// $Id: BilSubKernel.cu 816 2009-10-16 15:28:34Z cr333 $
#include "stdio.h"
#include "cudatemplates/copy.hpp"
#include "cudatemplates/devicememoryreference.hpp"
#include "UnmanagedPreProcessors.h"
#include "CudaHelperCommon.cuh"
#include "CudaMath.h"
#include "Utils.cuh"

/**
 * Bilateral background subtraction: for every pixel, computes a bilaterally
 * weighted average of its (2 * radius + 1)^2 neighbourhood and writes
 * "pixel - average + 127" to the output image.
 *
 * Launch layout: 2D grid of 2D blocks, one thread per pixel; threads that fall
 * outside width x height do nothing. No shared memory is used.
 *
 * Both images are addressed as rows of 32-bit packed pixels whose row length is
 * (pitch >> 2) words, so the pitch is assumed to be in bytes and a multiple of 4.
 * Only the first three channels of each pixel take part; the fourth output
 * channel is set to 0 before packing.
 *
 * @param gpuImgIn   input image (read only)
 * @param gpuImgOut  output image; must not alias gpuImgIn, because neighbouring
 *                   threads read input pixels that this thread's peers overwrite
 * @param width      image width in pixels
 * @param height     image height in pixels
 * @param radius     neighbourhood half-size in pixels
 * @param sigmaS     standard deviation of the spatial Gaussian (pixels)
 * @param sigmaR     standard deviation of the range Gaussian, applied to the
 *                   squared colour difference after srgb2rgb -> rgb2xyz -> xyz2lab
 */
__global__ void BilSubKernel(
	const cudaPitchedPtr gpuImgIn,
	const cudaPitchedPtr gpuImgOut,
	const unsigned int width, const unsigned int height, const int radius, const float sigmaS, const float sigmaR)
{
	const int x = blockIdx.x * blockDim.x + threadIdx.x;
	const int y = blockIdx.y * blockDim.y + threadIdx.y;

	// the grid is rounded up to whole blocks, so some threads lie outside the image
	if(x >= width || y >= height)
		return;

	// loop invariants: row lengths in 32-bit words and the two Gaussian denominators
	const size_t inRowWords  = gpuImgIn.pitch >> 2;
	const size_t outRowWords = gpuImgOut.pitch >> 2;
	const float spatialDenom = 2 * sigmaS * sigmaS;
	const float rangeDenom   = 2 * sigmaR * sigmaR;

	// homogeneous accumulator: xyz = weighted colour sum, w = sum of weights
	float4 acc = make_float4(0.0f, 0.0f, 0.0f, 0.0f);

	// read in centre pixel and convert it for the range term
	const uchar4 pix1 = unpack_xyzw<uchar4>(((unsigned int*)gpuImgIn.ptr)[inRowWords * y + x]);
	const float3 pix1f = make_float3<uchar3>(select_xyz<uchar4, uchar3>(pix1));
	const float3 lab1 = xyz2lab(rgb2xyz(srgb2rgb(pix1f / 255.0f)));

	// loop over all pixels in the neighbourhood, skipping those outside the image
	for(int dy = -radius; dy <= radius; dy++)
	{
		if(y + dy < 0 || y + dy >= height)
			continue;

		for(int dx = -radius; dx <= radius; dx++)
		{
			if(x + dx < 0 || x + dx >= width)
				continue;

			// NOTE(review): the centre pixel is read through unsigned int* and the
			// neighbours through int*; the casts are kept exactly as they were
			// because unpack_xyzw is defined elsewhere.
			const uchar4 pix2 = unpack_xyzw<uchar4>(((int*)gpuImgIn.ptr)[inRowWords * (y + dy) + (x + dx)]);
			const float3 pix2f = make_float3<uchar3>(select_xyz<uchar4, uchar3>(pix2));
			const float3 lab2 = xyz2lab(rgb2xyz(srgb2rgb(pix2f / 255.0f)));
			const float3 labDiff = lab1 - lab2;

			// spatial weight: Gaussian of the pixel distance
			const float wSpatial = __expf( - (dx * dx + dy * dy) / spatialDenom );

			// range weight: Gaussian of the colour difference
			const float wRange = __expf( - dot(labDiff, labDiff) / rangeDenom );

			const float w = wSpatial * wRange;
			acc += make_float4(w * pix2f, w);
		}
	}

	// The centre pixel always contributes weight 1, so acc.w >= 1 for sane
	// parameters. It is zero (or NaN) only when the neighbourhood is empty
	// (radius < 0) or a sigma is zero; then use the centre pixel itself as the
	// background instead of dividing by zero, which gives a flat 127 output.
	float4 background = make_float4(pix1f, 0.0f);
	if(acc.w > 0.0f)
		background = acc / acc.w;

	// subtract bilateral background from central pixel (and add offset)
	float4 bilsub = make_float4(pix1f, 0.0f) - background + 127.0f;
	bilsub.w = 0.0f;

	// the extra 0.5 presumably rounds to nearest inside pack_xyzw -- TODO confirm
	((int*)gpuImgOut.ptr)[outRowWords * y + x] = pack_xyzw(bilsub + 0.5f);
}

