// $Id: CostKernelBase.h 65 2010-03-18 17:06:22Z cr333 $
#pragma once

// A set of macros enabling different cost functions to be implemented statically 
// (as function pointers and virtual functions are unsupported in cuda)

// KERNEL_NAME: The name of the kernel to be defined, function signature:
//  For float costs:
//  __global__ void KERNEL_NAME(const unsigned int* const imageL_xrgb, const unsigned int* const imageR_xrgb,
//                              const int stride, const int width, const int height, const int disparityMax,
//							    const int sharedArrayLen, const float rescaleGradient, const float rescaleLimit,
//							    const cudaPitchedPtr outCostPtr)
//  For byte costs:
//  __global__ void KERNEL_NAME(const unsigned int* const imageL_xrgb, const unsigned int* const imageR_xrgb,
//                              const int stride, const int width, const int height, const int disparityMax,
//							    const int sharedArrayLen, const float costGradient, const cudaPitchedPtr outCostPtr)

// COST_FUNC: The name of an already defined device-based cost function, function signature:
//   __device__ float COST_FUNC(unsigned int p1, unsigned int p2)

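// N.B. The output of the cost function is expected to lie in the range [0, 1], where 0 means identical and 1 means
// maximally different. The implemented kernel multiplies this value by rescaleGradient and clamps the result to
// rescaleLimit, so (for rescaleGradient >= 0) its output lies in [0, rescaleLimit].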

// IMPLEMENT_COST_KERNEL(KERNEL_NAME, COST_FUNC)
// Defines a __global__ kernel KERNEL_NAME that fills a 3D cost volume, indexed as
// ACCESS_3D(outCostPtr, x, y, d), for every pixel x of image row y and every disparity d in [0, disparityMax).
//
// Launch layout, as implied by the indexing below:
//   - blockIdx.y selects the image row (pixels are read at stride * blockIdx.y + x).
//   - Consecutive blocks along x advance by (blockDim.x - disparityMax) pixels, so neighbouring blocks
//     overlap by disparityMax pixels. In every block except block 0, the first disparityMax threads only
//     load right-image pixels into shared memory and write no output.
//   - Dynamic shared memory must hold at least blockDim.x unsigned ints (one right-image pixel per thread).
//     NOTE(review): sharedArrayLen and height are not read by the body; presumably they mirror the launch
//     configuration. Confirm against the launcher.
//   - Assumes blockDim.x > disparityMax, otherwise blocks would not advance. TODO confirm that the launcher
//     enforces this.
//
// Each written cost is min(rescaleGradient * COST_FUNC(left[x], right[x - d]), rescaleLimit).
// When x - d would fall before the start of the row (only possible in block 0), rescaleLimit is written instead.
// ACCESS_3D must be defined by the including code.
#define IMPLEMENT_COST_KERNEL(KERNEL_NAME, COST_FUNC)                                                          \
__global__ void KERNEL_NAME(const unsigned int* const imageL_xrgb, const unsigned int* const imageR_xrgb,      \
                            const int stride, const int width, const int height, const int disparityMax,       \
                            const int sharedArrayLen, const float rescaleGradient, const float rescaleLimit,   \
                            const cudaPitchedPtr outCostPtr)                                                   \
{                                                                                                              \
	/* Right-image pixels of this block's (overlapping) row segment, one per thread. */                        \
	extern __shared__ unsigned int scanlineR[];                                                                \
                                                                                                               \
	const int col = threadIdx.x + blockIdx.x * (blockDim.x - disparityMax);                                    \
	const int pixelIdx = stride * blockIdx.y + col;                                                            \
	const bool inImage = col < width;                                                                          \
                                                                                                               \
	unsigned int leftPixel;                                                                                    \
	if(inImage)                                                                                                \
	{                                                                                                          \
		leftPixel = imageL_xrgb[pixelIdx];                                                                     \
		scanlineR[threadIdx.x] = imageR_xrgb[pixelIdx];                                                        \
	}                                                                                                          \
	/* Every thread must reach this barrier, so no thread may exit before it. */                               \
	__syncthreads();                                                                                           \
                                                                                                               \
	/* Threads past the row end, and the overlap threads of every block but the first, write nothing. */       \
	if(!inImage || (blockIdx.x != 0 && threadIdx.x < disparityMax))                                            \
		return;                                                                                                \
                                                                                                               \
	for(int d = 0; d < disparityMax; ++d)                                                                      \
	{                                                                                                          \
		/* col - d is inside the row only when d <= threadIdx.x (always true outside block 0). */              \
		const bool matchInRow = d <= threadIdx.x;                                                              \
		ACCESS_3D(outCostPtr, col, blockIdx.y, d) = matchInRow                                                 \
			? min(rescaleGradient * COST_FUNC(leftPixel, scanlineR[threadIdx.x - d]), rescaleLimit)            \
			: rescaleLimit;                                                                                    \
	}                                                                                                          \
}