#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Jan 22 12:32:35 2020

@author: am2806
"""

import numpy as np
import math
import cv2
import os
import skimage.measure as skmeasure
import torch as pt
import imageio
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from PIL import Image
import matplotlib.pyplot as plt

def load(path):
    """Open the image file at ``path`` with PIL and return the Image object."""
    return Image.open(path)

# Module-level torchvision transform instances (ToTensor, then ToPILImage),
# created once at import time and reused by callers.
to_tensor, to_pil = transforms.ToTensor(), transforms.ToPILImage()

def torch_to_np(img_var):
    """Convert a tensor to a NumPy array and return its first element along axis 0.

    The tensor is detached from the autograd graph and copied to CPU before
    conversion, so it works for tensors that require grad or live on a GPU.
    """
    host_array = img_var.detach().cpu().numpy()
    return host_array[0]

def np_to_torch(img_np):
    """Wrap a NumPy array as a torch tensor with a new leading axis of size 1.

    ``torch.from_numpy`` shares memory with ``img_np`` (no copy is made).
    """
    tensor = pt.from_numpy(img_np)
    return tensor.unsqueeze(0)

def reverse_channels(img):
    """Move the first axis of ``img`` to the last position (e.g. CHW -> HWC).

    Despite the name, the order of the channels themselves is not reversed;
    only the axis that holds them is relocated.
    """
    return np.moveaxis(img, source=0, destination=-1)


def read_convert_pt_image(image_path):
    """Read an image file into a channel-first ``(C, H, W)`` torch tensor.

    Args:
        image_path: path of the image file to read with ``imageio.imread``.

    Returns:
        A tensor whose first axis is the channel axis. The pixel dtype is
        whatever imageio produced; it is not converted.
        Single-channel (2-D) images are returned as ``(1, H, W)``.
    """
    image = imageio.imread(image_path)
    # asarray strips imageio's array wrapper so torch.from_numpy gets a plain ndarray.
    image = pt.from_numpy(imageio.core.asarray(image))
    if image.ndim == 2:
        # Grayscale image: add a trailing channel axis so the permute below
        # (which needs exactly three axes) does not fail.
        image = image.unsqueeze(-1)
    return image.permute(2, 0, 1)


def mse(x, y):
    """Return the mean squared error between two arrays.

    Both inputs are promoted to float64 before subtracting. Otherwise
    unsigned-integer images (e.g. uint8) would wrap around on ``x - y`` and
    on squaring, and the result would be silently wrong.

    Args:
        x, y: array-likes with broadcast-compatible shapes.

    Returns:
        The mean of ``(x - y) ** 2`` over all elements, as a NumPy float64.
    """
    diff = np.asarray(x, dtype=np.float64) - np.asarray(y, dtype=np.float64)
    return np.mean(diff ** 2)


def psnr(img1, img2, pixel_max=255.0):
    """Return the peak signal-to-noise ratio (dB) between two images.

    Inputs are promoted to float64 before subtracting, so uint8 images do not
    wrap around and inflate or deflate the error.

    Args:
        img1, img2: array-likes with broadcast-compatible shapes.
        pixel_max: peak pixel value of the images' scale (255 for 8-bit
            data, 1.0 for images normalized to [0, 1]).

    Returns:
        ``20 * log10(pixel_max / sqrt(MSE))``, or 100 when the images are
        identical (MSE == 0), in place of infinity.
    """
    diff = np.asarray(img1, dtype=np.float64) - np.asarray(img2, dtype=np.float64)
    mean_sq_err = np.mean(diff ** 2)
    if mean_sq_err == 0:
        return 100
    return 20 * math.log10(pixel_max / math.sqrt(mean_sq_err))

def PSNR(pred, gt, shave_border=0, pixel_max=255.0):
    """Return the PSNR (dB) between ``pred`` and ``gt`` after cropping a border.

    Args:
        pred: predicted image; its first two axes are treated as height and
            width, and its shape defines the crop applied to both images.
        gt: reference image cropped with the same window as ``pred``.
        shave_border: number of pixels removed from each edge of both axes
            before comparing (0 compares the whole image).
        pixel_max: peak pixel value of the images' scale (255 for 8-bit data).

    Returns:
        ``20 * log10(pixel_max / RMSE)``, or 100 if the cropped regions are
        identical.

    Raises:
        ValueError: if ``shave_border`` is negative or crops away every pixel.
    """
    # Compare in float64 so integer (e.g. uint8) inputs cannot wrap around.
    pred = np.asarray(pred, dtype=np.float64)
    gt = np.asarray(gt, dtype=np.float64)
    height, width = pred.shape[:2]
    if shave_border < 0 or 2 * shave_border >= min(height, width):
        raise ValueError(
            'shave_border must be >= 0 and leave at least one pixel; '
            'got {} for a {}x{} image'.format(shave_border, height, width)
        )
    rows = slice(shave_border, height - shave_border)
    cols = slice(shave_border, width - shave_border)
    imdff = pred[rows, cols] - gt[rows, cols]
    rmse = math.sqrt(np.mean(imdff ** 2))
    if rmse == 0:
        return 100
    return 20 * math.log10(pixel_max / rmse)


original_path = 'results/GT'
# Alternative output folders evaluated previously:
# generated_path = 'out/places10_CR_UN_Supervised_Version2'
# generated_path = 'Ablation_Out_Lambda/Ablation_1000'
# generated_path = 'Output_Images/Hazy'
generated_path = 'results/Bicubic'

# skimage.measure.compare_ssim was moved to skimage.metrics.structural_similarity
# and later removed from skimage.measure; prefer the new location.
import inspect
try:
    from skimage.metrics import structural_similarity as _structural_similarity
except ImportError:  # older scikit-image without skimage.metrics
    _structural_similarity = skmeasure.compare_ssim

# Newer scikit-image takes channel_axis; older releases take multichannel=True.
if 'channel_axis' in inspect.signature(_structural_similarity).parameters:
    _SSIM_CHANNEL_KWARGS = {'channel_axis': -1}
else:
    _SSIM_CHANNEL_KWARGS = {'multichannel': True}


def _load_on_255_scale(path):
    """Read an image with plt.imread and return it as float64 on a 0-255 scale.

    matplotlib documents that plt.imread returns floats in [0, 1] for PNG but
    integer data (e.g. uint8 in [0, 255]) for formats read through Pillow
    such as JPEG. The PSNR helpers assume a 255 peak, so float images are
    rescaled. A fourth (alpha) channel, if present, is dropped.
    """
    img = plt.imread(path)
    if img.ndim == 3:
        img = img[:, :, 0:3]
    if np.issubdtype(img.dtype, np.floating):
        return img.astype(float) * 255.0
    return img.astype(float)


psnr1 = []  # per-image PSNR from PSNR()
psnr2 = []  # unused; kept so the module-level name still exists
psnr3 = []  # per-image PSNR from psnr() (same metric as psnr1 when shave_border=0)
m = []      # per-image MSE
s = []      # per-image SSIM

for file in os.listdir(original_path):
    print(file)
    img_ref = _load_on_255_scale(os.path.join(original_path, file))

    # splitext keeps dots inside the stem ("img.v2.png" -> "img.v2"),
    # unlike split('.')[0].
    file_name_only = os.path.splitext(file)[0]
    print('file_name_only is', file_name_only)
    img_gen = _load_on_255_scale(
        os.path.join(generated_path, file_name_only, file_name_only + '_SR.jpg'))

    if img_ref.shape != img_gen.shape:
        raise ValueError('shape mismatch for {}: reference {} vs generated {}'.format(
            file, img_ref.shape, img_gen.shape))

    psnr1.append(PSNR(img_ref, img_gen, shave_border=0))
    psnr3.append(psnr(img_ref, img_gen))
    m.append(mse(img_ref, img_gen))
    # data_range is the range of *possible* values (0-255 after rescaling),
    # not this image's observed min-max.
    ssim_kwargs = _SSIM_CHANNEL_KWARGS if img_ref.ndim == 3 else {}
    s.append(_structural_similarity(img_ref, img_gen, data_range=255.0, **ssim_kwargs))


print('Mean PSNR is:', np.array(psnr1).mean())
print('Mean PSNR is:', np.array(psnr3).mean())

print('Mean MSE is:', np.array(m).mean())
print('Mean SSIM is:', np.array(s).mean())


#%%


