"""
Z240-MP4, Z240-CP13をコーディネータとして動作させる

実行方法:
    python run.py COM1
"""

import os
import serial
import threading
import time
import struct
import sys
from queue import Queue, Empty
from typing import Any

# デバイス管理用クラスの読み込み
import device_class as dc

# シリアルポートの設定
# コマンドライン引数でポート名(例:COM1,/dev/ttyUSB0)を指定する
PORT = sys.argv[1]
BAUDRATE = 115200
TIMEOUT = 0.01
READ_SIZE = 255


def write_data_to_module(ser: serial.Serial, data: bytes) -> None:
    ser.write(data)
    # print("WRITE:", ' '.join('%02X' % b for b in data))


def read_data_from_module(ser: serial.Serial) -> bytes:
    ret = ser.read(size=READ_SIZE)
    # print("READ: ", ' '.join('%02X' % b for b in ret))
    return ret


def calc_xor8(data: bytes) -> int:
    # チェックサムを計算
    xor = 0
    for d in data:
        xor = xor ^ d

    return xor


def display_all_devices(device_list: list[dc.Device]) -> None:
    # 端末の出力をクリアする
    if os.name == 'nt':
        os.system('cls')
    else:
        os.system('clear')

    # デバイス一覧を端末に出力する
    print("Device List:")
    if len(device_list) == 0:
        print(" no device.")
        print("---------------------------------------------------------------------")
        return
    for idx, dev in enumerate(device_list, start=1):
        print("[%d]" % idx)
        print(" ShortAddr: 0x%04X  MAC: 0x%016X" % (dev.short_addr, dev.mac_addr))
        print(" EndPoints:")
        for ep in dev.endpoint_list:
            print("  Port: %d  ProfileID: 0x%04X  DeviceID: 0x%04X" %
                  (ep.port, ep.profile_id, ep.device_id))
            if ep.profile_id == 0x0104:
                if ep.device_id == 0x0100:
                    print("  Type: On/Off Light")
                elif ep.device_id == 0x0302:
                    print("  Type: Temperature Sensor")
            print(f"   InClusters : ", end='')
            for cl in ep.in_cluster_list:
                print("0x%04X" % cl.id, end=' ')
            print(f"\n   OutClusters: ", end='')
            for cl in ep.out_cluster_list:
                print("0x%04X" % cl.id, end=' ')
            print("\n  Data:")
            for cl in ep.in_cluster_list:
                if len(cl.attribute_list) == 0:
                    continue
                val = cl.attribute_list[0].value
                time_str = cl.attribute_list[0].get_latest_time()
                if cl.id == 0x0006:
                    if val is not None:
                        print("   ONOFF : %s (%s)" % ('ON' if val else 'OFF', time_str))
                    else:
                        print("   ONOFF : no data")
                elif cl.id == 0x0402:
                    if val is not None:
                        print("   Temperature: %.1f °C (%s)" % ((val / 100), time_str))
                    else:
                        print("   Temperature: no data")
                elif cl.id == 0x0405:
                    if val is not None:
                        print("   Humidity   : %.1f %% (%s)" % ((val / 100), time_str))
                    else:
                        print("   Humidity   : no data")
        print("---------------------------------------------------------------------")


def parse_buffer_data(buffer: bytearray) -> list[bytearray]:
    # バッファ内のデータを1フレームずつ切り出してlistにして返す
    message_frames = []
    i = 0
    while i < len(buffer):
        if buffer[i] != 0x55:
            i += 1
            continue
        if i + 1 >= len(buffer):
            break
        length = buffer[i + 1]
        end = i + 2 + length
        if end <= len(buffer):
            frame = buffer[i:end]
            message_frames.append(frame)
            i = end
        else:
            break

    buffer[:] = buffer[i:]
    return message_frames


def send_network_command(ser: serial.Serial, cmd_code: int, short_addr: int,
                         cmd_param: bytes = bytes()) -> None:
    # ネットワーク管理コマンドの送信関数

    # コマンドデータを作成
    cmd_data = bytearray([0x01, cmd_code])
    cmd_data += struct.pack('<H', short_addr)
    cmd_data += cmd_param
    checksum = calc_xor8(cmd_data)
    cmd_data += bytes([checksum])
    frame_data = bytes([0x55, len(cmd_data)]) + cmd_data

   # コマンドを実行
    write_data_to_module(ser, frame_data)


