import numpy as np
import time
import sys
import os
import serv_manager as svm

# Number of consecutive trace samples gathered per IC index in Template.Guess:
# for each IC index ic, samples Trace[ic * PPC + p] for p in range(PPC) are used.
PPC = 10

def sortPos(val):
  """Sort key returning the second element of *val*.

  Template.Guess uses it to order its [candidate, probability] pairs by
  probability.
  """
  second_index = 1
  return val[second_index]

class Template:
  """Per-byte template data loaded from disk, plus candidate scoring.

  For each byte index in range(self.Size), the constructor stores:
    ICs[byte]: index list loaded from the ics file shared by 4 consecutive
      bytes (presumably the selected sample positions -- verify).
    INV[byte]: inverse (np.matrix.I) of the matrix in the "_scov" file.
    AVE[byte]: list of np.matrix vectors from the "_avts" file, used to
      project gathered samples.
    EXP[byte]: 256 expected projected vectors, one per candidate value,
      loaded from the "_expect" file or generated from the "_model" file.
  """

  def __init__(self, func):
    """Load (or build and cache) all per-byte template data for *func*.

    Args:
      func: template name: one of "A", "B", "C", "D" followed by an integer
        index in [0, 24) taken from func[1:3] (e.g. "A07"). "A"/"B" templates
        cover 200 bytes, "C"/"D" templates 40.

    Invalid names print "Class init error: ..." and call sys.exit().
    """
    if not self._valid_name(func):
      print("Class init error: " + str(func))
      sys.exit()
    label = "ics_original_010/ics_" + func + "_i"
    self.Size = 200 if func[0] in ("A", "B") else 40
    self.name = func
    print("Initializing " + self.name)
    fname = "templateLDA_O010/template_" + func + "/template"
    self.ICs = []
    self.INV = []
    self.EXP = []
    self.AVE = []
    for byte in range(self.Size):
      tag = str(byte).zfill(3)
      # Every 4 consecutive bytes share one ics file (file index byte // 4).
      self.ICs.append(svm.Load(label + str(byte // 4).zfill(2) + ".npy"))
      # Invert once here so Guess can reuse the inverse for all candidates.
      Scov = svm.Load(fname + "_scov_b" + tag + ".npy")
      self.INV.append(np.matrix(Scov).I)
      Avecs = svm.Load(fname + "_avts_b" + tag + ".npy")
      Amat = [np.matrix(vec) for vec in Avecs]
      self.AVE.append(Amat)
      self.EXP.append(self._load_expected(fname, tag, byte, Amat))

  @staticmethod
  def _valid_name(func):
    """Return True if *func* is "A".."D" followed by an integer in [0, 24)."""
    if not isinstance(func, str) or not func:
      return False
    if func[0] not in ("A", "B", "C", "D"):
      return False
    try:
      index = int(func[1:3])
    except ValueError:
      return False
    # Both bounds must hold (the range check needs `and`, not `or`).
    return 0 <= index < 24

  @staticmethod
  def _expected_traces(Models, Amat):
    """Build the 256 expected projected vectors from per-point linear models.

    For candidate x, each model row is applied to the column vector of x's
    8 bits (most significant first) followed by a constant 1.0. The resulting
    row vector is then projected onto every vector in *Amat*.

    Returns:
      list of 256 np.matrix row vectors, indexed by candidate value.
    """
    # Loop-invariant: convert each model row once, not once per candidate.
    coefs = [np.matrix(model) for model in Models]
    Tmeans = []
    for x in range(256):
      inter = [float(bit) for bit in format(x, "08b")] + [1.0]
      x_inter = np.transpose(np.matrix(inter))
      mtr = np.matrix([(coef * x_inter).item(0) for coef in coefs])
      Tmeans.append(np.matrix([(mtr * vec).item(0) for vec in Amat]))
    return Tmeans

  def _load_expected(self, fname, tag, byte, Amat):
    """Return the expected vectors for one byte, generating and saving them if absent.

    The "_expect" file is read with np.load. If it is missing or unreadable,
    the vectors are generated from the "_model" file, saved with svm.Save and
    read back with svm.Load.
    """
    Ename = fname + "_expect_b" + tag + ".npy"
    try:
      return np.load(Ename)
    except (OSError, ValueError, EOFError):
      # Missing (FileNotFoundError) or unreadable/truncated cache: rebuild it.
      print("  No expected trace for byte " + str(byte) + " and being generated.")
    Models = svm.Load(fname + "_model_b" + tag + ".npy")
    svm.Save(Ename, self._expected_traces(Models, Amat))
    return svm.Load(Ename)

  def Guess(self, Trace, sortbyprob=False):
    """Score every candidate value 0..255 for every byte of *Trace*.

    For each byte, the samples Trace[ic * PPC + p] (each IC index ic of that
    byte, p in range(PPC)) are projected onto the byte's AVE vectors, giving
    X. Candidate x is scored by exp(-0.5 * |d_x|), where
    d_x = (X - EXP[x]) * INV * (X - EXP[x])^T. Scores are normalized to sum to 1.

    Args:
      Trace: indexable sequence of samples.
      sortbyprob: if True, each byte's list is sorted by descending
        probability; otherwise it is in candidate order 0..255.

    Returns:
      list of length self.Size; each entry is a list of 256
      [candidate, probability] pairs.
    """
    print("Guessing " + self.name)
    Rank_Table = []
    for byte in range(self.Size):
      ips = [Trace[ic * PPC + p] for ic in self.ICs[byte] for p in range(PPC)]
      ips_mat = np.matrix(ips)
      Xm = np.matrix([(ips_mat * vec).item(0) for vec in self.AVE[byte]])
      ImatS = self.INV[byte]
      dists = []
      for x in range(256):
        matX = Xm - self.EXP[byte][x]
        dists.append(abs((matX * ImatS * np.transpose(matX)).item(0)))
      # Shift by the smallest distance before exponentiating. The normalized
      # probabilities are mathematically unchanged, but large distances no
      # longer underflow every weight to 0 (which produced 0/0 = nan).
      d_min = min(dists)
      weights = [np.exp(-0.5 * (d - d_min)) for d in dists]
      Prob_sum = sum(weights)
      Poss = [[x, w / Prob_sum] for x, w in enumerate(weights)]
      if sortbyprob:
        Poss.sort(reverse=True, key=sortPos)
      Rank_Table.append(Poss)
    return Rank_Table



