// $Id: JbuKernel.cu 792 2009-10-01 18:24:11Z daho2 $
#include "UnmanagedUpsamplers.h"
#include "CudaHelperCommon.cuh"

// Radius, in low-resolution pixels, of the square filter window used by JbuKernel: the window
// spans offsets -FILTER_RADIUS..+FILTER_RADIUS on each axis. The kernel's index computation also
// subtracts 2 * FILTER_RADIUS from each block dimension, so threads within this distance of a
// block edge act as an apron that only supplies neighbouring samples.
#define FILTER_RADIUS 2

// Gaussian-shaped colour similarity weight for a pair of packed pixels.
//
// p0, p1    - packed pixels. Three channels are read from the value itself, the value shifted
//             right by 8 and the value shifted right by 16, each masked with UINT_MASK_0
//             (defined outside this file; presumably the low-byte mask - TODO confirm).
// colourDev - colour deviation, relative to a channel range of 255.
//
// Returns exp(-d2 / (255 * colourDev)^2), where d2 is the squared Euclidean distance between the
// two colours.
__device__ float ColourSimilarity(const unsigned int p0, const unsigned int p1, const float colourDev)
{
	// Unpack every channel into a signed int first so the per-channel differences can go negative
	const int blue0  = (int)(p0 & UINT_MASK_0);
	const int green0 = (int)((p0 >> 8) & UINT_MASK_0);
	const int red0   = (int)((p0 >> 16) & UINT_MASK_0);

	const int blue1  = (int)(p1 & UINT_MASK_0);
	const int green1 = (int)((p1 >> 8) & UINT_MASK_0);
	const int red1   = (int)((p1 >> 16) & UINT_MASK_0);

	const int diffB = blue0 - blue1;
	const int diffG = green0 - green1;
	const int diffR = red0 - red1;

	const int squaredDistance = diffR * diffR + diffG * diffG + diffB * diffB;
	const float scale = 255.0f * 255.0f * colourDev * colourDev;

	return expf(-(float)squaredDistance / scale);
}

// Gaussian-shaped spatial weight for a 2D offset.
//
// x, y       - components of the offset vector
// spatialDev - deviation of the gaussian, in the same units as x and y
//
// Returns exp(-(x^2 + y^2) / (2 * spatialDev^2)).
__device__ float SpatialDistance(float x, float y, float spatialDev)
{
	const float squaredLength = x * x + y * y;
	const float twiceVariance = 2.0f * spatialDev * spatialDev;

	return expf(-squaredLength / twiceVariance);
}

// Joint bilateral up-sampling of a low-resolution depth map, guided by a packed-pixel image at the
// output resolution.
//
// Thread layout: 2D blocks, one thread per low-resolution pixel. Consecutive blocks advance by
// (blockDim - 2 * FILTER_RADIUS) pixels, so blocks overlap and the outer FILTER_RADIUS threads on
// each side only load data for their neighbours (except on the low side of the first block in
// each axis, where those threads also produce output).
// Each producing thread writes the sampleFactor x sampleFactor output pixels that start at
// (x * sampleFactor, y * sampleFactor).
//
// Dynamic shared memory required: blockDim.x * blockDim.y * (sizeof(unsigned int) + sizeof(float)).
//
// dMapIn, dMapInStride   - low-resolution depth map and its row stride, in elements
// image, imageStride     - guide image and its row stride, in elements
// inWidth, inHeight      - dimensions of dMapIn
// outWidth, outHeight    - bounds used for both image reads and dMapOut writes
// sampleFactor           - integer up-sampling factor, applied to both axes
// spatialDev, colourDev  - deviations forwarded to SpatialDistance and ColourSimilarity
// dMapOut, dMapOutStride - up-sampled depth map and its row stride, in elements
__global__ void JbuKernel(const float* const dMapIn, int dMapInStride, const unsigned int* const image, const int imageStride, 
						  const int inWidth, const int inHeight, const int outWidth, const int outHeight, const int sampleFactor, 
						  const float spatialDev, const float colourDev, float* const dMapOut, const int dMapOutStride)
{
	// Signed copies of the built-in indices so that negative filter offsets can be added safely
	const int tileWidth = (int)blockDim.x;
	const int localX = (int)threadIdx.x;
	const int localY = (int)threadIdx.y;

	// Low-resolution pixel handled by this thread
	const int x = (blockDim.x - 2 * FILTER_RADIUS) * blockIdx.x + threadIdx.x;
	const int y = (blockDim.y - 2 * FILTER_RADIUS) * blockIdx.y + threadIdx.y;
	const int tid = tileWidth * localY + localX;

	// Up-sampled image coordinates corresponding to the pixel (x, y) in the low-resolution input
	const int xUp = x * sampleFactor;
	const int yUp = y * sampleFactor;

	// Shared memory holds one image sample per thread, followed by one depth sample per thread
	extern __shared__ char shared_array[];
	unsigned int* shared_array_im = (unsigned int*)shared_array;
	float* shared_array_depth = (float*)(shared_array + blockDim.x * blockDim.y * sizeof(unsigned int));

	// Read in image and depth-map data to shared memory (zero for positions outside the data)
	shared_array_im[tid] = ((xUp >= outWidth || yUp >= outHeight) ? 0 : image[imageStride * yUp + xUp]);
	shared_array_depth[tid] = ((x >= inWidth || y >= inHeight) ? 0.0f : dMapIn[dMapInStride * y + x]);

	// Every thread reaches this barrier; the early return below comes after it
	__syncthreads();

	// Apron threads have done their job once the tile is loaded
	const bool isApronThread =
		((threadIdx.x < FILTER_RADIUS) && (blockIdx.x != 0)) ||
		((threadIdx.y < FILTER_RADIUS) && (blockIdx.y != 0)) ||
		(threadIdx.x >= blockDim.x - FILTER_RADIUS) ||
		(threadIdx.y >= blockDim.y - FILTER_RADIUS);

	if(isApronThread)
		return;

	// Iterate over the corresponding window of pixels in the upsampled image
	for(int dy = 0; dy < sampleFactor; ++dy)
	{
		const int outY = yUp + dy;
		if(outY >= outHeight)
			continue;

		for(int dx = 0; dx < sampleFactor; ++dx)
		{
			const int outX = xUp + dx;
			if(outX >= outWidth)
				continue;

			const unsigned int basePixel = image[imageStride * outY + outX];

			// Sub-pixel position of the output pixel, in low-resolution units
			const float deltaX = (float)dx / (float)sampleFactor;
			const float deltaY = (float)dy / (float)sampleFactor;

			float result = 0.0f, sumMul = 0.0f;

			// Iterate over the small filter window, skipping samples outside the low-resolution map
			for(int j = -FILTER_RADIUS; j <= FILTER_RADIUS; ++j)
			{
				if((y + j < 0) || (y + j >= inHeight))
					continue;

				for(int i = -FILTER_RADIUS; i <= FILTER_RADIUS; ++i)
				{
					if((x + i < 0) || (x + i >= inWidth))
						continue;

					const int neighbour = tileWidth * (localY + j) + localX + i;

					// Spatial and intensity-difference weighting
					const float mulFactor =
						SpatialDistance((float)i - deltaX, (float)j - deltaY, spatialDev) *
						ColourSimilarity(shared_array_im[neighbour], basePixel, colourDev);

					result += shared_array_depth[neighbour] * mulFactor;
					sumMul += mulFactor;
				}
			}

			// Output the normalized result. The weight sum is zero when every weight underflows or
			// when no low-resolution sample lies inside the window; dividing would then store NaN,
			// so fall back to this thread's own low-resolution sample (0 outside the input map).
			dMapOut[dMapOutStride * outY + outX] = (sumMul > 0.0f) ? (result / sumMul) : shared_array_depth[tid];
		}
	}
}