def send_zcl_command(ser: serial.Serial, cmd_code: int, dst_addr: int, dst_port: int,
                     frame_seq: int, cluster_id: int, manufacturer_code: int,
                     ext_data: bytes = bytes()) -> None:
    # ZCLコマンドの送信関数

    # このサンプルコードではパラメータの一部を固定
    send_mode = 0x00
    direction = 0x00
    response_mode = 0x00

    # コマンドデータを作成
    cmd_data = bytearray([0x02, cmd_code])
    cmd_data += bytes(([send_mode]))
    cmd_data += struct.pack('<H', dst_addr)
    cmd_data += bytes([dst_port])
    cmd_data += bytes([frame_seq])
    cmd_data += bytes([direction])
    cmd_data += struct.pack('<H', cluster_id)
    cmd_data += struct.pack('<H', manufacturer_code)
    cmd_data += bytes([response_mode])
    cmd_data += ext_data
    checksum = calc_xor8(cmd_data)
    cmd_data += bytes([checksum])
    frame_data = bytes([0x55, len(cmd_data)]) + cmd_data

   # コマンドを実行
    write_data_to_module(ser, frame_data)


def update_attribute_value(device: dc.Device, port: int, cluster_id: int,
                           attr_id: int, attr_value: Any) -> None:
    if device is None:
        return

    tmp_endpoint = device.get_endpoint_by_port_number(port)
    if tmp_endpoint is None:
        return

    tmp_cluster = tmp_endpoint.get_cluster_by_id(cluster_id)
    if tmp_cluster is None:
        return

    tmp_attribute = tmp_cluster.get_attribute_by_id(attr_id)
    if tmp_attribute is None:
        return

    # 属性値を更新
    tmp_attribute.set_value(attr_value)


def analyze_message_notify(cmd_code: int, msg_data: bytearray, manager: dc.DeviceManager) -> bool:
    # チェックサムを除くデータ部分のみの長さを計算
    msg_len = len(msg_data) - 1

    # デバイスのネットワーク参加の通知メッセージを受信（コマンドコード：0x04）
    if cmd_code == 0x04:
        if msg_len != 11:
            print("Message format error!")
            return False
        mac_addr, short_addr, device_type = \
            struct.unpack('<QHB', msg_data[:msg_len])
        # エンドデバイスであれば管理リストに追加または更新
        if device_type == 0x02 or device_type == 0x03:
            if manager.has_device(mac_addr):
                manager.update_device(mac_addr, short_addr)
            else:
                device = dc.Device(mac_addr, short_addr)
                if manager.add_device(device) is True:
                    time.sleep(3)
            # デバイスがサポートするポート番号の照会を実行する（コマンドコード：0x05）
            send_network_command(ser, 0x05, short_addr)
            return True

    # デバイスのネットワーク離脱のメッセージを受信（コマンドコード：0x06）
    elif cmd_code == 0x06:
        if msg_len != 8:
            print("Message format error!")
            return False
        mac_addr = struct.unpack('<Q', msg_data[:msg_len])[0]
        if manager.remove_device_by_mac_addr(mac_addr) is True:
            time.sleep(3)
        return True

    return False


