// $Id: MinAggKernel.cu 789 2009-10-01 10:58:01Z daho2 $
#include <cassert>
#include <cfloat>

#include "UnmanagedAggregators.h"
#include "CudaHelperCommon.cuh"

// Horizontal pass of a separable min filter over a cost volume.
//
// For every slice d in [0, disparityMax), each in-range output element
// (x, y, d) of costPtrOut receives
//     normFactor * min(costPtr(x - radius .. x + radius, y, d)).
// ACCESS_3D is assumed to address element (x, y, d) of a pitched float volume
// (it is defined outside this file -- TODO confirm).
//
// Launch requirements (as configured by RunMinAggKernel):
//  - 2D grid covering the x/y extent of costPtr; threads outside the extent
//    still take part in the shared-memory loads and barriers.
//  - blockDim.x >= 2 * radius, so that threads with threadIdx.x < 2 * radius
//    can fill both the left and right apron of their row.
//  - Dynamic shared memory: (blockDim.x + 2 * radius) * blockDim.y floats.
__global__ void MinAggKernel_H(cudaPitchedPtr costPtr, int disparityMax, int radius, float normFactor, cudaPitchedPtr costPtrOut)
{
	// Identity element of min: samples outside the image must never win.
	// (0.0f padding would drag every output within `radius` of the left/right
	// edge down to 0 for non-negative costs.)
	const float padValue = FLT_MAX;

	const int width = (int)(costPtr.xsize / sizeof(float));
	const int x = blockDim.x * blockIdx.x + threadIdx.x;
	const int y = blockDim.y * blockIdx.y + threadIdx.y;

	// Each thread row owns (blockDim.x + 2 * radius) shared floats; this
	// thread's own sample sits `radius` slots in from the start of its row.
	const int index = (blockDim.x + 2 * radius) * threadIdx.y + threadIdx.x + radius;
	const bool rowValid = (y < costPtr.ysize);
	const bool isValid = rowValid && (x < width);

	// Additional read for block edge values: threads with threadIdx.x < radius
	// load the left apron, threads with radius <= threadIdx.x < 2 * radius the right one.
	const int extraWriteIndex = index - radius + (threadIdx.x < radius ? 0 : blockDim.x);
	const int extraReadX = x - radius + (threadIdx.x < radius ? 0 : blockDim.x);
	const bool doExtraRead = (extraReadX >= 0) && (extraReadX < width);

	extern __shared__ float buffer_boxh[];

	// Iterate through all the disparities, running the aggregation kernel for each
	for(int d = 0; d < disparityMax; ++d)
	{
		float baseVal = (isValid ? ACCESS_3D(costPtr, x, y, d) : padValue);
		buffer_boxh[index] = baseVal;

		// Apron loads only depend on the row being inside the image; the column
		// of the loading thread itself may be out of range.
		if(threadIdx.x < 2 * radius && rowValid)
			buffer_boxh[extraWriteIndex] = (doExtraRead ? ACCESS_3D(costPtr, extraReadX, y, d) : padValue);

		__syncthreads();

		// Main filtering loop
		for(int i = 1; i <= radius; ++i)
			baseVal = min(baseVal, min(buffer_boxh[index + i], buffer_boxh[index - i]));

		// Keep the next slice's stores from overwriting values still being read.
		__syncthreads();

		// Write out result
		if(isValid)
			ACCESS_3D(costPtrOut, x, y, d) = baseVal * normFactor;
	}
}

