#!/usr/bin/env python3
import logging
log = logging.getLogger(__name__)

from contextlib import contextmanager
import io
import os
import pyftdi
from pyftdi.i2c import I2cController
import requests
import subprocess
import sys
import time
import random
import re
import argparse
import signal
import threading
import socket


from xtagctl.exceptions import *
from xtagctl.utils import *

from typing import Union

from pathlib import Path

# Seconds device_lock keeps retrying for the status.lock file before giving up
# (a negative value waits forever).
LOCK_TIMEOUT = 60
# Seconds allowed for a single `xrun -l` call before it is abandoned.
XRUN_TIMEOUT = 20
# Extra attempts made at parsing `xrun -l` after the first one fails.
XRUN_RETRIES = 3
# The shared device/uhubctl maps below are only fetched when this string occurs
# in the local host's domain name.
XMOS_DOMAIN = 'xmos.local'

# Shared copies of the device map and uhubctl map, read before any local files
# in the config directory.
GLOBAL_DEVICE_MAP_URL = "https://github0.xmos.com/raw/xmos-int/xtagctl_config/master/device_map"
GLOBAL_UHUBCTL_MAP_URL = "https://github0.xmos.com/raw/xmos-int/xtagctl_config/master/uhubctl_map"

# FT232H pin names accepted as reset pins in the device map, mapped to their
# bit position in the GPIO word: D4-D7 -> bits 4-7, C0-C7 -> bits 8-15.
FT232H_PIN_MAP = dict(
    [(f"D{bit}", bit) for bit in range(4, 8)]
    + [(f"C{bit - 8}", bit) for bit in range(8, 16)]
)

# Idea taken from Code Review Stack Exchange: https://codereview.stackexchange.com/a/54350
# Raising an exception on SIGTERM (e.g. a Jenkins abort) lets xtagctl unwind
# normally, so cleanup such as removing the lock file still runs.
def signal_handler(signum, frame):
    """Signal callback that turns a received signal into an XtagctlException.

    Installed for SIGTERM by GracefulExit. *frame* is required by the
    ``signal.signal`` callback signature but is not used.
    """
    message = f"Signal caught: {signum}"
    raise XtagctlException(message)

class GracefulExit:
    """Context manager that routes SIGTERM to ``signal_handler`` while active.

    ``signal_handler`` raises an exception, so a SIGTERM unwinds the stack and
    inner ``finally`` blocks / context managers (such as ``device_lock``) run
    their cleanup. The previous SIGTERM handler is restored on exit.

    Python only allows installing signal handlers from the main thread. On any
    other thread nothing is installed and a warning is logged instead.
    """

    def __enter__(self):
        # Remember whether we installed a handler, so __exit__ only restores
        # what __enter__ actually changed.
        self._installed = threading.current_thread() is threading.main_thread()
        if self._installed:
            self.old_handler = signal.getsignal(signal.SIGTERM)
            signal.signal(signal.SIGTERM, signal_handler)
        else:
            log.warning("Not on main thread, locks may not be cleared cleanly")
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self._installed:
            # getsignal() returns None when the previous handler was not set
            # from Python; signal.signal() rejects None, so fall back to the
            # default disposition instead of raising TypeError here.
            old_handler = self.old_handler
            if old_handler is None:
                old_handler = signal.SIG_DFL
            signal.signal(signal.SIGTERM, old_handler)
        # Do not suppress any exception raised inside the with-block.
        return False


