# Copyright 2024-2025 XMOS LIMITED.
# This Software is subject to the terms of the XMOS Public Licence: Version 1.
import scipy.signal as spsig
import numpy as np 
import matplotlib.pyplot as plt

import utils as utils
import biquad as bq
import signal_gen as gen
import os
from pathlib import Path

def stft_mean(sig, fs):
    """Return the frequency axis and time-averaged power spectrum of ``sig``.

    The signal is cut into non-overlapping 32768-sample segments, each
    segment is windowed with a Blackman-Harris window (no detrending), and
    the per-segment power spectra (``scaling="spectrum"``) are averaged.

    Args:
        sig: input samples.
        fs: sample rate in Hz, used to label the frequency axis.

    Returns:
        ``(freqs, mean_power)`` where ``mean_power`` is the spectrogram
        averaged over its last (time) axis.
    """
    segment_len = 2 * 16 * 1024  # 32768 samples per segment
    freqs, _times, power = spsig.spectrogram(
        sig,
        fs=fs,
        window="blackmanharris",
        nperseg=segment_len,
        noverlap=0,
        detrend=False,
        scaling="spectrum",
    )
    mean_power = power.mean(axis=-1)
    return freqs, mean_power


def plot_1(q_sig, q_coeff, out_name=None):
    """Compare the fixed-point noise floor of peaking biquads at three centre frequencies.

    A 997 Hz test tone is passed through three peaking biquads (centred at
    2000 Hz, 200 Hz and 50 Hz, q=2, 1 dB gain) using ``process_int``, and the
    time-averaged spectrum of each output is overlaid on one plot.

    Args:
        q_sig: fractional bits of the signal format; passed to the biquad as
            ``Q_sig`` and labelled Q(31-q_sig).q_sig in the title.
        q_coeff: fractional bits of the coefficient format; the biquad gets
            ``b_shift = 30 - q_coeff`` and the title labels it
            Q(31-q_coeff).q_coeff.
        out_name: optional output path. When given, the figure is saved there
            and all figures are closed; otherwise it is shown interactively.
    """
    fs = 48000
    tone_freq = 997
    # NOTE(review): the second gen.sin argument is presumably the length in
    # seconds; the first ``fs`` output samples are discarded below — confirm.
    duration = 4
    signal = gen.sin(fs, duration, tone_freq, 1.0)

    b = 30 - q_coeff

    # audient numbers
    filter_q = 2
    filter_g = 1
    centre_freqs = (2000, 200, 50)

    for centre in centre_freqs:
        # A fresh filter instance per centre frequency so nothing is shared
        # between them (presumably the biquad object carries internal state).
        biquad = bq.biquad(bq.make_biquad_peaking(fs, centre, filter_q, filter_g),
                           fs, 1, b_shift=b, Q_sig=q_sig)
        output = np.zeros(len(signal))
        for n, sample in enumerate(signal):
            output[n] = biquad.process_int(sample)

        # Drop the first second (fs samples), presumably to keep the filter's
        # start-up transient out of the averaged spectrum.
        f, spect = stft_mean(output[fs:], fs)
        plt.semilogx(f, utils.db_pow(spect))

    plt.legend(["%dHz" % centre for centre in centre_freqs])

    plt.grid()
    plt.xlabel("Frequency (Hz)")
    plt.ylabel("Level (dB)")
    # Title is built from the same values used to design the filters so the
    # two cannot drift apart.
    plt.title("%dHz sine, %ddB boost, q=%d, fs=%dkHz\nQ%d.%d coefficients, Q%d.%d signal"
              % (tone_freq, filter_g, filter_q, fs // 1000,
                 1 + b, 30 - b, 31 - q_sig, q_sig))
    plt.ylim([-175, 5])
    plt.xlim([10, 20000])

    if out_name:
        plt.savefig(out_name)
        plt.close("all")
    else:
        plt.show()


