#include "fft.h"

#ifdef DEBUG_FFT
// Debug helper (defined at the end of this file): dump a float vector to a
// text file under DEBUGDIR.
static void print_vector(float *vect, char * filename, unsigned int length);
#endif

/**
 * DFT matrix exponents factor.
 * Common factor of the phase arguments of the upsampled inverse DFT computed
 * in fft_dft_bw(): 2*pi / (UPS * RESX), i.e. the phase step between two
 * adjacent samples of the upsampled output grid.
 * The evaluation order (float pi, then the two divisions) determines the
 * rounding of the result; keep it unchanged.
 */
#define ARG_CONST ((float)M_PI*2/UPS/RESX)

/* Allocate an uninitialized, FFTW-aligned buffer holding n real (float)
 * samples and return it. fftwf_malloc behaves like malloc apart from the
 * alignment guarantee. Release the buffer with fft_free().
 */
float * fft_malloc_r(unsigned int n) {
	float *buf = fftwf_malloc(n * sizeof *buf);
	return buf;
}

/* Allocate an uninitialized, FFTW-aligned buffer holding n complex samples
 * (each one a [real, imaginary] pair of floats) and return it. fftwf_malloc
 * behaves like malloc apart from the alignment guarantee. Release the buffer
 * with fft_free().
 */
fftwf_complex * fft_malloc_c(unsigned int n) {
	fftwf_complex *buf = fftwf_malloc(n * sizeof *buf);
	return buf;
}

/* Free a vector allocated with fft_malloc_r() or fft_malloc_c() and reset the
 * caller's pointer to NULL, so calling fft_free() twice is harmless.
 * v is the address of the caller's pointer. Both v == NULL and *v == NULL are
 * accepted and do nothing.
 * NOTE(review): callers holding a float * or fftwf_complex * have to cast
 * &ptr to void **. Writing *v then stores through a void * lvalue, which
 * strict aliasing does not formally allow. It works on common ABIs, but typed
 * variants would be cleaner.
 */
void fft_free(void ** v) {
	if (v && *v) {
		fftwf_free(*v);
		*v = NULL;
	}
}

/* Create and return the fftwf plan for the forward fft of RESX real samples.
 * Wrapper of the fftwf function fftwf_plan_dft_r2c_1d: the input is real, so
 * the real-to-complex transform is used because it is faster than the complex
 * to complex one. The output holds RESX/2+1 complex bins.
 * Planning with FFTW_PATIENT overwrites both "in" and "out" (unless wisdom for
 * this problem is already loaded), so fill them only after this call.
 * Executing the out-of-place r2c plan itself leaves the input untouched.
 * The size of the vectors RESX is defined in common.h
 */
fftwf_plan fft_init_fw(float *in, fftwf_complex *out) {
	const unsigned int flags = FFTW_PATIENT;
	fftwf_plan plan = fftwf_plan_dft_r2c_1d(RESX, in, out, flags);

	return plan;
}

/* Create and return the fftwf plan for the inverse fft.
 * Wrapper of the fftwf function fftwf_plan_dft_c2r_1d, planned with the
 * zero-padded logical size ZPRU, so "in" holds ZPRU/2+1 complex bins.
 * The spatial cross correlation is real and its spectrum satisfies the
 * Hermitian condition, so the complex-to-real (c2r) plan is used because it is
 * faster than the complex to complex transform.
 * c2r transforms destroy their input by default, so FFTW_PRESERVE_INPUT is
 * requested. This keeps the zero-padding region intact between executions.
 * When ZPUPS > 1 the bins between the end of the original half spectrum
 * (RESX/2) and the new Nyquist bin (ZPRU2) are zeroed here, once. fft_xcorr()
 * only rewrites bins 0..RESX2 and ZPRU2 on each call.
 * The size of the image vectors RESX and upscaling UPS are defined in common.h
 */
