#ifndef RELAXATION_H
#define RELAXATION_H

#include <time.h>

#include <cmath>
#include <cstdio>
#include <vector>

#include "inverseIterative.h"
#include "dynamicKdTree.h"
#include "project.h"

void relaxPoints(double* points, double* normals, int numPoints, int dim, double sigmaP, double sigmaN, int max_num_neighbors, int maxIteration, bool sharpFeatures)
{
    
    //Normalize data
	double* pointsNormals = new double[dim*numPoints*2];
    for (int i = 0; i < numPoints; i++)
    {
		for (int j = 0; j < dim; j++)
		{
			//points[i*dim + j] /= sigmaP;
			pointsNormals[i*dim*2 + j] = points[i*dim + j] / sigmaP;
		}

		for (int j = 0; j < dim; j++)
		{
			//normals[i*dim + j] /= sigmaN;
			pointsNormals[i*dim*2 + j + dim] = normals[i*dim + j] / sigmaN;
		}
    }

    //Vars
    double gaussianStd = 2.5;
    double gaussianFactor = 1;
    double* point = new double[dim];
	double* normal = new double[dim];
	double* pointNormal = new double[dim*2];
    double* neighboringPoint = new double[dim];
    

    //Kd-Tree Init
    int* indicesData = new int[numPoints];
	
    kdtree * kdTree = kd_create( dim );

    for (unsigned int i = 0; i < numPoints; i++)
    {
        indicesData[i] = i;

		for (int j = 0; j < dim; j++)
			point[j] = points[i*dim + j];

        kd_insert( kdTree, point, &indicesData[i] );
    }


    //Get neighbors
    struct kdres * neighborIndicesDataRaw;
    int* neighborIndices = new int[numPoints*max_num_neighbors];
    int neighborIndex;
    unsigned int numNeighborIndices;

    for (unsigned int i = 0; i < numPoints; i++)
    {
		for (int j = 0; j < dim; j++)
			point[j] = points[i*dim + j];
        
        neighborIndicesDataRaw = kd_nearest_range(kdTree, point, gaussianStd*gaussianFactor*sigmaP);
        
        numNeighborIndices = 0;
        
        while(!kd_res_end(neighborIndicesDataRaw))
        {
            neighborIndex = *(int*)kd_res_item(neighborIndicesDataRaw, neighboringPoint);
            neighborIndices[i*max_num_neighbors + numNeighborIndices] = neighborIndex;
            numNeighborIndices++;
            kd_res_next(neighborIndicesDataRaw);
        }
        kd_res_free( neighborIndicesDataRaw );

        for (unsigned int j = 0; j < max_num_neighbors - numNeighborIndices; j++)
        {
            neighborIndices[i*max_num_neighbors + numNeighborIndices] = -1;
            numNeighborIndices++;
        }
    }



    //Vars
    double th = 0.001;
    unsigned int maxGradientDescendIteration = 1;
    //unsigned int maxIteration = 20;
    gaussianFactor = 1;
    double radius = gaussianStd * gaussianStd * gaussianFactor * gaussianFactor;
    double iterationError = 0;
    
    double* a = new double[max_num_neighbors];
    int* aIndices = new int[max_num_neighbors];
    //unsigned int k;
	unsigned int dummy;
    double* K = new double[max_num_neighbors*max_num_neighbors];
    double* kx = new double[max_num_neighbors]; 
    double p;
    
    int* neighbors = new int[max_num_neighbors];
    double* neighborPoints = new double[max_num_neighbors*dim];
	double* neighborNormals = new double[max_num_neighbors*dim];
	double* neighborPointsNormals = new double[max_num_neighbors*dim*2];
    int numNeighbors;
    
    double* gradient = new double[2*dim];
    double* xka = new double[2*dim];
    
    clock_t start = clock();
    //printf("%f\n", start);
    //Algorithm
	        
    for (unsigned int iteration = 0; iteration < maxIteration; iteration++)
    {
        iterationError = 0;
        for (unsigned int i = 0;i < numPoints; i++)
        {
			for (int j = 0; j < dim; j++)
				point[j] = pointsNormals[i*dim*2 + j];

			for (int j = 0; j < dim; j++)
				pointNormal[j] = pointsNormals[i*dim*2 + j];

			for (int j = 0; j < dim; j++)
				pointNormal[j + dim] = pointsNormals[i*dim*2 + dim + j];
        
            //Get neighbors within radius
            numNeighbors = 0;
            for (unsigned int j = 0; j < max_num_neighbors; j++)
            {
				if (neighborIndices[i*max_num_neighbors + j] == -1)
					break;

				for (int k = 0; k < dim; k++)
					neighboringPoint[k] = pointsNormals[neighborIndices[i*max_num_neighbors + j]*dim*2 + k];

                //if (sqDistVect(point, neighboringPoint, dim) < radius && neighborIndices[i*max_num_neighbors + j] != i)
				if (neighborIndices[i*max_num_neighbors + j] != i)
                {
                    neighbors[numNeighbors] = neighborIndices[i*max_num_neighbors + j];
                    for (int k = 0; k < dim; k++)
						neighborPoints[dim*numNeighbors + k] = points[dim*neighborIndices[i*max_num_neighbors + j] + k];
                    for (int k = 0; k < dim; k++)
						neighborNormals[dim*numNeighbors + k] = normals[dim*neighborIndices[i*max_num_neighbors + j] + k];
					for (int k = 0; k < dim; k++)
						neighborPointsNormals[dim*numNeighbors*2 + k] = pointsNormals[dim*neighborIndices[i*max_num_neighbors + j]*2 + k];
					for (int k = 0; k < dim; k++)
						neighborPointsNormals[dim*numNeighbors*2 + dim + k] = pointsNormals[dim*neighborIndices[i*max_num_neighbors + j]*2 + dim + k];
                    numNeighbors++;
                }   
            }

            //printf("%d\n", i);
            //Get Inverse K
			 if (numNeighbors == 0)
				 continue;
			
			//printf("%d\n", i);
            p = takeInverseIterative(pointNormal, 0, neighborPointsNormals, numNeighbors, dim*2, -1, a, aIndices, &dummy, K, kx);
            
            //Do gradient descend



            for (unsigned int gradientDescendIteration = 0; gradientDescendIteration < maxGradientDescendIteration; gradientDescendIteration++)
            {

				
				for (int k = 0; k < dim*2; k++)
					xka[k] = 0;
                for (int j = 0; j < numNeighbors; j++)
                {
					for (int k = 0; k < dim*2; k++)
						xka[k] += neighborPointsNormals[j*dim*2 + k]*kx[j]*a[j];
                }
                
				for (int k = 0; k < dim*2; k++)
					gradient[k] = 2*pointNormal[k]*p - 2*xka[k];
				

                //Descend
				for (int k = 0; k < dim; k++)
				{
					pointNormal[k] = pointNormal[k] + gradient[k];
					point[k] = pointNormal[k]*sigmaP;
				}
				for (int k = 0; k < dim; k++)
				{
					pointNormal[k + dim] = pointNormal[k + dim] + gradient[k + dim];
					normal[k] = pointNormal[k + dim]*sigmaN;
				}
	
            }
           
            for (int k = 0; k < dim; k++)
					neighborPoints[dim*numNeighbors + k] = points[dim*i + k];
            for (int k = 0; k < dim; k++)
					neighborNormals[dim*numNeighbors + k] = normals[dim*i + k];
			numNeighbors++;

			project(point, normal, sigmaP, sigmaN, neighborPoints, neighborNormals, numNeighbors, dim, sharpFeatures);

			for (int k = 0; k < dim; k++)
			{
				iterationError = iterationError + fabs(pointsNormals[i*dim*2 + k] - point[k] / sigmaP);
				pointsNormals[i*dim*2 + k] = point[k] / sigmaP;
				//points[i*dim + k] = point[k];
			}

			for (int k = 0; k < dim; k++)
			{
				iterationError = iterationError + fabs(pointsNormals[i*dim*2 + k + dim] - normal[k] / sigmaN);
				pointsNormals[i*dim*2 + k + dim] = normal[k] / sigmaN;
				//normals[i*dim + k] = normal[k];
			}
            
			//printf("%d\n", i);
        }
        
        printf("%f\n", iterationError / numPoints);
    }
	
    clock_t end = clock();
    printf("%f\n", static_cast<double>(end-start)/CLOCKS_PER_SEC);
    
    
    //DeNormalize data
    for (int i = 0; i < numPoints; i++)
    {
		for (int j = 0; j < dim; j++)
			points[i*dim + j] = pointsNormals[i*dim*2 + j]*sigmaP;

		for (int j = 0; j < dim; j++)
			normals[i*dim + j] = pointsNormals[i*dim*2 + dim + j]*sigmaN;
    }
 

    //Clean
    delete[] indicesData;
    kd_free( kdTree );
    delete[] point;
	delete[] normal;
    delete[] neighborIndices;
    delete[] neighboringPoint;
    delete[] neighborPoints;
	delete[] neighborNormals;
	delete[] neighborPointsNormals;
	delete[] pointNormal;
	delete[] pointsNormals;
    delete[] a;
    delete[] aIndices;
    delete[] K;
    delete[] kx;
    delete[] neighbors;
	delete[] gradient;
	delete[] xka;
   
}


				/*
                //Compute kx
                for (int j = 0; j < numNeighbors; j++)
                {
                    neighboringPoint[0] = points[neighbors[j]*2];
                    neighboringPoint[1] = points[neighbors[j]*2 + 1];
                    kx[j] = exp(-sqDistVect(point, neighboringPoint, 2));
                }
                
                //Compute gradient
              
                //a
                for (int j = 0; j < numNeighbors; j++)
                {
                    a[j] = 0;
                    for (int jj = 0; jj < numNeighbors; jj++)
                        a[j] = a[j] + K[j*numNeighbors + jj]*kx[jj];
                }
                
                //p
                p = 0;
                for (int j = 0; j < numNeighbors; j++)
                    p = p + kx[j]*a[j];
                */

#endif