#include "StdAfx.h"
#include "DisparityEstimationNode.h"
#include "SimpleWtaDisparityEstimator.h"

using namespace VideoLib::Frames;
using namespace System::Threading;

namespace VideoLib { namespace Stereo { namespace GpGpu {

// Constructs the node in its idle state: no native estimator allocated yet,
// no job flagged as ready, and no shutdown requested. The estimator itself and
// the worker thread are created later in Init().
DisparityEstimationNode::DisparityEstimationNode()
	: Node(),
	  disparityEstimator(0),
	  readyToProcess(false),
	  cudaAccessExit(false)
{
	// Managed parameter block through which Process() passes one job's
	// arguments to CudaAccessMain().
	processData = gcnew DisparityCallData();

	// Two inputs (left and right image) and a single output.
	AddInputPins("inL", "inR");
	AddOutputPin("out");
}
// Shuts the node down: asks the CUDA worker thread (if one was ever created)
// to exit, waits for it to finish, and then releases the native estimator.
DisparityEstimationNode::~DisparityEstimationNode()
{
	if(cudaAccessThread != nullptr)
	{
		// Raise the exit flag and wake every waiter under the monitor. The
		// try/finally mirrors the other monitor sections in this file so the
		// lock is always released, even if something throws in between.
		Monitor::Enter(this);
		try
		{
			cudaAccessExit = true;
			Monitor::PulseAll(this);
		}
		finally { Monitor::Exit(this); }

		// The worker only leaves its loop once it has observed the flag, so
		// the estimator must not be destroyed before the thread has ended.
		cudaAccessThread->Join();
	}

	// delete on a null pointer is a no-op, so no explicit guard is needed.
	delete disparityEstimator;
	disparityEstimator = 0;
}

void DisparityEstimationNode::Init()
{
	Monitor::Enter(this);
	try 
	{
		if(disparityEstimator == 0)
			disparityEstimator = new SimpleWtaDisparityEstimator();

		readyToProcess = true;
		cudaAccessExit = false;
		cudaAccessThread = gcnew Thread(gcnew ThreadStart(this, &DisparityEstimationNode::CudaAccessMain));
		cudaAccessThread->Start();
	}
	finally { Monitor::Exit(this); }

	Node::Init();
}

void DisparityEstimationNode::CudaAccessMain()
{
	Monitor::Enter(this);
	try 
	{
		while(!cudaAccessExit)
		{
			while(readyToProcess && !cudaAccessExit)
				Monitor::Wait(this);

			if(!cudaAccessExit)
			{
				// Call the completely unmanaged disparity map generation function
				disparityEstimator->FindDisparityMap(processData->inImageL, processData->inImageR, 
					processData->width, processData->height, 
					processData->ndisparities, processData->scaling, processData->outImage);
			}

			readyToProcess = true;
			Monitor::PulseAll(this);
		}
	}
	finally { Monitor::Exit(this); }
}


// Processes one stereo pair: reads the "inL"/"inR" frames, converts both to
// single-channel float images, hands them to the CUDA worker thread for
// disparity estimation, and pushes the result as a Grey8 frame.
//
// Throws Exception when the two inputs differ in size or sequence number, or
// when either one is not in RGB24 format. End-of-stream frames are passed on
// and end processing for this call.
//
// If shutdown (cudaAccessExit) has already been requested, the worker is not
// invoked and the pushed frame is all zero.
void DisparityEstimationNode::Process()
{
	// Get input frames
	Frame^ inL = GetInputFrame("inL");
	Frame^ inR = GetInputFrame("inR");
	if (PassOnEndOfStreamFrame(inL) || PassOnEndOfStreamFrame(inR)) return;

	// Get the input bitmap data
	BitmapFrame^ bitmapL = CastFrameTo<BitmapFrame^>(inL);
	BitmapFrame^ bitmapR = CastFrameTo<BitmapFrame^>(inR);

	// Check for same size
	if (bitmapL->Width != bitmapR->Width || bitmapL->Height != bitmapR->Height)
		throw gcnew Exception("Both input frames must have same size.");

	// Check that both frames use the expected bitmap format
	if (bitmapL->Format != BitmapFrame::FormatType::RGB24 || bitmapR->Format != BitmapFrame::FormatType::RGB24)
		throw gcnew Exception("Both input frames must be 24-bit BGR.");

	// Check for same sequence number
	if (bitmapL->SequenceNumber != bitmapR->SequenceNumber)
		throw gcnew Exception("Both input frames must have same sequence number.");

	unsigned int w = bitmapL->Width;
	unsigned int h = bitmapL->Height;
	const unsigned int pixelCount = w * h;

	// Native scratch buffers. They are released in the finally block below so
	// that an exception anywhere in between cannot leak them.
	float* imL = 0;
	float* imR = 0;
	float* result = 0;
	BitmapFrame^ outputFrame = nullptr;
	try
	{
		imL = new float[pixelCount];
		imR = new float[pixelCount];
		// Value-initialised (all zero) so the output is well defined even when
		// the worker never writes to it because shutdown was requested.
		result = new float[pixelCount]();

		// Copy input images to float arrays, as this is what the CUDA code
		// expects. Each pixel's three bytes are combined with the weights
		// 0.0721750 / 0.7151522 / 0.2126729 (bytes 0, 1, 2) and scaled by 1/255.
		for(unsigned int y = 0; y < h; y++)
		{
			for(unsigned int x = 0; x < w; x++)
			{
				imL[y * w + x] =
					(0.0721750f * bitmapL->Data[y * bitmapL->Stride + 3 * x] +
					0.7151522f * bitmapL->Data[y * bitmapL->Stride + 3 * x + 1] +
					0.2126729f * bitmapL->Data[y * bitmapL->Stride + 3 * x + 2] + 0.5f) / 255.0f;

				// The right image is sampled ConstantShift pixels to the left;
				// columns without a source pixel are filled with 0.
				imR[y * w + x] = ((x < ConstantShift) ? 0.0f :
					((0.0721750f * bitmapR->Data[y * bitmapR->Stride + 3 * (x - ConstantShift)] +
					0.7151522f * bitmapR->Data[y * bitmapR->Stride + 3 * (x - ConstantShift) + 1] +
					0.2126729f * bitmapR->Data[y * bitmapR->Stride + 3 * (x - ConstantShift) + 2] + 0.5f) / 255.0f));
			}
		}

		// The disparity map generation is done on another thread - all the
		// cuda accesses need to be done on the same thread
		Monitor::Enter(this);
		try
		{
			// Wait until the worker is free (or shutdown is requested).
			while(!readyToProcess && !cudaAccessExit)
				Monitor::Wait(this);

			if(!cudaAccessExit)
			{
				// Post the job parameters for the worker.
				processData->inImageL = imL;
				processData->inImageR = imR;
				processData->width = w;
				processData->height = h;
				processData->ndisparities = 128;
				processData->scaling = 1.0f / 128.0f;
				processData->outImage = result;

				// Clearing readyToProcess is the "job posted" signal.
				readyToProcess = false;
				Monitor::PulseAll(this);

				// Block until the worker reports completion; it also sets the
				// flag when it exits because of shutdown.
				while(!readyToProcess)
					Monitor::Wait(this);
			}
		}
		finally { Monitor::Exit(this); }

		// Copy result into the output frame, scaling by 255 for Grey8
		outputFrame = gcnew BitmapFrame(w, h, BitmapFrame::FormatType::Grey8);
		outputFrame->SequenceNumber = inL->SequenceNumber;
		for(unsigned int y = 0; y < h; y++)
			for(unsigned int x = 0; x < w; x++)
				outputFrame->Data[y * outputFrame->Stride + x] = result[y * w + x] * 255.0f;
	}
	finally
	{
		// Clean up. The buffers were allocated with new[], so they must be
		// released with delete[] (which is a no-op on a null pointer).
		delete[] result;
		delete[] imL;
		delete[] imR;
	}

	// Send off the frame
	PushFrame(outputFrame);
}

}}}