fftwf_plan fft_init_bw(fftwf_complex *in, float *out) {
	fftwf_plan plan;
	plan = fftwf_plan_dft_c2r_1d(ZPRU, in, out, FFTW_PATIENT | FFTW_PRESERVE_INPUT);

#if ZPUPS > 1
	// Planning with FFTW_PATIENT may overwrite the arrays (unless wisdom is
	// available), so the padding zeros are written after plan creation.
	// fft_xcorr() fills bins 0..RESX2 and moves the Nyquist term to ZPRU2,
	// hence every bin strictly between them must be zero. The loop used to
	// start at RESX+1, which left bins RESX2+1..RESX uninitialized.
	for (unsigned int i = RESX2 + 1; i < ZPRU2; i++) {
		in[i][0] = 0;
		in[i][1] = 0;
	}
#endif

	return plan;
}

/* Cross correlation in the frequency domain: out[k] = v1[k] * conj(v2[k]).
 * Both spectra come from real signals and are Hermitian, so their product is
 * Hermitian as well. Only the non-redundant half, bins 0..RESX/2, is computed.
 * With zero padding (ZPUPS > 1) the Nyquist bin RESX2 is moved to the Nyquist
 * bin of the padded spectrum, ZPRU2, and bin RESX2 is cleared. The bins in
 * between are not written here; they are expected to stay zero in the output
 * buffer between calls.
 */
void fft_xcorr(fftwf_complex *v1, fftwf_complex *v2, fftwf_complex *out) {
	for (unsigned int k = 0; k <= RESX2; k++) {
		const float *a = v1[k];
		const float *b = v2[k];
		float *c = out[k];

		c[0] = a[0] * b[0] + a[1] * b[1];	// Re(a * conj(b))
		c[1] = a[1] * b[0] - a[0] * b[1];	// Im(a * conj(b))
	}

#if (ZPUPS > 1)
	// Relocate the Nyquist term to the end of the padded half spectrum; the
	// zeros in between were written once when the plan was created.
	out[ZPRU2][0] = out[RESX2][0];
	out[ZPRU2][1] = out[RESX2][1];
	out[RESX2][0] = 0;
	out[RESX2][1] = 0;
#endif
}

/* Find the index of the maximum and compute the distance from the center of
 * the image, which corresponds to the displacement.
 * The cross correlation vector in space has the first half and the second half
 * swapped because of the FFT and the cross correlation. This is taken into
 * account when computing the index, without reordering the vector.
 * If there are 2 points with the same max value next to each other, then the
 * average of their positions is returned. If there are more than 2 points with
 * the same max value, or they are not adjacent, then the new position can not
 * be computed. This condition is really rare and usually due to poor lighting
 * or focus, or to the camera moving so fast that the image is blurred or
 * distorted.
 * The range considered is ZPRU/2 - deltarange : ZPRU/2 + deltarange (in the
 * re-centred vector), because considering the whole ZPRU range can lead to
 * errors due to the cyclic nature of the cross correlation (the wrong peak is
 * returned).
 * This also means that the maximum speed is reduced. If the camera moves
 * further than deltarange - deltath between two frames, the obtained
 * displacement is wrong and smaller than the correct one.
 *
 * in:         spatial cross correlation, ZPRU samples, not reordered.
 * deltarange: half-width of the search window in samples; must not exceed
 *             ZPRU/2.
 * index:      output, displacement of the peak in samples (a multiple of
 *             0.5). Written only on success.
 *
 * Returns 0 on success, or 1 if the peak is ambiguous or deltarange is too
 * large.
 * NOTE(review): the running max starts at 0, so only positive samples can
 * become the peak. If every sample is <= 0 (and no zero-valued ties are
 * flagged) the result is -deltarange. This presumably cannot happen for a
 * real correlation peak; confirm.
 */