// Host wrapper that launches JbuKernel over the whole low-resolution depth map.
//
// inputDepth, inputDepthStride   - low-resolution depth map and its row stride, in elements
// inWidth, inHeight              - dimensions of inputDepth
// refImage, refImageStride       - packed-pixel guide image and its row stride, in elements
// outWidth, outHeight            - dimensions of the output
// sigmaS, sigmaC                 - spatial and colour deviations passed to the kernel
// outputDepth, outputDepthStride - up-sampled depth map and its row stride, in elements
//
// The pointers are handed straight to the kernel, so they must be dereferenceable from device
// code. The up-sampling factor is the integer quotient outWidth / inWidth and is used for both
// axes. Nothing is launched when the input map is empty.
void RunJbuKernel(const float* inputDepth, int inputDepthStride, int inWidth, int inHeight, const unsigned int* refImage, int refImageStride, 
				  int outWidth, int outHeight, float sigmaS, float sigmaC, float* outputDepth, int outputDepthStride)
{
	// An empty input has nothing to up-sample, and would divide by zero when deriving the factor
	if(inWidth <= 0 || inHeight <= 0)
		return;

	dim3 blockDimension(16, 8);

	// Low-resolution pixels by which consecutive blocks advance (block size minus the apron).
	// Kept signed: dim3 members are unsigned, and mixing them into the division below would turn
	// the negative numerator of inputs smaller than FILTER_RADIUS + 1 into a huge block count.
	const int tileWidth = (int)blockDimension.x - 2 * FILTER_RADIUS;
	const int tileHeight = (int)blockDimension.y - 2 * FILTER_RADIUS;

	const int gridWidth = (inWidth - FILTER_RADIUS - 1) / tileWidth + 1;
	const int gridHeight = (inHeight - FILTER_RADIUS - 1) / tileHeight + 1;

	dim3 gridDimension(gridWidth, gridHeight);

	// One image sample and one depth sample per thread
	unsigned int sharedMemBytes = blockDimension.x * blockDimension.y * (sizeof(unsigned int) + sizeof(float));

	RECORD_KERNEL_LAUNCH("JBU Up-sampling kernel", gridDimension, blockDimension);

	JbuKernel<<<gridDimension, blockDimension, sharedMemBytes>>>(inputDepth, inputDepthStride, refImage, refImageStride, 
		inWidth, inHeight, outWidth, outHeight, outWidth / inWidth, sigmaS, sigmaC, outputDepth, outputDepthStride);
	
	CHECK_KERNEL_ERROR("JBU Up-sampling kernel");
}