// $Id: DcbSuperResKernel.cu 886 2009-11-20 16:01:46Z cr333 $

#include "cudatemplates/devicememory.hpp"
#include "cudatemplates/devicememorypitched.hpp"
#include "cudatemplates/devicememoryreference.hpp"
#include "Utils.cuh"
#include "CudaHelperCommon.cuh"
#include "DcbSuperResKernel.cuh"
#include "UnmanagedUpsamplers.h"
#include "UnmanagedAggregators.h"
#include "UnmanagedMaximizers.h"

#ifdef UPSAMPLE_GROUND_TRUTH // FOR PAPER: upsample ground truth disparity map for comparison -----
#include "cudatemplates/copy.hpp"
#include "cudatemplates/hostmemoryreference.hpp"
#include "GpuTiledImages.hpp"
#endif // FOR PAPER: upsample ground truth disparity map for comparison ---------------------------

// Builds a truncated-quadratic cost volume from a disparity map.
//
// For every pixel (x, y) inside width x height and every disparity level
// d in [0, numDisps), writes
//     cost(x, y, d) = min(clipCost, (d - disparity(x, y) / scaleDisps)^2).
//
// Launch layout: one thread per pixel; thread (x, y) is derived from the 2D
//   block/thread indices, so the grid must cover the whole image. Threads
//   outside the image write nothing.
// Shared memory: requires blockDim.x * blockDim.y * sizeof(float) bytes of
//   dynamic shared memory (one float per thread).
//
// dispmap    - pitched 2D float disparity map (pitch in bytes).
// costs      - pitched float cost volume; disparity slice d occupies rows
//              [height * d, height * (d + 1)), i.e. slices are stacked along y.
// width      - image width in pixels.
// height     - image height in pixels.
// numDisps   - number of disparity levels (slices) to write.
// scaleDisps - divisor applied to each stored disparity before it is
//              compared against the integer level d.
// clipCost   - upper bound (truncation value) for the squared difference.
__global__ void CostFromDepthKernel(
    const cudaPitchedPtr dispmap,
    const cudaPitchedPtr costs,
    const int width,
    const int height,
    const int numDisps,
    const float scaleDisps,
    const float clipCost)
{
    const int x = blockIdx.x * blockDim.x + threadIdx.x;
    const int y = blockIdx.y * blockDim.y + threadIdx.y;
    const bool insideImage = (x < width) && (y < height);

    extern __shared__ char sharedmem[];
    float* dispSlice = (float*)sharedmem;
    const int sm_pos = blockDim.x * threadIdx.y + threadIdx.x; // index into shared memory

    // copy slice of disparity map into shared memory, for later use
    if(insideImage)
    {
        // row stride of the disparity map in floats (pitch is in bytes)
        const size_t dispRowFloats = dispmap.pitch / sizeof(float);
        dispSlice[sm_pos] = ((const float*)dispmap.ptr)[dispRowFloats * y + x] / scaleDisps;
    }
    // barrier is outside the bounds check so that every thread of the block reaches it
    __syncthreads();

    // Disabled alternative, kept for reference: stage a blurred disparity instead.
    //// copy slice of blurred disparity map into shared memory, for later use
    //float* dispSlice = (float*)sharedmem;
    //float2 acc = make_float2(0.0f, 0.0f);
    //for(int yi = y - 2; yi <= y + 2; yi++)
    //{
    //    for(int xi = x - 2; xi < x + 2; xi++)
    //    {
    //        if(xi >= 0 && xi < width && yi >= 0 && yi < height)
    //        {
    //            acc += make_float2(((float*)dispmap.ptr)[(dispmap.pitch >> 2) * yi + xi], scaleDisps);
    //        }
    //    }
    //}

    //if(x < width && y < height) dispSlice[sm_pos] = acc.x / acc.y;
    //__syncthreads();

    // no further barriers below, so out-of-image threads can leave now
    if(!insideImage) return;

    // loop-invariant values
    const float pixelDisp = dispSlice[sm_pos];
    const size_t costRowFloats = costs.pitch / sizeof(float);
    float* const costBase = (float*)costs.ptr;

    // The loop condition alone bounds d; partial unrolling keeps a correct
    // remainder loop, so no extra per-iteration guard is required.
    #pragma unroll 8
    for(int d = 0; d < numDisps; d++)
    {
        const float dispDiff = (float)d - pixelDisp;

        // Index is kept in size_t: narrowing it to int would truncate (and
        // possibly go negative) for volumes with more than 2^31 elements.
        const size_t cost_pos = costRowFloats * ((size_t)height * d + y) + x;

        costBase[cost_pos] = min(clipCost, dispDiff * dispDiff);
    }
}

