﻿# Copyright 2025 XMOS LIMITED.
# This Software is subject to the terms of the XMOS Public Licence: Version 1.
from enum import IntEnum
from pathlib import Path
from copy import deepcopy
import sys
import numpy as np
from pydantic import ValidationError
import threading

from PySide6.QtWidgets import (
    QApplication,
    QMainWindow,
    QFileDialog,
    QMessageBox,
    QHBoxLayout,
    QPushButton,
)
from PySide6.QtCore import Signal, QObject, Slot, QTimer, QLocale
from PySide6.QtGui import QIcon

from audio_dsp.design.parse_json import make_pipeline, DspJson
from audio_dsp.design.pipeline import  generate_dsp_main
from tuning_utility import device
import os
import ctypes
import queue
import matplotlib.pyplot as plt

from .translations import tr_str

# On Windows only, register an explicit AppUserModelID for this process.
# NOTE(review): presumably done so the taskbar shows this application's own
# identity/icon instead of the Python interpreter's — confirm.
if os.name =="nt":
    ctypes.windll.shell32.SetCurrentProcessExplicitAppUserModelID(u'xmos.demo.app')

class DspWindow(QMainWindow):
    """Main window for the DSP controller application.

    This sets up the DSP state (which in turn starts a separate thread
    for the device control through the host application). It also
    provides buttons for loading and saving DSP parameters and for
    generating code.

    ``on_load_button`` uses ``self.tabs`` and ``self.make_tabs()``, which are
    not defined in this class.
    NOTE(review): presumably supplied by a subclass — confirm.

    Parameters
    ----------
    params : DspJson
        The DSP parameters loaded from a JSON file. If falsy, the user is
        asked to pick a file; cancelling that dialog exits the application.
    code_gen_dir : Path
        The directory initially offered when choosing where the generated
        DSP code will be saved.
    title : str, optional
        The title of the main window, by default "XMOS DSP Controller"

    """

    def __init__(self, params: DspJson, code_gen_dir: Path, title=tr_str("XMOS DSP Controller")):
        super().__init__()

        # No initial parameters: ask the user for a file before building
        # anything else.
        if not params:
            params = self.load_dialog()
            if params is None:
                # User clicked cancel with no initial file, just exit
                QApplication.quit()
                sys.exit()

        self.code_gen_dir = code_gen_dir

        # set title and logo
        self.setWindowTitle(title)
        logo = QIcon()
        logo_path = Path(__file__).parent / "assets" / "XMOS_logo__X_transparent.png"
        logo.addFile(str(logo_path))
        self.setWindowIcon(logo)

        # this kicks off the device thread
        self.state = DspState(params)

        # Get the current language code (e.g., 'zh_CN', 'es', 'en_US')
        lang = QLocale().name()

        if lang.startswith("zh"):
            # Prefer fonts with CJK coverage so matplotlib can render Chinese text.
            plt.rcParams['font.sans-serif'] = ['PingFang SC', 'SimHei', 'Noto Sans CJK SC', 'Microsoft YaHei', 'Arial Unicode MS']
            plt.rcParams['axes.unicode_minus'] = False  # To display minus sign correctly

    def _send_state(self):
        """Send the current parameters to the device thread.

        The command queue and the completion emitter are owned by
        ``DspState``, not by the window, so this delegates to it.
        """
        self.state._send_state()

    def closeEvent(self, event):
        """Stop the device thread, then perform the regular close handling."""
        # Send exit command to the device thread, then do regular close
        self.state.kill()
        super().closeEvent(event)

    def make_save_load_gen_buttons(self):
        """Create the Load / Save / Generate Code buttons.

        Returns
        -------
        QHBoxLayout
            A horizontal layout containing the three connected buttons.
        """
        # Buttons at the bottom
        button_layout = QHBoxLayout()
        self.load_button = QPushButton(tr_str("Load"))
        self.load_button.clicked.connect(self.on_load_button)

        self.save_button = QPushButton(tr_str("Save"))
        self.save_button.clicked.connect(self.on_save_button)

        self.gen_button = QPushButton(tr_str("Generate Code"))
        self.gen_button.clicked.connect(self.on_generate_button)

        button_layout.addWidget(self.load_button)
        button_layout.addWidget(self.save_button)
        button_layout.addWidget(self.gen_button)

        return button_layout

    def on_save_button(self):
        """Save the current state to a user-chosen file, then offer code generation."""
        filename, _ = QFileDialog.getSaveFileName(self, tr_str("Save xDSP"), ".", tr_str("xDSP Files (*.json)"), options=QFileDialog.DontUseNativeDialog)
        if not filename:
            # Dialog cancelled.
            return
        filepath = Path(filename)
        # JSON is UTF-8; do not depend on the platform's default encoding.
        filepath.write_text(self.state.state.model_dump_xdsp(indent=2), encoding="utf-8")

        msgBox = QMessageBox(self)
        msgBox.setWindowTitle(tr_str("Save xDSP"))
        msgBox.setText(tr_str("The parameters has been saved."))
        msgBox.setInformativeText(tr_str("Would you like to generate the code?"))
        msgBox.setStandardButtons(QMessageBox.Yes | QMessageBox.No)
        msgBox.setDefaultButton(QMessageBox.Yes)
        ret = msgBox.exec()

        if ret == QMessageBox.Yes:
            self.on_generate_button()

    def load_dialog(self):
        """Ask the user for an xDSP JSON file and parse it.

        The dialog is shown again after every file that fails to load, until
        a file parses or the user cancels.

        Returns
        -------
        DspJson or None
            The validated model, or ``None`` if the user cancelled.
        """
        # Loop rather than recurse so repeated failures do not grow the stack.
        while True:
            filename, _ = QFileDialog.getOpenFileName(self, tr_str("Open xDSP"), ".", tr_str("xDSP Files (*.json)"), options=QFileDialog.DontUseNativeDialog)
            if not filename:
                return None
            filepath = Path(filename)

            try:
                return DspJson.model_validate_json(filepath.read_text(encoding="utf-8"))
            except (ValidationError, UnicodeDecodeError) as e:
                # A file that is not valid UTF-8 is treated like any other
                # incompatible file: report it and ask again.
                print(e)
                QMessageBox.critical(self, tr_str("Pipeline error"), tr_str("The loaded .json pipeline is not "
                "compatible with this the JSON schema defined in the current version of lib_audio_dsp. "
                "Please try loading a different file."))

    def on_load_button(self):
        """Load a tuning from file and apply it to the GUI.

        If the pipeline checksum of the loaded file matches the current one,
        the existing state is updated in place. Otherwise a warning is shown
        and the device thread is restarted with the new pipeline. In both
        cases the tabs are rebuilt and the previously selected tab index is
        restored.
        """
        # Load the tuning from a file and make a pipeline
        tuning = self.load_dialog()
        if tuning is None:
            # User cancelled the dialog; keep the current state untouched.
            return
        p_current = make_pipeline(self.state.state)
        p_new = make_pipeline(tuning)

        if checksums_equal(p_current.pipeline_stage["checksum"], p_new.pipeline_stage["checksum"]):
            self.state.update_load(tuning)
        else:
            # If the checksum does not match, we can still load the pipeline,
            # but we will show a warning message box.
            self.load_invalid_checksum_messagebox()

            # kill the old device thread
            self.state.kill()

            # start a new device thread with the new pipeline
            self.state = DspState(tuning)

        current = self.tabs.currentIndex()
        self.tabs.clear()
        self.make_tabs()
        self.tabs.setCurrentIndex(current)

    def on_generate_button(self):
        """Ask for an output folder and generate the DSP code into it.

        Nothing is generated if the user cancels the folder dialog.
        """
        code_gen_dir = QFileDialog.getExistingDirectory(self, "Gen folder", str(self.code_gen_dir), options=QFileDialog.DontUseNativeDialog)
        if not code_gen_dir:
            # Dialog cancelled: an empty path must not reach the generator.
            return
        p = make_pipeline(self.state.state)
        generate_dsp_main(p, out_dir=code_gen_dir)

    def load_invalid_checksum_messagebox(self):
        """Warn the user that the loaded pipeline's checksum differs from the current one."""
        QMessageBox.warning(self, tr_str("Checksum mismatch"), tr_str("The checksum of the loaded pipeline is "
        "different to the current pipeline. Continue to reload the GUI? \n\n"
        "If this JSON file is incompatible with the current GUI, errors may occur.\n\n"
        "Remember to reflash your device with the correct generated DSP code."))


