// $Id: SubPixWtaKernel.cu 819 2009-10-21 15:24:41Z cr333 $
#include "UnmanagedMaximizers.h"
#include "CudaHelperCommon.cuh"

// Winner-takes-all minimisation of a cost volume with parabolic sub-pixel refinement.
//
// One thread per pixel (x, y). Each thread scans the `depth` cost samples
// ACCESS_3D(costPtr, x, y, d), d = 0 .. depth-1, and selects the smallest one. On a tie,
// the first sample wins. It then fits a parabola through the costs at d-1, d and d+1 of
// the winner to estimate a fractional offset of the true minimum.
//
// costPtr      - cost volume; width in elements is costPtr.xsize / sizeof(float),
//                height is costPtr.ysize. Threads outside that range do nothing.
// depth        - number of cost samples scanned per pixel.
// scaleFactor  - multiplier applied to the (refined) winning index before clamping to [0, 1].
// threshold    - sub-pixel offsets with |offset| > threshold (or NaN/inf) are discarded,
//                and the integer index is used instead.
// result       - output, one float per pixel, written to result[resultStride * y + x].
// resultStride - row stride of `result`, in floats.
//
// Boundary handling: at d == 0 the missing left neighbour is replaced by the minimum
// itself, and at d == depth-1 the missing right neighbour likewise. This yields an offset
// of -0.5 / +0.5 there, unless the other neighbour equals the minimum.
// Pixels with no cost below 1E6f (including depth <= 0) are written as 0.
__global__ void SubPixWtaKernel(cudaPitchedPtr costPtr, int depth, float scaleFactor, float threshold, float* result, int resultStride)
{
	int x = blockDim.x * blockIdx.x + threadIdx.x;
	int y = blockDim.y * blockIdx.y + threadIdx.y;

	// Grid tail: the launch rounds the grid up, so some threads fall outside the volume.
	if(!(x < costPtr.xsize / sizeof(float) && y < costPtr.ysize))
		return;

	// Sentinel: a cost must be strictly below this to be selected.
	// NOTE(review): presumably costs >= 1E6f mark invalid samples upstream - confirm
	// before replacing this with FLT_MAX/INFINITY.
	float minCost = 1E6f;
	int minIndex = -1;

	float prevCost = 0.0f;   // cost at d - 1 (only read when d > 0)
	float preMin = 1.0f;     // cost immediately before the current minimum
	float postMin = 1.0f;    // cost immediately after the current minimum

	for(int d = 0; d < depth; ++d)
	{
		const float cost = ACCESS_3D(costPtr, x, y, d);

		if(cost < minCost)
		{
			minIndex = d;
			minCost = cost;
			// Until the next sample is read, the right neighbour defaults to the minimum itself;
			// this default survives when the minimum is the last sample.
			postMin = cost;
			// No left neighbour at d == 0: use the minimum itself.
			preMin = (d == 0 ? cost : prevCost);
		}
		else if(d == minIndex + 1)
		{
			postMin = cost;
		}

		prevCost = cost;
	}

	// Only refine when a minimum was actually found. This also avoids computing with
	// uninitialized values when depth <= 0.
	float refined = 0.0f;
	if(minIndex >= 0)
	{
		// Vertex of the parabola through (-1, preMin), (0, minCost), (+1, postMin).
		// The denominator can be 0 (flat neighbourhood). The resulting NaN/inf fails the
		// threshold test below, so the integer index is kept.
		const float offset = (preMin - postMin) / (2.0f * (preMin + postMin - 2.0f * minCost));

		float index = (float)minIndex;
		if(fabsf(offset) <= threshold)
			index += offset;

		// Write the range-truncated refined value.
		refined = fmaxf(fminf(index * scaleFactor, 1.0f), 0.0f);
	}

	result[resultStride * y + x] = refined;
}

// Launch SubPixWtaKernel over the whole cost volume: one thread per pixel, 128x1 blocks,
// and enough blocks to cover width = costSpace.xsize / sizeof(float) and height = costSpace.ysize.
//
// costSpace    - cost volume to minimise.
// disparityMax - forwarded to the kernel as `depth`, the number of cost samples scanned per pixel.
//                NOTE(review): confirm whether callers pass the sample count or the largest
//                index - the kernel scans d = 0 .. disparityMax-1.
// scaleFactor, threshold, result, resultStride - forwarded unchanged to the kernel.
//
// An empty cost volume (zero width or height) is a no-op and launches nothing.
void RunSubPixWtaKernel(const cudaPitchedPtr & costSpace, int disparityMax, float scaleFactor, float threshold, float* result, int resultStride)
{
	const size_t width = costSpace.xsize / sizeof(float);
	const size_t height = costSpace.ysize;

	// With an empty volume, the unsigned "size - 1" form of the ceil-div would wrap around
	// and produce a zero grid dimension, which is an invalid launch configuration.
	if(width == 0 || height == 0)
		return;

	dim3 blockDimension(128, 1);
	dim3 gridDimension(
		(unsigned int)((width + blockDimension.x - 1) / blockDimension.x),
		(unsigned int)((height + blockDimension.y - 1) / blockDimension.y));

	RECORD_KERNEL_LAUNCH("Sub-pixel WTA kernel", gridDimension, blockDimension);

	SubPixWtaKernel<<<gridDimension, blockDimension>>>(costSpace, disparityMax, scaleFactor, threshold, result, resultStride);

	CHECK_KERNEL_ERROR("Sub-pixel WTA kernel");
}