// Vertical pass of a separable min filter over a cost volume.
//
// For every slice d in [0, disparityMax), each in-range output element
// (x, y, d) of costPtrOut receives
//     normFactor * min(costPtr(x, y - radius .. y + radius, d)).
// ACCESS_3D is assumed to address element (x, y, d) of a pitched float volume
// (it is defined outside this file -- TODO confirm).
//
// Launch requirements (as configured by RunMinAggKernel):
//  - 2D grid covering the x/y extent of costPtr; threads outside the extent
//    still take part in the shared-memory loads and barriers.
//  - blockDim.y >= 2 * radius, so that threads with threadIdx.y < 2 * radius
//    can fill both the top and bottom apron of their column.
//  - Dynamic shared memory: (blockDim.y + 2 * radius) * blockDim.x floats.
__global__ void MinAggKernel_V(cudaPitchedPtr costPtr, int disparityMax, int radius, float normFactor, cudaPitchedPtr costPtrOut)
{
	// Identity element of min: samples outside the image must never win.
	// (0.0f padding would drag every output within `radius` of the top/bottom
	// edge down to 0 for non-negative costs.)
	const float padValue = FLT_MAX;

	const int x = blockDim.x * blockIdx.x + threadIdx.x;
	const int y = blockDim.y * blockIdx.y + threadIdx.y;

	// Shared memory holds (blockDim.y + 2 * radius) rows of blockDim.x floats;
	// this thread's own sample sits `radius` rows below the top of the buffer.
	const int index = blockDim.x * threadIdx.y + threadIdx.x + blockDim.x * radius;
	const bool columnValid = (x < costPtr.xsize / sizeof(float));
	const bool isValid = columnValid && (y < costPtr.ysize);

	// Additional read for block edge values: threads with threadIdx.y < radius
	// load the top apron, threads with radius <= threadIdx.y < 2 * radius the bottom one.
	const int extraWriteIndex = index - blockDim.x * radius + (threadIdx.y < radius ? 0 : blockDim.y * blockDim.x);
	const int extraReadY = y - radius + (threadIdx.y < radius ? 0 : blockDim.y);
	const bool doExtraRead = (extraReadY >= 0) && (extraReadY < costPtr.ysize);

	extern __shared__ float buffer_boxv[];

	// Iterate through all the disparities, running the aggregation kernel for each
	for(int d = 0; d < disparityMax; ++d)
	{
		float baseVal = (isValid ? ACCESS_3D(costPtr, x, y, d) : padValue);
		buffer_boxv[index] = baseVal;

		// Gate only on the column: a thread whose own row lies past the bottom
		// edge must still fill its apron slot, because valid threads in the same
		// column read it. (Gating on the row left those slots uninitialised/stale.)
		if(threadIdx.y < 2 * radius && columnValid)
			buffer_boxv[extraWriteIndex] = (doExtraRead ? ACCESS_3D(costPtr, x, extraReadY, d) : padValue);

		__syncthreads();

		// Main filtering loop
		for(int i = 1; i <= radius; ++i)
			baseVal = min(baseVal, min(buffer_boxv[index + blockDim.x * i], buffer_boxv[index - blockDim.x * i]));

		// Keep the next slice's stores from overwriting values still being read.
		__syncthreads();

		// Write out result
		if(isValid)
			ACCESS_3D(costPtrOut, x, y, d) = baseVal * normFactor;
	}
}

// Runs a separable (2 * radius + 1) x (2 * radius + 1) min aggregation over
// the first disparityMax slices of a cost volume, writing the result back into
// costPtr: horizontal pass costPtr -> interPtr, then vertical pass
// interPtr -> costPtr. Each pass multiplies its result by 1 / (2 * radius + 1),
// so the final values carry a combined factor of 1 / (2 * radius + 1)^2.
//
// costPtr and interPtr are presumably pitched float volumes with identical
// extents -- TODO confirm against the caller/allocator.
//
// The range of allowed values of radius is 0 to 16: the vertical pass needs
// blockDim.y >= 2 * radius but caps blockDim.y at 32 (16 x 32 = 512 threads).
// Both kernels are launched on the default stream.
void RunMinAggKernel(const cudaPitchedPtr & costPtr, int disparityMax, int radius, const cudaPitchedPtr & interPtr)
{
	assert(radius >= 0 && radius <= 16);

	const size_t width = costPtr.xsize / sizeof(float);
	const size_t height = costPtr.ysize;

	// The ceil-divisions below would underflow for an empty volume, producing
	// an invalid grid; there is nothing to aggregate anyway.
	if(width == 0 || height == 0)
		return;

	const float normFactor = 1.0f / (float)(2 * radius + 1);

	// Horizontal pass: x-blocks are the smallest multiple of 16 that is
	// >= 2 * radius (so both aprons can be loaded), with ~256 threads per block.
	int minFilterDim = 16 * ((2 * radius - 1) / 16 + 1);

	dim3 blockDimension(minFilterDim, max(1, 256 / minFilterDim));
	dim3 gridDimension((width - 1) / blockDimension.x + 1,
		(height - 1) / blockDimension.y + 1);

	int sharedMemBytes = (blockDimension.x + 2 * radius) * blockDimension.y * sizeof(float);

	RECORD_KERNEL_LAUNCH("Min aggregation kernel H", gridDimension, blockDimension);
	MinAggKernel_H<<<gridDimension, blockDimension, sharedMemBytes>>>(costPtr, disparityMax, radius, normFactor, interPtr);
	CHECK_KERNEL_ERROR("Min aggregation kernel H");

	// Vertical pass: y-blocks are the smallest multiple of 4 that is
	// >= 2 * radius, clamped to [4, 32].
	int prefFilterBlocks = 4 * ((2 * radius - 1) / 4 + 1);

	blockDimension = dim3(16, max(4, min(32, prefFilterBlocks)));
	gridDimension = dim3((width - 1) / blockDimension.x + 1,
		(height - 1) / blockDimension.y + 1);

	sharedMemBytes = (blockDimension.y + 2 * radius) * blockDimension.x * sizeof(float);

	RECORD_KERNEL_LAUNCH("Min aggregation kernel V", gridDimension, blockDimension);
	MinAggKernel_V<<<gridDimension, blockDimension, sharedMemBytes>>>(interPtr, disparityMax, radius, normFactor, costPtr);
	CHECK_KERNEL_ERROR("Min aggregation kernel  V");
}