def analyze_message_zdo(cmd_code: int, msg_data: bytearray, manager: dc.DeviceManager) -> bool:
    # チェックサムを除くデータ部分のみの長さを計算
    msg_len = len(msg_data) - 1

    # デバイスのMACアドレスの照会結果（コマンドコード：0x01）
    if cmd_code == 0x01:
        if msg_len != 14:
            print("Message format error!")
            return False
        short_addr, cmd_seq, cmd_result, mac_addr, reserved = \
            struct.unpack('<HBBQH', msg_data[:14])

        # コーディネータは対象外
        if short_addr == 0x0000:
            return False

        device = dc.Device(mac_addr, short_addr)
        if manager.add_device(device) is True:
            time.sleep(3)
            send_network_command(ser, 0x05, short_addr)
        return True

    # デバイスがサポートするクラスタの照会結果（コマンドコード：0x04）
    elif cmd_code == 0x04:
        if msg_len < 12:
            print("Message format error!")
            return False

        short_addr, cmd_seq, cmd_result,  port_number, profile_id, device_id, device_ver = \
            struct.unpack('<HBBBHHB', msg_data[:10])
        if cmd_result != 0x00:
            print("Command failure!")
            return False

        in_cluster_count = msg_data[10]
        in_cluster_list = []
        idx = 11
        for _ in range(in_cluster_count):
            in_cluster_id = struct.unpack('<H', msg_data[idx:idx + 2])[0]
            in_cluster_list.append(in_cluster_id)
            idx += 2

        out_cluster_count = msg_data[idx]
        out_cluster_list = []
        for _ in range(out_cluster_count):
            out_cluster_id = struct.unpack('<H', msg_data[idx:idx + 2])[0]
            out_cluster_list.append(out_cluster_id)
            idx += 2

        # メッセージ内のショートアドレスに対応するデバイスを管理リストから検索、エンドポイント情報を追加
        device = manager.get_device_by_short_addr(short_addr)
        if device is not None:
            tmp_in_list = []
            for cluster_id in in_cluster_list:
                # このサンプルでは対象の属性IDを0x0000のみとする
                cluster = dc.Cluster(cluster_id)
                attrbute = dc.Attribute(id=0x0000)
                cluster.add_attribute(attrbute)
                tmp_in_list.append(cluster)

            tmp_out_list = []
            for cluster_id in out_cluster_list:
                cluster = dc.Cluster(cluster_id)
                # このサンプルでは出力クラスタの属性は対象としない
                tmp_out_list.append(cluster)

            endpoint = dc.EndPoint(port_number, profile_id, device_id, tmp_in_list, tmp_out_list)
            device.add_endpoint(endpoint)
            return True

    # デバイスがサポートするポート番号の照会結果（コマンドコード：0x05）
    elif cmd_code == 0x05:
        if msg_len < 5:
            print("Message format error!")
            return False

        short_addr,  cmd_seq, cmd_result, port_count = \
            struct.unpack('<HBBB', msg_data[:5])

        if cmd_result != 0x00:
            print("Command failure!")
            return False
        # このサンプルコードでは1つ目のポート番号のみ管理対象とする
        if port_count > 0:
            port_number = msg_data[5]
            # ポート番号に対応するクラスタのリストを照会するコマンドを送信
            send_network_command(ser, 0x04, short_addr, bytes([port_number]))

    return False


