#include "StereoDataTypes.h"
#include "CudaHelperCommon.cuh"

// Exchanges the values of two objects of the same type by copying through a temporary.
template<typename T>
void Swap(T & a, T & b)
{
	const T savedA(a);	// keep a's old value before it is overwritten
	a = b;
	b = savedA;
}

/******************************************************* RgbImage **************************************************************/

// Releases the device image (if any) when the wrapper goes out of scope.
RgbImage::~RgbImage()
{
	this->Destroy();
}
// (Re)allocates a pitched device image of width x height pixels, one unsigned int per pixel.
// Any previous allocation is released first; the row pitch (in bytes) chosen by the
// runtime is stored in gpuImagePitch.
void RgbImage::Create(RgbImageType type, int width, int height)
{
	// Always start from an empty image so the old buffer is never leaked.
	Destroy();

	CUDA_CALL(cudaMallocPitch((void**)&gpuImage, &gpuImagePitch, width * sizeof(unsigned int), height));

	// Record metadata only after the allocation call above.
	this->imageType = type;
	this->w = width;
	this->h = height;
}
void RgbImage::Destroy()
{
	if(gpuImage != NULL)
	{
		cudaFree(gpuImage);
		gpuImage = NULL;
		w = h = 0;
	}
}
// Exchanges the device buffer and all describing metadata with another image, in O(1),
// without copying any pixel data.
void RgbImage::SwapData(RgbImage & other)
{
	// Geometry first, then the buffer and its layout.
	Swap(w, other.w);
	Swap(h, other.h);
	Swap(gpuImage, other.gpuImage);
	Swap(gpuImagePitch, other.gpuImagePitch);
	Swap(imageType, other.imageType);
}

/******************************************************* CostSpaceGrid **************************************************************/

// Releases the 3D device allocation (if any) when the grid goes out of scope.
CostSpaceGrid::~CostSpaceGrid()
{
	this->Destroy();
}
// (Re)allocates a width x height x depth cost volume in device memory with cudaMalloc3D.
// Any previous allocation is released first.
//   COST_SPACE_GRID_TYPE_ANY / _SINGLE: one float per cell, full depth.
//   COST_SPACE_GRID_TYPE_HALFS_D:       one unsigned int per (x, y, z) element with the depth
//                                       axis halved, rounded up; presumably two depth slices
//                                       are packed per 32-bit word — TODO confirm with kernels.
// For any other type value, nothing is allocated and the grid is left empty; previously
// cudaMalloc3D would have received an uninitialised extent.
// Note: extent.width is a byte count, as cudaMalloc3D requires.
void CostSpaceGrid::Create(CostSpaceGridType type, int width, int height, int depth)
{
	Destroy();

	// Zero-initialise so no path can hand indeterminate sizes to cudaMalloc3D.
	cudaExtent extent = make_cudaExtent(0, 0, 0);

	switch(type)
	{
	case COST_SPACE_GRID_TYPE_ANY: //fall-through
	case COST_SPACE_GRID_TYPE_SINGLE:
		extent.width = width * sizeof(float);
		extent.height = height;
		extent.depth = depth;
		break;

	case COST_SPACE_GRID_TYPE_HALFS_D:
		extent.width = width * sizeof(unsigned int);
		extent.height = height;
		extent.depth = (depth - 1) / 2 + 1;	// ceil(depth / 2) for depth >= 1
		break;

	default:
		// Unknown layout: refuse to allocate rather than invoke undefined behaviour.
		return;
	}

	CUDA_CALL(cudaMalloc3D(&gpuGrid, extent));
	w = width;
	h = height;
	d = depth;
	gridType = type;
}
// Frees the 3D device allocation, if any, and returns the grid to an empty state: the
// cudaPitchedPtr is fully cleared, not only its pointer, and the dimensions are zeroed.
void CostSpaceGrid::Destroy()
{
	if(gpuGrid.ptr != NULL)
	{
		// NOTE(review): the cudaFree status is not checked. Destroy() is also reached from the
		// destructor, so confirm what CUDA_CALL does on failure before wrapping this call.
		cudaFree(gpuGrid.ptr);
		gpuGrid.ptr = NULL;
		// Clear the layout fields so they cannot describe freed memory.
		gpuGrid.pitch = 0;
		gpuGrid.xsize = 0;
		gpuGrid.ysize = 0;
		w = h = d = 0;
	}
}
// Exchanges the device volume (pointer, pitch and extents) and metadata with another
// grid in O(1); no cost data is copied.
void CostSpaceGrid::SwapData(CostSpaceGrid & other)
{
	// Dimensions.
	Swap(w, other.w);
	Swap(h, other.h);
	Swap(d, other.d);
	// Storage and its layout tag.
	Swap(gpuGrid, other.gpuGrid);
	Swap(gridType, other.gridType);
}
void CostSpaceGrid::SizeToMatch(const CostSpaceGrid & other)
{
	Create(other.gridType, other.GetWidth(), other.GetHeight(), other.GetDepth());
}

/******************************************************* DepthMap **************************************************************/

// Releases the device depth buffer (if any) when the map goes out of scope.
DepthMap::~DepthMap()
{
	this->Destroy();
}
// (Re)allocates a pitched device depth map of width x height floats.
// Any previous allocation is released first; the row pitch (in bytes) chosen by the
// runtime is stored in gpuImagePitch.
void DepthMap::Create(DepthMapType type, int width, int height)
{
	// Always start from an empty map so the old buffer is never leaked.
	Destroy();

	CUDA_CALL(cudaMallocPitch((void**)&gpuImage, &gpuImagePitch, width * sizeof(float), height));

	// Record metadata only after the allocation call above.
	this->mapType = type;
	this->w = width;
	this->h = height;
}
void DepthMap::Destroy()
{
	if(gpuImage != NULL)
	{
		cudaFree(gpuImage);
		gpuImage = NULL;
		w = h = 0;
	}
}
// Exchanges the device buffer and all describing metadata with another depth map in O(1);
// no depth values are copied.
void DepthMap::SwapData(DepthMap & other)
{
	// Geometry first, then the buffer, its pitch and its type tag.
	Swap(w, other.w);
	Swap(h, other.h);
	Swap(gpuImage, other.gpuImage);
	Swap(gpuImagePitch, other.gpuImagePitch);
	Swap(mapType, other.mapType);
}