from pydantic import BaseModel, RootModel, Field, create_model
from typing import Literal, Annotated, List, Union
from annotated_types import Len
from functools import partial


# Reusable ``Field`` presets shared by the biquad filter models below. Call a
# preset (e.g. ``DEFAULT_Q()``) to get a new ``Field`` for a model attribute.
# The default value is bound by keyword, so a model can reuse a preset with a
# different default or bound while keeping the rest, e.g.
# ``DEFAULT_FILTER_FREQ(default=1000)`` (partial keywords can be overridden).
DEFAULT_Q = partial(
    Field, default=0.707, gt=0, le=10, description="Q factor of the filter."
)
# The 24000 Hz upper bound is half of the assumed 48kHz sample rate.
DEFAULT_FILTER_FREQ = partial(
    Field, default=500, gt=0, lt=24000, description="Frequency of the filter in Hz."
)
DEFAULT_BW = partial(
    Field, default=1, gt=0, le=10, description="Bandwidth of the filter in octaves."
)
DEFAULT_BOOST_DB = partial(
    Field, default=0, ge=-24, le=24, description="Gain of the filter in dB."
)


class _BaseModel(BaseModel, extra="forbid"):
    # Common base for the parameter models in this module. ``extra="forbid"``
    # makes validation reject unknown keys instead of ignoring them.
    ...


class biquad_allpass(_BaseModel):
    # Parameters of an "allpass" biquad stage; ``type`` is the tag that
    # identifies this member of BIQUAD_TYPES.
    # Field order is significant: CascadedBiquadParams.update_stage passes the
    # dumped values of each stage positionally.
    type: Literal["allpass"] = "allpass"
    filter_freq: float = DEFAULT_FILTER_FREQ()
    q_factor: float = DEFAULT_Q()


class biquad_bandpass(_BaseModel):
    # Parameters of a "bandpass" biquad stage; width is given as ``bw`` (in
    # octaves, per DEFAULT_BW) rather than as a Q factor.
    # Field order is significant: CascadedBiquadParams.update_stage passes the
    # dumped values of each stage positionally.
    type: Literal["bandpass"] = "bandpass"
    filter_freq: float = DEFAULT_FILTER_FREQ()
    bw: float = DEFAULT_BW()


class biquad_bandstop(_BaseModel):
    # Parameters of a "bandstop" biquad stage; width is given as ``bw`` (in
    # octaves, per DEFAULT_BW) rather than as a Q factor.
    # Field order is significant: CascadedBiquadParams.update_stage passes the
    # dumped values of each stage positionally.
    type: Literal["bandstop"] = "bandstop"
    filter_freq: float = DEFAULT_FILTER_FREQ()
    bw: float = DEFAULT_BW()


class biquad_bypass(_BaseModel):
    # A "bypass" biquad stage: only the ``type`` tag, no tunable parameters.
    # This is the stage used to fill the default cascade (see _8biquads).
    type: Literal["bypass"] = "bypass"


class biquad_constant_q(_BaseModel):
    # Parameters of a "constant_q" biquad stage (frequency, Q and boost).
    # Field order is significant: CascadedBiquadParams.update_stage passes the
    # dumped values of each stage positionally.
    type: Literal["constant_q"] = "constant_q"
    filter_freq: float = DEFAULT_FILTER_FREQ()
    q_factor: float = DEFAULT_Q()
    boost_db: float = DEFAULT_BOOST_DB()


class biquad_gain(_BaseModel):
    # Parameters of a "gain" biquad stage.
    # NOTE(review): ``gain_db`` is unbounded, unlike the +/-24 dB limits of
    # DEFAULT_BOOST_DB used by the shelf/peaking stages - confirm intended.
    type: Literal["gain"] = "gain"
    gain_db: float = 0


class biquad_highpass(_BaseModel):
    # Parameters of a "highpass" biquad stage.
    # Field order is significant: CascadedBiquadParams.update_stage passes the
    # dumped values of each stage positionally.
    type: Literal["highpass"] = "highpass"
    filter_freq: float = DEFAULT_FILTER_FREQ()
    q_factor: float = DEFAULT_Q()


