// $Id: BoxAggKernel.cu 789 2009-10-01 10:58:01Z daho2 $
#include "UnmanagedAggregators.h"
#include "CudaHelperCommon.cuh"

// Horizontal pass of the box aggregation.
//
// For every disparity d in [0, disparityMax) and every pixel (x, y), writes
//   normFactor * (sum of costPtr over columns x - radius .. x + radius of row y, at disparity d)
// to costPtrOut. Columns outside the image contribute 0.
//
// Launch layout: 2D grid over (x, y). blockDim.x must be at least 2 * radius. Threads with
// threadIdx.x < radius load the left halo and threads radius .. 2 * radius - 1 load the right halo.
// Dynamic shared memory: (blockDim.x + 2 * radius) * blockDim.y * sizeof(float) bytes, i.e. one
// padded row per threadIdx.y.
__global__ void BoxAggKernel_H(const cudaPitchedPtr costPtr, const int disparityMax, const int radius, const float normFactor, const cudaPitchedPtr costPtrOut)
{
	extern __shared__ float buffer_boxh[];

	// Global pixel coordinates handled by this thread
	const int col = blockIdx.x * blockDim.x + threadIdx.x;
	const int row = blockIdx.y * blockDim.y + threadIdx.y;
	const size_t widthInFloats = costPtr.xsize / sizeof(float);
	const bool rowInImage = row < costPtr.ysize;
	const bool pixelInImage = rowInImage && col < widthInFloats;

	// Each threadIdx.y owns one shared row of blockDim.x + 2 * radius floats. The thread's own
	// value sits `radius` slots in, leaving room for the left halo.
	const int rowStride = blockDim.x + 2 * radius;
	const int centre = rowStride * threadIdx.y + threadIdx.x + radius;

	// Halo assignment:
	//   threads 0 .. radius - 1 fetch column col - radius (left halo);
	//   threads radius .. 2 * radius - 1 fetch column col - radius + blockDim.x (right halo).
	// Threads in rows outside the image skip the halo load.
	const bool loadsHalo = threadIdx.x < 2 * radius && rowInImage;
	const int haloShift = (threadIdx.x < radius) ? 0 : (int)blockDim.x;
	const int haloSlot = centre - radius + haloShift;
	const int haloCol = col - radius + haloShift;
	const bool haloInImage = haloCol >= 0 && haloCol < widthInFloats;

	for(int disp = 0; disp < disparityMax; ++disp)
	{
		// Stage this thread's value (0 outside the image) and, if assigned, one halo value
		float acc = pixelInImage ? ACCESS_3D(costPtr, col, row, disp) : 0.0f;
		buffer_boxh[centre] = acc;
		if(loadsHalo)
			buffer_boxh[haloSlot] = haloInImage ? ACCESS_3D(costPtr, haloCol, row, disp) : 0.0f;

		// The whole padded row must be staged before neighbours are read
		__syncthreads();

		for(int k = 1; k <= radius; ++k)
			acc += buffer_boxh[centre + k] + buffer_boxh[centre - k];

		// Keep the next disparity's stores from clobbering values still being read
		__syncthreads();

		if(pixelInImage)
			ACCESS_3D(costPtrOut, col, row, disp) = acc * normFactor;
	}
}

// Vertical pass of the box aggregation, performed in place on costPtr.
//
// Blocks are tiled across the (X, D) plane: threadIdx.x / blockIdx.x select the column x, and
// threadIdx.y / blockIdx.y select the disparity d. Each thread then walks down its own (x, d)
// column in the Y direction, keeping a running sum over a (2 * radius + 1)-row window.
//
// Shared memory: each thread keeps the last (2 * radius + 1) values it read in a private circular
// buffer. Slot k of thread t lives at buffer_boxv[k * blockDim.x * blockDim.y + t], so the launch
// must provide blockDim.x * blockDim.y * (2 * radius + 1) * sizeof(float) bytes. Threads never
// touch each other's slots, so no __syncthreads() is required.
//
// In-place safety: the value written to row (y - radius) is built from values already read (and
// kept in the buffer), and every later read is from a row > y.
//
// Rows radius .. ysize - radius - 1 (those whose full window lies inside the image) are written.
// The first and last `radius` rows are left unchanged. Nothing is written when
// ysize < 2 * radius + 1.
__global__ void BoxAggKernel_V(const cudaPitchedPtr costPtr, const int disparityMax, const int radius, const float normFactor)
{
	const int x = blockDim.x * blockIdx.x + threadIdx.x;
	const int d = blockDim.y * blockIdx.y + threadIdx.y;

	// This thread's slot-0 position in the shared circular buffer (computed once)
	const int baseBufIndex = blockDim.x * threadIdx.y + threadIdx.x;
	// Distance, in floats, between consecutive slots of one thread's circular buffer
	const int bufStride = blockDim.x * blockDim.y;
	// Number of rows in the vertical window
	const int windowSize = 2 * radius + 1;

	extern __shared__ float buffer_boxv[];

	if(x < costPtr.xsize / sizeof(float) && d < disparityMax)
	{
		int bufIndex = 0;
		float runningSum = 0;

		// First loop fills the buffer with rows 0 .. windowSize - 1 (or fewer if the image is
		// shorter than the window) and accumulates their sum
		int y = 0;
		for(; y < costPtr.ysize && y < windowSize; ++y)
		{
			float rVal = ACCESS_3D(costPtr, x, y, d);
			runningSum += rVal;
			buffer_boxv[bufStride * bufIndex + baseBufIndex] = rVal;

			// Advance the circular-buffer index (avoids a modulo)
			++bufIndex;
			if(bufIndex >= windowSize)
				bufIndex = 0;
		}

		// If the buffer was filled completely, runningSum is the sum of rows 0 .. 2 * radius,
		// i.e. the full window centred on row 'radius'. The sliding loop below only starts
		// writing at row radius + 1, so this first complete window must be written here.
		if(y == windowSize)
			ACCESS_3D(costPtr, x, radius, d) = runningSum * normFactor;

		// Second loop slides the window down one row at a time. It drops the oldest row (held in
		// the slot about to be overwritten), adds row y, and writes the average for the window
		// centre, row y - radius.
		for(; y < costPtr.ysize; ++y)
		{
			runningSum -= buffer_boxv[bufStride * bufIndex + baseBufIndex];

			float rVal = ACCESS_3D(costPtr, x, y, d);
			runningSum += rVal;
			buffer_boxv[bufStride * bufIndex + baseBufIndex] = rVal;

			// Save the result
			ACCESS_3D(costPtr, x, y - radius, d) = runningSum * normFactor;

			// Advance the circular-buffer index (avoids a modulo)
			++bufIndex;
			if(bufIndex >= windowSize)
				bufIndex = 0;
		}
	}
}