int fft_displacement(float *in, unsigned int deltarange, float *index) {

	// First and last index of the max interval. All the points in the
	// interval must have the same max value.
	unsigned int first_index, last_index;

	// Max value and pointer to the input vector
	float max, *inptr;

	// Flag used to signal that a non-adjacent sample has the same max value
	char many_peak;

	// The negative half of the window starts at in + ZPRU - deltarange and the
	// positive half ends at in + deltarange - 1. Beyond ZPRU/2 the two halves
	// overlap (the same samples are visited twice), and beyond ZPRU the reads
	// leave the vector. Reject such values.
	if (deltarange > (ZPRU) / 2)
		return 1;

	// Initialize the pointer to the input vector. Start at ZPRU - deltarange
	// because of the fftshift: negative positions are in the second half of
	// the vector, while positive positions are in the first half.
	inptr = in + ZPRU - deltarange;

#ifdef DEBUG_FFT
	volatile unsigned int indexes[3] = {0, 0, 0};
#endif

	// Find index of the max in the vector
	max = 0;
	first_index = 0;
	last_index = 0;
	many_peak = 0;
	for (unsigned int i = 0; i < (deltarange << 1); i++) {

		// New max found
		if (*inptr > max) {
			max = *inptr;
			first_index = i;
			last_index = i;

			// Clear the flag: the invalid points found so far were not
			// peaks
			many_peak = 0;

		// Another element with the same max value found
		} else if (*inptr == max) {

			// Non-adjacent peak
			if (i > first_index + 1) {
				many_peak = 1;

#ifdef DEBUG_FFT
				// Store the indexes
				indexes[0] = first_index;
				indexes[1] = last_index;
				indexes[2] = i;
#endif

			// Adjacent temporary max found
			} else {
				last_index = i;
			}
		}

		// After deltarange items move the pointer to the beginning of the
		// input vector (the beginning of the positive positions); otherwise
		// just increment it
		if (i == deltarange - 1)
			inptr = in;
		else
			inptr++;
	}

	// If the flag is still set after the loop, the equal values found really
	// were separate maxima and the displacement can not be evaluated
	if (many_peak) {

#ifdef DEBUG_FFT
		// Print the max interval and the not allowed value
		fprintf(stderr, "Multiple peak indexes: %u %u %u\n", indexes[0], indexes[1], indexes[2]);
#endif

		return 1;
	}

	// The peak position is the average of the first and last index. Loop
	// index deltarange corresponds to zero displacement, so subtract
	// 2 * deltarange from the index sum before halving.
	*index = (float)((int)(last_index + first_index) - (int)(deltarange << 1)) * 0.5f;

	return 0;
}

/* Run the inverse transform described by plan p on the arrays it was created
 * with (intended for the c2r plan returned by fft_init_bw()).
 * Thin wrapper around fftwf_execute.
 */
void fft_execute_bw(fftwf_plan p) {
	fftwf_plan const plan = p;

	fftwf_execute(plan);
}

/* Run the forward real-to-complex transform of plan p on new arrays.
 * Uses FFTW's new-array execute function fftwf_execute_dft_r2c, so one plan
 * (from fft_init_fw()) can transform different buffers. FFTW requires the new
 * arrays to have the same alignment as the ones used at planning time
 * (fftwf_malloc-based buffers such as those from fft_malloc_r()/fft_malloc_c()
 * share FFTW's preferred alignment), and the same in-place/out-of-place
 * relationship.
 */
void fft_execute_fw(fftwf_plan p, float *in, fftwf_complex *out) {
	float *const src = in;
	fftwf_complex *const dst = out;

	fftwf_execute_dft_r2c(p, src, dst);
}

