import math
import numpy as np


class TransferFunction:
    """
    Create a transfer function coding object. The input (camera RAW) signal will be between 0 and
    2**bits-1. The encoded values will be between 0 and 1.

    Subclasses implement encode() and decode_float(); decode() rounds the output of
    decode_float() to integer camera values.
    """

    def __init__(self, bits):
        self.bits = bits

    def encode(self, L):
        """
        Map camera values L (0 to 2**bits-1) to encoded values in the 0-1 range.

        Must be implemented by subclasses.
        """
        raise NotImplementedError()

    def decode(self, V):
        """Decode V with decode_float() and round the result to integer camera values."""
        L_float = self.decode_float(V)
        # np.add also works when decode_float() returns a plain Python number (which has no
        # .astype()). Adding 0.5 before the truncating int cast rounds non-negative values
        # half-up.
        return np.add(L_float, 0.5).astype(int)

    def decode_float(self, V):
        """
        Map encoded values V (0-1 range) back to floating-point camera values.

        Must be implemented by subclasses.
        """
        raise NotImplementedError()


class TF_linear(TransferFunction):
    """
    Linear transfer function: camera values (irradiance) are only rescaled.

    Camera values between 0 and 2**bits-1 are mapped linearly to encoded values V in the
    0-1 range, and decode_float() performs the inverse scaling.
    """

    def encode(self, L):
        """
        Encode linear camera values (irradiance) given as integer values between 0 and 2**bits-1.
        The encoded values V will be in the 0-1 range.
        """
        # np.asarray (instead of L.astype) also accepts Python scalars and lists.
        return np.asarray(L, dtype=float) / (2**self.bits - 1)

    def decode_float(self, V):
        """
        Decode values produced by the "encode" method into linear camera values (irradiance).
        The camera linear values are in the range between 0 and 2**bits-1.
        The encoded values V must be in the range from 0 to 1.
        """
        # np.multiply returns a NumPy value even for a Python float V, so the result always
        # supports .astype() (used by TransferFunction.decode).
        return np.multiply(V, 2**self.bits - 1)


class TF_gamma(TransferFunction):
    """
    Apply gamma encoding, which is equivalent to raising the provided value to an exponent.
    In practice, gamma lies between 2.2 and 2.4.

    encode() raises the normalized camera value L/(2**bits-1) to the power 1/gamma;
    decode_float() raises V to the power gamma and scales back to the 0 to 2**bits-1 range.
    """

    def __init__(self, bits, gamma=2.2):
        super().__init__(bits)
        self.gamma = gamma

    def encode(self, L):
        """Map camera values L (0 to 2**bits-1) to gamma-encoded values in the 0-1 range."""
        # np.asarray (instead of L.astype) also accepts Python scalars and lists.
        normalized = np.asarray(L, dtype=float) / (2**self.bits - 1)
        return np.power(normalized, 1 / self.gamma)

    def decode_float(self, V):
        """Invert encode(): map encoded values V (0-1) back to float camera values."""
        return np.power(V, self.gamma) * (2**self.bits - 1)


class TF_log(TransferFunction):
    """
    A logarithmic transfer function that encodes the values from 1 to 2**bits-1.

    Camera values L in [1, 2**bits-1] are mapped to V = log2(L) / log2(2**bits-1), i.e. to
    the 0-1 range. Values below 1 (including 0) are encoded as V = 0.

    NOTE: decode_float() maps every V below 1/(2**bits-1) to 0, so L=1 (encoded as V=0)
    decodes to 0, the same as L=0.
    """

    def encode(self, L):
        """Map camera values L to log-encoded values in the 0-1 range (L < 1 -> 0)."""
        # Work in float64 on an array so Python scalars and lists are accepted as well.
        L = np.asarray(L, dtype=float)
        # log2 is evaluated on max(L, 1) so values below 1 never hit log2(0); np.where then
        # replaces their result with 0.
        log2_L = np.log2(np.maximum(L, 1.0))
        return np.where(L >= 1, log2_L / math.log2(2**self.bits - 1), 0.0)

    def decode_float(self, V):
        """Invert encode(): map encoded values V back to float camera values."""
        V = np.asarray(V)
        L = np.power(2, V * math.log2(2**self.bits - 1))
        # np.where instead of in-place masked assignment, which failed for scalar input.
        return np.where(V < 1 / (2**self.bits - 1), 0.0, L)


class TF_PQ(TransferFunction):
    """
    The Perceptual Quantization (PQ) transfer function was specifically designed to maximize
    luminance encoding efficiency for human vision.

    Here the PQ curve is applied to camera values normalized by 2**bits-1 (rather than to an
    absolute peak luminance), and the result is rescaled so that L=1 maps to V=0 and
    L=2**bits-1 maps to V=1.
    """

    # PQ curve constants (m1, m2, c1, c2, c3 in SMPTE ST 2084), kept under the original names.
    n = 0.15930175781250000   # m1 = 2610/16384
    m = 78.843750000000000    # m2 = 2523/4096*128
    c1 = 0.83593750000000000  # 3424/4096
    c2 = 18.851562500000000   # 2413/4096*32
    c3 = 18.687500000000000   # 2392/4096*32

    def __init__(self, bits):
        super().__init__(bits)
        self.max_val = 2**self.bits - 1
        # Un-rescaled code value of L=1; encode() subtracts it so that L=1 maps to V=0.
        self.min_val = self._abs_encode(1)

    def _abs_encode(self, L):
        """Apply the PQ curve to L/max_val, without the min_val rescaling."""
        im_t = np.power(np.asarray(L) / self.max_val, self.n)
        return np.power((self.c2 * im_t + self.c1) / (1 + self.c3 * im_t), self.m)

    def _abs_decode(self, V_abs):
        """Invert _abs_encode(): map an un-rescaled PQ code value back to camera values."""
        # Clamp at 0 so the fractional powers never receive negative bases.
        im_t = np.power(np.maximum(V_abs, 0), 1 / self.m)
        # The denominator is positive for V_abs <= 1 because c2 > c3.
        ratio = np.maximum(im_t - self.c1, 0) / (self.c2 - self.c3 * im_t)
        return self.max_val * np.power(ratio, 1 / self.n)

    def encode(self, L):
        """Map camera values L to PQ-encoded values (L=1 -> 0, L=2**bits-1 -> 1)."""
        return (self._abs_encode(L) - self.min_val) / (1 - self.min_val)

    def decode_float(self, V):
        """Invert encode(): undo the min_val rescaling, then invert the PQ curve."""
        V_rescaled = np.asarray(V) * (1 - self.min_val) + self.min_val
        return self._abs_decode(V_rescaled)
