// $Id: WtaKernel.cu 792 2009-10-01 18:24:11Z daho2 $
#include "UnmanagedMaximizers.h"
#include "CudaHelperCommon.cuh"

// Winner-takes-all (WTA) disparity selection.
//
// One thread handles one pixel (x, y). It scans the cost values for
// d = 0 .. depth-1, picks the index with the smallest cost, and writes that
// index multiplied by scaleFactor to result[resultStride * y + x].
//
// Launch layout: a 2D grid whose x extent covers costPtr.xsize / sizeof(float)
// threads and whose y extent covers costPtr.ysize threads. Threads outside
// that range exit without writing anything.
//
// Parameters:
//   costPtr      - pitched cost volume. Elements are read via
//                  ACCESS_3D(costPtr, x, y, d), presumably the float at
//                  column x, row y, slice d (see CudaHelperCommon.cuh; confirm).
//   depth        - number of disparity levels scanned (d runs over [0, depth)).
//   scaleFactor  - multiplier applied to the winning index before it is stored.
//   result       - output float image, one value per pixel.
//   resultStride - row stride of result, in floats (not bytes).
//
// Ties keep the lowest d, because the comparison is a strict '<'.
// Only costs strictly below 1e6 can win, and NaN costs never win.
// If nothing qualifies (including depth <= 0), 0.0f is written.
__global__ void WtaKernel(cudaPitchedPtr costPtr, int depth, float scaleFactor, float* result, int resultStride)
{
	const int x = blockIdx.x * blockDim.x + threadIdx.x;
	const int y = blockIdx.y * blockDim.y + threadIdx.y;

	// Threads that fall outside the image plane have nothing to do.
	const size_t width = costPtr.xsize / sizeof(float);
	if((size_t)x >= width || (size_t)y >= costPtr.ysize)
		return;

	// Sentinel: a cost must be strictly below this value to be selected.
	float bestCost = 1E6f;
	int bestIndex = -1;

	for(int level = 0; level < depth; ++level)
	{
		const float cost = ACCESS_3D(costPtr, x, y, level);
		if(cost < bestCost)
		{
			bestCost = cost;
			bestIndex = level;
		}
	}

	// No qualifying level was found, so fall back to 0.
	float value = 0.0f;
	if(bestIndex >= 0)
		value = (float)bestIndex * scaleFactor;

	result[resultStride * y + x] = value;
}

// Host-side launcher for WtaKernel.
//
// Launches one thread per pixel of the cost volume's image plane. The plane is
// (costSpace.xsize / sizeof(float)) wide and costSpace.ysize tall, and threads
// are grouped in blocks of 128 x 1.
//
// Parameters:
//   costSpace    - pitched cost volume passed straight to WtaKernel.
//   disparityMax - passed to WtaKernel as `depth`. Disparities 0 .. disparityMax-1
//                  are evaluated, i.e. this is an exclusive upper bound.
//   scaleFactor  - multiplier applied to the winning disparity index.
//   result       - device output buffer, one float per pixel.
//   resultStride - row stride of result, in floats.
//
// An empty image plane (zero width or height) is a no-op and launches nothing.
// Launch errors are reported through CHECK_KERNEL_ERROR; how they surface is
// defined by that project macro.
void RunWtaKernel(const cudaPitchedPtr & costSpace, int disparityMax, float scaleFactor, float* result, int resultStride)
{
	const size_t width = costSpace.xsize / sizeof(float);
	const size_t height = costSpace.ysize;

	// Guard against an empty image. The old form "(size - 1) / block + 1" would
	// wrap around in unsigned arithmetic here and produce an invalid grid size.
	if(width == 0 || height == 0)
		return;

	dim3 blockDimension(128, 1);

	// Ceil-divide the image extent by the block extent. dim3 holds unsigned int,
	// so the narrowing from size_t is made explicit.
	dim3 gridDimension(
		static_cast<unsigned int>((width + blockDimension.x - 1) / blockDimension.x),
		static_cast<unsigned int>((height + blockDimension.y - 1) / blockDimension.y));

	RECORD_KERNEL_LAUNCH("WTA kernel", gridDimension, blockDimension);
	WtaKernel<<<gridDimension, blockDimension>>>(costSpace, disparityMax, scaleFactor, result, resultStride);
	CHECK_KERNEL_ERROR("WTA kernel");
}