// $Id: CostSpaceKernels.cu 907 2010-03-14 22:18:36Z cr333 $

#include "cudatemplates/devicememoryreference.hpp"
//#include "CostSpaceComputers.h"
#include "CudaHelperCommon.cuh"
#include "Utils.cuh"

// Fills a cost volume with per-pixel matching costs between a left image
// (src1) and a right image (src2) for disparities d = 0 .. numDisps-1.
//
// Launch layout:
//   - 1D thread blocks along x; blockIdx.y selects the image row y.
//   - dynamic shared memory of at least (blockDim.x + numDisps) * sizeof(uchar4)
//     bytes: numDisps halo pixels followed by one right-image pixel per thread.
//
// Memory layout:
//   - src1 / src2 hold 4-byte pixels; each image is indexed with its own pitch.
//   - costs holds floats; the cost for (x, y, d) is stored at
//     (height * d + y) * (costs.pitch / 4) + x.
//
// For every valid pixel x and disparity d the kernel stores
//   min(scaleCost * CostFunction::computeCost(left[x], right[x - d]), clipCost)
// or clipCost when x - d falls to the left of the right image.
template <class CostFunction> __global__ void costKernel(
    const cudaPitchedPtr src1,
    const cudaPitchedPtr src2,
    const cudaPitchedPtr costs,
    const int width,
    const int height,
    const int numDisps,
    const float scaleCost,
    const float clipCost)
{
    const int blockStartX = blockIdx.x * blockDim.x;
    const int x = blockStartX + threadIdx.x;
    const int y = blockIdx.y;

    // row pitches in 4-byte elements; kept in size_t so that large volumes
    // do not overflow a 32-bit index
    const size_t pitchL = src1.pitch >> 2;
    const size_t pitchR = src2.pitch >> 2;
    const size_t pitchC = costs.pitch >> 2;

    // start of the current scanline in the left and right image
    const int* rowL = (const int*)src1.ptr + y * pitchL;
    const int* rowR = (const int*)src2.ptr + y * pitchR;

    extern __shared__ char sharedmem[];
    uchar4* smemR = (uchar4*)sharedmem;

// 1. copy relevant bits from left and right images into shared memory ----------------------------

    // only read when x < width; initialised to keep the compiler quiet
    uchar4 pixL = make_uchar4(0, 0, 0, 0);
    if(x < width)
    {
        // this thread's pixel of the left and right scanline
        pixL = int_to_uchar4(rowL[x]);
        smemR[numDisps + threadIdx.x] = int_to_uchar4(rowR[x]);
    }

    // Halo: the numDisps right-image pixels immediately to the left of this
    // block. The strided loop covers the whole halo even when numDisps is
    // larger than the block width, and the bounds check skips pixels that lie
    // outside the image (e.g. the entire halo of the first block).
    for(int i = threadIdx.x; i < numDisps; i += blockDim.x)
    {
        const int haloX = blockStartX - numDisps + i;
        if(haloX >= 0 && haloX < width)
        {
            smemR[i] = int_to_uchar4(rowR[haloX]);
        }
    }

    // wait for all shared memory activity to be finished
    __syncthreads();

// 2. now compute the per-pixel costs -------------------------------------------------------------

    if(x < width) // only for valid pixels
    {
        float* const costBase = (float*)costs.ptr;

        for(int d = 0; d < numDisps; d++)
        {
            // position where to write cost in cost volume
            const size_t pos3 = ((size_t)height * d + y) * pitchC + x;

            // default: maximum cost (for disparities that map pixels outside the right image)
            float cost = clipCost;

            if(x - d >= 0)
            {
                // right-image pixel x - d; slots with x - d >= 0 were filled above
                const uchar4 pixR = smemR[numDisps + threadIdx.x - d];
                cost = scaleCost * CostFunction::computeCost(pixL, pixR);
            }

            // clip and store cost
            costBase[pos3] = fminf(cost, clipCost);
        }
    }
}

//---- AD cost computation ------------------------------------------------------------------------

struct AbsoluteDifference
{
    // Mean absolute difference of the x, y and z channels of two pixels.
    // The w channel does not contribute to the cost.
    inline static __device__ float computeCost(const uchar4 pixL, const uchar4 pixR)
    {
        // per-channel absolute differences, computed in float
        const float absX = fabs(static_cast<float>(pixL.x) - static_cast<float>(pixR.x));
        const float absY = fabs(static_cast<float>(pixL.y) - static_cast<float>(pixR.y));
        const float absZ = fabs(static_cast<float>(pixL.z) - static_cast<float>(pixR.z));

        // average over the three channels
        const float total = absX + absY + absZ;
        return total / 3.0f;
    }
};

//---- SD cost computation ------------------------------------------------------------------------