// Host-side launcher for CostFromDepthKernel.
//
// gpuImgD    - 2D disparity map on the device; its width and height must
//              match the first two extents of gpuCost (asserted).
// gpuCost    - 3D cost volume on the device (width x height x disparities)
//              that the kernel fills.
// scaleDisps - forwarded to the kernel as the disparity divisor.
// clipCosts  - forwarded to the kernel as the cost truncation value.
//
// Uses one-row blocks of at most 256 threads and a grid that covers the
// whole image; one float of dynamic shared memory is reserved per thread.
void RunCostFromDepth(
    const Cuda::DeviceMemory<float, 2>& gpuImgD,
    const Cuda::DeviceMemory<float, 3>& gpuCost,
    const float scaleDisps, const float clipCosts)
{
    assert(gpuImgD.size[0] == gpuCost.size[0] && gpuImgD.size[1] == gpuCost.size[1]);

    // extents of the cost volume: image width, image height, disparity levels
    const unsigned int volWidth  = gpuCost.size[0];
    const unsigned int volHeight = gpuCost.size[1];
    const unsigned int volLevels = gpuCost.size[2];

    // Blocks are a single row of up to 256 threads (narrower for small images).
    const unsigned int threadsPerRow = (volWidth < 256u) ? volWidth : 256u;
    dim3 costBlock(threadsPerRow, 1, 1);

    // The grid follows from the block shape: round up so that the blocks
    // cover the complete image in both dimensions.
    const unsigned int blocksAcross = (volWidth + costBlock.x - 1) / costBlock.x;
    const unsigned int blocksDown   = (volHeight + costBlock.y - 1) / costBlock.y;
    dim3 costGrid(blocksAcross, blocksDown, 1);

    // one float per thread for the kernel's staged disparity values
    const size_t sharedBytes = costBlock.x * costBlock.y * sizeof(float);

    RECORD_KERNEL_LAUNCH("Cost-from-depth kernel", costGrid, costBlock);
    CostFromDepthKernel<<<costGrid, costBlock, sharedBytes>>>(
        toPitchedPtr(gpuImgD),
        toPitchedPtr(gpuCost),
        volWidth,
        volHeight,
        volLevels,
        scaleDisps,
        clipCosts);
    CHECK_KERNEL_ERROR("Cost-from-depth kernel");
}