class biquad_highshelf(_BaseModel):
    # Parameters of a "highshelf" biquad stage (frequency, Q and boost).
    # Field order is significant: CascadedBiquadParams.update_stage passes the
    # dumped values of each stage positionally.
    type: Literal["highshelf"] = "highshelf"
    filter_freq: float = DEFAULT_FILTER_FREQ()
    q_factor: float = DEFAULT_Q()
    boost_db: float = DEFAULT_BOOST_DB()


class biquad_linkwitz(_BaseModel):
    # Parameters of a "linkwitz" biquad stage: an (f0, q0) pair and an
    # (fp, qp) pair.
    # Field order is significant: CascadedBiquadParams.update_stage passes the
    # dumped values of each stage positionally.
    type: Literal["linkwitz"] = "linkwitz"
    # Both frequencies use the same bounds as DEFAULT_FILTER_FREQ
    # (0 < f < 24000 Hz, i.e. below half the assumed 48kHz sample rate), so a
    # zero, negative or out-of-band frequency is rejected at validation time
    # like every other filter frequency in this module.
    f0: float = Field(500, gt=0, lt=24000, description="Frequency f0 in Hz.")
    q0: float = DEFAULT_Q()
    fp: float = Field(1000, gt=0, lt=24000, description="Frequency fp in Hz.")
    qp: float = DEFAULT_Q()


class biquad_lowpass(_BaseModel):
    # Parameters of a "lowpass" biquad stage.
    # Field order is significant: CascadedBiquadParams.update_stage passes the
    # dumped values of each stage positionally.
    type: Literal["lowpass"] = "lowpass"
    filter_freq: float = DEFAULT_FILTER_FREQ()
    q_factor: float = DEFAULT_Q()


class biquad_lowshelf(_BaseModel):
    # Parameters of a "lowshelf" biquad stage (frequency, Q and boost).
    # Field order is significant: CascadedBiquadParams.update_stage passes the
    # dumped values of each stage positionally.
    type: Literal["lowshelf"] = "lowshelf"
    filter_freq: float = DEFAULT_FILTER_FREQ()
    q_factor: float = DEFAULT_Q()
    boost_db: float = DEFAULT_BOOST_DB()


class biquad_notch(_BaseModel):
    # Parameters of a "notch" biquad stage.
    # Field order is significant: CascadedBiquadParams.update_stage passes the
    # dumped values of each stage positionally.
    type: Literal["notch"] = "notch"
    filter_freq: float = DEFAULT_FILTER_FREQ()
    q_factor: float = DEFAULT_Q()


class biquad_peaking(_BaseModel):
    # Parameters of a "peaking" biquad stage (frequency, Q and boost).
    # Field order is significant: CascadedBiquadParams.update_stage passes the
    # dumped values of each stage positionally.
    type: Literal["peaking"] = "peaking"
    filter_freq: float = DEFAULT_FILTER_FREQ()
    q_factor: float = DEFAULT_Q()
    boost_db: float = DEFAULT_BOOST_DB()


# Union of every biquad stage model accepted in a cascade; each member carries
# a distinct ``type`` literal. This is a plain (non-discriminated) Union, so
# member order is left as-is: pydantic's union validation may fall back on it
# when an input is valid for several members - confirm before reordering.
BIQUAD_TYPES = Union[
    biquad_allpass,
    biquad_bandpass,
    biquad_bandstop,
    biquad_bypass,
    biquad_constant_q,
    biquad_gain,
    biquad_highpass,
    biquad_highshelf,
    biquad_linkwitz,
    biquad_lowpass,
    biquad_lowshelf,
    biquad_notch,
    biquad_peaking,
]


def _8biquads():
    """Return a new list of eight ``biquad_bypass`` stages.

    Used as the default factory of the cascaded biquad model, so every model
    instance gets its own list of distinct stage objects.
    """
    n_stages = 8
    return [biquad_bypass() for _stage in range(n_stages)]


