// name: konvolve.cpp
// desc: convolver: time-domain, time-domain optimized, fft

#include <assert.h>
#include <limits.h>
#include <math.h>
#include <stdlib.h>
#include <string.h>

#include <sndfile.h>

#include "chuck_fft.h"

#include <iostream>
#include <new>
#include <string>
#include <vector>
using namespace std;

// our sample type: every audio buffer in this file is an array of SAMPLE
#define SAMPLE float
// larger / smaller of two values.
// every use of an argument is parenthesized so that operands containing
// low-precedence operators (|, &, ?:, ...) are compared as a whole.
// note: each argument is still evaluated twice; avoid side effects.
#define MAX(a,b) ((a) > (b) ? (a) : (b))
#define MIN(a,b) ((a) < (b) ? (a) : (b))

// read data from file
//
// loads an entire mono sound file into a newly allocated buffer.
//   filename : path handed to sf_open()
//   size     : out, number of samples read (0 on failure)
//   srate    : out, sample rate reported by the file (0 on failure)
// returns the buffer (allocated with new []; the caller must delete [] it),
// or NULL after printing an error message.
SAMPLE * readData( const string & filename, int * size, int * srate )
{
    // zero out the results first so every failure path reports 0 / 0
    *size = 0;
    *srate = 0;

    // info
    SF_INFO info;
    // because the doc says (format must be 0 when opening for read)
    info.format = 0;

    // open it
    SNDFILE * sf = sf_open( filename.c_str(), SFM_READ, &info );
    // check it
    if( !sf )
    {
        // error message
        cout << "error: cannot open '" << filename << "'" << endl;
        return NULL;
    }

    // stays NULL unless the whole file was read successfully
    SAMPLE * buffer = NULL;

    if( info.channels > 1 )
    {
        // make sure it's mono
        cout << "error: '" << filename << "' is not MONO" << endl;
    }
    else if( info.frames < 0 || info.frames > INT_MAX )
    {
        // the length is reported through an int; refuse files whose
        // frame count would be truncated by that conversion
        cout << "error: '" << filename << "' is too large" << endl;
    }
    else
    {
        // allocate the whole thing!
        // nothrow, so that running out of memory takes the error path
        // below instead of throwing std::bad_alloc
        buffer = new (std::nothrow) SAMPLE[(size_t)info.frames];
        if( !buffer )
        {
            // error message
            cout << "error: out of memory... frak" << endl;
        }
        else if( sf_read_float( sf, buffer, info.frames ) != info.frames )
        {
            // short read: throw the partial data away
            cout << "error: can't read file..." << endl;
            delete [] buffer; buffer = NULL;
        }
        else
        {
            // success: report size and sample rate
            *size = (int)info.frames;
            *srate = info.samplerate;
        }
    }

    // close sf (always open at this point)
    sf_close( sf );

    return buffer;
}

// write data to file
bool writeData( const string & filename, SAMPLE * buffy, int size, int srate )
{
    // handle
    SNDFILE * sf = NULL;
    // info
    SF_INFO info;
    // result
    bool result = false;
    
    // fill in
    info.samplerate = srate;
    info.channels = 1;
    info.format = SF_FORMAT_WAV | SF_FORMAT_PCM_16;
    info.frames = 0;
    
    // open
    sf = sf_open( filename.c_str(), SFM_WRITE, &info );
    // check
    if( !sf )
    {
        // error message
        cout << "error: can't open '" << filename << "' for write" << endl;
        goto done;
    }
    
    // write it
    if( sf_write_float( sf, buffy, size ) != size )
    {
        // error message
        cout << "error: can't write to file..." << endl;
        goto done;
    }
    
    // set flag
    result = true;
    
done:
    if( sf ) sf_close( sf );
    
    return result;
}

// convolve
//
// direct time-domain convolution of f (fsize samples) with g (gsize
// samples). the result is written to buffy, which must hold exactly
// fsize + gsize - 1 samples (checked with assert). progress is printed to
// stdout on every 1000th sample of f.
void convolve( SAMPLE * f, int fsize, SAMPLE * g, int gsize, SAMPLE * buffy, int size )
{
    // the output length is fixed by the two input lengths
    assert( size == (fsize + gsize - 1) );

    // start from silence; contributions are accumulated below
    memset( buffy, 0, sizeof(SAMPLE) * size );

    // add a copy of g, scaled by f[n] and shifted by n, for every n
    for( int n = 0; n < fsize; ++n )
    {
        const SAMPLE weight = f[n];
        SAMPLE * out = buffy + n;

        for( int k = 0; k < gsize; ++k )
            out[k] += weight * g[k];

        // progress report
        if( n % 1000 == 0 )
            cout << n << " / " << fsize << endl;
    }
}

// convolve with hand optimization
//
// same computation as the plain time-domain convolution, with the inner
// loop unrolled four samples at a time. buffy must hold exactly
// fsize + gsize - 1 samples (checked with assert). progress is printed to
// stdout on every 1000th sample of f.
void convolve_unroll( SAMPLE * f, int fsize, 
                      SAMPLE * g, int gsize, 
                      SAMPLE * buffy, int size )
{
    // the output length is fixed by the two input lengths
    assert( size == (fsize + gsize - 1) );

    // start from silence; contributions are accumulated below
    memset( buffy, 0, sizeof(SAMPLE) * size );

    // number of samples of g covered by whole groups of four
    const int unrolledEnd = gsize - (gsize % 4);

    for( int n = 0; n < fsize; ++n )
    {
        const SAMPLE weight = f[n];
        SAMPLE * out = buffy + n;
        int k = 0;

        // groups of four
        for( ; k < unrolledEnd; k += 4 )
        {
            out[k]     += weight * g[k];
            out[k + 1] += weight * g[k + 1];
            out[k + 2] += weight * g[k + 2];
            out[k + 3] += weight * g[k + 3];
        }

        // the 0 to 3 samples left over
        for( ; k < gsize; ++k )
            out[k] += weight * g[k];

        // progress report
        if( n % 1000 == 0 )
            cout << n << " / " << fsize << endl;
    }
}