struct SquaredDifference
{
    // Mean squared difference of the x, y and z channels of two pixels.
    // The w channel does not contribute to the cost.
    inline static __device__ float computeCost(const uchar4 pixL, const uchar4 pixR)
    {
        // accumulate the squared per-channel differences in x, y, z order
        float sumSquares = 0.0f;

        float delta = static_cast<float>(pixL.x) - static_cast<float>(pixR.x);
        sumSquares += delta * delta;

        delta = static_cast<float>(pixL.y) - static_cast<float>(pixR.y);
        sumSquares += delta * delta;

        delta = static_cast<float>(pixL.z) - static_cast<float>(pixR.z);
        sumSquares += delta * delta;

        // average over the three channels
        return sumSquares / 3.0f;
    }
};

//---- runner functions ---------------------------------------------------------------------------

// Launches costKernel<CostFunction> over a w x h image pair and writes the
// resulting costs for numDisps disparities into gpuCost.
//
//   gpuImg1, gpuImg2 - input images, passed to the kernel as src1 / src2
//   gpuCost          - output cost volume
//   w, h             - image dimensions in pixels
//   numDisps         - number of disparities to evaluate
//   scaleCost        - factor applied to the raw cost
//   clipCost         - upper bound on the stored cost
//   costName         - label handed to the launch-recording and error-check macros
template <typename CostFunction> void runCost(
    const Cuda::DeviceMemory<unsigned int, 2>& gpuImg1,
    const Cuda::DeviceMemory<unsigned int, 2>& gpuImg2,
    Cuda::DeviceMemory<float, 3>& gpuCost,
    const unsigned int w, const unsigned int h,
    const unsigned int numDisps,
    const float scaleCost,
    const float clipCost,
    const char* costName)
{
    // Scanlines are processed independently, so each block is a single
    // row of 256 threads.
    const unsigned int threadsPerBlock = 256;

    // Enough blocks along x to cover the full image width (rounded up),
    // and one block row per scanline.
    const unsigned int blocksPerRow = (w + threadsPerBlock - 1) / threadsPerBlock;

    dim3 costBlock(threadsPerBlock, 1, 1);
    dim3 costGrid(blocksPerRow, h, 1);

    // Dynamic shared memory: one uchar4 per thread plus numDisps extra pixels.
    const size_t sharedBytes = (threadsPerBlock + numDisps) * sizeof(uchar4);

    RECORD_KERNEL_LAUNCH(costName, costGrid, costBlock);

    costKernel<CostFunction><<<costGrid, costBlock, sharedBytes>>>(
        toPitchedPtr(gpuImg1),
        toPitchedPtr(gpuImg2),
        toPitchedPtr(gpuCost),
        w, h, numDisps,
        scaleCost, clipCost);

    CHECK_KERNEL_ERROR(costName);
}


// Helper macro that defines a cost kernel runner function named RUN_FUNCTION.
// The generated function wraps the raw image pointers and the result volume in
// device memory references, applies their pitches (stride is given in
// unsigned int elements, result.pitch in bytes) and forwards everything to
// runCost<COST_FUNCTOR>, using COST_NAME as the kernel label.
#define DEFINE_COST_KERNEL(RUN_FUNCTION, COST_FUNCTOR, COST_NAME)                                   \
void RUN_FUNCTION(                                                                                  \
    unsigned int* leftImage,                                                                        \
    unsigned int* rightImage,                                                                       \
    const int stride,                                                                               \
    const int width,                                                                                \
    const int height,                                                                               \
    const int disparityMax,                                                                         \
    const float rescaleGradient,                                                                    \
    const float rescaleLimit,                                                                       \
    const cudaPitchedPtr& result)                                                                   \
{                                                                                                   \
    /* input images: both share the same row pitch */                                               \
    Cuda::DeviceMemoryReference2D<unsigned int> leftRef(width, height, leftImage);                  \
    leftRef.setPitch(stride * sizeof(unsigned int));                                                \
    Cuda::DeviceMemoryReference2D<unsigned int> rightRef(width, height, rightImage);                \
    rightRef.setPitch(stride * sizeof(unsigned int));                                               \
                                                                                                    \
    /* output cost volume, one slice per disparity */                                               \
    Cuda::DeviceMemoryReference3D<float> costRef(width, height, disparityMax, (float*)result.ptr);  \
    costRef.setPitch(result.pitch);                                                                 \
                                                                                                    \
    runCost<COST_FUNCTOR>(leftRef, rightRef, costRef, width, height, disparityMax, rescaleGradient, rescaleLimit, (COST_NAME)); \
}


// Runner functions generated by DEFINE_COST_KERNEL, one per cost functor:
//   RunCostSpaceCrAdKernel -> AbsoluteDifference, labelled "AD cost kernel"
//   RunCostSpaceCrSdKernel -> SquaredDifference,  labelled "SD cost kernel"
DEFINE_COST_KERNEL(RunCostSpaceCrAdKernel,   AbsoluteDifference, "AD cost kernel")
DEFINE_COST_KERNEL(RunCostSpaceCrSdKernel,   SquaredDifference,  "SD cost kernel")

#undef DEFINE_COST_KERNEL // cleanup: the helper macro is only needed for the definitions above