@contextmanager
def device_lock(lock_dir: Union[Path, str, None], timeout=LOCK_TIMEOUT):
    """Hold the xtagctl inter-process lock for the duration of a ``with`` block.

    The lock is the file ``status.lock`` inside *lock_dir*. It is created with
    O_CREAT | O_EXCL, so creation fails while another holder's file exists. It
    is closed and removed on exit, including when the body raises.

    Args:
        lock_dir: Directory for the lock file; a falsy value means
            ``get_lock_dir()``.
        timeout: Seconds to keep retrying while the lock file exists. A
            negative value waits forever; 0 makes a single attempt.

    Raises:
        XtagctlLockTimeout: if the lock was not acquired within *timeout*.
    """
    lock_dir = lock_dir or get_lock_dir()
    lock_path = Path(lock_dir) / "status.lock"
    sleep_time = 0.1
    # Measure elapsed time with a monotonic clock rather than by summing the
    # float sleep intervals (which drifts and ignores time spent in os.open).
    deadline = None if timeout < 0 else time.monotonic() + timeout
    log.debug("Acquiring lock file at: " + str(lock_path))
    lock_fd = None
    try:
        while True:
            try:
                # Atomically attempt to create the lock file
                lock_fd = os.open(lock_path, os.O_CREAT | os.O_EXCL)
                break
            except FileExistsError:
                pass
            # Only give up after at least one attempt, and only ever proceed
            # to the body once the lock file has really been created.
            if deadline is not None and time.monotonic() >= deadline:
                raise XtagctlLockTimeout("Error: Timeout while acquiring lockfile")
            time.sleep(sleep_time)

        # Lock file acquired
        log.info("Lock acquired.")
        yield
        log.info("Lock released.")
    finally:
        if lock_fd is not None:
            os.close(lock_fd)
            os.remove(lock_path)


def get_local_domain():
    """Return the domain part of this host's fully-qualified domain name.

    The FQDN from ``socket.getfqdn()`` is split at its first "." and the rest
    is returned, e.g. "host.sub.example.com" -> "sub.example.com". If the name
    contains no "." (or the host name cannot be read), a warning is logged and
    "" is returned.
    """
    # A bare `except:` here would also swallow the exception raised by the
    # SIGTERM handler (getfqdn can block on DNS), so catch only what can fail.
    try:
        fqdn = socket.getfqdn()
    except OSError:
        fqdn = ""
    _hostname, dot, domain = fqdn.partition(".")
    if not dot:
        log.warning("No domain name found...")
        return ""
    return domain


def get_lock_dir():
    """Return the directory that holds the lock file and the acquired file.

    Uses $XTAGCTL_LOCK_DIR when it is set (without creating it); otherwise
    falls back to ~/.xtag, creating that directory if it does not exist.
    """
    home_default = Path.home() / ".xtag"
    configured = os.environ.get("XTAGCTL_LOCK_DIR")
    if configured is None:
        # Only the default location is created automatically
        home_default.mkdir(exist_ok=True)
        lock_dir = home_default
    else:
        lock_dir = Path(configured)
    log.debug(f"Lock directory: {lock_dir}")
    return lock_dir


def get_config_dir():
    """Return the directory that holds the local device_map and uhubctl_map.

    Uses $XTAGCTL_CONFIG_DIR when it is set (without creating it); otherwise
    falls back to ~/.xtag, creating that directory if it does not exist.
    """
    home_default = Path.home() / ".xtag"
    configured = os.environ.get("XTAGCTL_CONFIG_DIR")
    if configured is None:
        # Only the default location is created automatically
        home_default.mkdir(exist_ok=True)
        config_dir = home_default
    else:
        config_dir = Path(configured)
    log.debug(f"Config directory: {config_dir}")
    return config_dir


def parse_target_pattern(target):
    """Convert a target name given on the command line into a compiled regex.

    A string wrapped in slashes, e.g. ``"/xcore.*/"``, is treated as a regular
    expression with the slashes stripped. Callers use ``.match``, so it is
    anchored at the start only. Any other string, including ``""`` and a lone
    ``"/"``, is escaped and must match the whole target exactly.
    """
    # Require at least two characters: a lone "/" would otherwise become the
    # empty regex (matching every target), and "" would raise IndexError.
    if len(target) >= 2 and target.startswith("/") and target.endswith("/"):
        return re.compile(target[1:-1])
    return re.compile("^" + re.escape(target) + "$")