//-----------------------------------------------------------------------------
// name: next_power_2()
// desc: returns twice the highest set bit of n, i.e. the smallest power of
//       two strictly greater than n (so an exact power of two is doubled:
//       8 -> 16); returns 0 for n == 0
// thanks: to Niklas Werner / music-dsp
//-----------------------------------------------------------------------------
unsigned long next_power_2( unsigned long n )
{
    // strip the lowest set bit until a single bit (the highest) remains
    unsigned long top = n;
    for( ;; )
    {
        const unsigned long rest = top & (top - 1);
        if( rest == 0 )
            break;
        top = rest;
    }

    // one step up from the highest bit
    return top << 1;
}

// convolve in freq domain
//
// convolves f (fsize samples) with g (gsize samples) by zero-padding both to
// a power-of-two length, transforming them with rfft(), multiplying the
// spectra and transforming back. the first `size` samples of the result are
// copied to buffy, which must hold exactly fsize + gsize - 1 samples
// (checked with assert). if either input is empty, buffy is zero-filled.
void convolve_fft( SAMPLE * f, int fsize, SAMPLE * g, int gsize, SAMPLE * buffy, int size )
{
    // sanity check
    assert( (fsize + gsize - 1) == size );

    // convolving with an empty signal yields silence
    if( fsize <= 0 || gsize <= 0 )
    {
        if( size > 0 ) memset( buffy, 0, sizeof(SAMPLE) * size );
        return;
    }

    // padded length; always >= 2 here because size >= 1
    const unsigned long fftsize = next_power_2( size );
    // rfft() is called with half the number of real samples
    const long half = (long)(fftsize / 2);

    // zero-initialized work buffers, released automatically on return
    std::vector<SAMPLE> fbuf( fftsize, 0 );
    std::vector<SAMPLE> gbuf( fftsize, 0 );
    std::vector<SAMPLE> result( fftsize, 0 );

    // copy in; the rest of each buffer stays zero (padding)
    memcpy( &fbuf[0], f, sizeof(SAMPLE) * fsize );
    memcpy( &gbuf[0], g, sizeof(SAMPLE) * gsize );

    // take fft, in place
    rfft( &fbuf[0], half, FFT_FORWARD );
    rfft( &gbuf[0], half, FFT_FORWARD );

    // the spectra are read as interleaved (re, im) pairs.
    // NOTE(review): chuck_fft's rfft() is documented as packing the real
    // Nyquist value into index 1 next to the real DC value in index 0 --
    // confirm against chuck_fft.h. those two bins are therefore multiplied
    // independently rather than as one complex number; this is also correct
    // if index 1 is simply a zero imaginary part.
    result[0] = fbuf[0] * gbuf[0];
    result[1] = fbuf[1] * gbuf[1];

    // remaining bins: plain complex multiplication, which is the same as
    // multiplying magnitudes and adding phases
    for( long i = 1; i < half; i++ )
    {
        const SAMPLE fre = fbuf[2*i], fim = fbuf[2*i + 1];
        const SAMPLE gre = gbuf[2*i], gim = gbuf[2*i + 1];

        result[2*i]     = fre * gre - fim * gim;
        result[2*i + 1] = fre * gim + fim * gre;
    }

    // inverse fft, in place
    rfft( &result[0], half, FFT_INVERSE );

    // copy into buffy (size <= fftsize, so this never reads past result)
    memcpy( buffy, &result[0], sizeof(SAMPLE) * size );
}


// normalize
//
// rescales buffy in place so that its largest absolute sample value becomes
// `scale`. a buffer that is empty or contains only zeros is left untouched,
// since there is no peak to scale by (dividing by it would produce NaNs).
void normalize( SAMPLE * buffy, int size, SAMPLE scale = 1.0f )
{
    // find the peak magnitude
    SAMPLE peak = 0;
    for( int i = 0; i < size; i++ )
    {
        const SAMPLE mag = fabs( buffy[i] );
        if( mag > peak )
            peak = mag;
    }

    // silent (or empty) signal: nothing to do
    if( peak <= 0 )
        return;

    // scale every sample relative to the peak
    for( int i = 0; i < size; i++ )
        buffy[i] = (buffy[i] / peak) * scale;
}

// entry point
int main( int argc, char ** argv )
{
    // check args
    if( argc != 4 )
    {
        // error message
        cout << "konvolve: not enough arguments" << endl;
        cout << "usage: konvolve [file1] [file2] [fileOut]" << endl;
        exit( 1 );
    }
    
    // go
    int size1 = 0;
    int srate1 = 0;
    string filename1 = argv[1];
    SAMPLE * buffer1 = readData( filename1, &size1, &srate1 );
    if( !buffer1 ) exit( 1 );
    
    // go
    int size2 = 0;
    int srate2 = 0;
    string filename2 = argv[2];
    SAMPLE * buffer2 = readData( filename2, &size2, &srate2 );
    if( !buffer2 ) exit( 1 );
    
    // size
    int size = size1 + size2 - 1;
    // allocate buffer, size is always fsize+gsize-1
    SAMPLE * buffy = new SAMPLE[size];
    
    // convolve
    convolve_fft( buffer1, size1, buffer2, size2, buffy, size );
    
    // normalize
    normalize( buffy, size, .95f );
    
    // write
    int srate = srate1;
    string filenameOut = argv[3];
    writeData( filenameOut, buffy, size, srate );
    
    return 0;
}
