#include "StereoSolverImpls.h"
#include "CudaHelperCommon.cuh"

StdStereoSolver::~StdStereoSolver()
{
	// Release the device-side buffers explicitly (their own destructors would
	// also do this when the members go out of scope).
	gpuImLeft.Destroy();
	gpuImRight.Destroy();
	gpuCostGrid.Destroy();
	gpuDepthMap.Destroy();

	// Free every pipeline stage this solver owns. Applying delete to a NULL
	// pointer is a no-op, so stages that were never created need no guard.
	delete preprocessNode;
	delete costNode;
	delete aggregateNode;
	delete optNode;
	delete wtaNode;
}

#pragma region Node Creation

PreProcStereoNode* StdStereoSolver::CreatePreNode(const StereoNodeFactory<PreProcStereoNode>* factory)
{
	// Replaces the optional pre-processing stage and returns the new node
	// (NULL when no factory is supplied, which disables the stage).

	// Drop the current node first and clear the slot straight away so it
	// never dangles, even if Create() below does not return normally.
	delete preprocessNode;
	preprocessNode = NULL;

	if(factory == NULL)
		return NULL;

	preprocessNode = factory->Create(0);
	return preprocessNode;
}
CostComputeStereoNode* StdStereoSolver::CreateCostNode(const StereoNodeFactory<CostComputeStereoNode>* factory)
{
	// Replaces the cost-computation stage and returns the new node (NULL when
	// no factory is supplied). FindDisparityMap does nothing while this is NULL.

	// Drop the current node first and clear the slot straight away so it
	// never dangles, even if Create() below does not return normally.
	delete costNode;
	costNode = NULL;

	if(factory == NULL)
		return NULL;

	costNode = factory->Create(0);
	return costNode;
}
CostAggregateStereoNode* StdStereoSolver::CreateAggregateNode(const StereoNodeFactory<CostAggregateStereoNode>* factory)
{
	// Replaces the optional cost-aggregation stage and returns the new node
	// (NULL when no factory is supplied, which disables the stage).

	// Drop the current node first and clear the slot straight away so it
	// never dangles, even if Create() below does not return normally.
	delete aggregateNode;
	aggregateNode = NULL;

	if(factory == NULL)
		return NULL;

	aggregateNode = factory->Create(0);
	return aggregateNode;
}
GlobalOptStereoNode* StdStereoSolver::CreateOptNode(const StereoNodeFactory<GlobalOptStereoNode>* factory)
{
	// Replaces the optional global-optimisation stage and returns the new node
	// (NULL when no factory is supplied, which disables the stage).

	// Drop the current node first and clear the slot straight away so it
	// never dangles, even if Create() below does not return normally.
	delete optNode;
	optNode = NULL;

	if(factory == NULL)
		return NULL;

	optNode = factory->Create(0);
	return optNode;
}
WtaSolveStereoNode* StdStereoSolver::CreateWtaNode(const StereoNodeFactory<WtaSolveStereoNode>* factory)
{
	// Replaces the winner-takes-all solve stage and returns the new node (NULL
	// when no factory is supplied). FindDisparityMap does nothing while this is NULL.

	// Drop the current node first and clear the slot straight away so it
	// never dangles, even if Create() below does not return normally.
	delete wtaNode;
	wtaNode = NULL;

	if(factory == NULL)
		return NULL;

	wtaNode = factory->Create(0);
	return wtaNode;
}
#pragma endregion