def analyze_message_zcl(cmd_code: int, msg_data: bytearray, manager: dc.DeviceManager) -> bool:
    # チェックサムを除くデータ部分のみの長さを計算
    msg_len = len(msg_data) - 1

    # 共通コマンドヘッダのパラメータ処理
    if msg_len < 11:
        print("Message format error!")
        return False

    mode, src_addr, src_port, frame_seq, direction, cluster_id, manufacturer_code, rssi = \
        struct.unpack('<BHBBBHHB',  msg_data[:11])

    # 属性値の照会結果の受信(コマンドコード：0x00)
    if cmd_code == 0x00:

        tmp_device = manager.get_device_by_short_addr(src_addr)
        # 管理デバイスリストに無い場合
        if tmp_device is None:
            return False

        idx = 11
        attr_count = msg_data[idx]
        idx += 1
        attr_id = struct.unpack('<H', msg_data[idx:idx+2])[0]
        # このサンプルコードではAttribute ID = 0x0000のみを対象にする
        if attr_count != 1 or attr_id != 0x0000:
            return False

        idx += 2
        zcl_status = msg_data[idx]
        if zcl_status != 0x00:
            return False
        idx += 1
        attr_data_type = msg_data[idx]
        idx += 1
        # 属性値のデータ型 bool
        if attr_data_type == 0x10:
            attr_value = bool(msg_data[idx])
        # 属性値のデータ型 uint16
        elif attr_data_type == 0x21:
            attr_value = struct.unpack('<H', msg_data[idx:idx+2])[0]
        # 属性値のデータ型 int16
        elif attr_data_type == 0x29:
            attr_value = struct.unpack('<h', msg_data[idx:idx+2])[0]
        else:
            attr_value = None

        if attr_value is not None:
            update_attribute_value(tmp_device, src_port, cluster_id, attr_id, attr_value)
            return True

    # 属性のアクティブレポートの受信(コマンドコード：0x0A)
    elif cmd_code == 0x0A:
        tmp_device = manager.get_device_by_short_addr(src_addr)
        # 管理デバイスリストに無い場合
        if tmp_device is None:
            # デバイスのMACアドレスを照会するコマンドを送信(コマンドコード:0x01)
            send_network_command(ser, 0x01, src_addr)
        # このサンプルではOnOff,温度,湿度のクラスタのみを対象
        if cluster_id == 0x0006 or cluster_id == 0x0402 or cluster_id == 0x0405:
            idx = 11
            attr_count = msg_data[idx]
            idx += 1
            for _ in range(attr_count):
                attr_id = struct.unpack('<H', msg_data[idx:idx + 2])[0]
                # このサンプルコードではAttribute ID = 0x0000のみを対象にする
                if attr_id != 0x0000:
                    break

                idx += 2
                attr_data_type = msg_data[idx]
                idx += 1
                data_size = 0
                # 属性値のデータ型 bool
                if attr_data_type == 0x10:
                    attr_value = bool(msg_data[idx])
                    data_size = 1
                # 属性値のデータ型 uint16
                elif attr_data_type == 0x21:
                    attr_value = struct.unpack('<H', msg_data[idx:idx + 2])[0]
                    data_size = 2
                # 属性値のデータ型 int16
                elif attr_data_type == 0x29:
                    attr_value = struct.unpack('<h', msg_data[idx:idx + 2])[0]
                    data_size = 2
                else:
                    attr_value = None
                idx += data_size
                if attr_value is not None:
                    update_attribute_value(tmp_device, src_port, cluster_id, attr_id, attr_value)
                    return True

    # ZCL制御コマンドのデフォルト応答フレームの受信(コマンドコード：0x0B)
    elif cmd_code == 0x0B:
        # 結果のメッセージをキューに入れる
        rx_queue.put(msg_data)
    # ZCL制御コマンドの応答メッセージの受信(コマンドコード：0x0F)
    elif cmd_code == 0x0F:
        # 結果のメッセージをキューに入れる
        rx_queue.put(msg_data)

    return False


def send_zcl_onoff(ser: serial.Serial, manager: dc.DeviceManager, q: Queue, number: int) -> None:
    target_addr = None
    target_port = None
    cluster_id = 0
    manufacturer_code = 0

    # OnOffクラスタを持つデバイスを管理リストから検索
    for dev in manager.device_list:
        for ep in dev.endpoint_list:
            cl = ep.get_cluster_by_id(0x0006)
            if cl is not None:
                target_addr = dev.short_addr
                target_port = ep.port
                cluster_id = cl.id
                manufacturer_code = cl.manufacturer_code
                break

    if target_addr is None:
        return

    # OnOffクラスタを持つデバイスにON(0x01)またはOFF(0x00)のZCLコマンドを送信する
    onoff = number % 2
    send_zcl_command(ser, 0x0F, target_addr, target_port, number, cluster_id,
                     manufacturer_code, bytes([onoff]))

    print("[MESSAGE] Sent ZCL control command. (addr=0x%04X, command=%s)" %
          (target_addr, ('ON' if onoff else 'OFF')))

    while True:
        try:
            # キューからZCL応答フレームを取り出す
            response = q.get(timeout=3)
            # 送信したコマンドのシーケンス番号と同一かを確認
            resp_frame_seq = response[4]
            if resp_frame_seq == number:
                # このサンプルコードでは応答フレームを処理しない
                break
        except Empty:
            # 3秒以内に応答が無ければタイムアウト
            print("[MESSAGE] Timeout waiting for ZCL command response.")
            break


