#include "FlexibleDisparityEstimator.h"

using namespace Stereo::GpGpuLib;

// Creates the native solver, fetches the managed factory lists for every
// pipeline stage and instantiates the first cost-computation and first
// winner-takes-all factory as the initial selection.
FlexibleDisparityEstimator::FlexibleDisparityEstimator()
{
	solver = new StdStereoSolver();

	// One managed factory list per stage
	preFactories       = GetPreFactoriesManaged();
	costFactories      = GetCostFactoriesManaged();
	aggregateFactories = GetAggregateFactoriesManaged();
	optFactories       = GetOptFactoriesManaged();
	wtaFactories       = GetWtaFactoriesManaged();

	// -1 marks "nothing selected" for every stage
	preIdx  = -1;
	costIdx = -1;
	aggIdx  = -1;
	optIdx  = -1;
	wtaIdx  = -1;

	// Select the head of the cost-computation list
	if(costFactories != nullptr)
	{
		CostComputeStereoNode* initialCost = solver->CreateCostNode(costFactories->Node->pFactory);
		if(initialCost != NULL)
		{
			costComputeTarget = costFactories->CreateParams(initialCost);
		}
		// The index is recorded whether or not node creation succeeded
		costIdx = 0;
	}

	// Select the head of the winner-takes-all list
	if(wtaFactories != nullptr)
	{
		WtaSolveStereoNode* initialWta = solver->CreateWtaNode(wtaFactories->Node->pFactory);
		if(initialWta != NULL)
		{
			wtaTarget = wtaFactories->CreateParams(initialWta);
		}
		// The index is recorded whether or not node creation succeeded
		wtaIdx = 0;
	}

	// Refresh factory validity for the initial node selection
	ValidateFactories();
}
// Releases the native solver created in the constructor.
// NOTE(review): this type holds managed handles, so it is presumably a ref
// class; in C++/CLI its destructor is then exposed as Dispose(), which a caller
// may invoke more than once. The pointer is cleared after deletion so a repeated
// call deletes a null pointer (a no-op) instead of freeing the solver twice.
// NOTE(review): no finalizer is visible in this file; if the object is never
// disposed the solver is presumably leaked - confirm against the header.
FlexibleDisparityEstimator::~FlexibleDisparityEstimator()
{
	delete solver;
	solver = nullptr;
}

void FlexibleDisparityEstimator::ValidateFactories()
{
	if(preFactories != nullptr)
		solver->ValidateNodeFactories(*preFactories->Node);
	if(costFactories != nullptr)
		solver->ValidateNodeFactories(*costFactories->Node);
	if(aggregateFactories != nullptr)
		solver->ValidateNodeFactories(*aggregateFactories->Node);
	if(optFactories != nullptr)
		solver->ValidateNodeFactories(*optFactories->Node);
	if(wtaFactories != nullptr)
		solver->ValidateNodeFactories(*wtaFactories->Node);
}

// Runs the currently configured solver on one stereo pair.
//
// inImageL / inImageR : left and right input images
// width, height       : image dimensions forwarded to the solver
// ndisparities        : disparity count forwarded to the solver
// scaling             : currently unused - it is not forwarded to the solver
// outImage            : receives the disparity map written by the solver
//
// Throws System::ArgumentNullException if any array is null and
// System::ArgumentException if any array is empty (pinning element 0 of an
// empty array is not possible).
// NOTE(review): the buffers presumably need at least width * height elements;
// that is not checked here because the solver's layout is not visible - confirm.
void FlexibleDisparityEstimator::ProcessFrame(array<unsigned int>^ inImageL, array<unsigned int>^ inImageR,
						int width, int height, int ndisparities, float scaling, array<float>^ outImage)
{
	// Fail with a descriptive exception before any pointer is pinned
	if(inImageL == nullptr)
		throw gcnew System::ArgumentNullException("inImageL");
	if(inImageR == nullptr)
		throw gcnew System::ArgumentNullException("inImageR");
	if(outImage == nullptr)
		throw gcnew System::ArgumentNullException("outImage");
	if(inImageL->Length == 0 || inImageR->Length == 0 || outImage->Length == 0)
		throw gcnew System::ArgumentException("Image buffers must not be empty.");

	// Pin the pointers to allow unmanaged code to use them
	pin_ptr<unsigned int> pInImageL = &inImageL[0];
	pin_ptr<unsigned int> pInImageR = &inImageR[0];
	pin_ptr<float> pOutImage = &outImage[0];

	// Calculate the disparity map with the current nodes
	solver->FindDisparityMap(pInImageL, pInImageR, width, height, ndisparities, pOutImage);
}