class DeviceMap:
    """Snapshot of the xtag adapter configuration and live connection status.

    Construction reads, in order:
    1. the device map(s) (``_parse_device_map``);
    2. the acquired file in the lock directory (``_parse_acquired``, which also
       rewrites it);
    3. the output of ``<xrun_bin> -l`` (``_parse_xrun``). This is retried up to
       XRUN_RETRIES more times when it raises XtagctlException.
    """

    def __init__(
        self, lock_dir=None, config_dir=None, xrun_bin="xrun", xrun_timeout=XRUN_TIMEOUT
    ):
        # adapterID -> target name
        self.target_map = {}
        # adapterID -> reset method: an FT232H_PIN_MAP key, or a string
        # containing "UHUBCTL" (optionally followed by "[opt,opt,...]")
        self.reset_map = {}
        # adapterIDs listed in the acquired file that are also in target_map
        self.acquired = []
        # adapterIDs reported by xrun: all, with a device attached, in use
        self.xtag_connected = []
        self.device_connected = []
        self.device_inuse = []

        # Raw output (stdout + stderr) of the last xrun call
        self.xrun_output = None

        # Setup paths
        self.config_dir = config_dir or get_config_dir()
        self.lock_dir = lock_dir or get_lock_dir()

        self.acquire_path = self.lock_dir / Path("acquired")
        self.devicemap_path = self.config_dir / Path("device_map")
        self.uhubctl_path = self.config_dir / Path("uhubctl_map")

        self.xrun_bin = xrun_bin
        self.xrun_timeout = xrun_timeout

        self._parse_device_map()
        self._parse_acquired()

        # Parse xrun output: one attempt plus up to XRUN_RETRIES retries
        retry_count = 0
        while True:
            try:
                self._parse_xrun()
                break
            except XtagctlException:
                log.warning("Parsing xrun failed", exc_info=True)
                retry_count += 1
                if retry_count > XRUN_RETRIES:
                    raise
                log.debug("Retrying...")

    def _parse_device_map_string(self, device_map_string: str):
        """Parse text in device_map syntax and merge it into the mappings.

        Each line is ``<adapterID> <target> [<reset>]``. ``<reset>`` is either
        an FT232H pin name (case-insensitive, see FT232H_PIN_MAP) or a token
        containing "UHUBCTL" (kept in its original case, since its bracketed
        options are passed to uhubctl verbatim). Lines whose first token
        starts with "#" and lines without 2 or 3 tokens are ignored. Later
        entries override earlier ones for the same adapterID.

        Raises:
            XtagctlException: if a reset field is neither a known pin nor
                uhubctl. The reported line number is 1-based.
        """
        for lineno, line in enumerate(device_map_string.splitlines(), start=1):
            tokens = line.split()
            if tokens and tokens[0].startswith("#"):
                continue
            if len(tokens) not in (2, 3):
                continue
            adapterID, target = tokens[0], tokens[1]
            if len(tokens) == 3:
                if 'UHUBCTL' in tokens[2].upper():
                    reset_pin = tokens[2]
                else:
                    reset_pin = tokens[2].upper()
                    if reset_pin not in FT232H_PIN_MAP:
                        raise XtagctlException(
                            f"Invalid reset pin or method on line {lineno}: {line}"
                        )
                self.reset_map[adapterID] = reset_pin
            self.target_map[adapterID] = target

    def _parse_device_map(self):
        """Rebuild target_map/reset_map from the global and local map files.

        When XMOS_DOMAIN occurs in the local domain name, the global device map
        and uhubctl map are fetched over HTTP first; request failures are
        logged and ignored. The local ``device_map`` and ``uhubctl_map`` files
        in the config directory are then applied on top, if they exist.
        """
        log.info("Parsing device map...")

        self.target_map = {}
        self.reset_map = {}
        # Parse the global device map
        if XMOS_DOMAIN in get_local_domain():
            try:
                for url in (GLOBAL_DEVICE_MAP_URL, GLOBAL_UHUBCTL_MAP_URL):
                    r = requests.get(url, timeout=10)
                    r.raise_for_status()
                    self._parse_device_map_string(r.text)
            except requests.RequestException as err:
                log.warning(f"Error while making request for global map: {err}")
                log.warning(f"Local domain name returns: {get_local_domain()}")
                log.warning("Continuing...")

        # Parse the local device map and uhubctl map
        for map_path in (self.devicemap_path, self.uhubctl_path):
            if map_path.exists():
                with open(map_path, "r") as f:
                    self._parse_device_map_string(f.read())

        log.debug("Target map: " + str(self.target_map))
        log.debug("Reset map: " + str(self.reset_map))

    def _parse_acquired(self):
        """Load the acquired list from disk and write the filtered list back.

        Adapter IDs that are not in target_map are dropped. A missing file is
        treated as an empty list.
        """
        log.info("Parsing acquired file...")

        self.acquired = []
        try:
            with open(self.acquire_path, "r") as f:
                for line in f:
                    adapterID = line.strip()
                    if adapterID in self.target_map:
                        self.acquired.append(adapterID)
        except FileNotFoundError:
            pass
        log.debug("Acquired xtags: " + str(self.acquired))
        self._save_acquired()

    def _parse_xrun(self):
        """Run ``<xrun_bin> -l`` and populate the connection-status lists.

        Raises:
            XtagctlXrunTimeout: if xrun does not finish within xrun_timeout
                seconds.
            XtagctlException: if xrun cannot be found, or its output is too
                short or does not start with the expected header.
        """
        log.info("Parsing xrun...")

        self.xtag_connected = []
        self.device_connected = []
        self.device_inuse = []
        try:
            xrun_ret = subprocess.run(
                [self.xrun_bin, "-l"],
                stderr=subprocess.STDOUT,
                stdout=subprocess.PIPE,
                text=True,
                timeout=self.xrun_timeout,
            )
        except subprocess.TimeoutExpired:
            raise XtagctlXrunTimeout(f"Error: Call to xrun timed out after {self.xrun_timeout}s")
        except FileNotFoundError:
            raise XtagctlException(f"Error: xrun command not found: {self.xrun_bin}")

        xrun_out = xrun_ret.stdout.split("\n")
        self.xrun_output = xrun_ret.stdout
        log.debug(xrun_out)

        # Check that the first 4 lines of xrun_out match the expected lines
        expected_header = ["", "Available XMOS Devices", "----------------------", ""]
        if len(xrun_out) < len(expected_header):
            raise XtagctlException(
                f"Error: xrun output:\n{xrun_out}\n"
                f"does not contain expected header:\n{expected_header}"
            )

        if xrun_out[:len(expected_header)] != expected_header:
            raise XtagctlException(
                f"Error: xrun output header:\n{xrun_out[:4]}\n"
                f"does not match expected header:\n{expected_header}"
            )

        try:
            if "No Available Devices Found" in xrun_out[4]:
                return
        except IndexError:
            raise XtagctlException(f"Error: xrun output is too short:\n{xrun_out}\n")

        # Device rows start at line 6 (lines 4-5 are presumably the table's
        # column header/separator). Fixed columns: adapter ID in [26:34],
        # device status after it ("None" = no device, "In Use" = busy).
        for line in xrun_out[6:]:
            if not line.strip():
                continue
            adapterID = line[26:34].strip()
            status = line[34:].strip()
            self.xtag_connected.append(adapterID)
            if status != "None":
                self.device_connected.append(adapterID)
            if status == "In Use":
                self.device_inuse.append(adapterID)

    def _save_acquired(self):
        """Overwrite the acquired file with one adapter ID per line."""
        log.debug("Saving acquired file...")
        log.debug(f"Acquired file path: {self.acquire_path}")

        with open(self.acquire_path, "w") as f:
            f.writelines(f"{d}\n" for d in self.acquired)

    @classmethod
    def _get_i2c_controller(cls):
        """Open the FT232H as a pyftdi I2cController.

        Adapted from adafruit_blinka/microcontroller/ft232h/pin.py (MIT).

        Raises:
            XtagctlFt232hException: if the device is missing or misconfigured.
        """
        i2c_controller = I2cController()
        try:
            i2c_controller.configure("ftdi://ftdi:ft232h/1")
        except pyftdi.usbtools.UsbToolsError as e:
            raise XtagctlFt232hException("Could not find FT232H device") from e
        except ValueError as e:
            raise XtagctlFt232hException(f"FT232H error: {e}") from e
        return i2c_controller

    @classmethod
    def _initialise_output_pin(cls, ft232h_gpio, pin_id):
        """Add bit *pin_id* to the GPIO pin set and make it an output.

        Adapted from adafruit_blinka/microcontroller/ft232h/pin.py (MIT).
        """
        pin_mask = ft232h_gpio.pins | 1 << pin_id
        current = ft232h_gpio.direction
        current |= 1 << pin_id  # Set current pin to OUT
        ft232h_gpio.set_direction(pin_mask, current)

    @classmethod
    def _write_pin(cls, ft232h_gpio, pin_id, val):
        """Read (val is None) or drive (truthy -> high, falsy -> low) GPIO bit *pin_id*.

        Adapted from adafruit_blinka/microcontroller/ft232h/pin.py (MIT).
        Returns 0/1 when reading, None when writing.
        """
        current = ft232h_gpio.read(with_output=True)
        # read
        if val is None:
            return 1 if current & (1 << pin_id) != 0 else 0
        # write
        if val:
            current |= 1 << pin_id
        else:
            current &= ~(1 << pin_id)
        # must mask out any input pins
        ft232h_gpio.write(current & ft232h_gpio.direction)

    def acquire(self, adapterID):
        """Mark an adapter as acquired and persist the acquired file."""
        self.acquired.append(adapterID)
        self._save_acquired()

    def release(self, adapterID):
        """Unmark an adapter (ValueError if not acquired) and persist the file."""
        self.acquired.remove(adapterID)
        self._save_acquired()

    def usb_reset(self, adapterID):
        """Power-cycle an adapter's USB port with the external ``uhubctl`` tool.

        Per-adapter uhubctl options may be given in the map as a comma-separated
        list in square brackets after the UHUBCTL keyword, e.g.
        ``UHUBCTL[-a2,-l,1-1,-p,2]``. Without them, or with an empty list,
        ``-a2 -s <adapterID>`` is used. uhubctl failures are reported on stdout
        rather than raised. Afterwards it waits 1 s and logs whether the xtag
        still has firmware loaded.
        """
        reset_method = self.reset_map[adapterID]
        uhub_options = []
        # Manual override options, if any, are between square brackets
        # appended to the UHUBCTL keyword.
        if reset_method.find('[') > 0:
            raw_options = reset_method.split('[')[1].split(']')[0]
            uhub_options = [opt.strip() for opt in raw_options.split(',') if opt.strip()]
        if not uhub_options:
            # Default: cycle (-a2) the port matched by "-s <adapterID>"; the
            # adapter ID is expected to match the xtag's USB iSerial.
            uhub_options = ['-a2', '-s', adapterID]

        try:
            log.info("Attempting uhubctl powercycle of device: %s", adapterID)
            log.debug("Using uhubctl options: %s", uhub_options)
            # check=True: a non-zero exit status raises CalledProcessError
            subprocess.run(["uhubctl"] + uhub_options, check=True, capture_output=True, timeout=10)
            log.info("Uhubctl attempt reports success.")
        except subprocess.CalledProcessError as euhub:
            print("Uhubctl Error:", euhub.stderr.decode())
        except subprocess.TimeoutExpired:
            print("Uhubctl Error: Attempt Timed out")
        except FileNotFoundError:
            print("Uhubctl Error: uhubctl not found, is it installed?")
        # Allow xtag to come back on the bus before proceeding
        time.sleep(1.0)

        # check_loaded_firmware is not defined in this file (presumably it
        # comes from the xtagctl.utils star import).
        if check_loaded_firmware(adapterID):
            log.warning("WARNING: USB reset failed, XTAG %s still has a firmware loaded", adapterID)
        else:
            log.info("Success: XTAG %s firmware removed.", adapterID)

    def ftdi_reset(self, adapterID):
        """Reset an adapter by pulsing its mapped FT232H pin low for 0.1 s.

        The FTDI device is always released afterwards. Then it waits 1 s for
        the xtag to re-enumerate.
        """
        pin_id = FT232H_PIN_MAP[self.reset_map[adapterID]]
        i2c_controller = self._get_i2c_controller()
        try:
            ft232h_gpio = i2c_controller.get_gpio()
            self._initialise_output_pin(ft232h_gpio, pin_id)
            # Set low
            self._write_pin(ft232h_gpio, pin_id, 0)
            time.sleep(0.1)
            # Set high
            self._write_pin(ft232h_gpio, pin_id, 1)
        finally:
            # Release FTDI device
            i2c_controller.terminate()
        # Allow xtag to come back on the bus before proceeding
        time.sleep(1.0)

    def reset(self, adapterID):
        """Reset an adapter via uhubctl or an FT232H pin, per its reset_map entry.

        Raises:
            XtagctlException: if the adapter has no reset method mapped.
        """
        if adapterID not in self.reset_map:
            raise XtagctlException(f"Adapter {adapterID} reset not mapped to an IO pin or uhubctl")
        # Membership test, not str.find(): find() returns -1 (truthy) when the
        # keyword is absent and 0 (falsy) when it is at the start. Compare
        # case-insensitively because uhubctl entries keep their original case.
        if 'UHUBCTL' in self.reset_map[adapterID].upper():
            self.usb_reset(adapterID)
        else:
            self.ftdi_reset(adapterID)