/**
 * Host entry point: runs BilSubKernel over an image and stores the result back
 * into the same image buffer.
 *
 * The kernel reads from `image` and writes to `tempImage`; afterwards the
 * contents of `tempImage` are copied back over `image`.
 *
 * @param image            image buffer, wrapped as a Cuda::DeviceMemoryReference2D
 *                         and handed to the kernel; overwritten with the result
 * @param imageStride      pitch of `image` as passed to setPitch (the kernel
 *                         divides the pitch by 4 to index 32-bit pixels, so this
 *                         is presumably in bytes -- confirm against the caller)
 * @param tempImage        scratch buffer with the same dimensions as `image`
 * @param tempImageStride  pitch of `tempImage`, same units as `imageStride`
 * @param width            image width in pixels
 * @param height           image height in pixels
 * @param radius           neighbourhood half-size in pixels
 * @param sigmaS           spatial standard deviation forwarded to the kernel
 * @param sigmaR           range standard deviation forwarded to the kernel
 *
 * Any std::exception raised inside the try block is reported on stderr and not
 * propagated to the caller.
 */
void RunBilSubKernel(
	unsigned int* image, const int imageStride,
	unsigned int* tempImage, const int tempImageStride,
	const unsigned int width, const unsigned int height,
	const unsigned int radius, const float sigmaS, const float sigmaR)
{
	// an empty image needs no work, and a zero-sized grid is not a valid launch configuration
	if(width == 0 || height == 0)
		return;

	// wrap pointers in cudatemplates
	Cuda::DeviceMemoryReference2D<unsigned int> ctImgIn(width, height, image);
	Cuda::DeviceMemoryReference2D<unsigned int> ctImgTmp(width, height, tempImage);
	ctImgIn.setPitch(imageStride);
	ctImgTmp.setPitch(tempImageStride);

	// asserts
	assert(ctImgIn.size == ctImgTmp.size);
	assert(ctImgIn.size[0] == width && ctImgIn.size[1] == height);

	try
	{
		// using regular square tiles for blocks; the grid is rounded up so that
		// the whole image is covered (the kernel bounds-checks the overhang)
		dim3 block(16, 16, 1);
		dim3 grid((width + block.x - 1) / block.x, (height + block.y - 1) / block.y, 1);

		RECORD_KERNEL_LAUNCH("BilSub kernel", grid, block);
		// the kernel takes the radius as a signed int; make the conversion explicit
		BilSubKernel<<<grid, block>>>(
			toPitchedPtr(ctImgIn), toPitchedPtr(ctImgTmp), width, height,
			static_cast<int>(radius), sigmaS, sigmaR);
		CHECK_KERNEL_ERROR("BilSub kernel");

		// copy results back (destination first, source second)
		Cuda::copy(ctImgIn, ctImgTmp);
	}
	catch(const std::exception &e)
	{
		// newline-terminated so the message does not run into later stderr output
		fprintf(stderr, "Error: %s\n", e.what());
	}
}