#pragma region Technique getters
// Lists the selectable pre-processing techniques: a leading "None" entry
// (index -1) followed by one entry per factory that currently reports IsValid().
//
// Each entry carries the factory's position in the list, counting invalid
// factories as well. SetPreprocessTechnique resolves an index by walking the
// list position by position without regard to validity, so numbering only the
// valid factories would make it pick the wrong one whenever an invalid factory
// precedes a valid one.
// NOTE(review): assumes the managed wrapper list mirrors the native list
// one-to-one - confirm.
List<TechniqueDescription^>^ FlexibleDisparityEstimator::GetPreprocessTechniques()
{
	List<TechniqueDescription^>^ result = gcnew List<TechniqueDescription^>();

	result->Add(gcnew TechniqueDescription("None", -1));

	if(preFactories != nullptr)
	{
		int index = 0;
		for(StereoNodeFactoryListNode<PreProcStereoNode>* pNode = preFactories->Node; pNode != NULL; pNode = pNode->pNext, ++index)
			if(pNode->pFactory->IsValid())
				result->Add(gcnew TechniqueDescription(gcnew String(pNode->pFactory->GetName()), index));
	}

	return result;
}
// Lists the selectable cost-computation techniques: one entry per factory that
// currently reports IsValid(). No "None" entry is offered for this stage.
//
// Each entry carries the factory's position in the list, counting invalid
// factories as well. SetCostComputeTechnique resolves an index by walking the
// list position by position without regard to validity, so numbering only the
// valid factories would make it pick the wrong one whenever an invalid factory
// precedes a valid one.
// NOTE(review): assumes the managed wrapper list mirrors the native list
// one-to-one - confirm.
List<TechniqueDescription^>^ FlexibleDisparityEstimator::GetCostComputeTechniques()
{
	List<TechniqueDescription^>^ result = gcnew List<TechniqueDescription^>();

	if(costFactories != nullptr)
	{
		int index = 0;
		for(StereoNodeFactoryListNode<CostComputeStereoNode>* pNode = costFactories->Node; pNode != NULL; pNode = pNode->pNext, ++index)
			if(pNode->pFactory->IsValid())
				result->Add(gcnew TechniqueDescription(gcnew String(pNode->pFactory->GetName()), index));
	}

	return result;
}
// Lists the selectable cost-aggregation techniques: a leading "None" entry
// (index -1) followed by one entry per factory that currently reports IsValid().
//
// Each entry carries the factory's position in the list, counting invalid
// factories as well. SetCostAggregateTechnique resolves an index by walking the
// list position by position without regard to validity, so numbering only the
// valid factories would make it pick the wrong one whenever an invalid factory
// precedes a valid one.
// NOTE(review): assumes the managed wrapper list mirrors the native list
// one-to-one - confirm.
List<TechniqueDescription^>^ FlexibleDisparityEstimator::GetCostAggregateTechniques()
{
	List<TechniqueDescription^>^ result = gcnew List<TechniqueDescription^>();

	result->Add(gcnew TechniqueDescription("None", -1));

	if(aggregateFactories != nullptr)
	{
		int index = 0;
		for(StereoNodeFactoryListNode<CostAggregateStereoNode>* pNode = aggregateFactories->Node; pNode != NULL; pNode = pNode->pNext, ++index)
			if(pNode->pFactory->IsValid())
				result->Add(gcnew TechniqueDescription(gcnew String(pNode->pFactory->GetName()), index));
	}

	return result;
}
// Lists the selectable global-optimization techniques: a leading "None" entry
// (index -1) followed by one entry per factory that currently reports IsValid().
//
// Each entry carries the factory's position in the list, counting invalid
// factories as well. SetGlobalOptimizationTechnique resolves an index by walking
// the list position by position without regard to validity, so numbering only
// the valid factories would make it pick the wrong one whenever an invalid
// factory precedes a valid one.
// NOTE(review): assumes the managed wrapper list mirrors the native list
// one-to-one - confirm.
List<TechniqueDescription^>^ FlexibleDisparityEstimator::GetGlobalOptimizationTechniques()
{
	List<TechniqueDescription^>^ result = gcnew List<TechniqueDescription^>();

	result->Add(gcnew TechniqueDescription("None", -1));

	if(optFactories != nullptr)
	{
		int index = 0;
		for(StereoNodeFactoryListNode<GlobalOptStereoNode>* pNode = optFactories->Node; pNode != NULL; pNode = pNode->pNext, ++index)
			if(pNode->pFactory->IsValid())
				result->Add(gcnew TechniqueDescription(gcnew String(pNode->pFactory->GetName()), index));
	}

	return result;
}
// Lists the selectable winner-takes-all techniques: one entry per factory that
// currently reports IsValid(). No "None" entry is offered for this stage.
//
// Each entry carries the factory's position in the list, counting invalid
// factories as well. SetWtaTechnique resolves an index by walking the list
// position by position without regard to validity, so numbering only the valid
// factories would make it pick the wrong one whenever an invalid factory
// precedes a valid one.
// NOTE(review): assumes the managed wrapper list mirrors the native list
// one-to-one - confirm.
List<TechniqueDescription^>^ FlexibleDisparityEstimator::GetWtaTechniques()
{
	List<TechniqueDescription^>^ result = gcnew List<TechniqueDescription^>();

	if(wtaFactories != nullptr)
	{
		int index = 0;
		for(StereoNodeFactoryListNode<WtaSolveStereoNode>* pNode = wtaFactories->Node; pNode != NULL; pNode = pNode->pNext, ++index)
			if(pNode->pFactory->IsValid())
				result->Add(gcnew TechniqueDescription(gcnew String(pNode->pFactory->GetName()), index));
	}

	return result;
}
#pragma endregion