void StdStereoSolver::FindDisparityMap(unsigned int* leftIm, unsigned int* rightIm, int width, int height, int nDisparities, float* depthMap)
{
	// Computes a disparity map for one image pair by running the configured
	// pipeline: [pre-process] -> cost -> [aggregate] -> [optimise] -> WTA.
	// Bracketed stages are optional and skipped when their node is NULL.
	//
	// leftIm / rightIm : host images of width*height pixels with tightly packed
	//                    rows (row pitch = width * sizeof(unsigned int)); they are
	//                    uploaded into XRGB32 device images.
	// depthMap         : host output buffer of width*height floats, rows tightly packed.
	// nDisparities     : depth (third dimension) of the cost grid.
	//
	// Returns without touching depthMap when an argument is invalid or when the
	// essential cost/WTA nodes have not been created.

	if(width <= 0 || height <= 0 || nDisparities <= 0)
		return;

	// Missing host buffers would otherwise only surface as a failure inside cudaMemcpy2D
	if(leftIm == NULL || rightIm == NULL || depthMap == NULL)
		return;

	// The cost computation and wta nodes are essential
	if(costNode == NULL || wtaNode == NULL)
		return;

	// Create/re-create the inter-node data structures as required. The two
	// images and the depth map are (re)created together here, so checking
	// gpuImLeft is taken to cover all three.
	if(width != gpuImLeft.GetWidth() || height != gpuImLeft.GetHeight())
	{
		gpuImLeft.Create(RGB_IMAGE_TYPE_XRGB32, width, height);
		gpuImRight.Create(RGB_IMAGE_TYPE_XRGB32, width, height);
		gpuDepthMap.Create(DEPTH_MAP_TYPE_FLOAT, width, height);
	}
	// NOTE(review): the cost grid is only rebuilt when its dimensions change. If the
	// cost node is swapped for one with a different GetCostSpaceType(), the old grid
	// type is kept. TODO confirm whether ComputeCosts tolerates that.
	if(width != gpuCostGrid.GetWidth() || height != gpuCostGrid.GetHeight() || nDisparities != gpuCostGrid.GetDepth())
		gpuCostGrid.Create(costNode->GetCostSpaceType(), width, height, nDisparities);

	// Host buffers are tightly packed, so their pitch equals the row width in bytes
	const size_t imageRowBytes = width * sizeof(unsigned int);
	const size_t depthRowBytes = width * sizeof(float);

	// Copy in input data
	CUDA_CALL(cudaMemcpy2D(gpuImLeft.GetImage(), gpuImLeft.GetPitch(), leftIm, imageRowBytes,
		imageRowBytes, height, cudaMemcpyHostToDevice));
	CUDA_CALL(cudaMemcpy2D(gpuImRight.GetImage(), gpuImRight.GetPitch(), rightIm, imageRowBytes,
		imageRowBytes, height, cudaMemcpyHostToDevice));

	// The actual processing
	if(preprocessNode != NULL)
		preprocessNode->PreProcessImages(gpuImLeft, gpuImRight);

	costNode->ComputeCosts(gpuImLeft, gpuImRight, gpuCostGrid);

	if(aggregateNode != NULL)
		aggregateNode->AggregateCosts(gpuImLeft, gpuImRight, gpuCostGrid);

	if(optNode != NULL)
		optNode->OptimizeSolution(gpuCostGrid);

	wtaNode->FindSolution(gpuCostGrid, gpuDepthMap);

	// Copy out output data
	CUDA_CALL(cudaMemcpy2D(depthMap, depthRowBytes, gpuDepthMap.GetDepthMap(), gpuDepthMap.GetPitch(),
		depthRowBytes, height, cudaMemcpyDeviceToHost));
}

void StdStereoSolver::ValidateNodeFactories(StereoNodeFactoryListNodeBase & factories)
{
	// Asks every factory in the linked list (starting at 'factories') to check
	// itself against the data formats this solver passes between stages.

	// Image and depth-map formats are fixed by FindDisparityMap
	RgbImageType imageType = RGB_IMAGE_TYPE_XRGB32;
	DepthMapType depthType = DEPTH_MAP_TYPE_FLOAT;

	// The cost-grid format follows the installed cost node; with none installed, any format is accepted
	CostSpaceGridType gridType = COST_SPACE_GRID_TYPE_ANY;
	if(costNode != NULL)
		gridType = costNode->GetCostSpaceType();

	StereoNodeFactoryListNodeBase* current = &factories;
	while(current != NULL)
	{
		current->GetFactory()->CheckValid(imageType, gridType, depthType);
		current = current->GetNext();
	}
}


/*************************************************** DoubleStereoSolver *************************************************************************/