class ChecksumError(Exception):
    """Signals that two pipeline checksums were compared and found to differ."""


class DspState(QObject):
    """Holds the current state of the DSP pipeline and manages device control.

    This sets up the DSP pipeline state, starts the device control thread, and initializes
    logging and device communication mechanisms.

    Parameters
    ----------
    params : DspJson
        The DSP parameters loaded from a JSON file, used to initialize the pipeline state.
    """

    # Emitted with the current DspJson whenever the state changes.
    STATE_CHANGED = Signal(DspJson)

    def __init__(self, params: DspJson):

        super().__init__()
        p = make_pipeline(params)

        # setup the device control thread and logging
        self.q = queue.Queue()
        self.device_logger = LogEmitter()
        self.log_outputter = Logger()
        self.device_logger.log_signal.connect(self.log_outputter.log)
        self.device_thread = threading.Thread(
            target=device.device_thread, args=(self.q, p, params, self.device_logger)
        )
        self.device_thread.start()
        self.device_emitter = DeviceEmitter()

        # Set up the initial state
        self._state = params
        self.node_dict = self._index_nodes(params)

        # Connect the device state machine to the state change signal
        # and send the initial state to the device thread.
        # Keep the state machine as an attribute so its lifetime does not
        # depend only on the lambdas below capturing it.
        device_state = UpdateDeviceStateMachine(self._send_state)
        self._device_state = device_state
        self.device_emitter.send_tuning_complete.connect(
            lambda _: device_state.tuning_complete()
        )
        self.STATE_CHANGED.connect(lambda _: device_state.state_updated())
        device_state.state_updated()

    @staticmethod
    def _index_nodes(state):
        """Map each node's placement name to its index in ``state.graph.nodes``."""
        return {node.placement.name: i for i, node in enumerate(state.graph.nodes)}

    def _send_state(self):
        """Send the current parameters to the device thread."""
        self.q.put(device.StateUpdated(self.state))
        self.q.put(
            device.SendTuning(
                lambda s: self.device_emitter.send_tuning_complete.emit(s)
            )
        )

    def kill(self):
        """Terminate the device thread and clean up resources.

        This must be called before the application exits to ensure
        that the device thread is properly terminated.
        """
        # The hasattr checks tolerate an instance whose __init__ did not
        # get as far as creating the queue or the thread.
        # Send exit command to the device thread
        if hasattr(self, "q"):
            self.q.put(device.Exit())
        if hasattr(self, "device_thread"):
            self.device_thread.join()  # Wait for thread to finish

    @property
    def state(self) -> DspJson:
        """Get the current state of the DSP pipeline as a DspJson object."""
        return self._state

    def get_node(self, node_name):
        """Get a node by its name from the state."""
        return self.state.graph.nodes[self.node_dict[node_name]]

    def get_parameters(self, node_name):
        """Get the parameters of a node by its name."""
        return self.get_node(node_name).parameters

    def update(self, update_dict):
        """Update the state with a dictionary of node parameters.

        ``STATE_CHANGED`` is emitted once if at least one node's parameters
        actually changed.

        Parameters
        ----------
        update_dict : dict
            A dictionary where keys are node indexes and values are StageParameter objects.
        """
        state_changed = False
        for key, value in update_dict.items():
            current_parameters = self._state.graph.nodes[key].parameters
            if current_parameters != value:
                self._state.graph.nodes[key].parameters = value
                state_changed = True

        if state_changed:
            self.STATE_CHANGED.emit(self._state)

    def update_load(self, tuning: DspJson):
        """Update the state with a new DspJson object.

        The object is deep-copied before being stored. If it differs from the
        current state, the name-to-index lookup is rebuilt and
        ``STATE_CHANGED`` is emitted.

        Parameters
        ----------
        tuning : DspJson
            The new pydantic model to update the state with.
        """
        state = deepcopy(tuning)
        if state != self._state:
            self._state = state
            # Keep the lookup in step with the nodes of the new state.
            self.node_dict = self._index_nodes(state)
            self.STATE_CHANGED.emit(self.state)

    def deep_update(self, index, field, value):
        """Update a specific field of a node's parameters.
        This is used to update a single field of a node's parameters, such as
        the gain of a stage, without affecting the rest of the parameters.

        Parameters
        ----------
        index : int
            The index of the node in the graph.
        field : str
            The field of the parameters to update.
        value : Any
            The new value for the field.
        """
        data = self.state.graph.nodes[index].parameters
        data = data.model_copy(update={field: value})
        self.update({index: data})


