import numpy as np
import get_tables as template
import processing as dejoint
import SASCA_scan
import serv_manager as svm
import KECCAK as KEC
import os
import sys

###################################################################################
#
# Independent parameters (edit these to select the attacked configuration)
#
# SHA-3 output (digest) size in bits:
D_Size = 512
# Number of Keccak-f invocations (one trace each) per attacked input:
INVOC = 1
# Default number of inputs attacked per trace set (Recover_One_Set's T_upper):
Tr_NUM = 100
###################################################################################
# Derived parameters
# Capacity is twice the digest size; the rate is what remains of the 1600-bit state.
C_Size = D_Size*2
R_Size = 1600 - C_Size
# Per-bit [value, probability] tables: unknown bit, certainly 0, certainly 1.
P_N = [[0, 0.5], [1, 0.5]]
P_0 = [[0, 1.0], [1, 0.0]]
P_1 = [[0, 0.0], [1, 1.0]]
# File/directory name prefix, e.g. "SHA3_512_I01_" for the defaults above.
Project_Tag = "".join(["SHA3_", str(D_Size), "_I", str(INVOC).zfill(2), "_"])
# Processing tag inserted (zero-padded to 2 digits) into trace-set directory
# names; its physical meaning is not visible in this file.
PPC = 10

class Template_Recover:
  """Holds the 16 profiled templates and turns one trace into BP input tables.

  Templates are labelled "A00".."A03", "B00".."B03", "C00".."C03" and
  "D00".."D03". Each one is loaded with template.Template(label) and stored
  as the attribute Template_<label>, e.g. self.Template_C02.
  """
  # Template labels in load/query order: A00..A03, B00..B03, C00..C03, D00..D03.
  _LABELS = tuple(group + str(idx).zfill(2) for group in "ABCD" for idx in range(4))

  def __init__(self):
    for label in self._LABELS:
      setattr(self, "Template_" + label, template.Template(label))
    # Not read anywhere in this file; kept so the attribute still exists.
    self.ZeroIN = []

  def in_table(self, State):
    """Build the 1600-entry prior table for the permutation input state.

    Args:
      State: 25 lane values; bit k of State[lane] is state bit 64*lane + k.

    Returns:
      A list of 1600 per-bit tables. The first R_Size (rate) positions are
      P_N (50/50, unknown). Every later position is P_0 or P_1 according to
      the corresponding bit of State. Entries are shared references to the
      module-level P_* lists, not copies.
    """
    state_bits = [
      0 if (State[pos // 64] & (1 << (pos % 64))) == 0 else 1
      for pos in range(1600)
    ]
    prior = [P_N] * R_Size
    prior.extend(P_1 if state_bits[pos] else P_0 for pos in range(R_Size, 1600))
    return prior

  def recovering(self, trace, pre_inputs):
    """Turn one trace plus known state into the 17 tables fed to State_Scan.

    Guess(trace) is first evaluated on every template in A00..D03 order. Each
    guess is then converted with dejoint.table_processing in the same order,
    and finally the input prior is built from pre_inputs. The formats of the
    guesses and tables are defined by the get_tables / processing modules.

    Args:
      trace: one measured trace, passed unchanged to each template's Guess.
      pre_inputs: 25 lanes of known state, forwarded to in_table.

    Returns:
      A 17-tuple:
      (b_INP, C00, D00, A00, B00, C01, D01, A01, B01,
       C02, D02, A02, B02, C03, D03, A03, B03).
    """
    guesses = [getattr(self, "Template_" + label).Guess(trace) for label in self._LABELS]
    tables = dict(zip(self._LABELS, [dejoint.table_processing(g) for g in guesses]))
    # State_Scan expects, for each index 00..03, the C, D, A, B tables in that order.
    ordered = [tables[group + str(idx).zfill(2)] for idx in range(4) for group in "CDAB"]
    return (self.in_table(pre_inputs),) + tuple(ordered)

def Table2State(Table):
  """Collapse a 1600-entry per-bit probability table into 25 integer lanes.

  Entry Table[64*lane + bit] describes bit `bit` of lane `lane`. Only
  Table[i][0][1] is read. With the [[0, p0], [1, p1]] layout of P_0/P_1/P_N,
  that value is the probability of the bit being 0.

  A bit is set unless that value is strictly greater than 0.5. An evenly
  split (unknown) bit therefore decodes as 1.

  Returns:
    A list of 25 non-negative ints, each below 2**64.
  """
  return [
    sum(1 << bit for bit in range(64) if not (Table[64*lane + bit][0][1] > 0.5))
    for lane in range(25)
  ]


def Consistent_Check(TableINP, TableA00):
  """Check a recovered A00 state against the recovered input table.

  Both tables are decoded with Table2State. The A00 lanes are passed through
  KEC.Back_RhoPi and then KEC.Back_Theta (by name, the inverse rho/pi and
  theta steps) to obtain a candidate input state.

  Both states are rendered with KEC.lane2hex. Only hex digits from index
  R_Size//4 onward are compared. This is presumably the capacity part;
  it assumes lane2hex emits one hex digit per 4 state bits, in bit order
  (TODO confirm).

  Returns:
    (passed, candidate_lanes, next_state): passed is True when the compared
    hex suffixes are equal. candidate_lanes is the Back_Theta output.
    next_state is KEC.Keccak_f1600 applied to the candidate (computed
    whether or not the check passes).
  """
  lanes_a00 = Table2State(TableA00)
  lanes_inp = Table2State(TableINP)
  candidate = KEC.Back_Theta(KEC.Back_RhoPi(lanes_a00))
  candidate_hex = KEC.lane2hex(candidate)
  observed_hex = KEC.lane2hex(lanes_inp)
  compare_from = R_Size // 4
  next_state = KEC.Keccak_f1600(KEC.hex2lane(candidate_hex))
  passed = bool(candidate_hex[compare_from:] == observed_hex[compare_from:])
  return passed, candidate, next_state

def Keccak_input_strip(kec_in):
  """Remove SHA-3 padding from a hex-encoded absorbed input.

  Two padded forms are accepted:
    * message || "86"  (single pad byte combining 0x06 and 0x80), and
    * message || "06" || "00"*k || "80".

  Zero bytes are stripped two hex digits at a time, so byte alignment is
  respected.

  Returns:
    The message hex string, or "Wrong_In" when neither form matches.
  """
  tail = kec_in[-2:]
  if tail == "86":
    return kec_in[:-2]
  if tail != "80":
    return "Wrong_In"
  body = kec_in[:-2]
  while body.endswith("00"):
    body = body[:-2]
  return body[:-2] if body.endswith("06") else "Wrong_In"

def _recover_single_input(TA, Set_Tag, t):
  """Recover input #t of a set from its INVOC invocation traces.

  For each invocation:
    1. Load the trace and build tables with TA.recovering.
    2. Run SASCA_scan.State_Scan.
    3. Validate the result with Consistent_Check.
  On success, the first R_Size//4 hex digits of (recovered input XOR previous
  state) are appended to the prediction. The Keccak_f1600 output returned by
  Consistent_Check then becomes the known state for the next invocation.

  Returns:
    (prediction, scan_counts).
    prediction is the padding-stripped hex input, or "Wrong_In" when
    Keccak_input_strip rejects the padding, or "Not Found!" when a
    consistency check failed.
    scan_counts holds the last value returned by State_Scan for each
    invocation; invocations never reached keep -1.
  """
  rec_IN = ""
  prev = [0]*25
  Scan_it = [-1]*INVOC
  for invoc in range(0, INVOC):
    Tname = "./"+Set_Tag+"/trace_"+str(t).zfill(3)+str(invoc).zfill(1)+".npy"
    print(Tname)
    trace = svm.Load(Tname)
    tables = TA.recovering(trace, prev)
    print("Loopy-BP processing...")
    INP, A00, B00, A01, B01, A02, B02, A03, B03, Scan_it[invoc] = SASCA_scan.State_Scan(*tables)
    Suc, inv_in, inv_out = Consistent_Check(INP, A00)
    if not Suc:
      print("Consistence check failed!")
      # Return the marker directly: running it through Keccak_input_strip
      # would turn it into "Wrong_In" and hide why recovery failed.
      return "Not Found!", Scan_it
    print("Consistence check passed!")
    IN_XOR_bin = [inv_in[lane]^prev[lane] for lane in range(25)]
    rec_IN += KEC.lane2hex(IN_XOR_bin)[:(R_Size//4)]
    prev = inv_out
  return Keccak_input_strip(rec_IN), Scan_it


def Set_recover(Set_N, tr):
  """Attack the first `tr` inputs of trace set `Set_N`.

  The reference inputs are loaded from ./data_raw_in/ and the traces from the
  already-extracted ./<Set_Tag>/ directory. Each input is recovered with
  _recover_single_input. All predictions are saved to
  ./recovered_data/recovered_inputs_<Set_N>.npy.

  Args:
    Set_N: trace-set index.
    tr: number of inputs to attack (indices 0..tr-1).

  Returns:
    (Success, Scan_Num).
    Success[t] is Predictions[t] == ans_inputs[t].
    Scan_Num[t] is the per-invocation State_Scan count list for input t.
  """
  Predictions = []
  Scan_Num = []
  TA = Template_Recover()
  Set_Tag = Project_Tag+str(PPC).zfill(2)+"_"+str(Set_N).zfill(4)
  ################################################################
  ans_name = "./data_raw_in/Raw_"+Project_Tag+str(Set_N).zfill(4)+"_data_in.npy"
  ans_inputs = svm.Load(ans_name)
  ################################################################
  for t in range(0, tr):
    print("======================================================")
    print("SHA3_input #"+str(t).zfill(2))
    print("======================================================")
    rec_IN, Scan_it = _recover_single_input(TA, Set_Tag, t)
    print("Prediction: "+rec_IN)
    print("Answer    : "+ans_inputs[t])
    Predictions.append(rec_IN)
    Scan_Num.append(Scan_it)
  ################################################################
  pred_name = "./recovered_data/recovered_inputs_"+str(Set_N).zfill(4)+".npy"
  svm.Save(pred_name, Predictions)
  Success = [(Predictions[t]==ans_inputs[t]) for t in range(0, len(Predictions))]
  return Success, Scan_Num

def Recover_One_Set(set_n, T_upper=Tr_NUM):
  """Unpack, attack and clean up one trace set.

  The steps are:
    1. Renew the Kerberos ticket ("kinit -R").
    2. Extract Processed_10/<tag>.zip into the working directory.
    3. Run Set_recover on the first T_upper inputs.
    4. Save the success flags to Success/ and the State_Scan counts to Iter_Num/.
    5. Remove the extracted ./<tag>/ directory.

  Args:
    set_n: trace-set index.
    T_upper: number of inputs to attack in this set.

  Returns:
    The list of per-input success flags from Set_recover.

  Raises:
    Whatever Set_recover or the saves raise. The extracted directory is still
    removed in that case.
  """
  print("==============================================================================")
  os.system("kinit -R")
  tag = Project_Tag+str(PPC).zfill(2)+"_"+str(set_n).zfill(4)
  print(tag)
  # NOTE(review): "Processed_10" looks like it mirrors PPC == 10 — confirm
  # before changing PPC.
  zipname = "Processed_10/"+tag+".zip"
  zipcmd = "unzip -qq "+zipname
  print(zipcmd)
  svm.System(zipcmd)
  rm_name = "rm -r ./"+tag+"/"
  try:
    Suc, Num = Set_recover(set_n, T_upper)
    svm.Save(("Success/success_"+str(set_n).zfill(4)+".npy"), Suc)
    svm.Save(("Iter_Num/iteration_"+str(set_n).zfill(4)+".npy"), Num)
  finally:
    # Always remove the extracted traces, even on failure. A leftover
    # directory would make a later `unzip` of this set (no -o flag) stop at
    # unzip's overwrite prompt, and it would waste disk space.
    print(rm_name)
    svm.System(rm_name)
  return Suc

if __name__=='__main__':
  # Request a renewable Kerberos ticket up front. Recover_One_Set renews it
  # ("kinit -R") before every set, presumably so long batch runs keep their
  # credentials.
  os.system("kinit -r 99d")
  # Command line: first set index (inclusive) and last set index (exclusive).
  S_L, S_U = int(sys.argv[1]), int(sys.argv[2])
  Success = []
  for S in range(S_L, S_U):
    Success.extend(Recover_One_Set(S))
  print("++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
  print("Results:")
  print("Total  : {}".format(len(Success)))
  print("Success: {}".format(Success.count(True)))