// Starts with every pipeline stage empty; the Create* methods install them.
DoubleStereoSolver::DoubleStereoSolver()
	: preprocessNode(NULL),
	  lCostNode(NULL), rCostNode(NULL),
	  lAggregateNode(NULL), rAggregateNode(NULL),
	  lOptNode(NULL), rOptNode(NULL),
	  lWtaNode(NULL), rWtaNode(NULL),
	  fOptNode(NULL), fWtaNode(NULL)
{
}

DoubleStereoSolver::~DoubleStereoSolver()
{
	// Release the device-side buffers explicitly (their own destructors would
	// also do this when the members go out of scope).
	gpuImLeft.Destroy();
	gpuImRight.Destroy();
	gpuCostGridLeft.Destroy();
	gpuCostGridRight.Destroy();
	gpuCostGridFinal.Destroy();
	gpuDepthMapLeft.Destroy();
	gpuDepthMapRight.Destroy();

	// Free every pipeline stage this solver owns. Applying delete to a NULL
	// pointer is a no-op, so stages that were never created need no guard.
	delete preprocessNode;
	delete lCostNode;
	delete rCostNode;
	delete lAggregateNode;
	delete rAggregateNode;
	delete lOptNode;
	delete rOptNode;
	delete lWtaNode;
	delete rWtaNode;
	delete fOptNode;
	delete fWtaNode;
}

#pragma region Node Creation

PreProcStereoNode* DoubleStereoSolver::CreatePreNode(const StereoNodeFactory<PreProcStereoNode>* factory)
{
	// Replaces the shared pre-processing stage and returns the new node
	// (NULL when no factory is supplied).

	// Drop the current node first and clear the slot straight away so it
	// never dangles, even if Create() below does not return normally.
	delete preprocessNode;
	preprocessNode = NULL;

	if(factory == NULL)
		return NULL;

	preprocessNode = factory->Create(0);
	return preprocessNode;
}
void DoubleStereoSolver::CreateCostNodes(const StereoNodeFactory<CostComputeStereoNode>* factory, CostComputeStereoNode** left, CostComputeStereoNode** right)
{
	// Replaces the left and right cost-computation nodes with two fresh
	// instances from 'factory' and reports them through *left / *right.
	// A NULL factory clears both stages and writes NULL to both outputs.
	// 'left' and 'right' must point to writable storage.

	// Release BOTH previous nodes. Each slot is cleared right after deletion so
	// it never dangles if Create() below does not return normally. (Releasing
	// lCostNode twice instead of rCostNode would leak the right-hand node.)
	delete lCostNode;
	lCostNode = NULL;
	delete rCostNode;
	rCostNode = NULL;

	if(factory != NULL)
	{
		*left = lCostNode = factory->Create(0);
		*right = rCostNode = factory->Create(0);
	}
	else *left = *right = lCostNode = rCostNode = NULL;
}
void DoubleStereoSolver::CreateAggregateNodes(const StereoNodeFactory<CostAggregateStereoNode>* factory, CostAggregateStereoNode** left, CostAggregateStereoNode** right)
{
	// Replaces the left and right cost-aggregation nodes with two fresh
	// instances from 'factory' and reports them through *left / *right.

	// Drop the existing pair, clearing each slot immediately so nothing dangles
	delete lAggregateNode;
	lAggregateNode = NULL;
	delete rAggregateNode;
	rAggregateNode = NULL;

	// Without a factory both stages stay empty and both outputs are cleared
	if(factory == NULL)
	{
		*right = NULL;
		*left = NULL;
		return;
	}

	// One separate instance for the left side and one for the right side
	lAggregateNode = factory->Create(0);
	*left = lAggregateNode;
	rAggregateNode = factory->Create(0);
	*right = rAggregateNode;
}
void DoubleStereoSolver::CreateOptNodes(const StereoNodeFactory<GlobalOptStereoNode>* factory, GlobalOptStereoNode** left, GlobalOptStereoNode** right)
{
	// Replaces the left and right global-optimisation nodes with two fresh
	// instances from 'factory' and reports them through *left / *right.

	// Drop the existing pair, clearing each slot immediately so nothing dangles
	delete lOptNode;
	lOptNode = NULL;
	delete rOptNode;
	rOptNode = NULL;

	// Without a factory both stages stay empty and both outputs are cleared
	if(factory == NULL)
	{
		*right = NULL;
		*left = NULL;
		return;
	}

	// One separate instance for the left side and one for the right side
	lOptNode = factory->Create(0);
	*left = lOptNode;
	rOptNode = factory->Create(0);
	*right = rOptNode;
}
void DoubleStereoSolver::CreateWtaNodes(const StereoNodeFactory<WtaSolveStereoNode>* factory, WtaSolveStereoNode** left, WtaSolveStereoNode** right)
{
	// Replaces the left and right winner-takes-all nodes with two fresh
	// instances from 'factory' and reports them through *left / *right.

	// Drop the existing pair, clearing each slot immediately so nothing dangles
	delete lWtaNode;
	lWtaNode = NULL;
	delete rWtaNode;
	rWtaNode = NULL;

	// Without a factory both stages stay empty and both outputs are cleared
	if(factory == NULL)
	{
		*right = NULL;
		*left = NULL;
		return;
	}

	// One separate instance for the left side and one for the right side
	lWtaNode = factory->Create(0);
	*left = lWtaNode;
	rWtaNode = factory->Create(0);
	*right = rWtaNode;
}
GuidedGlobalOptStereoNode* DoubleStereoSolver::CreateGuidedOptNode(const StereoNodeFactory<GuidedGlobalOptStereoNode>* factory)
{
	// Replaces the final guided-optimisation node and returns the new one
	// (NULL when no factory is supplied).

	// Drop the current node first and clear the slot straight away so it
	// never dangles, even if Create() below does not return normally.
	delete fOptNode;
	fOptNode = NULL;

	if(factory == NULL)
		return NULL;

	fOptNode = factory->Create(0);
	return fOptNode;
}
WtaSolveStereoNode* DoubleStereoSolver::CreateFinalWtaNode(const StereoNodeFactory<WtaSolveStereoNode>* factory)
{
	// Replaces the final winner-takes-all node and returns the new one
	// (NULL when no factory is supplied).

	// Drop the current node first and clear the slot straight away so it
	// never dangles, even if Create() below does not return normally.
	delete fWtaNode;
	fWtaNode = NULL;

	if(factory == NULL)
		return NULL;

	fWtaNode = factory->Create(0);
	return fWtaNode;
}
#pragma endregion