class LogEmitter(QObject):
    """Forward log messages onto a Qt signal."""

    # Emitted with the message passed to ``log``.
    log_signal = Signal(str)

    def log(self, s):
        """Emit ``s`` on ``log_signal``."""
        self.log_signal.emit(s)


class Logger(QObject):
    """Print log messages delivered through a Qt signal."""

    # Declare the slot with the argument it actually takes, so it matches
    # the ``Signal(str)`` it is connected to.
    @Slot(str)
    def log(self, str):
        """Print the received message to standard output.

        The parameter name shadows the builtin ``str``; it is kept unchanged
        so that existing callers are unaffected.
        """
        print(str)


class DeviceEmitter(QObject):
    """Qt signals used to report device-related events."""

    # Emitted with a bool when a tuning request sent to the device thread completes.
    send_tuning_complete = Signal(bool)
    # Carries two numpy arrays.
    # NOTE(review): not emitted anywhere in this file — confirm its users and payload.
    plot_updated = Signal(np.ndarray, np.ndarray)


class UpdateDeviceStateMachine(QObject):
    """Debounce state updates and serialise tuning requests.

    ``state_updated`` is called whenever the state changes and
    ``tuning_complete`` when a tuning request has finished. The tuning
    callback is invoked when the debounce timer fires while idle, and again
    after a tuning completes if an update arrived in the meantime.

    Parameters
    ----------
    do_tuning_cb : callable
        Called with no arguments to start a tuning request.
    """

    # Delay in milliseconds between an update while idle and the tuning request.
    DEBOUNCE_MS = 0

    class _State(IntEnum):
        IDLE = 0
        TUNING = 1
        UPDATED_WHILE_TUNING = 2
        # Not entered by any transition in this class.
        DEBOUNCE = 3

    def __init__(self, do_tuning_cb):
        # A QObject subclass must initialise its base class before use.
        super().__init__()
        self._state = self._State.IDLE
        self._do_tuning_cb = do_tuning_cb
        self._timer = QTimer()
        self._timer.setSingleShot(True)
        self._timer.timeout.connect(self._timeout)

    @Slot()
    def _timeout(self):
        """Start a tuning request when the debounce timer fires while idle."""
        if self._state == self._State.IDLE:
            # Reference the member through the enum class rather than
            # through another member.
            self._state = self._State.TUNING
            self._do_tuning_cb()

    def tuning_complete(self):
        """Handle completion of a tuning request.

        Raises
        ------
        ValueError
            If no tuning request was in progress.
        """
        if self._state == self._State.TUNING:
            self._state = self._State.IDLE
        elif self._state == self._State.UPDATED_WHILE_TUNING:
            # An update arrived during tuning: send the newer state right away.
            self._state = self._State.TUNING
            self._do_tuning_cb()
        else:
            raise ValueError(f"Invalid state transition: {self._state}")

    def state_updated(self):
        """Record that the state changed and schedule a tuning request if idle."""
        if self._state == self._State.TUNING:
            self._state = self._State.UPDATED_WHILE_TUNING
        elif self._state == self._State.IDLE:
            # State updated, assume more updates to come so
            # stay in idle until the debounce timer elapses.
            self._timer.stop()
            self._timer.start(self.DEBOUNCE_MS)
        else:
            # stay in this state
            pass

def validate_checksum(actual, desired):
    """Check that two checksums are element-wise identical.

    Parameters
    ----------
    actual : array_like
        The checksum that was obtained.
    desired : array_like
        The checksum that was expected.

    Raises
    ------
    ChecksumError
        If the two checksums are not equal.
    """
    # Guard clause: matching checksums need no further work.
    if np.array_equal(np.array(actual), np.array(desired)):
        return

    raise ChecksumError(
        tr_str(
            "Pipeline mismatch; the pipeline defined in the GUI does not match "
            "the pipeline defined in the saved JSON file.\n"
            f"\n\tExpected checksum: {desired}\n\tGot {actual}"
        )
    )

def checksums_equal(actual, desired):
    """Return whether two checksums match, without raising.

    Wraps ``validate_checksum`` and turns its ``ChecksumError`` into a
    ``False`` result.
    """
    try:
        validate_checksum(actual, desired)
    except ChecksumError:
        return False
    else:
        return True