def command_thread(ser: serial.Serial, manager: dc.DeviceManager, q: Queue) -> None:
    # フレームシーケンス番号
    frame_seq = 0

    # On/Offのコマンドの送信間隔
    interval_onoff = 30

    # 属性値読み込みのコマンドの送信間隔
    interval_property = 20

    current_time = time.time()
    next_time_onoff = current_time + interval_onoff
    next_time_property = current_time + interval_property

    while True:
        current_time = time.time()

        if current_time >= next_time_property:
            next_time_property += interval_property
            for dev in manager.device_list:
                for ep in dev.endpoint_list:
                    for cl in ep.in_cluster_list:
                        # このサンプルではOnOff,温度,湿度のクラスタのみを対象
                        if cl.id == 0x0006 or cl.id == 0x0402 or cl.id == 0x0405:
                            # 属性数と属性IDを拡張データにして送信
                            ext_data = bytearray([0x01])
                            ext_data += struct.pack('<H', 0x0000)
                            send_zcl_command(ser, 0x00, dev.short_addr, ep.port, 0xFF,
                                             cl.id, cl.manufacturer_code, ext_data)
                            time.sleep(0.5)

        if current_time >= next_time_onoff:
            next_time_onoff += interval_onoff
            frame_seq = (frame_seq + 1) % 255
            send_zcl_onoff(ser, manager, q, frame_seq)
            time.sleep(1)

        time.sleep(1)


# シリアルポートを開く
ser = serial.Serial(port=PORT, baudrate=BAUDRATE, timeout=TIMEOUT)

# ネットワークを開始し、デバイスの参加を許可する(コマンドコード:0x02)
print("permitting to join network...")
write_data_to_module(ser, bytes([0x55, 0x03, 0x00, 0x02, 0x02]))
read_data_from_module(ser)

device_manager = dc.DeviceManager()
rx_queue = Queue()


# スレッドを起動する
# ZCLコマンドを定期的に送信
# 管理デバイスの属性値を定期的に確認
sub_thread = threading.Thread(target=command_thread,
                              args=(ser, device_manager, rx_queue), daemon=True)
sub_thread.start()


# デバイス一覧の表示間隔
interval_print_devices = 10

# デバイス参加許可の実行間隔
interval_permit_join = 300

current_time = time.time()
next_time_print_devices = current_time
next_time_permit_join = current_time + interval_permit_join

# デバイス状態の更新確認用フラグ
changed = True

# 受信データの一時格納用バッファ
recv_buffer = bytearray()

while True:
    # シリアルポートの受信データの確認

    rx_data = ser.read(READ_SIZE)

    if len(rx_data) > 0:
        recv_buffer.extend(rx_data)
        frames = parse_buffer_data(recv_buffer)
        for frame in frames:
            cmd_type = frame[2]
            cmd_code = frame[3]
            # システム通知メッセージ（コマンドタイプ:0x80）
            if cmd_type == 0x80:
                changed = analyze_message_notify(cmd_code, frame[4:], device_manager)
            # ネットワーク管理コマンドの応答メッセージ（コマンドタイプ:0x81）
            elif cmd_type == 0x81:
                changed = analyze_message_zdo(cmd_code, frame[4:], device_manager)
            # ZCL制御コマンドの応答メッセージ（コマンドタイプ:0x82）
            elif cmd_type == 0x82:
                changed = analyze_message_zcl(cmd_code, frame[4:], device_manager)

    current_time = time.time()
    # 更新予定時刻に達するかデバイス状態が変更されたらデバイス一覧を表示する
    if current_time >= next_time_print_devices or changed is True:
        display_all_devices(device_manager.device_list)
        next_time_print_devices += interval_print_devices
        changed = False

    # ネットワーク許可を一定時間ごとに実行する
    if current_time >= next_time_permit_join:
        print("[MESSAGE] Permitting to join network...")
        write_data_to_module(ser, bytes([0x55, 0x03, 0x00, 0x02, 0x02]))
        read_data_from_module(ser)
        next_time_permit_join += interval_permit_join

    time.sleep(0.01)
