import configparser
import subprocess
import platform
from pathlib import Path
import time
import os
import shutil
import re

import mido

from hardware_test_tools.XcoreApp import XcoreApp


class UaDut(XcoreApp):
    """USB Audio device-under-test built on top of an ``XcoreApp``.

    Adds host-side helpers for a USB Audio device: waiting for it to
    enumerate, changing stream format / clock source / volume through the
    ``volcontrol`` tool, and locating its MIDI ports through mido.
    """

    def __init__(
        self,
        adapter_id,
        fw_path,
        pid,
        in_chans,
        out_chans,
        vid=0x20B1,
        winbuiltin=False,
        timeout=60,
        xflash=False,
        writeall=False,
        target=None
    ):
        """Record the device parameters and locate the host-side tools.

        Args:
            adapter_id: Adapter ID, passed through to ``XcoreApp``.
            fw_path: ``Path`` to the firmware; must already exist.
            pid: USB product ID the device enumerates with.
            in_chans: Number of input channels of the application.
            out_chans: Number of output channels of the application.
            vid: USB vendor ID.
            winbuiltin: On Windows, True when the tusbaudio driver's
                ``custom.ini`` should not be consulted (no driver GUID is
                passed to volcontrol) and the stream format cannot be changed.
            timeout, xflash, writeall, target: Passed through to ``XcoreApp``.

        Raises:
            AssertionError: If the firmware is missing, or (Windows without
                ``winbuiltin``) the tusbaudio ``custom.ini`` or its
                ``InterfaceGUID`` cannot be found.
        """
        assert fw_path.exists(), f"Firmware not present at {fw_path}"

        self.clock_src = "Internal"
        self.pid = pid
        self.vid = vid
        self.chan_i = in_chans
        self.chan_o = out_chans

        self.winbuiltin = winbuiltin
        if platform.system() == "Windows" and not self.winbuiltin:
            ini_path = (
                Path(os.environ["PROGRAMFILES"])
                / "XMOS"
                / "USB Audio Device Driver"
                / "x64"
                / "custom.ini"
            )
            assert (
                ini_path.exists()
            ), f"tusbaudio SDK custom.ini not found in expected location: {ini_path}"

            with open(ini_path, "r") as f:
                config = configparser.ConfigParser()
                config.read_file(f)
            try:
                self.driver_guid = config.get("DriverInterface", "InterfaceGUID")
            except (configparser.NoSectionError, configparser.NoOptionError):
                assert 0, f"Could not find InterfaceGUID in {ini_path}"
        else:
            self.driver_guid = None

        # volcontrol is optional; methods that need it assert it was found
        volcontrol_dir = Path(__file__).parents[1] / "build" / "volcontrol"
        if platform.system() == "Windows":
            volcontrol_path = volcontrol_dir / "volcontrol.exe"
        else:
            volcontrol_path = volcontrol_dir / "volcontrol"

        if volcontrol_path.exists():
            self.volcontrol_cmd = [volcontrol_path]
            if platform.system() == "Windows" and self.driver_guid:
                # Tell volcontrol which driver interface to talk to
                self.volcontrol_cmd.append(f"-g{self.driver_guid}")
        else:
            self.volcontrol_cmd = None

        super().__init__(fw_path, adapter_id, timeout=timeout, xflash=xflash, writeall=writeall, target=target)

    def wait_for_enumeration(self):
        """Poll the host OS until the device appears, then set ``self.usb_name``.

        Polls up to 10 times at 1 s intervals (pnputil on Windows,
        system_profiler USB + audio listings on macOS).

        Raises:
            AssertionError: On an unsupported platform, if a device query fails,
                or if the device never appears. In the last case the xcore
                register/thread state (``xrun --dump-state``) is included.
        """
        for _ in range(10):
            time.sleep(1)

            if platform.system() == "Windows":
                ret = subprocess.run(["pnputil", "/enum-devices", "/connected"], text=True, timeout=30, capture_output=True)
                assert ret.returncode == 0, f"Querying devices failed:\n{ret.stdout}\n{ret.stderr}"
                enum_re = fr"Instance ID:\s+USB\\VID_{self.vid:04X}&PID_{self.pid:04X}&MI_00\\.*\nDevice Description:\s+(.*)"
                match = re.search(enum_re, ret.stdout)
                if match:
                    self.usb_name = match.group(1)
                    return
            elif platform.system() == "Darwin":
                ret = subprocess.run(["system_profiler", "SPUSBDataType"], text=True, timeout=30, capture_output=True)
                assert ret.returncode == 0, f"Querying devices failed:\n{ret.stdout}\n{ret.stderr}"
                usb_enum_re = fr"^\s+(.*):\n\n\s+Product ID: 0x{self.pid:04x}\n\s+Vendor ID: 0x{self.vid:04x}\n"
                match = re.search(usb_enum_re, ret.stdout, flags=re.MULTILINE)
                if not match:
                    continue
                usb_name = match.group(1)
                # Being on the USB bus is not enough: wait until it is also listed as an audio device
                ret = subprocess.run(["system_profiler", "SPAudioDataType"], text=True, timeout=30, capture_output=True)
                aud_enum_re = fr"^\s+{re.escape(usb_name)}:\n\n"
                if re.search(aud_enum_re, ret.stdout, flags=re.MULTILINE):
                    self.usb_name = usb_name
                    return
            else:
                assert 0, f"Unsupported platform: {platform.system()}"

        fail_str = f"Device ({self.vid:04X}:{self.pid:04X}) failed to enumerate\n"

        # Device doesn't appear to have started, so dump the state of the xcore.
        # NOTE(review): self.adapter_id / self.xe_path are presumably set by XcoreApp.
        ret = subprocess.run(
            ["xrun", "--adapter-id", self.adapter_id, "--dump-state", self.xe_path],
            text=True,
            timeout=10,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
        )
        fail_str += "Register and thread state of the xcore device\n"
        fail_str += "\n".join(ret.stdout.splitlines())
        assert 0, fail_str

    def __enter__(self):
        """Start the app via ``XcoreApp.__enter__`` and wait for USB enumeration."""
        super().__enter__()

        if platform.system() == "Windows":
            # Delay to allow the Windows Audio service to set up the new device
            time.sleep(40)

        self.wait_for_enumeration()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Restore the clock source, stop the app, and (Windows) clean up the device.

        ``XcoreApp.__exit__`` is always called, even if resetting the clock
        source fails, so the application is not left running.
        """
        try:
            if self.volcontrol_cmd and self.clock_src != "Internal":
                # Reset the clock source to default before stopping the application
                self.set_clock_src("Internal")
        finally:
            super().__exit__(exc_type, exc_val, exc_tb)

        if platform.system() == "Windows":
            # Delay to allow the Windows Audio service to settle after the device is removed
            time.sleep(15)

            # If usbdeview program is present, uninstall the device to force
            # re-enumeration next time and avoid caching of device features by the OS
            usbdeview_path = shutil.which("usbdeview")
            if usbdeview_path:
                subprocess.run(
                    [
                        usbdeview_path,
                        "/RunAsAdmin",
                        "/remove_by_pid",
                        f"{hex(self.vid)};{hex(self.pid)}",
                    ],
                    timeout=10,
                )

    def _set_stream_format(self, direction, samp_freq, num_chans, bit_depth, fail_on_err):
        """Run ``volcontrol --set-format`` for one direction; return its exit code."""
        cmd = self.volcontrol_cmd + [
            "--set-format",
            direction,
            f"{samp_freq}",
            f"{num_chans}",
            f"{bit_depth}",
        ]
        ret = subprocess.run(cmd, timeout=30, capture_output=True, text=True)
        if fail_on_err:
            assert ret.returncode == 0, f"failed to setup stream format: {direction}, {samp_freq} fs, {num_chans} channels, {bit_depth} bit\n{ret.stdout}\n{ret.stderr}"

        return ret.returncode

    def _set_full_stream_format(self, samp_freq, in_chans, in_bit_depth, out_chans, out_bit_depth, fail_on_err):
        """Run ``volcontrol --set-full-format`` for both directions; return its exit code."""
        cmd = self.volcontrol_cmd + ["--set-full-format", f"{samp_freq}", f"{in_chans}", f"{in_bit_depth}", f"{out_chans}", f"{out_bit_depth}"]
        ret = subprocess.run(cmd, timeout=30, capture_output=True, text=True)
        if fail_on_err:
            assert ret.returncode == 0, f"failed to setup full stream format: {samp_freq} fs, {in_chans} in channels, {in_bit_depth} in bit-depth, {out_chans} out channels, {out_bit_depth} out bit-depth.\n{ret.stdout}\n{ret.stderr}"
        return ret.returncode

    def _get_current_stream_format(self):
        """Query volcontrol for the current stream format.

        Returns:
            ``(sample_rate, in_chans, in_bit_depth, out_chans, out_bit_depth)``
            as ints parsed from ``volcontrol --show-current-format``.

        Raises:
            AssertionError: If volcontrol fails, or any of the five fields is
                absent from its output.
        """
        cmd = self.volcontrol_cmd + ["--show-current-format"]
        ret = subprocess.run(cmd, timeout=30, capture_output=True, text=True)
        assert ret.returncode == 0, f"failed to get current stream format.\n{ret.stdout}\n{ret.stderr}"

        # Line labels in volcontrol's output, in the order they are returned
        labels = (
            "Sampling rate",
            "Input number of channels",
            "Input bit depth",
            "Output number of channels",
            "Output bit depth",
        )
        values = dict.fromkeys(labels)
        for line in ret.stdout.splitlines():
            for label in labels:
                m = re.match(rf"{re.escape(label)}:\s*([0-9]+)", line)
                if m:
                    # A repeated label overwrites the earlier value (last one wins)
                    values[label] = int(m.group(1))

        missing = [label for label, value in values.items() if value is None]
        assert not missing, f"failed to parse current stream format, missing {missing}:\n{ret.stdout}"

        return tuple(values[label] for label in labels)

    def set_stream_format(self, direction, samp_freq, num_chans, bit_depth, fail_on_err=True):
        """Set the stream format for one direction.

        Above 96 kHz at most 10 channels per interface are allowed. If the other
        direction currently has more than 10 channels, both directions are set
        together with the other one limited to 10.

        Returns:
            volcontrol's exit code, or None when ``winbuiltin`` is set (the
            format cannot be changed and nothing is done).
        """
        assert self.volcontrol_cmd

        if self.winbuiltin:
            # Cannot change the stream format
            return

        if samp_freq > 96000: # USB spec limits interfaces to max 10 channels when samp_freq > 96000
            assert num_chans <= 10, f"Cannot set more than 10 channels ({num_chans}) for samp_freq > 96000 ({samp_freq})"
            # Get the current input and output format
            _, in_chans, in_bit_depth, out_chans, out_bit_depth = self._get_current_stream_format()
            # If the number of channels in the other direction is more than 10, limit it to 10
            if direction == "input":
                if out_chans > 10:
                    return self._set_full_stream_format(samp_freq, num_chans, bit_depth, 10, out_bit_depth, fail_on_err)
            elif direction == "output":
                if in_chans > 10:
                    return self._set_full_stream_format(samp_freq, 10, in_bit_depth, num_chans, bit_depth, fail_on_err)

        # If we've got here, just set the stream format that the user has asked for
        return self._set_stream_format(direction, samp_freq, num_chans, bit_depth, fail_on_err)

    def set_clock_src(self, clock_src):
        """Select the clock source ("Internal", "SPDIF" or "ADAT") via volcontrol.

        ``self.clock_src`` is only updated once volcontrol reports success.
        """
        assert clock_src in [
            "Internal",
            "SPDIF",
            "ADAT",
        ], f"Invalid clock source: {clock_src}"
        assert self.volcontrol_cmd

        cmd = self.volcontrol_cmd + ["--clock", clock_src]
        ret = subprocess.run(cmd, capture_output=True, text=True, timeout=30)

        assert (
            ret.returncode == 0
        ), f"setting clock source failed, cmd: {cmd}\nstdout:\n{ret.stdout}\nstderr:\n{ret.stderr}"

        self.clock_src = clock_src

    def volume_reset(self):
        """Reset all volumes via ``volcontrol --resetall``.

        Requires equal input and output channel counts. volcontrol is given the
        channel count plus one.
        """
        assert (
            self.chan_i == self.chan_o
        ), "Volume reset is only supported by volcontrol for applications with the same number of channels for input and output"
        assert self.volcontrol_cmd

        num_chans = self.chan_i
        cmd = self.volcontrol_cmd + ["--resetall", f"{num_chans + 1}"]
        ret = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
        assert (
            ret.returncode == 0
        ), f"resetting volume failed, cmd: {cmd}\nstdout:\n{ret.stdout}\nstderr:\n{ret.stderr}"

    def volume_set(self, direction, channel, value):
        """Set the volume of one channel ("input" or "output").

        ``channel`` is zero-based. volcontrol is passed ``channel + 1``, because
        index 0 is used for the master volume (see ``volume_set_master``).
        """
        assert direction in ["input", "output"]
        assert self.volcontrol_cmd

        cmd = self.volcontrol_cmd + ["--set", direction, f"{channel + 1}", f"{value}"]
        ret = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
        assert (
            ret.returncode == 0
        ), f"setting channel volume failed, cmd: {cmd}\nstdout:\n{ret.stdout}\nstderr:\n{ret.stderr}"

    def volume_set_master(self, direction, value):
        """Set the master volume (volcontrol channel index 0) for a direction."""
        assert direction in ["input", "output"]
        assert self.volcontrol_cmd

        cmd = self.volcontrol_cmd + ["--set", direction, "0", f"{value}"]
        ret = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
        assert (
            ret.returncode == 0
        ), f"setting master volume failed, cmd: {cmd}\nstdout:\n{ret.stdout}\nstderr:\n{ret.stderr}"

    def wait_for_midi_ports(
        self, timeout_s=60, restart_attempts=15
    ):  # Win 11 agent has been seen to take up to 40s
        """Wait for XMOS MIDI input and output ports to appear.

        Polls once per second for up to ``timeout_s`` seconds. If the ports are
        not found, restarts the device (exit + enter) and tries again, for at
        most ``restart_attempts`` attempts. There is no restart after the final
        attempt, because nothing is polled afterwards.

        Sets ``self.midi_in`` / ``self.midi_out`` to the first matching port
        names.

        Returns:
            True if both ports were found, otherwise False.
        """
        for restart_count in range(restart_attempts):
            for i in range(timeout_s):
                midi_in = [md for md in mido.get_input_names() if "XMOS" in md]
                midi_out = [md for md in mido.get_output_names() if "XMOS" in md]
                if midi_in and midi_out:
                    self.midi_in = midi_in[0]
                    self.midi_out = midi_out[0]
                    print(f"Hooray! XMOS MIDI ports found: {self.midi_in}, {self.midi_out}")
                    return True
                time.sleep(1)
                print(f"MIDI ports not found... retrying {i+1} of {timeout_s}, attempt {restart_count}")
            if restart_count < restart_attempts - 1:
                # Stop and restart the device before the next attempt
                self.__exit__(None, None, None)
                self.__enter__()

        print(f"No XMOS MIDI ports found: {mido.get_input_names()}, {mido.get_output_names()}")
        return False
