# Copyright 2023-2024 XMOS LIMITED.
# This Software is subject to the terms of the XMOS Public Licence: Version 1.
"""
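Generate polyphase FIR filter coefficients for rational-factor 48 kHz <-> 32 kHz sample-rate
conversion (2-phase downsampler, 3-phase upsampler), optionally writing C source/header files
and filter response plots.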
"""
import argparse
import numpy as np
from scipy import signal
import matplotlib.pyplot as plt
from pathlib import Path

# Design sample rate (Hz) for the prototype filter: the least common
# multiple of 48 kHz and 32 kHz.
fs: int = 96_000

# Length of the prototype low-pass filter.
NUM_TAPS: int = 96
# Polyphase branches used for the 48 -> 32 kHz downsampler tables.
NUM_PHASES_DS: int = 2
# Polyphase branches used for the 32 -> 48 kHz upsampler tables.
NUM_PHASES_US: int = 3

def test_bounds(Y, F, freq, min, max):
    # This will find the closest frequency we can get and test the responce
    idx = (np.abs(F - freq)).argmin()
    #print(f"{Y[idx]} dB at {F[idx]}")
    assert Y[idx] < max
    if min != None:
        assert Y[idx] > min

def test_filter(taps_fl):
    """
    Verify that ``taps_fl`` has the expected low-pass shape.

    Requires at least 6 dB attenuation at 16 kHz, a gain within (-2, 0.05) dB
    at 15 kHz and within +-0.05 dB at the probed passband points from 80 Hz
    to 14 kHz. Each check is delegated to ``test_bounds``.
    """
    omega, response = signal.freqz(taps_fl)
    gain_db = 20 * np.log10(np.abs(response))
    freqs_hz = 0.5 * fs * omega / np.pi

    # (probe frequency in Hz, lower bound dB or None, upper bound dB)
    limits = (
        (16000, None, -6),
        (15000, -2, 0.05),
        (14000, -0.05, 0.05),
        (12500, -0.05, 0.05),
        (5000, -0.05, 0.05),
        (200, -0.05, 0.05),
        (80, -0.05, 0.05),
    )
    for probe_hz, lower_db, upper_db in limits:
        test_bounds(gain_db, freqs_hz, probe_hz, lower_db, upper_db)

def mix_coefs(taps, fact_up):
    """
    Split a prototype filter into ``fact_up`` polyphase branches.

    Branch ``ph`` takes every ``fact_up``-th tap starting at index ``ph``,
    i.e. ``poly_taps[ph][t] == taps[t * fact_up + ph]``.

    Args:
        taps: 1-D sequence of prototype filter coefficients. Its length must
            be a multiple of ``fact_up``.
        fact_up: number of polyphase branches.

    Returns:
        (poly_taps, poly_taps_int): a float64 array of shape
        ``(fact_up, len(taps) // fact_up)`` and the same coefficients scaled
        by 2**30 (q30) and truncated toward zero to int32.

    Raises:
        ValueError: if ``len(taps)`` is not a multiple of ``fact_up`` (the
            trailing taps would otherwise be dropped silently), or if a
            coefficient does not fit in q30 int32 (|tap| must be < 2).
    """
    taps = np.asarray(taps, dtype=np.float64)
    if len(taps) % fact_up != 0:
        raise ValueError(
            f"number of taps ({len(taps)}) is not a multiple of the number of phases ({fact_up})")
    num_taps_per_phase = len(taps) // fact_up

    # Row t of the reshape holds taps[t*fact_up:(t+1)*fact_up]; transposing
    # turns each column (phase) into a row. copy() detaches the result from
    # the input and makes it C-contiguous.
    poly_taps = taps.reshape(num_taps_per_phase, fact_up).T.copy()

    scaled = poly_taps * 2 ** 30
    # astype() gives meaningless values for out-of-range floats or NaN
    # instead of raising, so reject them up front.
    if not np.all((scaled >= -2 ** 31) & (scaled < 2 ** 31)):
        raise ValueError("filter coefficients do not fit in q30 int32 (|tap| must be < 2)")
    # astype() truncates toward zero; kept to reproduce existing coefficient tables.
    poly_taps_int = scaled.astype(np.int32)

    return poly_taps, poly_taps_int

def gen_coefs(total_num_taps = NUM_TAPS, num_phases_ds = NUM_PHASES_DS, num_phases_us = NUM_PHASES_US):
    """
    Get 16 kHz low pass filter coefficients for the 48 - 32 kHz rational factor polyphase filtering.

    The prototype is a ``firwin2`` design at ``fs`` with a Kaiser window
    (beta 3.2), unity gain up to 15 kHz and zero gain from 17 kHz. It is
    checked with ``test_filter`` (AssertionError if out of spec) and then
    split into polyphase tables with ``mix_coefs``.

    Args:
        total_num_taps: length of the prototype filter; must be a multiple
            of both ``num_phases_ds`` and ``num_phases_us``.
        num_phases_ds: number of phases for the downsampler tables.
        num_phases_us: number of phases for the upsampler tables.

    Returns:
        A tuple ``(taps, taps_ds, taps_ds_int, taps_us, taps_us_int)``:

        taps[total_num_taps] in float for plotting and debugging

        taps_ds[num_phases_ds][total_num_taps // num_phases_ds] in float for debugging the downsampler

        taps_ds_int[num_phases_ds][total_num_taps // num_phases_ds] in int32 (q30) for the downsampler implementation

        taps_us[num_phases_us][total_num_taps // num_phases_us] in float for debugging the upsampler

        taps_us_int[num_phases_us][total_num_taps // num_phases_us] in int32 (q30) for the upsampler implementation
    """
    # alternative filter left here for manual testing
    #lpf_16k = signal.remez(total_num_taps, [0, 14300, 17700, 0.5 * fs], [1, 0], [.05, 1], fs=fs)
    lpf_16k = signal.firwin2(total_num_taps, [0, 15000, 17000, 0.5 * fs], [1, 1, 0, 0], window = ("kaiser", 3.2), fs=fs)

    # Fails (AssertionError) if the response is outside the expected bounds.
    test_filter(lpf_16k)
    poly_ds, poly_ds_int = mix_coefs(lpf_16k, num_phases_ds)
    poly_us, poly_us_int = mix_coefs(lpf_16k, num_phases_us)

    return lpf_16k, poly_ds, poly_ds_int, poly_us, poly_us_int

def plot_response(taps, passband = False, freq_domain = True, output_dir = None):
    """
    Plot a filter and save the figure as a PNG.

    Args:
        taps: filter coefficients.
        passband: if True, zoom into the passband (0 .. fs/4 Hz, +-0.2 dB) and
            save as ``lpf_rat_pb.png``; otherwise show the full band (0 .. fs/2)
            and save as ``lpf_rat.png``. The passband view is only defined
            for the frequency domain, so ``passband=True`` with
            ``freq_domain=False`` does nothing.
        freq_domain: if True plot the magnitude response in dB, otherwise
            plot the tap values.
        output_dir: directory the PNG is written to; defaults to the
            directory containing this script.
    """
    if passband and not freq_domain:
        return

    fig, ax = plt.subplots()
    try:
        if freq_domain:
            w, h = signal.freqz(taps)
            ax.plot(0.5 * fs * w / np.pi, 20 * np.log10(np.abs(h)))
            ax.set_xlabel('Frequency (Hz)')
            ax.set_ylabel('Gain (dB)')
        else:
            ax.plot(taps)
            ax.set_xlabel('Time')
            ax.set_ylabel('Amplitude')

        if passband:
            # freq_domain is necessarily True here (see early return).
            ax.set_ylim(-0.2, 0.2)
            ax.set_xlim(0, 0.25 * fs)
            title = "lpf_rat_pb"
        else:
            if freq_domain:
                ax.set_xlim(0, 0.5 * fs)
            title = "lpf_rat"
        ax.grid(True)
        ax.set_title(title)

        out_dir = Path(__file__).parent if output_dir is None else Path(output_dir)
        fig.savefig(out_dir / (title + ".png"), dpi = 200)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)

def generate_header_file(output_path, total_num_taps = NUM_TAPS, num_phases_ds = NUM_PHASES_DS, num_phases_us = NUM_PHASES_US, filename = None):
    """
    Write the C header declaring the polyphase SRC coefficient tables.

    Args:
        output_path: existing directory the header is written into.
        total_num_taps: length of the prototype filter.
        num_phases_ds: number of downsampler phases.
        num_phases_us: number of upsampler phases.
        filename: header file name; defaults to ``src_rat_fir_coefs.h``.
            NOTE: the generated .c file includes ``src_rat_fir_coefs.h`` by
            name, so a different name must be accounted for by the build.
    """
    header_template = """\
// Copyright 2023 XMOS LIMITED.
// This Software is subject to the terms of the XCORE VocalFusion Licence.

/*********************************/
/* AUTOGENERATED. DO NOT MODIFY! */
/*********************************/

// Use src_rat_fir_gen.py script to regenerate this file
// python src_rat_fir_gen.py -o <outpath> -gc -nt %(num_taps)s

#ifndef _SRC_RAT_COEFS_H_
#define _SRC_RAT_COEFS_H_

#include <stdint.h>

#ifndef ALIGNMENT
#  ifdef __xcore__
#    define ALIGNMENT(N)  __attribute__((aligned (N)))
#  else
#    define ALIGNMENT(N)
#  endif
#endif

#define SRC_RAT_FIR_NUM_TAPS (%(num_taps)s)
#define SRC_RAT_FIR_NUM_PHASES_DS (%(phases_ds)s)
#define SRC_RAT_FIR_TAPS_PER_PHASE_DS (%(taps_per_phase_ds)s)
#define SRC_RAT_FIR_NUM_PHASES_US (%(phases_us)s)
#define SRC_RAT_FIR_TAPS_PER_PHASE_US (%(taps_per_phase_us)s)

/** q30 coefficients to use for the 48 -> 32 kHz polyphase rational factor downsampling */
extern const int32_t src_rat_fir_ds_coefs[SRC_RAT_FIR_NUM_PHASES_DS][SRC_RAT_FIR_TAPS_PER_PHASE_DS];

/** q30 coefficients to use for the 32 -> 48 kHz polyphase rational factor upsampling */
extern const int32_t src_rat_fir_us_coefs[SRC_RAT_FIR_NUM_PHASES_US][SRC_RAT_FIR_TAPS_PER_PHASE_US];

#endif // _SRC_RAT_COEFS_H_

"""

    # Previously the argument was overwritten unconditionally; honour it now.
    if filename is None:
        filename = "src_rat_fir_coefs.h"
    header_path = Path(output_path) / filename

    tph_ds = total_num_taps // num_phases_ds
    tph_us = total_num_taps // num_phases_us

    with open(header_path, "w", encoding="utf-8") as header_file:
        header_file.write(header_template % {
                                        'num_taps':total_num_taps,
                                        'phases_ds':num_phases_ds,
                                        'taps_per_phase_ds':tph_ds,
                                        'phases_us':num_phases_us,
                                        'taps_per_phase_us':tph_us})

def _format_coef_table(coefs, num_phases, taps_per_phase, taps_per_line = 6):
    """Format coefs[num_phases][taps_per_phase] as C initialiser rows, one ``{ ... },`` per phase."""
    parts = []
    for phase in range(num_phases):
        parts.append('    {\n    ')
        for tap in range(taps_per_phase):
            parts.append(' ' + str(coefs[phase][tap]).rjust(12) + ',')
            if (tap + 1) % taps_per_line == 0:
                parts.append('\n    ')
        parts.append('},\n')
    return ''.join(parts)

def generate_c_file(output_path, mixed_taps_ds, mixed_taps_us, total_num_taps = NUM_TAPS,
                    num_phases_ds = NUM_PHASES_DS, num_phases_us = NUM_PHASES_US):
    """
    Write ``src_rat_fir_coefs.c`` defining the q30 polyphase coefficient tables.

    Args:
        output_path: existing directory the file is written into.
        mixed_taps_ds: int coefficients indexed [phase][tap] for the downsampler.
        mixed_taps_us: int coefficients indexed [phase][tap] for the upsampler.
        total_num_taps: length of the prototype filter.
        num_phases_ds: number of downsampler phases.
        num_phases_us: number of upsampler phases.
    """
    c_template = """\
// Copyright 2023 XMOS LIMITED.
// This Software is subject to the terms of the XCORE VocalFusion Licence.

/*********************************/
/* AUTOGENERATED. DO NOT MODIFY! */
/*********************************/

// Use src_rat_fir_gen.py script to regenerate this file
// python src_rat_fir_gen.py -o <outpath> -gc -nt %(num_taps)s

#include "src_rat_fir_coefs.h"
#include <stdint.h>

/** q30 coefficients to use for the 48 -> 32 kHz polyphase rational factor downsampling */
const int32_t ALIGNMENT(8) src_rat_fir_ds_coefs[SRC_RAT_FIR_NUM_PHASES_DS][SRC_RAT_FIR_TAPS_PER_PHASE_DS] = {
%(coefs_ds)s
};

/** q30 coefficients to use for the 32 -> 48 kHz polyphase rational factor upsampling */
const int32_t ALIGNMENT(8) src_rat_fir_us_coefs[SRC_RAT_FIR_NUM_PHASES_US][SRC_RAT_FIR_TAPS_PER_PHASE_US] = {
%(coefs_us)s
};

"""
    tph_ds = total_num_taps // num_phases_ds
    tph_us = total_num_taps // num_phases_us

    coefs_ds = _format_coef_table(mixed_taps_ds, num_phases_ds, tph_ds)
    coefs_us = _format_coef_table(mixed_taps_us, num_phases_us, tph_us)

    filename = "src_rat_fir_coefs.c"
    c_path = Path(output_path) / filename

    with open(c_path, "w", encoding="utf-8") as c_file:
        c_file.write(c_template % {
                                    'num_taps':total_num_taps,
                                    'coefs_ds':coefs_ds,
                                    'coefs_us':coefs_us})

if __name__ == "__main__":
    # Command-line entry point: design the filter, then optionally emit the
    # C sources (-gc, requires -o) and response plots (-gp).
    parser = argparse.ArgumentParser(description="Generate FIR coefficients for a 48 - 32 kHz polyphase SRC")
    parser.add_argument('--output_dir','-o', help='output path for filter files')
    parser.add_argument('--gen_c_files','-gc', help='Generate .h and .c files', action='store_true')
    parser.add_argument('--gen_plots', '-gp', help='Generate .png files', action='store_true')
    parser.add_argument('--num_taps', '-nt', help='number of taps', default=NUM_TAPS, type=int)
    args = parser.parse_args()

    # Path(None) would otherwise fail with a TypeError deep in the run.
    if args.gen_c_files and args.output_dir is None:
        parser.error("--output_dir is required with --gen_c_files")
    # Every tap must belong to exactly one phase of both polyphase splits,
    # otherwise the tables silently lose taps and disagree with NUM_TAPS.
    if args.num_taps <= 0 or args.num_taps % NUM_PHASES_DS or args.num_taps % NUM_PHASES_US:
        parser.error(f"--num_taps must be a positive multiple of {NUM_PHASES_DS} and {NUM_PHASES_US}")

    print(f"Running In {Path(__file__).name}. num_taps = {args.num_taps}")

    taps, poly_ds, poly_ds_int, poly_us, poly_us_int = gen_coefs(args.num_taps)

    if args.gen_c_files:
        Path(args.output_dir).mkdir(exist_ok=True, parents=True)
        generate_header_file(args.output_dir, args.num_taps)
        generate_c_file(args.output_dir, poly_ds_int, poly_us_int, args.num_taps)
    if args.gen_plots:
        plot_response(taps, False)
        plot_response(taps, True)
