#pragma once

#include "StereoSolver.h"

#include "PreProcStereoNode.h"
#include "CostComputeStereoNode.h"
#include "CostAggregateStereoNode.h"
#include "GlobalOptStereoNode.h"
#include "WtaSolveStereoNode.h"
#include "GuidedGlobalOptStereoNode.h"

// Stereo solver built from a single chain of processing nodes. The node
// members are declared in pipeline order: pre-processing, matching-cost
// computation, cost aggregation, global optimisation and a winner-takes-all
// solve. Each node is obtained from a factory through the Create*Node methods.
class StdStereoSolver : public StereoSolver
{
protected:
	// The nodes which form the stereo solver process. They all start out null
	// (see the constructor); presumably each is assigned by the matching
	// Create*Node method and released by the out-of-line destructor - confirm.
	PreProcStereoNode* preprocessNode;
	CostComputeStereoNode* costNode;
	CostAggregateStereoNode* aggregateNode;
	GlobalOptStereoNode* optNode;
	WtaSolveStereoNode* wtaNode;

	// The CUDA data that is passed between nodes
	RgbImage gpuImLeft;
	RgbImage gpuImRight;
	CostSpaceGrid gpuCostGrid;
	DepthMap gpuDepthMap;

public:
	StdStereoSolver() : preprocessNode(0), costNode(0), aggregateNode(0), optNode(0), wtaNode(0) { }

	// Virtual so that deleting a further-derived solver through a
	// StdStereoSolver* runs the complete destructor chain.
	virtual ~StdStereoSolver();

	// Create the node for one pipeline stage from the given factory and return it.
	PreProcStereoNode* CreatePreNode(const StereoNodeFactory<PreProcStereoNode>* factory);
	CostComputeStereoNode* CreateCostNode(const StereoNodeFactory<CostComputeStereoNode>* factory);
	CostAggregateStereoNode* CreateAggregateNode(const StereoNodeFactory<CostAggregateStereoNode>* factory);
	GlobalOptStereoNode* CreateOptNode(const StereoNodeFactory<GlobalOptStereoNode>* factory);
	WtaSolveStereoNode* CreateWtaNode(const StereoNodeFactory<WtaSolveStereoNode>* factory);

	// Compute a disparity map for the given left/right image pair.
	//   leftIm, rightIm : input images of width x height pixels, one unsigned int
	//                     per pixel (presumably packed RGB(A) - confirm).
	//   nDisparities    : presumably the number of disparity levels to search.
	//   depthMap        : caller-provided output buffer (presumably width x height
	//                     floats - confirm).
	virtual void FindDisparityMap(unsigned int* leftIm, unsigned int* rightIm, int width, int height, int nDisparities, float* depthMap);

	// Presumably checks that the supplied node factories can be used by this
	// solver's pipeline - confirm against the implementation.
	virtual void ValidateNodeFactories(StereoNodeFactoryListNodeBase & factories);

private:
	// The solver stores raw node pointers and declares a destructor, so the
	// implicitly generated copy operations would make two solvers share (and
	// potentially both release) the same nodes. Copying is therefore disabled:
	// these are intentionally declared private and never defined.
	StdStereoSolver(const StdStereoSolver&);
	StdStereoSolver& operator=(const StdStereoSolver&);
};

// TODO: Finish implementing DoubleStereoSolver.
// Stereo solver with separate left and right processing chains
// (cost computation, cost aggregation, global optimisation and winner-takes-all
// solve for each side) that share one pre-processing node. The two chains are
// followed by a guided global optimisation node and a final winner-takes-all
// node.
class DoubleStereoSolver : public StereoSolver
{
protected:
	// The nodes which form the stereo solver process. The "l"/"r" prefixes mark
	// the left and right chains; the "f" prefix marks the final stage.
	PreProcStereoNode* preprocessNode;
	CostComputeStereoNode* lCostNode;
	CostComputeStereoNode* rCostNode;
	CostAggregateStereoNode* lAggregateNode;
	CostAggregateStereoNode* rAggregateNode;
	GlobalOptStereoNode* lOptNode;
	GlobalOptStereoNode* rOptNode;
	WtaSolveStereoNode* lWtaNode;
	WtaSolveStereoNode* rWtaNode;
	GuidedGlobalOptStereoNode* fOptNode;
	WtaSolveStereoNode* fWtaNode;

	// The CUDA data that is passed between nodes
	RgbImage gpuImLeft, gpuImRight;
	CostSpaceGrid gpuCostGridLeft, gpuCostGridRight, gpuCostGridFinal;
	DepthMap gpuDepthMapLeft, gpuDepthMapRight;

public:
	DoubleStereoSolver();

	// Virtual so that deleting a further-derived solver through a
	// DoubleStereoSolver* runs the complete destructor chain.
	virtual ~DoubleStereoSolver();

	// Create the shared pre-processing node from the given factory and return it.
	PreProcStereoNode* CreatePreNode(const StereoNodeFactory<PreProcStereoNode>* factory);

	// Create one node for each side of the pipeline from the same factory.
	// The created nodes are presumably written to *left and *right - confirm.
	void CreateCostNodes(const StereoNodeFactory<CostComputeStereoNode>* factory, CostComputeStereoNode** left, CostComputeStereoNode** right);
	void CreateAggregateNodes(const StereoNodeFactory<CostAggregateStereoNode>* factory, CostAggregateStereoNode** left, CostAggregateStereoNode** right);
	void CreateOptNodes(const StereoNodeFactory<GlobalOptStereoNode>* factory, GlobalOptStereoNode** left, GlobalOptStereoNode** right);
	void CreateWtaNodes(const StereoNodeFactory<WtaSolveStereoNode>* factory, WtaSolveStereoNode** left, WtaSolveStereoNode** right);

	// Create the final-stage nodes that follow the left/right chains.
	GuidedGlobalOptStereoNode* CreateGuidedOptNode(const StereoNodeFactory<GuidedGlobalOptStereoNode>* factory);
	WtaSolveStereoNode* CreateFinalWtaNode(const StereoNodeFactory<WtaSolveStereoNode>* factory);

	// Compute a disparity map for the given left/right image pair.
	//   leftIm, rightIm : input images of width x height pixels, one unsigned int
	//                     per pixel (presumably packed RGB(A) - confirm).
	//   nDisparities    : presumably the number of disparity levels to search.
	//   depthMap        : caller-provided output buffer (presumably width x height
	//                     floats - confirm).
	virtual void FindDisparityMap(unsigned int* leftIm, unsigned int* rightIm, int width, int height, int nDisparities, float* depthMap);

	// Presumably checks that the supplied node factories can be used by this
	// solver's pipeline - confirm against the implementation.
	virtual void ValidateNodeFactories(StereoNodeFactoryListNodeBase & factories);

private:
	// The solver stores raw node pointers and declares a destructor, so the
	// implicitly generated copy operations would make two solvers share (and
	// potentially both release) the same nodes. Copying is therefore disabled:
	// these are intentionally declared private and never defined.
	DoubleStereoSolver(const DoubleStereoSolver&);
	DoubleStereoSolver& operator=(const DoubleStereoSolver&);
};