/* Inverse DFT evaluated on an upsampled grid around the coarse peak.
 * Computes WIN samples of the spatial cross correlation, spaced 1/UPS apart
 * and centred on the coarse displacement "delta", and locates their maximum.
 * The DFT is computed as a sum of products using its definition (a row by
 * column product, if we think in terms of the DFT matrix).
 * The inner loop that computes one output sample has only RESX/2 - 1
 * iterations instead of RESX because it exploits the Hermitian symmetry. This
 * symmetry holds both for the cross correlation in the frequency domain and
 * for the rows of the DFT matrix (the phase terms of the DFT definition). Each
 * conjugate pair adds up to twice its real part, so the output is real and
 * only the real part is accumulated.
 * Another optimization is using the sine and cosine addition formulas to
 * compute the real and imaginary parts of the phase terms. The argument grows
 * by a fixed amount from one term to the next, so each phase term is the
 * previous one rotated by a fixed increment.
 *
 * cross:     half spectrum of the cross correlation as produced by fft_xcorr():
 *            bins 0..RESX/2-1, with the Nyquist term read from index ZPRU2.
 * delta:     coarse displacement. The phase step of output sample m is
 *            ARG_CONST * (delta * UPS - WIN/2 + m), so delta is presumably in
 *            units of the original RESX-point grid (TODO confirm vs caller).
 * win_delta: output, peak position relative to the window centre, in units
 *            of 1/UPS of delta's unit.
 *
 * Returns 1 when a single maximum sits on the first or last window sample, so
 * the true peak is probably outside the window. This includes the case where
 * no sample is positive, since the running max starts at 0. Otherwise returns
 * 0. *win_delta is written in both cases.
 */
int fft_dft_bw(fftwf_complex *cross, float delta, float * win_delta) {

	float tmp, arg_inc, shift_fix, rc_product, max;
	float phase_term[2], phase_term_inc[2];
	unsigned int m, k, id, id_end;
	fftwf_complex *cc;

#ifdef DEBUG_FFT
	float v[WIN];
	volatile float zr[RESX], zi[RESX];
#endif

	// Part of the argument common to all the elements: centres the window of
	// WIN samples on delta (expressed in upsampled units)
	shift_fix = delta * UPS - WIN/2;

	// Initialize the variables used to find the max of the spatial cross
	// correlation
	max = 0;
	id = 0;
	id_end = 0;

	// Main loop, iterate over the elements of the output spatial cross
	// correlation
	for (m = 0; m < WIN; m++) {

		// Initialize the row by column product (sum of products) used to
		// compute the DFT.
		// The result is real, so the imaginary part is not computed.
		rc_product = 0;

		// Argument increment between the phase terms of the DFT
		arg_inc = ARG_CONST * (shift_fix + m);

		// Argument increment written as real and imaginary part
		phase_term_inc[0] = cos(arg_inc);
		phase_term_inc[1] = sin(arg_inc);

		// Initial phase term
		phase_term[0] = 1;
		phase_term[1] = 0;

		// Pointer to the second cross correlation element (skip the dc
		// component for now)
		cc = cross + 1;

		// Loop over half of the elements of the DFT and compute the sum of
		// products between the cross correlation and the phase terms
		for (k = 0; k < RESX2 - 1; k++) {

			// Phase terms computed using sine and cosine addition formulas,
			// incremented each iteration by the same amount.
			tmp = phase_term[0] * phase_term_inc[0] -
				phase_term[1] * phase_term_inc[1];
			phase_term[1] = phase_term[0] * phase_term_inc[1] +
				phase_term[1] * phase_term_inc[0];
			phase_term[0] = tmp;

			// Sum of the products between the cross correlation and the
			// phase terms (real part only), DFT definition
			rc_product += (*cc)[0] * phase_term[0] - (*cc)[1] * phase_term[1];

			// Increment cross correlation pointer
			cc++;

#ifdef DEBUG_FFT
			// Save in a vector the real and imaginary parts of the computed
			// phase terms (half DFT matrix)
			zr[k] = phase_term[0];
			zi[k] = phase_term[1];
#endif
		}

		// Exploit the Hermitian redundancy to compute the complete row by
		// column product.
		rc_product *= 2;

		// Add the terms that have no Hermitian twin.
		// Nyquist term: after the loop phase_term = e^{i (RESX/2-1) arg_inc},
		// so one more rotation gives e^{i (RESX/2) arg_inc}, whose real part is
		// p0*i0 - p1*i1. Bin -RESX/2 uses the complex conjugate phase, whose
		// real part is the same cosine, so for a real Nyquist coefficient the
		// contribution is X_re * cos(RESX/2 * arg_inc). The previous
		// "p0*i0 + p1*i1" evaluated cos((RESX/2 - 2) * arg_inc) instead.
		// Assumes RESX is even (real Nyquist bin).
		rc_product += cross[ZPRU2][0] * (phase_term[0] * phase_term_inc[0] -
			phase_term[1] * phase_term_inc[1]);

		// The dc component phase term is 1
		rc_product += cross[0][0];

#ifdef DEBUG_FFT
		// Save the computed spatial DFT element in a vector
		v[m] = rc_product;
#endif

		// Search for the peak while computing the spatial cross correlation
		// elements. Here it is quite common to have multiple adjacent elements
		// with the same max value, so the position of the max is evaluated as
		// the mean between the positions of the first and the last max
		// elements.
		if (rc_product > max) {
			max = rc_product;
			id = m;
			id_end = m;
		} else if (rc_product == max) {
			id_end = m;
		}

	}
#ifdef DEBUG_FFT
	print_vector(v, "cross_win", WIN);
#endif

	// Return the displacement: the position of the peak, computed as the
	// average of the positions of the first and last elements with the max
	// value, relative to the center of the window
	*win_delta = (float)((int)(id + id_end) - (int)WIN) * 0.5;

	// The peak is not centered in the window: it is at one of the extremes.
	if ((id == 0 || id_end == WIN-1) && id == id_end) {
		return 1;
	}

	return 0;
}