class CascadedBiquadParams(RootModel, title="Cascaded Biquads"):
    # The schema title is passed as a class keyword (the same mechanism used
    # for ``extra="forbid"`` on _BaseModel) instead of a nested
    # ``class Config``, which pydantic v2 deprecates.
    #
    # Exactly eight stages: ``Len(8, 8)`` pins both the minimum and the
    # maximum length. This replaces ``Len(8)`` (a minimum only) combined with
    # the pydantic-v1 ``max_items=8`` argument, which pydantic v2 deprecates
    # in favour of ``max_length``.
    root: Annotated[list[BIQUAD_TYPES], Len(8, 8)] = Field(default_factory=_8biquads)

    def update_stage(self, stage):
        """Update a cascaded biquad stage with the parameters in this model.

        Every stage is dumped to a dict and turned into the list of its
        values in field-declaration order, ``type`` first (for example
        ``["peaking", filter_freq, q_factor, boost_db]``). The resulting list
        of lists is passed to ``stage.make_parametric_eq``.

        Args:
            stage: object exposing ``make_parametric_eq`` - presumably the
                cascaded-biquad DSP stage this model configures.
        """
        biquads = [list(spec.values()) for spec in self.model_dump()]
        stage.make_parametric_eq(biquads)


class VolumeControlParams(_BaseModel):
    # Gain and mute state applied to a volume-control stage.
    gain_db: float = 0
    mute: bool = False

    def update_stage(self, stage):
        """Apply the gain, then the mute state, to ``stage``."""
        gain, muted = self.gain_db, self.mute
        stage.set_gain(gain)
        stage.set_mute_state(muted)


class SwitchParams(_BaseModel):
    # Position selected on a switch stage.
    position: int = 0

    def update_stage(self, stage):
        """Move ``stage`` to the configured switch position."""
        target = self.position
        stage.move_switch(target)


class NoiseSuppressorExpanderParams(_BaseModel):
    threshold: float = Field(-45, ge=-50, le=10, title="Threshold (dB)")
    attack: float = Field(0.005, ge=0, le=0.2, title="Attack (s)")
    release: float = Field(0.12, ge=0, le=1.0, title="Release (s)")
    ratio: float = Field(3, ge=1, le=20, title="Ratio")

    def update_stage(self, stage):
        """Configure ``stage`` through ``make_noise_suppressor_expander``.

        The model's field names differ from the keyword names that the stage
        method expects, so they are mapped explicitly here.
        """
        settings = {
            "ratio": self.ratio,
            "threshold_db": self.threshold,
            "attack_t": self.attack,
            "release_t": self.release,
        }
        stage.make_noise_suppressor_expander(**settings)


class ReverbParams(_BaseModel, title="Reverb"):
    # The schema title is passed as a class keyword (like ``extra="forbid"``
    # on _BaseModel) rather than through a nested ``class Config``, which is
    # deprecated in pydantic v2.
    damping: float = Field(
        0.5,
        ge=0,
        le=1,
        title="Damping",
        description="This controls how much high frequency attenuation "
        "is in the room. Higher values yield shorter "
        "reverberation times at high frequencies. Range: 0 to 1",
    )
    decay: float = Field(
        0.5,
        ge=0,
        le=1,
        title="Decay",
        description="This sets how reverberant the room is. Higher "
        "values will give a longer reverberation time for "
        "a given room size. Range: 0 to 1",
    )
    early_diffusion: float = Field(
        0.5, ge=0, le=1, title="Early Diffusion", description="Range: 0 to 1"
    )
    late_diffusion: float = Field(
        0.5, ge=0, le=1, title="Late Diffusion", description="Range: 0 to 1"
    )
    # The bounds are 0..24000 (presumably Hz - 24000 is half the 48kHz sample
    # rate assumed for filter frequencies; confirm), so the description must
    # not claim a 0 to 1 range as it previously did.
    bandwidth: float = Field(
        8000, ge=0, le=24000, title="Bandwidth", description="Range: 0 to 24000"
    )
    predelay: float = Field(
        15,
        ge=0,
        le=30,
        title="Predelay",
        description="Set the predelay in milliseconds.",
    )
    width: float = Field(1.0, ge=0, le=1, title="Width", description="Range: 0 to 1")
    # NOTE(review): the description warns against exceeding a "default" of
    # 0.015, but the actual default here is 0.5. One of the two is wrong;
    # both are left unchanged pending confirmation of the intended default.
    pregain: float = Field(
        0.5,
        ge=0,
        le=1,
        title="Pregain",
        description="It is not advised to increase this value above the "
        "default 0.015, as it can result in saturation inside "
        "the reverb delay lines.",
    )
    # The description was previously a copy of the pregain warning.
    wet_dry: float = Field(
        0.5,
        ge=0,
        le=1,
        title="Wet/Dry mix",
        description="Balance between the wet and dry signals. Range: 0 to 1",
    )

    def update_stage(self, stage):
        """Apply every reverb parameter to ``stage`` through its setters.

        Args:
            stage: object exposing the ``set_*`` methods called below -
                presumably the reverb DSP stage this model configures.
        """
        stage.set_damping(self.damping)
        stage.set_decay(self.decay)
        stage.set_early_diffusion(self.early_diffusion)
        stage.set_late_diffusion(self.late_diffusion)
        stage.set_bandwidth(self.bandwidth)
        stage.set_predelay(self.predelay)
        stage.set_width(self.width)
        stage.set_pre_gain(self.pregain)
        # Set by physical slider
        stage.set_wet_dry_mix(self.wet_dry)


