import numpy as np
import time
import sys
import os
from numpy import linalg as LA
import scipy.stats as ss
from sklearn.linear_model import LinearRegression
import serv_manager as svm

# Run configuration shared by the routines in this file.
TAG = "O010"                     # suffix of the "./templateLDA_<TAG>/" output directory
PPC = 10                         # multiplied by len(ICS) to get the number of trace points
Dir_ics = "./ics_original_010/"  # directory prefix of the "ics_<group>_iNN.npy" files

# Maps each lowercase hexadecimal digit character to its integer value.
alphabet2number = {
    digit: value for value, digit in enumerate("0123456789abcdef")
}

def _load_trace_sets(set_dir, set_count, traces_per_set, point_count):
  """Read every trace set once and stack them into a single 2-D array.

  Args:
    set_dir: directory prefix (with trailing slash) holding "set_NNN.npy".
    set_count: number of set files to read, numbered from 0.
    traces_per_set: number of leading rows taken from each set.
    point_count: number of leading columns taken from each trace.

  Returns:
    Array of shape (set_count*traces_per_set, point_count); row
    t_set*traces_per_set+t holds trace t of set t_set.

  Raises:
    ValueError: if a loaded set is not 2-D or is smaller than requested.
  """
  traces = None
  for t_set in range(0, set_count):
    setname = set_dir+"set_"+str(t_set).zfill(3)+".npy"
    chunk = np.asarray(svm.Load(setname))
    if (chunk.ndim != 2 or chunk.shape[0] < traces_per_set
        or chunk.shape[1] < point_count):
      raise ValueError(
        "Trace set "+setname+" has shape "+str(chunk.shape)
        +", expected at least "+str((traces_per_set, point_count)))
    if traces is None:
      # Allocate once, keeping the on-disk dtype to avoid inflating memory.
      traces = np.empty((set_count*traces_per_set, point_count),
                        dtype=chunk.dtype)
    first = t_set*traces_per_set
    traces[first:first+traces_per_set] = chunk[:traces_per_set, :point_count]
  if traces is None:
    traces = np.empty((0, point_count))
  return traces


def _set_residuals(traces, expected, byte_values, t_set, traces_per_set):
  """Return the traces of one set minus the model prediction for their byte."""
  rows = slice(t_set*traces_per_set, (t_set+1)*traces_per_set)
  return traces[rows]-expected[byte_values[rows]]