class DeviceController:
    """Command-level operations (acquire, release, reset, status) on a DeviceMap."""

    def __init__(
        self,
        lock_dir: Path = None,
        config_dir: Path = None,
        xrun_bin: str = "xrun",
        verbose: bool = False
    ):
        self.dev = DeviceMap(lock_dir, config_dir, xrun_bin)
        # When True, show_status prints to stdout instead of the debug log
        self.verbose = verbose

    def acquire_target(self, target: str):
        """Acquire a random free, connected adapter whose target matches *target*.

        *target* is an exact name or a /regex/ (see parse_target_pattern).
        Returns the chosen adapter ID.

        Raises:
            XtagctlDeviceNotFound: no mapped target matches.
            XtagctlDeviceNotConnected: matching adapters exist but none has a
                device connected.
            XtagctlDeviceInUse: every connected match is acquired or in use.
        """
        target_pattern = parse_target_pattern(target)

        if not any(target_pattern.match(t) for t in self.dev.target_map.values()):
            raise XtagctlDeviceNotFound(
                f'Error: target matching "{target}" not present in device target_map.'
            )

        # Check if there are any free xtags connected to the specified target
        connected_targets = 0
        free_xtags = []
        for adapter, mapped_target in self.dev.target_map.items():
            if not target_pattern.match(mapped_target):
                continue
            if adapter not in self.dev.device_connected:
                continue
            connected_targets += 1
            if adapter in self.dev.acquired or adapter in self.dev.device_inuse:
                continue
            free_xtags.append(adapter)

        if not free_xtags:
            self.show_status()
            if connected_targets == 0:
                message = f'Error: No connected xTags match the target device "{target}".'
                log.info(message)
                raise XtagctlDeviceNotConnected(message)
            message = f'Error: All xTags mapped to "{target}" are already acquired or in-use.'
            log.info(message)
            raise XtagctlDeviceInUse(message)

        chosen_xtag = random.choice(free_xtags)
        self.dev.acquire(chosen_xtag)
        log.info(f"Acquired {chosen_xtag}")
        return chosen_xtag

    def release_adapter(self, adapterID: str):
        """Release an acquired adapter and show the updated status.

        Raises:
            XtagctlException: if the adapter is unmapped or not acquired.
        """
        if adapterID not in self.dev.target_map:
            raise XtagctlException(
                f'Error: adapter "{adapterID}" not present in device map.'
            )
        if adapterID not in self.dev.acquired:
            raise XtagctlException(f'Error: adapter "{adapterID}" is already free.')

        self.dev.release(adapterID)
        log.info(f"Released {adapterID}")
        self.show_status()

    def reset_adapter(self, adapterID: str):
        """Reset one adapter (only warn if it has no reset method), then show status."""
        if adapterID not in self.dev.reset_map:
            print(f'Warning: adapter "{adapterID}" not mapped to a reset pin or uhubctl.')
        else:
            log.info(f"Resetting {adapterID}")
            self.dev.reset(adapterID)
        self.show_status()

    def reset_all_adapters(self, targets: list):
        """Reset every connected, mapped adapter whose target matches any of *targets*.

        Each adapter is reset at most once, even if it matches several
        targets; order follows *targets*, then xrun's adapter order.
        """
        already_reset = set()
        for target in targets:
            target_pattern = parse_target_pattern(target)
            for adapterID in self.dev.xtag_connected:
                if adapterID in already_reset or adapterID not in self.dev.target_map:
                    continue
                if target_pattern.match(self.dev.target_map[adapterID]):
                    already_reset.add(adapterID)
                    self.reset_adapter(adapterID)

    def show_status(self):
        """Print a table of mapped adapters: target, reset method, acquired, status.

        Output goes to stdout when verbose, otherwise to the debug log.
        Adapters reported by xrun but missing from the map, and the raw xrun
        output, are listed after the table.
        """
        printfn = print if self.verbose else log.debug

        if self.dev.target_map:
            max_target_name_len = max(
                10, *[len(name) for name in self.dev.target_map.values()]
            )

            table_fmt = (
                "{:<10}| {:<"
                + str(max_target_name_len + 1)
                + "}| {:<25}| {:<10}| {:<10}"
            )

            header = table_fmt.format(
                "Adapter", "Target", "Reset Pin/Method", "Acquired", "Status"
            )
            printfn(header)
            # Same width as the header, with "+" under each column bar
            printfn("".join("+" if c == "|" else "-" for c in header))
        else:
            printfn("Warning: No entries in xtagctl device_map.")

        for adapter, mapped_target in self.dev.target_map.items():
            # Most specific state wins: in use > device ready > xtag only
            if adapter in self.dev.device_inuse:
                status = "In Use"
            elif adapter in self.dev.device_connected:
                status = "Ready"
            elif adapter in self.dev.xtag_connected:
                status = "Device Disconnected"
            else:
                status = "xTag Disconnected"

            reset_pin = str(self.dev.reset_map.get(adapter, "N/A"))

            printfn(
                table_fmt.format(
                    adapter,
                    mapped_target,
                    reset_pin,
                    "Acquired" if adapter in self.dev.acquired else "Free",
                    status,
                )
            )

        for adapter in self.dev.xtag_connected:
            if adapter not in self.dev.target_map:
                printfn(f"Adapter connected but not mapped: {adapter}")

        printfn("xrun output:")
        printfn(self.dev.xrun_output)