/* Destroy a fftwf plan and reset the caller's handle to NULL, so calling
 * fft_destroy() twice is harmless.
 * Wrapper of the fftwf function fftwf_destroy_plan. Both p == NULL and
 * *p == NULL (never created or already destroyed) are accepted and do nothing.
 */
void fft_destroy(fftwf_plan *p) {
	if (p && *p) {
		fftwf_destroy_plan(*p);
		*p = NULL;
	}
}

/* Read fftwf wisdom (previously measured plans) from the file WISDOM.
 * Wrapper of fftwf function fftwf_import_wisdom_from_filename.
 * Filename defined in config.h
 * Loading wisdom before creating the plans lets FFTW_PATIENT planning reuse
 * saved results instead of measuring again.
 * Returns non-zero on success and 0 on failure (FFTW convention, the opposite
 * of the 0-on-success convention used by fft_displacement()/fft_dft_bw()).
 */
int fft_read_wisdom(void) {
	return fftwf_import_wisdom_from_filename(WISDOM);
}

/* Write the accumulated fftwf wisdom to the file WISDOM.
 * Wrapper of fftwf function fftwf_export_wisdom_to_filename.
 * Filename defined in config.h
 * Returns non-zero on success and 0 on failure (FFTW convention, the opposite
 * of the 0-on-success convention used by fft_displacement()/fft_dft_bw()).
 */
int fft_save_wisdom(void) {
	return fftwf_export_wisdom_to_filename(WISDOM);
}

#ifdef DEBUG_FFT
// Debug helper: write "length" floats from "vect", one per line, to the file
// DEBUGDIR/filename (created or truncated). This is diagnostic output only, so
// failures are reported on stderr and otherwise ignored.
static void print_vector(float *vect, char * filename, unsigned int length){
	char str[64];
	FILE *fp;
	int n;

	// snprintf bounds the write. A truncated path would open the wrong file,
	// so give up instead.
	n = snprintf(str, sizeof str, "%s/%s", DEBUGDIR, filename);
	if (n < 0 || (size_t)n >= sizeof str) {
		fprintf(stderr, "print_vector: path too long for %s\n", filename);
		return;
	}

	fp = fopen(str, "w+");
	if (!fp) {
		perror(str);
		return;
	}

	for (unsigned int i = 0; i < length; i++) {
		fprintf(fp, "%lf\n", vect[i]);
	}

	// fclose flushes the stream; report a failed write
	if (fclose(fp) != 0)
		perror(str);
}
#endif
