#include "SimpleWtaDisparityEstimator.h"

#include "CostSpaceComputers.h"
#include "WtaComputers.h"
#include "CudaHelperCommon.cuh"

// Destructor: releases the device buffers handled by FreeCudaData().
SimpleWtaDisparityEstimator::~SimpleWtaDisparityEstimator()
{
	this->FreeCudaData();
}
	
namespace
{
	// Frees one device allocation and resets the caller's pointer to NULL.
	// Does nothing (and does not call cudaFree) when the pointer is already NULL.
	// The status returned by cudaFree is not inspected.
	template <typename T>
	void ReleaseDeviceBuffer(T*& buffer)
	{
		if(buffer == NULL)
		{
			return;
		}
		cudaFree(buffer);
		buffer = NULL;
	}
}

// Releases the left image, right image and cost-space device allocations,
// in that order. Every pointer is reset to NULL afterwards, so calling this
// method again is harmless.
void SimpleWtaDisparityEstimator::FreeCudaData()
{
	ReleaseDeviceBuffer(lImageCudaData);
	ReleaseDeviceBuffer(rImageCudaData);
	ReleaseDeviceBuffer(costCudaData.ptr);
}

// (Re)allocates the device memory used by the estimator.
//
// Any buffers held from an earlier call are freed first. Two linear image
// buffers of imagesPitch * height bytes are then allocated, followed by a 3D
// cost space of width floats x height rows x ndisparities slices.
//
//   width, height - image size in pixels (one unsigned int per pixel)
//   ndisparities  - number of slices in the cost space
void SimpleWtaDisparityEstimator::ReserveNewSpace(int width, int height, int ndisparities)
{
	// Granularity, in bytes, to which each image row is padded
	const int kRowAlignmentBytes = 128;

	// Drop whatever an earlier reservation left behind
	FreeCudaData();

	// Start from the tightly packed row size, then round it up to the next
	// multiple of the alignment
	imagesPitch = sizeof(unsigned int) * width;
	imagesPitch = ((imagesPitch - 1) / kRowAlignmentBytes + 1) * kRowAlignmentBytes;

	// One linear device buffer per input image
	CUDA_CALL(cudaMalloc(reinterpret_cast<void**>(&lImageCudaData), imagesPitch * height));
	CUDA_CALL(cudaMalloc(reinterpret_cast<void**>(&rImageCudaData), imagesPitch * height));

	// Describe the cost space (the extent width is expressed in bytes) and allocate it
	costSpaceExtent.depth = ndisparities;
	costSpaceExtent.height = height;
	costSpaceExtent.width = width * sizeof(float);
	CUDA_CALL(cudaMalloc3D(&costCudaData, costSpaceExtent));
}
void SimpleWtaDisparityEstimator::GenerateDisparityMap(unsigned int* inImageL, unsigned int* inImageR, 
														int width, int height, int ndisparities, float scaling, 
														float* outImage)
{
	CUDA_CALL(cudaMemcpy2D(lImageCudaData, imagesPitch, inImageL, width * sizeof(unsigned int), width * sizeof(unsigned int), height, cudaMemcpyHostToDevice));
	CUDA_CALL(cudaMemcpy2D(rImageCudaData, imagesPitch, inImageR, width * sizeof(unsigned int), width * sizeof(unsigned int), height, cudaMemcpyHostToDevice));

	RunCostSpaceSadKernel(lImageCudaData, rImageCudaData, imagesPitch / sizeof(unsigned int), width, height, ndisparities, 1.0f, 1.0f, costCudaData);
	ComputeWtaDepthImage(costCudaData, ndisparities, scaling, (float*)lImageCudaData, imagesPitch / sizeof(float));

	CUDA_CALL(cudaMemcpy2D(outImage, width * sizeof(float), lImageCudaData, imagesPitch, width * sizeof(float), height, cudaMemcpyDeviceToHost));
}