def training(byte, group, set_count=400, traces_per_set=160):
  """Build and save the regression models and LDA-style projection for a byte.

  Steps:
    1. Load the intermediate bits of `byte` and the resampled traces.
    2. For every trace point, fit a linear regression of the samples on
       the 8 bits and save the 8 coefficients plus intercept ("_model_").
    3. From the model predictions build a byte-count-weighted scatter of
       the predictions around their average (Bmat) and a scatter of the
       traces around their prediction (Wmat).
    4. Keep the eigenvectors of inv(Wmat)*Bmat whose eigenvalue magnitude
       exceeds 0.1% of the summed magnitudes ("_avts_").
    5. Project the residual traces on those vectors and save the scaled
       sum of their outer products ("_scov_").

  Args:
    byte: index of the intermediate byte; byte//4 selects the ICS file and
      the "Ints_" trace directory.
    group: group tag used in all input and output file names.
    set_count: number of "set_NNN.npy" trace files (default 400).
    traces_per_set: number of traces used from each set (default 160).

  Side effects:
    Reads through svm.Load and np.load("Bits.npy"), writes three files
    through svm.Save, and prints progress. Terminates the process with
    status 1 when the per-byte trace counts do not add up to the total.

  Note:
    Scatter matrices are accumulated one set at a time with matrix
    products, so the last floating-point bits can differ from a
    trace-by-trace accumulation.
  """
  total_traces = set_count*traces_per_set
  print("==============")
  print([byte, group])
  ints = byte//4
  icsname = Dir_ics+"ics_"+group+"_i"+str(ints).zfill(2)+".npy"
  ICS = svm.Load(icsname)
  point_count = PPC*len(ICS)
  print("Loading Intermediate Values (in bits)...", time.asctime())
  intername = "./intermediate_B_"+group+"/"+group+"_b"+str(byte).zfill(3)+".npy"
  InterBits = svm.Load(intername)
  # Recombine the 8 bit columns (most significant first) into byte values.
  bit_weights = np.array([128.0, 64.0, 32.0, 16.0, 8.0, 4.0, 2.0, 1.0])
  byte_values = np.dot(np.asarray(InterBits), bit_weights)[:total_traces].astype(int)
  print("Loading Resampled Trace Data", time.asctime())
  set_dir = "./Ints_"+group+"_i"+str(ints).zfill(2)+"/"
  # Loaded once and reused by the regression, scatter and compression steps.
  traces = _load_trace_sets(set_dir, set_count, traces_per_set, point_count)
  ## Get Models ##################
  print(traces.shape[1], traces.shape[0], time.asctime())
  Models = []
  for p in range(0, point_count):
    if p%PPC==0:
      print("point ", p,"/",point_count)
    reg = LinearRegression().fit(InterBits, traces[:, p])
    # One row per point: the 8 bit coefficients followed by the intercept.
    Models.append(list(reg.coef_[:8])+[reg.intercept_])
  fname = "./templateLDA_"+TAG+"/template_"+group+"/template"
  Mname = fname+"_model_b"+str(byte).zfill(3)+".npy"
  svm.Save(Mname, Models)
  model_matx = np.transpose(np.array(Models))  # rows: 8 bits + intercept
  ## Adding B #########################################
  print("  Adding B", time.asctime())
  # Model evaluated with every bit at 0.5 and the intercept term at 1.0.
  ave_weights = np.array([0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 1.0])
  total_ave = np.dot(ave_weights, model_matx)
  load_bits = np.load("Bits.npy")
  # Count the traces per byte value; out-of-range values are left out so
  # that they surface as a count mismatch below.
  in_range = byte_values[(byte_values >= 0) & (byte_values < 256)]
  class_counts = np.bincount(in_range, minlength=256)
  if int(class_counts.sum()) != total_traces:
    print("Error: number did not match.", time.asctime())
    sys.exit(1)
  # Row `by` is the predicted trace for byte value `by`.
  # NOTE(review): assumes Bits.npy rows are [8 bits, 1.0] - confirm.
  expected = np.dot(np.asarray(load_bits)[:256], model_matx)
  between = expected-total_ave
  Bmat = np.dot(np.transpose(between)*class_counts, between)
  ## Adding W #########################################
  print("  Adding W", time.asctime())
  Wmat = np.zeros((point_count, point_count))
  for t_set in range(0, set_count):
    if (t_set%50)==0:
      print("  Set = ", t_set, time.asctime())
    residual = _set_residuals(traces, expected, byte_values, t_set, traces_per_set)
    Wmat += np.dot(np.transpose(residual), residual)
  print("  Finding A", time.asctime())
  # Kept as np.matrix so the eigenvectors come back as matrix columns,
  # which is the form stored in the "_avts_" file.
  Target = np.asmatrix(np.dot(LA.inv(Wmat), Bmat))
  EigVLs, EigVRs = LA.eig(Target)
  tempEigVLs = abs(EigVLs)
  eig_total = sum(tempEigVLs)
  AVectors = []
  for VIndx in range(0, len(tempEigVLs)):
    if tempEigVLs[VIndx]>(0.001*eig_total):
      print(VIndx, tempEigVLs[VIndx])
      AVectors.append(EigVRs[:,VIndx])
  print("There are "+str(len(AVectors))+" vectors non-zero")
  print("Total: ", eig_total)
  print("  Compressing", time.asctime())
  if AVectors:
    projection = np.asarray(np.hstack(AVectors))
  else:
    projection = np.zeros((point_count, 0))
  kept = projection.shape[1]
  Scov = np.zeros((kept, kept), dtype=projection.dtype)
  for t_set in range(0, set_count):
    residual = _set_residuals(traces, expected, byte_values, t_set, traces_per_set)
    compressed = np.dot(residual, projection)
    # Plain transpose (no conjugation) even if the eigenvectors are complex.
    Scov += np.dot(np.transpose(compressed), compressed)
  # NOTE(review): 9 presumably equals the number of fitted terms per point
  # (8 bits + intercept) - confirm.
  Scov *= 1.0/(float(total_traces-9))
  #######################################################
  Sname = fname+"_scov_b"+str(byte).zfill(3)+".npy"
  svm.Save(Sname, np.asmatrix(Scov))
  Aname = fname+"_avts_b"+str(byte).zfill(3)+".npy"
  svm.Save(Aname, AVectors)
  print("  Finished.", time.asctime())
  return

if __name__=='__main__':
  # Command line: <group> <lower> <upper>
  #   group: tag used in the "Ints_<group>_iNN.zip" archive names.
  #   lower, upper: half-open range of archive indices; each index covers
  #   the four bytes ints*4 .. ints*4+3.
  if len(sys.argv) < 4:
    sys.exit("Usage: "+sys.argv[0]+" <group> <lower> <upper>")
  tS = time.time()
  print(time.asctime())
  Gr = sys.argv[1]
  lower = int(sys.argv[2])
  upper = int(sys.argv[3])
  for ints in range(lower, upper):
    os.system("kinit -R")
    Name_Int = "Ints_"+Gr+"_i"+str(ints).zfill(2)
    # NOTE(review): Gr is pasted unquoted into the command strings below;
    # confirm it is always a plain token.
    cmd = "unzip -qq "+Name_Int+".zip"
    print(cmd)
    svm.System(cmd)
    try:
      for byte in range((ints*4),(ints*4+4)):
        training(byte, Gr)
    finally:
      # Remove the extracted directory even when training raises or exits,
      # so a rerun does not find stale files from the previous attempt.
      cmd = "rm -r "+Name_Int+"/"
      print(cmd)
      svm.System(cmd)
    print(time.asctime())
  tE = time.time()
  print("Exe. time = ", tE-tS)