def plot_2(qsig, q_coeff, filter_f, out_name=None):
    """Compare double, int32 and VPU implementations of one peaking biquad.

    A 997 Hz test tone is passed through three identically designed peaking
    biquads (centre ``filter_f``, q=2, 1 dB gain), each run with a different
    processing method (``process``, ``process_int``, ``process_xcore``). The
    time-averaged spectra of the outputs are overlaid on one plot.

    Args:
        qsig: fractional bits of the signal format; passed to the biquad as
            ``Q_sig`` and labelled Q(31-qsig).qsig in the title.
        q_coeff: fractional bits of the coefficient format; the biquad gets
            ``b_shift = 30 - q_coeff`` and the title labels it
            Q(31-q_coeff).q_coeff.
        filter_f: centre frequency of the peaking filter in Hz.
        out_name: optional output path. When given, the figure is saved there
            and all figures are closed; otherwise it is shown interactively.
    """
    fs = 48000
    tone_freq = 997
    # NOTE(review): the second gen.sin argument is presumably the length in
    # seconds; the first ``fs`` output samples are discarded below — confirm.
    duration = 4
    signal = gen.sin(fs, duration, tone_freq, 1.0)

    # audient numbers
    filter_q = 2
    filter_g = 1

    b = 30 - q_coeff

    def make_biquad():
        # One fresh instance per implementation so nothing is shared between
        # them (presumably the biquad object carries internal state).
        return bq.biquad(bq.make_biquad_peaking(fs, filter_f, filter_q, filter_g),
                         fs, 1, b_shift=b, Q_sig=qsig)

    # NOTE: a DF2 comparison (biquad built with Q_sig=qsig-10 and run through
    # process_xcore_df2) used to be stubbed out here; re-add it to this table
    # if it is needed again.
    implementations = (
        ("double", make_biquad().process),
        ("int32", make_biquad().process_int),
        ("VPU", make_biquad().process_xcore),
    )

    for _label, process in implementations:
        output = np.zeros(len(signal))
        for n, sample in enumerate(signal):
            output[n] = process(sample)

        # Drop the first second (fs samples), presumably to keep the filter's
        # start-up transient out of the averaged spectrum.
        f, spect = stft_mean(output[fs:], fs)
        plt.semilogx(f, utils.db_pow(spect))

    plt.legend([label for label, _process in implementations])

    plt.grid()
    plt.xlabel("Frequency (Hz)")
    plt.ylabel("Level (dB)")
    # Title is built from the same values used to design the filter so the
    # two cannot drift apart.
    plt.title("%dHz sine, %ddB %dHz boost, q=%d, fs=%dkHz\nQ%d.%d coefficients, Q%d.%d signal"
              % (tone_freq, filter_g, filter_f, filter_q, fs // 1000,
                 1 + b, 30 - b, 31 - qsig, qsig))
    plt.ylim([-175, 5])
    plt.xlim([20, 20000])

    if out_name:
        plt.savefig(out_name)
        plt.close("all")
    else:
        plt.show()


if __name__ == "__main__":
    # Figures go to <parent of this script's directory>/doc/rst/images.
    script_path = Path(os.path.abspath(__file__))
    image_dir = script_path.parents[1] / "doc" / "rst" / "images"

    # (q_sig, q_coeff, output file) for the centre-frequency comparison plots.
    plot_1_jobs = (
        (30, 30, "picture-2.pdf"),
        (27, 30, "picture-3.pdf"),
        (27, 28, "picture-4.pdf"),
    )
    for sig_bits, coeff_bits, file_name in plot_1_jobs:
        plot_1(sig_bits, coeff_bits, image_dir / file_name)

    # (q_sig, q_coeff, filter centre Hz, output file) for the
    # implementation comparison plots.
    plot_2_jobs = (
        (29, 30, 50, "picture-7.pdf"),
        (27, 28, 50, "picture-8.pdf"),
        (27, 28, 2000, "picture-9.pdf"),
    )
    for sig_bits, coeff_bits, centre_hz, file_name in plot_2_jobs:
        plot_2(sig_bits, coeff_bits, centre_hz, image_dir / file_name)