// Copyright 2025 XMOS LIMITED.
// This Software is subject to the terms of the XMOS Public Licence: Version 1.

#include <stdbool.h>
#include <stdint.h>
#include <string.h>
#include <assert.h>

#include <xs1.h>
#include <platform.h>

#include <xcore/chanend.h>
#include <xcore/channel.h>
#include <xcore/select.h>
#include <xcore/parallel.h>

#include "some_ai.hpp"
#include "config.h"

#include "kws_labels.h"
#include "kwd_spotter.hpp"

// Mono sample rate produced by some_ai_collect() after decimation;
// some_ai() checks SAMPLE_FREQUENCY (from config.h) against it.
#define AI_SAMPLE_FREQUENCY (16000)
// Number of consecutive input samples combined into one decimated sample.
#define SAMPLE_COUNT        (2)
// A keyword score must be strictly greater than this to be reported.
#define THRESHOLD           (80)

// lib_xcore job declarations so both tasks can be launched with PAR_JOBS().
DECLARE_JOB(some_ai_collect, (chanend_t, chanend_t));
DECLARE_JOB(some_ai_process, (chanend_t, chanend_t));


/*
 * Collect task: reads interleaved audio words from c_ai, keeps channel 0,
 * decimates it by SAMPLE_COUNT and forwards blocks of N_SAMPLES mono samples
 * to the process task over c_collect.
 *
 * c_ai      - chanend delivering CHANS_PER_FRAME words per input sample
 *             period; the first word of each group is treated as channel 0.
 * c_collect - chanend to the process task. Each frame is sent with
 *             chan_out_buf_word(), which blocks until the receiver takes it.
 *
 * The select case always ends in `continue`, so this function never returns.
 */
void some_ai_collect(chanend_t c_ai, chanend_t c_collect) {
    // `frame` holds KWD_FRAME_SIZE words but is filled and sent as N_SAMPLES
    // words; refuse to build if that would overrun the stack buffer.
    _Static_assert(N_SAMPLES <= KWD_FRAME_SIZE,
                   "N_SAMPLES must not exceed KWD_FRAME_SIZE");
    // The decimator below combines exactly two input samples.
    _Static_assert(SAMPLE_COUNT == 2,
                   "decimator in some_ai_collect assumes SAMPLE_COUNT == 2");

    // Only channel 0 contributes to the output, so only it is buffered.
    int32_t mono[SAMPLE_COUNT] = {0};
    unsigned samples_idx = 0;
    unsigned channel_idx = 0;

    int32_t frame[KWD_FRAME_SIZE] = {0};
    unsigned frame_idx = 0;

    SELECT_RES(CASE_THEN(c_ai, event_ai_chanend))
    {
        event_ai_chanend:{
            // Every word must be read to keep channel framing in step, even
            // the ones from channels that are discarded.
            int32_t word = chan_in_word(c_ai);
            if (channel_idx == 0) {
                mono[samples_idx] = word;
            }

            channel_idx = (channel_idx + 1) % CHANS_PER_FRAME;
            if (channel_idx == 0) {
                samples_idx = (samples_idx + 1) % SAMPLE_COUNT;
                if (samples_idx == 0) {
                    // Decimate by 2 with a 2-tap average of channel 0. Each
                    // term is halved before adding so the sum cannot overflow;
                    // this relies on >> being an arithmetic shift for negative
                    // values (implementation-defined in C).
                    frame[frame_idx++] = (mono[0] >> 1) + (mono[1] >> 1);

                    if (frame_idx >= N_SAMPLES) { // frame collected
                        frame_idx = 0;
                        chan_out_buf_word(c_collect, (uint32_t *)frame, N_SAMPLES);
                    }
                }
            }
            continue;
        } // event_ai_chanend
    } // select
}

/*
 * Process task: receives N_SAMPLES-word frames on c_process, runs the keyword
 * spotter on each and, whenever kwd_spotter_compute() reports
 * KWD_SPOTTER_DONE, sends a kwd_label_indices_t result on c_gpo.
 *
 * Result: LEFT if its score exceeds THRESHOLD; otherwise RIGHT if its score
 * exceeds THRESHOLD; otherwise UNKNOWN. LEFT takes precedence when both pass.
 *
 * The select case always ends in `continue`, so this function never returns.
 */
void some_ai_process(chanend_t c_process, chanend_t c_gpo) {
    kwd_spotter_ctx_t ctx = {{0}};

    kwd_spotter_state_t state = kwd_spotter_init(&ctx);
    assert(state == KWD_SPOTTER_INIT_DONE);

    // Incoming frames are written directly into the spotter's input buffer.
    uint32_t *input_buf = ctx.kwd_frame_ptr;

    SELECT_RES(CASE_THEN(c_process, event_frame_received))
    {
        event_frame_received:{
            chan_in_buf_word(c_process, input_buf, N_SAMPLES);

            state = kwd_spotter_compute(&ctx);
            if (state != KWD_SPOTTER_DONE) {
                // No result for this frame; nothing is sent to the GPO task.
                continue;
            }

            kwd_label_indices_t label =
                (ctx.kwd_output[LEFT] > THRESHOLD)  ? LEFT  :
                (ctx.kwd_output[RIGHT] > THRESHOLD) ? RIGHT :
                                                      UNKNOWN;
            chan_out_word(c_gpo, label); // Notify GPO task
            continue;
        } // event_frame_received
    } // select
}

/*
 * C wrapper: runs the collect and process tasks in parallel, connected by an
 * internally allocated channel.
 *
 * c_ai  - chanend delivering interleaved audio words at SAMPLE_FREQUENCY,
 *         consumed by some_ai_collect().
 * c_gpo - chanend on which some_ai_process() sends a kwd_label_indices_t
 *         for each completed spotter result.
 *
 * Neither job has a return path (each select case ends in `continue`), so
 * PAR_JOBS() is not expected to return; chan_free() is only reached if that
 * ever changes.
 */
void some_ai(chanend_t c_ai,  chanend_t c_gpo) {
    // some_ai_collect() decimates by SAMPLE_COUNT, so the input rate must be an
    // exact SAMPLE_COUNT multiple of AI_SAMPLE_FREQUENCY. Checked at build time
    // so a misconfiguration cannot slip through NDEBUG builds.
    _Static_assert(SAMPLE_FREQUENCY % AI_SAMPLE_FREQUENCY == 0,
                   "SAMPLE_FREQUENCY must be a multiple of AI_SAMPLE_FREQUENCY");
    _Static_assert(SAMPLE_FREQUENCY / AI_SAMPLE_FREQUENCY == SAMPLE_COUNT,
                   "SAMPLE_FREQUENCY / AI_SAMPLE_FREQUENCY must equal SAMPLE_COUNT");

    channel_t c_task = chan_alloc();
    PAR_JOBS(
        PJOB(some_ai_collect, (c_ai, c_task.end_a)),
        PJOB(some_ai_process, (c_task.end_b, c_gpo))
    );
    chan_free(c_task);
} // End of some_ai