// Iteratively refines a low-resolution depth map to the resolution of the
// colour images.
//
// Steps:
//   1. Nearest-neighbour upsample the low-res depth map (or, when
//      UPSAMPLE_GROUND_TRUTH is defined, the host ground-truth map gtDepth)
//      into outputDepth, using the integer factor outWidth / inWidth.
//   2. Repeat numInterations times:
//        a. build a cost volume from the current output depth map,
//        b. aggregate the cost volume with the DCB-grid aggregator, guided
//           by the left/right images,
//        c. extract a refined depth map with the sub-pixel WTA kernel,
//           overwriting outputDepth.
//
// All pointers except gtDepth are wrapped as device memory references.
// Strides are handed to setPitch() unchanged; the pitches read back are
// divided by the element size, so they are presumably in bytes - confirm
// against the callers.
//
// inNumDisps is not used by this function.
// Only the width ratio determines the upsampling factor; the height ratio is
// presumably the same - TODO confirm with callers.
void RunDcbSuperResolution(
#ifdef UPSAMPLE_GROUND_TRUTH // FOR PAPER: upsample ground truth disparity map for comparison -----
float* gtDepth,
#endif // FOR PAPER: upsample ground truth disparity map for comparison ---------------------------
// - low res depth map: ptr, width, pitch, height
const float* inputDepth,
const unsigned int inputDepthStride,
const unsigned int inWidth,
const unsigned int inHeight,
const unsigned int inNumDisps,
// - high res input images: left ptr, right ptr, width, pitch, height
const unsigned int* leftImage,
const unsigned int* rightImage,
const unsigned int inputImageStride,
// - temporary cost space: ptr, pitch (have width, height of high res images)
cudaPitchedPtr tempCostSpace,
// - temporary grid texture: ptr, width, height, pitch
float2* tempGrid,
const unsigned int gridStride,
const unsigned int gridWidth,
const unsigned int gridHeight,
// - high res depth map: ptr, pitch (have width, height of high res images)
float* outputDepth,
const unsigned int outputDepthStride,
const unsigned int outWidth,
const unsigned int outHeight,
const unsigned int outNumDisps,
// parameters:
const float clipCosts,
const float sigmaS,
const float sigmaC,
const unsigned int numInterations
)
{
    // wrap everything in cuda templates
    Cuda::DeviceMemoryReference2D<const float> ctInputDepth(inWidth, inHeight, inputDepth);
    Cuda::DeviceMemoryReference2D<const unsigned int> ctLeftImage(outWidth, outHeight, leftImage);
    Cuda::DeviceMemoryReference2D<const unsigned int> ctRightImage(outWidth, outHeight, rightImage);
    Cuda::DeviceMemoryReference3D<float> ctTempCost(outWidth, outHeight, outNumDisps, (float*)tempCostSpace.ptr);
    Cuda::DeviceMemoryReference2D<float2> ctTempGrid(gridWidth, gridHeight, tempGrid);
    Cuda::DeviceMemoryReference2D<float> ctOutputDepth(outWidth, outHeight, outputDepth);

    // make sure we use the right pitch for everything
    ctInputDepth.setPitch(inputDepthStride);
    ctLeftImage.setPitch(inputImageStride);
    ctRightImage.setPitch(inputImageStride);
    ctTempCost.setPitch(tempCostSpace.pitch);
    ctTempGrid.setPitch(gridStride);
    ctOutputDepth.setPitch(outputDepthStride);

    assert(ctLeftImage.size == ctRightImage.size);
    assert(ctLeftImage.size == ctOutputDepth.size);
    assert(ctLeftImage.size[0] == ctTempCost.size[0] && ctLeftImage.size[1] == ctTempCost.size[1] && ctTempCost.size[2] == outNumDisps);

    // the upsampling factor below divides by the input width
    assert(inWidth > 0);

    // integer (truncating) upsampling factor from input to output resolution
    const unsigned int upsampleFactor = outWidth / inWidth;

    // row stride of the output depth map in floats; the wrapper's pitch is
    // fixed for the rest of this function
    const size_t outputDepthPitchFloats = ctOutputDepth.getPitch() / sizeof(float);

#ifdef UPSAMPLE_GROUND_TRUTH // FOR PAPER: upsample ground truth disparity map for comparison -----

    //GreyImage gtDisps("../../../data/tsukuba/true.png");
    Cuda::HostMemoryReference2D<float> cthGtDisps(inWidth, inHeight, gtDepth);

    // check the sizes before any data is transferred
    assert(cthGtDisps.size == ctInputDepth.size);

    Cuda::DeviceMemoryPitched2D<float> ctdGtDisps(inWidth, inHeight);
    Cuda::copy(ctdGtDisps, cthGtDisps);

    // upsample ground-truth depth map to output resolution
    // ctdGtDisps => ctOutputDepth (use as temporary)
    // (pitch is converted to elements of the float source map)
    RunNnKernel(ctdGtDisps.getBuffer(), ctdGtDisps.getPitch() / sizeof(float), inWidth, inHeight,
        outWidth, outHeight, upsampleFactor, ctOutputDepth.getBuffer(), outputDepthPitchFloats);

    // only referenced by the disabled RunAggregationDCBGrid2 call below
    GpuTiledImages3D<float2>* grids = new GpuTiledImages3D<float2>(1 << 13);

#else // FOR PAPER: upsample ground truth disparity map for comparison ----------------------------

    // upsample low-res input depth map to output resolution
    // ctInputDepth => ctOutputDepth (use as temporary)
    // (pitch is converted to elements of the float source map)
    RunNnKernel(ctInputDepth.getBuffer(), ctInputDepth.getPitch() / sizeof(float), inWidth, inHeight,
        outWidth, outHeight, upsampleFactor, ctOutputDepth.getBuffer(), outputDepthPitchFloats);

#endif // FOR PAPER: upsample ground truth disparity map for comparison ---------------------------

    // disparity scale shared by the cost construction and the WTA step
    const float dispScale = 1.0f / outNumDisps;

    for(unsigned int i = 0; i < numInterations; i++)
    {
        // compute cost volume from depth map
        RunCostFromDepth(ctOutputDepth, ctTempCost, dispScale, clipCosts);

        // apply filtering/aggregation
        //RunAggregationYoonKweon(toPitchedPtr(ctTempCost), outNumDisps, ctLeftImage.getBuffer(), ctRightImage.getBuffer(), ctLeftImage.getPitch(), ctLeftImage.size[0], ctLeftImage.size[1], 17, sigmaS, sigmaC);
        //RunAggregationNaiveDCB(toPitchedPtr(ctTempCost), outNumDisps, ctLeftImage.getBuffer(), ctRightImage.getBuffer(), ctLeftImage.getPitch(), ctLeftImage.size[0], ctLeftImage.size[1], 17, sigmaS, sigmaC);
        RunAggregationDCBGrid(toPitchedPtr(ctTempCost), outNumDisps, ctLeftImage.getBuffer(), ctRightImage.getBuffer(), ctLeftImage.getPitch(), ctTempGrid.getBuffer(), ctTempGrid.size[0], ctTempGrid.size[1], ctTempGrid.getPitch(), ctLeftImage.size[0], ctLeftImage.size[1], sigmaS, sigmaC);
        //RunAggregationDCBGrid2(toPitchedPtr(ctTempCost), outNumDisps, ctLeftImage.getBuffer(), ctRightImage.getBuffer(), ctLeftImage.getPitch(), grids, ctLeftImage.size[0], ctLeftImage.size[1], 10, 10, 10);

        // compute refined depth map (written back into ctOutputDepth, which
        // feeds the next iteration)
        RunSubPixWtaKernel(toPitchedPtr(ctTempCost), outNumDisps, dispScale, 1.0f, ctOutputDepth.getBuffer(), outputDepthPitchFloats);
    }

#ifdef UPSAMPLE_GROUND_TRUTH // FOR PAPER: upsample ground truth disparity map for comparison -----
    delete grids;
#endif // FOR PAPER: upsample ground truth disparity map for comparison ---------------------------
}