import numpy as np
import chipwhisperer as cw
import time
import sys
import KECCAK
# Selector that CW_KECCAK places at the start of its configuration command,
# keyed by the name of the requested Keccak variant.
Type_dic = {
  'SHA3_224': 'x01',
  'SHA3_256': 'x02',
  'SHA3_384': 'x03',
  'SHA3_512': 'x04',
  'SHAKE128': 'x05',
  'SHAKE256': 'x06',
}

class CW_KECCAK:
  """Drive a Keccak firmware target through a ChipWhisperer scope.

  The target is controlled with short newline-terminated text commands:

  * ``<x_tag><size><d>`` - configure: ``x_tag`` is ``Type_dic[Type]``
    (e.g. ``"x02"``), ``size`` is the message length in bytes and ``d`` the
    requested output length in bytes, both as 16-bit little-endian hex
    (4 digits, low byte first);
  * ``i<16 hex digits>`` - load the next 8-byte message fragment;
  * ``k``                - start the computation;
  * ``o``                - fetch the next 8 output bytes; the reply is one
    prefix character followed by the hex digits.

  Completion is signalled by the reply ``"z00\\n"``.
  """

  # Seconds to wait after each command before reading the target's reply.
  CMD_DELAY = 0.1
  # Bytes carried by one "i" fragment / one "o" reply.
  _FRAG_BYTES = 8

  def __init__(self, Type):
    """Open the scope and select the Keccak variant (target left unpowered).

    Args:
      Type: a key of ``Type_dic`` such as ``'SHA3_256'`` or ``'SHAKE128'``;
        an unknown name raises ``KeyError``.
    """
    self.type = Type
    self.x_tag = Type_dic[self.type]
    self.FCPU = 5000000  # clock requested from clkgen in init_part2 (Hz)
    self.scope = cw.scope()
    self.scope.default_setup()
    # Keep the target off until init_part2 has configured the clock.
    self.scope.io.target_pwr = False
    self.scope.io.hs2 = 'disabled'

  def init_part2(self):
    """Configure the clock, attach to the target and power-cycle it.

    Prints the clkgen multiplier/divider and the measured counter frequency
    so the operator can check that ``FCPU`` was reached.
    """
    self.scope.clock.clkgen_freq = self.FCPU
    print(self.scope.clock.clkgen_mul, self.scope.clock.clkgen_div)
    time.sleep(1)
    self.scope.clock.freq_ctr_src = 'clkgen'
    time.sleep(1)
    print(self.scope.clock.freq_ctr)
    self.scope.io.target_pwr = True
    self.target = cw.target(self.scope)
    # Power-cycle the target, then drop whatever it sent while starting up.
    self.scope.io.target_pwr = False
    time.sleep(0.5)
    self.scope.io.target_pwr = True
    time.sleep(0.5)
    self.target.read()

  @staticmethod
  def _le16_hex(value, what):
    """Return ``value`` as 4 lowercase hex digits, low byte first.

    Raises:
      ValueError: if ``value`` is outside 0..0xFFFF; such a value cannot be
        encoded in the 4-digit field and would corrupt the command line.
    """
    if not 0 <= value <= 0xFFFF:
      raise ValueError("%s must be in 0..65535, got %r" % (what, value))
    return format(value & 0xFF, '02x') + format(value >> 8, '02x')

  def set_input(self, p, d = 0):
    """Send the configuration command followed by the message fragments.

    Args:
      p: message as a hex string, two digits per byte.
      d: requested output length in bytes (CW_SHA3 leaves this at 0).

    Returns:
      0, kept for compatibility with existing callers.

    Raises:
      ValueError: if the message length or ``d`` does not fit in 16 bits.
    """
    P_size = len(p) // 2
    size_tag = self._le16_hex(P_size, "message length")
    d_tag = self._le16_hex(d, "output length")
    self.target.write(self.x_tag + size_tag + d_tag + "\n")
    time.sleep(self.CMD_DELAY)
    self.target.read()
    # floor(P_size / 8) + 1 fragments are sent; the last one is zero-padded
    # to 16 hex digits (it is all zeros when P_size is a multiple of 8).
    step = 2 * self._FRAG_BYTES
    for frag_idx in range(P_size // self._FRAG_BYTES + 1):
      start = frag_idx * step
      self.target.write("i" + p[start:start + step].ljust(step, '0') + "\n")
      time.sleep(self.CMD_DELAY)
      self.target.read()
    return 0

  def keccak_exec(self):
    """Tell the target to start the computation; does not wait for it."""
    self.target.write("k\n")

  def keccak_exec_done(self):
    """Read one reply and return True if it is the completion code ``z00``."""
    return self.target.read() == "z00\n"

  def get_output(self, d):
    """Fetch ``d`` output bytes from the target as a lowercase hex string.

    One "o" command is issued per full 8-byte chunk plus one final command
    for the remaining ``d % 8`` bytes. The final command is issued even when
    that remainder is 0.
    NOTE(review): confirm against the firmware whether that extra "o" is required.
    """
    n_full, n_tail = divmod(d, self._FRAG_BYTES)
    parts = []
    for chunk_idx in range(n_full + 1):
      self.target.write("o\n")
      time.sleep(self.CMD_DELAY)
      reply = self.target.read()
      n_hex = 2 * (n_tail if chunk_idx == n_full else self._FRAG_BYTES)
      # reply[0] is a prefix character; the hex digits follow it.
      parts.append(reply[1:1 + n_hex])
    return "".join(parts).lower()

  def close(self):
    """Power-cycle the target (leaving it powered) and disconnect.

    The scope is disconnected even if disconnecting the target fails.

    Returns:
      0.
    """
    self.scope.io.target_pwr = False
    time.sleep(2)
    self.scope.io.target_pwr = True
    try:
      self.target.dis()
    finally:
      self.scope.dis()
    return 0

class CW_SHA3:
  """Fixed-length SHA-3 (224/256/384/512) driven through CW_KECCAK."""

  _SIZES = (224, 256, 384, 512)

  def __init__(self, Type):
    """Open a CW_KECCAK connection for SHA3-``Type``.

    Raises:
      ValueError: if ``Type`` is not 224, 256, 384 or 512. This replaces
        printing a message and returning an object with no ``connection``.
    """
    if Type not in self._SIZES:
      raise ValueError("Error: No SHA-3 output size = "+str(Type)+".")
    self.type = Type
    self.connection = CW_KECCAK('SHA3_' + str(self.type))

  def init_part2(self):
    """Finish hardware setup (clock, target power-cycle)."""
    self.connection.init_part2()

  def set_input(self, p):
    """Load hex message ``p``; the output-length field is sent as 0."""
    return self.connection.set_input(p)

  def keccak_exec(self):
    """Start the hash computation on the target."""
    self.connection.keccak_exec()

  def keccak_exec_done(self):
    """Return True once the target reports completion."""
    return self.connection.keccak_exec_done()

  def get_output(self):
    """Return the digest (``Type // 8`` bytes) as lowercase hex."""
    return self.connection.get_output(self.type >> 3)

  def close(self):
    """Shut down the connection; returns 0."""
    self.connection.close()
    return 0


class CW_SHAKE:
  """Variable-output SHAKE128/SHAKE256 driven through CW_KECCAK."""

  _SIZES = (128, 256)

  def __init__(self, Type):
    """Open a CW_KECCAK connection for SHAKE``Type``.

    Raises:
      ValueError: if ``Type`` is not 128 or 256. This replaces printing a
        message and returning an object with no ``connection``.
    """
    if Type not in self._SIZES:
      raise ValueError("Error: No SHAKE function size = "+str(Type)+".")
    self.type = Type
    self.connection = CW_KECCAK('SHAKE' + str(self.type))
    # Output length in bytes requested by the last set_input call.
    self.output_len = 0

  def init_part2(self):
    """Finish hardware setup (clock, target power-cycle)."""
    self.connection.init_part2()

  def set_input(self, p, d):
    """Load hex message ``p`` and request ``d`` output bytes.

    ``d`` is remembered so that get_output reads back the same length.
    """
    self.output_len = d
    return self.connection.set_input(p, d)

  def keccak_exec(self):
    """Start the computation on the target."""
    self.connection.keccak_exec()

  def keccak_exec_done(self):
    """Return True once the target reports completion."""
    return self.connection.keccak_exec_done()

  def get_output(self):
    """Return ``output_len`` output bytes as lowercase hex."""
    return self.connection.get_output(self.output_len)

  def close(self):
    """Shut down the connection; returns 0."""
    self.connection.close()
    return 0

