#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Jan 22 12:32:35 2020

@author: am2806
"""

import numpy as np
import math
import cv2
import os
import skimage.measure as skmeasure
import torch as pt
import imageio
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from PIL import Image
import matplotlib.pyplot as plt

def load(path):
    """Open the image file at *path* with PIL and return the resulting Image object."""
    opened = Image.open(path)
    return opened

# Reusable torchvision converter instances:
#   to_tensor -- transforms.ToTensor()    (image -> tensor)
#   to_pil    -- transforms.ToPILImage()  (tensor/array -> PIL image)
to_tensor, to_pil = transforms.ToTensor(), transforms.ToPILImage()

def torch_to_np(img_var):
    """Return the first element along axis 0 of *img_var* as a NumPy array.

    The tensor is detached from autograd and moved to the CPU before the
    conversion, so it works for grad-tracking and GPU tensors alike.
    """
    host = img_var.detach().cpu()
    as_array = host.numpy()
    return as_array[0]

def np_to_torch(img_np):
    """Wrap *img_np* as a torch tensor with a new leading axis of size 1.

    ``torch.from_numpy`` shares memory with the input array, and the added
    axis is a view, so no data is copied.
    """
    wrapped = pt.from_numpy(img_np)
    batched = wrapped[None, :]
    return batched

def reverse_channels(img):
    """Move axis 0 of *img* to the last position (e.g. (C, H, W) -> (H, W, C)).

    Despite the name, channel order is not reversed; only the axis holding
    the channels is relocated. The result is a view of the input.
    """
    return np.moveaxis(img, source=0, destination=-1)


def read_convert_pt_image(image_path):
    """Read an image file and return it as a channel-first torch tensor.

    Parameters
    ----------
    image_path : str
        Path of the image, decoded with ``imageio.imread``.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(C, H, W)`` with the dtype of the decoded pixels.
        Single-channel (2-D) images are returned with ``C == 1`` instead of
        failing in ``permute``.
    """
    # np.asarray (public API) instead of imageio's internal core.asarray helper;
    # both yield a plain ndarray view of the decoded data.
    image = np.asarray(imageio.imread(image_path))
    if image.ndim == 2:
        # Grayscale decodes to (H, W); add a trailing channel axis so the
        # (H, W, C) -> (C, H, W) permute below is valid.
        image = image[:, :, None]
    return pt.from_numpy(image).permute(2, 0, 1)


def Resize(image):   # Original Size of the GT Image
    """Return a bilinear-resampled copy of *image* at 480 x 320 (width x height)."""
    target_size = (480, 320)  # PIL sizes are given as (width, height)
    resized = image.resize(target_size, Image.BILINEAR)
    return resized

def Resize_small(image, factor):
    """Downscale *image* by *factor* in both dimensions with bilinear filtering.

    Parameters
    ----------
    image : PIL.Image.Image
        Image to resize (uses PIL's ``size`` / ``resize`` interface).
    factor : float
        Positive divisor applied to width and height; values > 1 shrink.

    Returns
    -------
    PIL.Image.Image
        Resized copy. Each side is truncated to an int and clamped to at
        least 1 pixel, because PIL rejects zero-sized targets.

    Raises
    ------
    ValueError
        If *factor* is not positive.
    """
    if factor <= 0:
        raise ValueError(f"factor must be positive, got {factor!r}")
    # PIL reports size as (width, height); the original named these h, w.
    width, height = image.size
    new_size = (max(1, int(width / factor)), max(1, int(height / factor)))
    return image.resize(new_size, Image.BILINEAR)
#%%


#gt_path= 'DS_US/106024.png'
#image_name= '106024'
# Re-encode the first `num_images` super-resolution outputs of every method
# folder under `path`, plus the ground-truth images, as JPEG (quality 40),
# mirroring the directory layout under `save_path`.
path = 'Sample/results'

save_path = 'New'
folders = os.listdir(path)

# Entries of `path` that are not per-method result folders and must be skipped.
skip_entries = {'GT', 'html_generator_new.py', 'test.html', 'test_new.html',
                'GT_png', 'DQ', 'LR_Images'}
# Number of entries (in sorted name order) converted per folder.
num_images = 30
jpeg_quality = 40

print(folders)
for folder_name in folders:
    print(folder_name)
    application = os.path.join(path, folder_name)
    # Skip known non-method entries and any stray files, which would make
    # os.listdir below fail.
    if folder_name in skip_entries or not os.path.isdir(application):
        continue

    for files in sorted(os.listdir(application))[:num_images]:
        print(files)
        # `with` closes the source file handle once the RGB copy exists.
        with load(os.path.join(application, files, files + '_SR.png')) as img:
            rgb_im = img.convert('RGB')

        out_dir = os.path.join(save_path, application, files)
        os.makedirs(out_dir, exist_ok=True)
        rgb_im.save(os.path.join(out_dir, files + '_SR.jpg'), 'JPEG',
                    quality=jpeg_quality)

gt_dir = os.path.join(path, 'GT')
gt_out_dir = os.path.join(save_path, path, 'GT')
os.makedirs(gt_out_dir, exist_ok=True)

for files in sorted(os.listdir(gt_dir))[:num_images]:
    print(files)
    with load(os.path.join(gt_dir, files)) as img:
        rgb_im = img.convert('RGB')
    # NOTE(review): the output keeps the source file name (presumably '.png')
    # although the data written is JPEG. Kept as-is because downstream pages
    # may reference these names; confirm before renaming to '.jpg'.
    rgb_im.save(os.path.join(gt_out_dir, files), 'JPEG', quality=jpeg_quality)


#%%
import Loss_Modules.singan_loss_corrected as losses
import skimage.io
import torch
import pandas as pd
import matplotlib.pyplot as plt
import math
import numpy as np


# Evaluate the SinGAN loss between a ground-truth image and its
# down/up-sampled versions for each factor, then plot loss vs. factor.
gt_path = 'DS_US/106024.png'
image_name = '106024'
generated_path = 'DS_US/Images/'
discriminator_path = 'Loss_Modules/discriminator_models/Ds_trees.pth'
# Extra positional arguments of the SinGAN loss call. They also label the
# plot title and file name, so the labels always match the run.
# NOTE(review): their meaning is defined in Loss_Modules.singan_loss_corrected.
loss_args = (1, 1)

factor_list = np.arange(1, 8.5, 0.5)


def _to_cuda_batch(img):
    """Convert an (H, W, C) image array to a (1, C, H, W) float32 CUDA tensor.

    At most the first three channels are kept (dropping e.g. alpha), so the
    ground truth and generated images agree, and values are divided by 255.
    """
    img = img[:, :, 0:3]
    img = img[:, :, :, None].transpose([3, 2, 0, 1]) / 255.
    return torch.from_numpy(img).type(torch.cuda.FloatTensor)


x = _to_cuda_batch(skimage.io.imread(gt_path))

# Build the loss once. Its constructor arguments never change, so rebuilding
# it (presumably reloading the checkpoint) for every factor is wasted work.
sin_gan_loss = losses.SinGANLoss(discriminator_path)

total_loss = []
for factor in factor_list:
    print(factor)
    distored_img = image_name + "_" + str(factor) + ".png"
    y = skimage.io.imread(generated_path + '/' + distored_img)
    print('y.shape is', y.shape)
    y = _to_cuda_batch(y)
    loss = sin_gan_loss(x, y, *loss_args)
    # Plotting needs NumPy conversion, which torch refuses for CUDA or
    # grad-tracking tensors. float() accepts Python numbers, NumPy scalars
    # and single-element tensors.
    total_loss.append(float(loss))

plot_tag = '_'.join(str(arg) for arg in loss_args)
plt.figure()  # fresh figure so curves from other cells are not overlaid
plt.plot(factor_list, total_loss, 'bo--')
plt.xlabel('DS US Factor')
plt.ylabel('Loss')
plt.title('Ours Trees ' + plot_tag)
plt.savefig('DS_US/DS_US_Plots/DS_US_Ours_Trees_' + plot_tag)
plt.close()


#%%
from All_Losses import LC
# NOTE: this rebinds the imported class name to an instance. The import line
# above restores the class whenever this cell is re-run.
LC = LC()

# SSIM criterion from All_Losses, moved to the GPU to match the inputs below.
criterion = LC.get_attribute('SSIM').cuda()

gt_path = 'DS_US/106024.png'
image_name = '106024'
generated_path = 'DS_US/Images/'

factor_list = np.arange(1, 8.5, 0.5)


def _to_cuda_batch(img):
    """Convert an (H, W, C) image array to a (1, C, H, W) float32 CUDA tensor.

    At most the first three channels are kept (dropping e.g. alpha), so the
    ground truth and generated images agree, and values are divided by 255.
    """
    img = img[:, :, 0:3]
    img = img[:, :, :, None].transpose([3, 2, 0, 1]) / 255.
    return torch.from_numpy(img).type(torch.cuda.FloatTensor)


x = _to_cuda_batch(skimage.io.imread(gt_path))

total_loss = []
for factor in factor_list:
    print(factor)
    distored_img = image_name + "_" + str(factor) + ".png"
    y = skimage.io.imread(generated_path + '/' + distored_img)
    print('y.shape is', y.shape)
    y = _to_cuda_batch(y)
    loss = criterion(x, y)
    # Plotting needs NumPy conversion, which torch refuses for CUDA or
    # grad-tracking tensors. float() accepts Python numbers, NumPy scalars
    # and single-element tensors.
    total_loss.append(float(loss))

plt.figure()  # fresh figure so curves from other cells are not overlaid
plt.plot(factor_list, total_loss, 'bo--')
plt.xlabel('DS US Factor')
plt.ylabel('Loss')
plt.title('SSIM')
plt.savefig('DS_US/DS_US_Plots/DS_US_SSIM')
plt.close()