class CompressorSidechainParams(_BaseModel):
    threshold: float = Field(-40, ge=-50, le=10, title="Threshold (dB)")
    attack: float = Field(0.01, ge=0, le=0.2, title="Attack (s)")
    release: float = Field(0.5, ge=0, le=1, title="Release (s)")
    ratio: float = Field(5, ge=1, le=20, title="Ratio")

    def update_stage(self, stage):
        """Configure ``stage`` through ``make_compressor_sidechain``.

        The model's field names differ from the keyword names that the stage
        method expects, so they are mapped explicitly here.
        """
        settings = {
            "ratio": self.ratio,
            "threshold_db": self.threshold,
            "attack_t": self.attack,
            "release_t": self.release,
        }
        stage.make_compressor_sidechain(**settings)


class EnvelopeDetetectorParams(_BaseModel):
    # Attack and release settings for an envelope detector. Unlike the other
    # parameter models here, the values are unconstrained and no
    # ``update_stage`` method is defined.
    attack: float = 0
    release: float = 0


# Correctly spelled alias for the class above; the misspelled name is kept so
# existing imports keep working.
EnvelopeDetectorParams = EnvelopeDetetectorParams


class Params(BaseModel, extra="forbid"):
    # Top-level parameter set. Declaration order below is intentional: it is
    # the order of the generated schema and of ``model_dump()`` output.

    # Required; a bare ``List`` leaves the element type unconstrained.
    checksum: List
    reverb: ReverbParams = Field(
        title="Reverb\nWet/Dry", default_factory=ReverbParams
    )

    # Volume controls.
    headphone_volume: VolumeControlParams = Field(
        title="Headphone\nVolume", default_factory=VolumeControlParams
    )
    output_volume: VolumeControlParams = Field(
        title="Output\nVolume", default_factory=VolumeControlParams
    )
    mic_volume: VolumeControlParams = Field(
        title="Microphone\nVolume", default_factory=VolumeControlParams
    )
    music_volume: VolumeControlParams = Field(
        title="Music\nVolume", default_factory=VolumeControlParams
    )

    # Parametric EQ (its title comes from the model's own config).
    peq: CascadedBiquadParams = Field(default_factory=CascadedBiquadParams)

    # Switches.
    duck_switch: SwitchParams = Field(title="Ducking", default_factory=SwitchParams)
    monitor_switch: SwitchParams = Field(
        title="Monitor", default_factory=SwitchParams
    )
    loopback_switch: SwitchParams = Field(
        title="Loopback", default_factory=SwitchParams
    )
    reverb_switch: SwitchParams = Field(title="Reverb", default_factory=SwitchParams)
    peq_switch: SwitchParams = Field(title="Equaliser", default_factory=SwitchParams)
    denoise_switch: SwitchParams = Field(
        title="Denoise", default_factory=SwitchParams
    )

    # Dynamics processors.
    denoise: NoiseSuppressorExpanderParams = Field(
        title="Denoise", default_factory=NoiseSuppressorExpanderParams
    )
    ducking: CompressorSidechainParams = Field(
        title="Ducking", default_factory=CompressorSidechainParams
    )