// Box-aggregates a cost volume over a (2 * radius + 1) x (2 * radius + 1) spatial window, one slice
// at a time for every disparity d in [0, disparityMax). The filter is applied separably:
//   1. BoxAggKernel_H reads costPtr and sums each horizontal window (columns outside the image
//      count as 0). It writes that sum, scaled by 1 / (2 * radius + 1), into costPtrOut.
//   2. BoxAggKernel_V sums vertical windows of costPtrOut in place and scales by
//      1 / (2 * radius + 1) again, so interior values end up as the mean over the full window.
//      See BoxAggKernel_V for how rows near the top and bottom edges are treated.
// costPtr is only read and the result is left in costPtrOut. Grid sizes are derived from costPtr
// only, so costPtrOut is assumed to have the same dimensions.
//
// Radius limits: documented as working for radius between 1 and 16. The horizontal block width
// grows with radius (it must be >= 2 * radius). The vertical pass needs
// 64 * (2 * radius + 1) * sizeof(float) bytes of shared memory per block, which caps radius at 31
// on devices with 16 KB of shared memory per block.
void RunBoxAggKernel(const cudaPitchedPtr & costPtr, int disparityMax, int radius, const cudaPitchedPtr & costPtrOut)
{
	// Each pass divides by the window side length once
	const float normFactor = 1.0f / (float)(2 * radius + 1);
	const size_t widthInFloats = costPtr.xsize / sizeof(float);

	// --- Horizontal pass: blocks tiled over (x, y) ---
	// Width: smallest multiple of 16 that holds the 2 * radius halo-loading threads (16 when
	// radius is 0). Height: enough rows for roughly 256 threads, and at least 1.
	const int hBlockWidth = 16 * ((2 * radius - 1) / 16 + 1);
	dim3 hBlock(hBlockWidth, max(1, 256 / hBlockWidth));
	dim3 hGrid((widthInFloats - 1) / hBlock.x + 1, (costPtr.ysize - 1) / hBlock.y + 1);
	// One padded shared row (block width plus both halos) per threadIdx.y
	int hSharedBytes = (hBlock.x + 2 * radius) * hBlock.y * sizeof(float);

	RECORD_KERNEL_LAUNCH("Box aggregation kernel H", hGrid, hBlock);
	BoxAggKernel_H<<<hGrid, hBlock, hSharedBytes>>>(costPtr, disparityMax, radius, normFactor, costPtrOut);
	CHECK_KERNEL_ERROR("Box aggregation kernel H");

	// --- Vertical pass: blocks tiled over (x, d), each thread walking down y ---
	dim3 vBlock(32, 2);
	dim3 vGrid((widthInFloats - 1) / vBlock.x + 1, (disparityMax - 1) / vBlock.y + 1);
	// Per-thread circular buffer of 2 * radius + 1 floats. At 64 threads per block this limits
	// radius to (16384 / (4 * 64) - 1) / 2 = 31 with 16 KB of shared memory.
	int vSharedBytes = vBlock.y * vBlock.x * (2 * radius + 1) * sizeof(float);

	RECORD_KERNEL_LAUNCH("Box aggregation kernel V", vGrid, vBlock);
	BoxAggKernel_V<<<vGrid, vBlock, vSharedBytes>>>(costPtrOut, disparityMax, radius, normFactor);
	// NOTE(review): this label has a double space, unlike the RECORD_KERNEL_LAUNCH label above; kept as-is
	CHECK_KERNEL_ERROR("Box aggregation kernel  V");
}