#pragma region Technique setters

// Selects the pre-processing technique identified by description->index.
//
// - An index equal to the currently recorded one is a no-op.
// - A negative index passes NULL to the solver's CreatePreNode and clears the
//   parameter target (presumably disabling the stage).
// - Otherwise the factory at that position in the managed list is asked to
//   create the node; the parameter target is set from it, or cleared if the
//   solver returned no node.
// - If no list entry has that position, nothing is changed: previously the
//   bogus index was still recorded although the solver kept its old node.
//
// After a change the index is recorded and factory validity is refreshed.
// Throws System::ArgumentNullException if description is null.
void FlexibleDisparityEstimator::SetPreprocessTechnique(TechniqueDescription^ description)
{
	if(description == nullptr)
		throw gcnew System::ArgumentNullException("description");

	int requested = description->index;
	if(requested == preIdx)
		return;

	if(requested < 0)
	{
		solver->CreatePreNode(NULL);
		preProcessTarget = nullptr;
	}
	else
	{
		bool found = false;
		int index = 0;
		for(StereoNodeFactoryListNodeWrap<PreProcStereoNode>^ node = preFactories; node != nullptr; node = node->Next, ++index)
			if(index == requested)
			{
				PreProcStereoNode* preNode = solver->CreatePreNode(node->Node->pFactory);
				if(preNode != NULL)
					preProcessTarget = node->CreateParams(preNode);
				else preProcessTarget = nullptr;

				found = true;
				break;//for
			}

		// Unknown index: the solver was not touched, so keep the recorded state as it is
		if(!found)
			return;
	}

	preIdx = requested;
	ValidateFactories();
}
// Selects the cost-computation technique identified by description->index.
//
// - An index equal to the currently recorded one is a no-op.
// - A negative index passes NULL to the solver's CreateCostNode and clears the
//   parameter target (presumably disabling the stage).
// - Otherwise the factory at that position in the managed list is asked to
//   create the node; the parameter target is set from it, or cleared if the
//   solver returned no node.
// - If no list entry has that position, nothing is changed: previously the
//   bogus index was still recorded although the solver kept its old node.
//
// After a change the index is recorded and factory validity is refreshed.
// Throws System::ArgumentNullException if description is null.
void FlexibleDisparityEstimator::SetCostComputeTechnique(TechniqueDescription^ description)
{
	if(description == nullptr)
		throw gcnew System::ArgumentNullException("description");

	int requested = description->index;
	if(requested == costIdx)
		return;

	if(requested < 0)
	{
		solver->CreateCostNode(NULL);
		costComputeTarget = nullptr;
	}
	else
	{
		bool found = false;
		int index = 0;
		for(StereoNodeFactoryListNodeWrap<CostComputeStereoNode>^ node = costFactories; node != nullptr; node = node->Next, ++index)
			if(index == requested)
			{
				CostComputeStereoNode* costNode = solver->CreateCostNode(node->Node->pFactory);
				if(costNode != NULL)
					costComputeTarget = node->CreateParams(costNode);
				else costComputeTarget = nullptr;

				found = true;
				break;//for
			}

		// Unknown index: the solver was not touched, so keep the recorded state as it is
		if(!found)
			return;
	}

	costIdx = requested;
	ValidateFactories();
}
// Selects the cost-aggregation technique identified by description->index.
//
// - An index equal to the currently recorded one is a no-op.
// - A negative index passes NULL to the solver's CreateAggregateNode and clears
//   the parameter target (presumably disabling the stage).
// - Otherwise the factory at that position in the managed list is asked to
//   create the node; the parameter target is set from it, or cleared if the
//   solver returned no node.
// - If no list entry has that position, nothing is changed: previously the
//   bogus index was still recorded although the solver kept its old node.
//
// After a change the index is recorded and factory validity is refreshed.
// Throws System::ArgumentNullException if description is null.
void FlexibleDisparityEstimator::SetCostAggregateTechnique(TechniqueDescription^ description)
{
	if(description == nullptr)
		throw gcnew System::ArgumentNullException("description");

	int requested = description->index;
	if(requested == aggIdx)
		return;

	if(requested < 0)
	{
		solver->CreateAggregateNode(NULL);
		aggregateTarget = nullptr;
	}
	else
	{
		bool found = false;
		int index = 0;
		for(StereoNodeFactoryListNodeWrap<CostAggregateStereoNode>^ node = aggregateFactories; node != nullptr; node = node->Next, ++index)
			if(index == requested)
			{
				CostAggregateStereoNode* aggNode = solver->CreateAggregateNode(node->Node->pFactory);
				if(aggNode != NULL)
					aggregateTarget = node->CreateParams(aggNode);
				else aggregateTarget = nullptr;

				found = true;
				break;//for
			}

		// Unknown index: the solver was not touched, so keep the recorded state as it is
		if(!found)
			return;
	}

	aggIdx = requested;
	ValidateFactories();
}
// Selects the global-optimization technique identified by description->index.
//
// - An index equal to the currently recorded one is a no-op.
// - A negative index passes NULL to the solver's CreateOptNode and clears the
//   parameter target (presumably disabling the stage).
// - Otherwise the factory at that position in the managed list is asked to
//   create the node; the parameter target is set from it, or cleared if the
//   solver returned no node.
// - If no list entry has that position, nothing is changed: previously the
//   bogus index was still recorded although the solver kept its old node.
//
// After a change the index is recorded and factory validity is refreshed.
// Throws System::ArgumentNullException if description is null.
void FlexibleDisparityEstimator::SetGlobalOptimizationTechnique(TechniqueDescription^ description)
{
	if(description == nullptr)
		throw gcnew System::ArgumentNullException("description");

	int requested = description->index;
	if(requested == optIdx)
		return;

	if(requested < 0)
	{
		solver->CreateOptNode(NULL);
		optTarget = nullptr;
	}
	else
	{
		bool found = false;
		int index = 0;
		for(StereoNodeFactoryListNodeWrap<GlobalOptStereoNode>^ node = optFactories; node != nullptr; node = node->Next, ++index)
			if(index == requested)
			{
				GlobalOptStereoNode* optNode = solver->CreateOptNode(node->Node->pFactory);
				if(optNode != NULL)
					optTarget = node->CreateParams(optNode);
				else optTarget = nullptr;

				found = true;
				break;//for
			}

		// Unknown index: the solver was not touched, so keep the recorded state as it is
		if(!found)
			return;
	}

	optIdx = requested;
	ValidateFactories();
}
// Selects the winner-takes-all technique identified by description->index.
//
// - An index equal to the currently recorded one is a no-op.
// - A negative index passes NULL to the solver's CreateWtaNode and clears the
//   parameter target (presumably disabling the stage).
// - Otherwise the factory at that position in the managed list is asked to
//   create the node; the parameter target is set from it, or cleared if the
//   solver returned no node.
// - If no list entry has that position, nothing is changed: previously the
//   bogus index was still recorded although the solver kept its old node.
//
// After a change the index is recorded and factory validity is refreshed.
// Throws System::ArgumentNullException if description is null.
void FlexibleDisparityEstimator::SetWtaTechnique(TechniqueDescription^ description)
{
	if(description == nullptr)
		throw gcnew System::ArgumentNullException("description");

	int requested = description->index;
	if(requested == wtaIdx)
		return;

	if(requested < 0)
	{
		solver->CreateWtaNode(NULL);
		wtaTarget = nullptr;
	}
	else
	{
		bool found = false;
		int index = 0;
		for(StereoNodeFactoryListNodeWrap<WtaSolveStereoNode>^ node = wtaFactories; node != nullptr; node = node->Next, ++index)
			if(index == requested)
			{
				WtaSolveStereoNode* wtaNode = solver->CreateWtaNode(node->Node->pFactory);
				if(wtaNode != NULL)
					wtaTarget = node->CreateParams(wtaNode);
				else wtaTarget = nullptr;

				found = true;
				break;//for
			}

		// Unknown index: the solver was not touched, so keep the recorded state as it is
		if(!found)
			return;
	}

	wtaIdx = requested;
	ValidateFactories();
}

#pragma endregion