void DoubleStereoSolver::FindDisparityMap(unsigned int* leftIm, unsigned int* rightIm, int width, int height, int nDisparities, float* depthMap)
{
	// NOTE(review): not implemented yet. This is an empty stub: no GPU work is
	// performed and depthMap is left untouched. The casts below only silence
	// unused-parameter warnings until the pipeline is written.
	(void)leftIm;
	(void)rightIm;
	(void)width;
	(void)height;
	(void)nDisparities;
	(void)depthMap;
}
void DoubleStereoSolver::ValidateNodeFactories(StereoNodeFactoryListNodeBase & factories)
{
	// Asks every factory in the linked list (starting at 'factories') to check
	// itself against the data formats this solver passes between stages.

	// Image and depth-map formats are fixed for this solver
	RgbImageType imageType = RGB_IMAGE_TYPE_XRGB32;
	DepthMapType depthType = DEPTH_MAP_TYPE_FLOAT;

	// The cost-grid format follows the left cost node. CreateCostNodes builds the
	// left and right nodes from the same factory, so the left node is taken to
	// represent both. With no cost node installed, any format is accepted.
	CostSpaceGridType gridType = COST_SPACE_GRID_TYPE_ANY;
	if(lCostNode != NULL)
		gridType = lCostNode->GetCostSpaceType();

	StereoNodeFactoryListNodeBase* current = &factories;
	while(current != NULL)
	{
		current->GetFactory()->CheckValid(imageType, gridType, depthType);
		current = current->GetNext();
	}
}