class App:
    """Command-line entry point: ``xtagctl <command> [<args>]``.

    The constructor parses the sub-command name from ``sys.argv[1]``. It then
    runs the matching method while holding the device lock, with SIGTERM
    converted to an exception (GracefulExit). Each sub-command parses its own
    arguments from ``sys.argv[2:]``.
    """

    # Sub-commands that may be dispatched from the command line. Anything else,
    # including inherited/dunder attributes such as "__init__", is rejected.
    _COMMANDS = ("acquire", "release", "reset", "reset_all", "status")

    def __init__(self, xtagctl_path=None, lock_dir=None, config_dir=None):
        parser = argparse.ArgumentParser(
            description="Software for managing multiple xCore devices on a \
                         single host.",
            usage="""xtagctl <command> [<args>]

Command list:
   acquire <TARGET>             Acquire an adapter ID connected to a device
                                matching <TARGET> or /<TARGET_PATTERN>/.
   release <adapterID>          Release the specified adapter.
   reset <adapterID>            Resets the specified adapter.
   reset_all <TARGETS>          Resets all adapters matching space separated <TARGETS>.
   status                       Get the adapter-device map and the current
                                status of each device.
""",
        )
        parser.add_argument("command", help="Subcommand to run")
        args = parser.parse_args(sys.argv[1:2])
        if args.command not in self._COMMANDS:
            print("Unrecognized command")
            parser.print_help()
            sys.exit(1)

        # Setup paths
        self.xtagctl_path = Path(__file__).parent
        if xtagctl_path:
            self.xtagctl_path = xtagctl_path

        self.lock_dir = lock_dir or get_lock_dir()
        self.config_dir = config_dir or get_config_dir()

        with GracefulExit():
            with device_lock(self.lock_dir):
                getattr(self, args.command)()

    def acquire(self):
        """`xtagctl acquire <target>`: print the ID of a newly acquired adapter."""
        parser = argparse.ArgumentParser(
            prog="xtagctl acquire",
            description="Attempts to acquire the specified target   ",
        )
        parser.add_argument("target", type=str,
                            help="The name of the target as specified in a device map. "
                            "Also accepts /wrapped/ strings as regex patterns")
        args = parser.parse_args(sys.argv[2:])

        dev_controller = DeviceController(self.lock_dir, self.config_dir)
        chosen_xtag = dev_controller.acquire_target(args.target)
        print(chosen_xtag)

    def release(self):
        """`xtagctl release <adapterID>`: free a previously acquired adapter."""
        parser = argparse.ArgumentParser(
            prog="xtagctl release",
            description="Frees the xtag with the specified adapter ID",
        )
        parser.add_argument("adapterID", type=str)
        args = parser.parse_args(sys.argv[2:])

        dev_controller = DeviceController(self.lock_dir, self.config_dir)
        dev_controller.release_adapter(args.adapterID)

    def reset(self):
        """`xtagctl reset <adapterID>`: reset a single adapter."""
        parser = argparse.ArgumentParser(
            prog="xtagctl reset",
            description="Resets the xtag with the specified adapter ID",
        )
        parser.add_argument("adapterID", type=str)
        args = parser.parse_args(sys.argv[2:])

        dev_controller = DeviceController(self.lock_dir, self.config_dir)
        dev_controller.reset_adapter(args.adapterID)

    def reset_all(self):
        """`xtagctl reset_all <targets...>`: reset every adapter matching a target."""
        parser = argparse.ArgumentParser(
            prog="xtagctl reset_all",
            description="Resets all the attached xtags",
        )
        parser.add_argument("targets", type=str, nargs="+",
                            help="The name of the target as specified in a device map. "
                            "Also accepts /wrapped/ strings as regex patterns")
        args = parser.parse_args(sys.argv[2:])

        dev_controller = DeviceController(self.lock_dir, self.config_dir)
        dev_controller.reset_all_adapters(args.targets)

    def status(self):
        """`xtagctl status`: print the device map table and xrun output."""
        parser = argparse.ArgumentParser(
            prog="xtagctl status",
            description="Shows the current status of the device_map",
        )
        # No arguments are accepted; parsing still handles --help and rejects
        # unexpected extra arguments.
        parser.parse_args(sys.argv[2:])

        dev_controller = DeviceController(self.lock_dir, self.config_dir, verbose=True)
        dev_controller.show_status()
