import numpy as np
import chipwhisperer as cw
import time
import sys
import ni_fgen
import ni_scope
import CW_ASCON_connection as CW_Con
import random
import os
import h5py

# ---------------------------------------------------------------------------
# Acquisition / processing parameters.
# ---------------------------------------------------------------------------
DETECT = True                  # also compute and store the per-clock 'detects' values
TraceLen = 10000000            # record length (samples) passed to ni_scope.NIScope
PPC = 500                      # raw samples per clock cycle (unit of Record_Lower/Upper)
NEW_PPC = 50                   # samples per clock cycle kept after downsampling
Resample_Rate = PPC//NEW_PPC   # raw samples summed into one output sample
# Raw-sample index of clock 0. NOTE(review): presumably the scope's
# pre-trigger share (1% of TraceLen) plus a fixed alignment delay of 455
# samples -- confirm against the trigger setup.
OFFSET = (TraceLen//100)+455
# Clock-cycle windows [Record_Lower[i], Record_Upper[i]) kept from each trace.
Record_Lower = np.array([3030, 3350, 4215, 6290, 9780, 10840, 14540])
Record_Upper = np.array([3110, 3860, 4265, 6310, 9860, 10940, 14900])
# Total number of clock cycles kept (plain int so it is safe in shapes/ranges).
NUM_CLOCKS = int((Record_Upper-Record_Lower).sum())
OUT_Len = NUM_CLOCKS*NEW_PPC   # length of one processed (downsampled) trace

class Test:
  """Hold the ChipWhisperer target, NI function generator and NI scope for a capture run."""

  def __init__(self, tag):
    """Open and initialise the instruments.

    Args:
      tag: forwarded unchanged to ``CW_Con.CW_ASCON`` (``"ENC_128"`` in ``__main__``).

    NOTE(review): if a later step raises, instruments opened by earlier steps
    are not closed here.
    """
    self.CW = CW_Con.CW_ASCON(tag)
    self.nifgen = ni_fgen.NIFgen()
    self.nifgen.check_locked()
    # Fixed settle delay after the generator lock check.
    time.sleep(1)
    # NOTE(review): arguments presumably are (channel, sample rate in S/s,
    # record length) -- confirm against ni_scope.NIScope.
    self.niscope = ni_scope.NIScope(1, 2500000000, TraceLen)
    #self.niscope.check_locked()
    self.CW.init_part2()

  def test1(self, key, nonce, a_data, text):
    """Run one operation on the target and capture the scope trace for it.

    Loads the inputs, starts the operation and records a scope trace; this is
    repeated until ``self.CW.exec_done()`` is truthy, so only the trace of the
    last attempt is returned.

    NOTE(review): there is no retry limit -- a target that never reports done
    loops forever.

    Returns:
      (Response, Trace, Gain, Offset, Counter) where Trace/Gain/Offset come
      from ``niscope.record()`` and Response/Counter from ``CW.get_response()``.
    """
    while True:
      self.CW.set_data(key, nonce, a_data, text)
      self.CW.execute()
      Trace, Gain, Offset = self.niscope.record()
      if self.CW.exec_done():
        break
    Response, Counter = self.CW.get_response()
    print("Input    : "+text)
    print("Response : "+Response)
    return Response, Trace, Gain, Offset, Counter

  def close(self):
    """Close the target, the scope and the function generator, in that order.

    Every device is closed even if an earlier ``close()`` raises; the
    exception is then re-raised after the remaining devices are released.
    """
    print("Closing Ports.")
    try:
      self.CW.close()
    finally:
      try:
        self.niscope.close()
      finally:
        self.nifgen.close()

def _window_sums(data, starts, width):
  """Return ``sum(data[s:s + width])`` for every start index ``s`` in *starts*.

  Positions past the end of *data* count as zero, just like slicing past the
  end of an array yields fewer elements. Accumulation is in float64 so narrow
  integer sample types cannot wrap around.
  """
  starts = np.asarray(starts, dtype=np.int64)
  idx = starts[:, None] + np.arange(width)
  size = int(idx.max()) + 1 if idx.size else 0
  padded = np.zeros(size, dtype=np.float64)
  n = min(len(data), size)
  padded[:n] = data[:n]
  return padded[idx].sum(axis=1)

def Trace_Proc(trace):
  """Cut the recorded clock windows out of one raw trace and compress them.

  The windows ``[Record_Lower[i], Record_Upper[i])`` (in clock cycles of PPC
  samples, shifted by OFFSET samples) are concatenated. Then:

  * ``Samples``: every run of ``Resample_Rate`` consecutive samples is summed,
    giving ``OUT_Len`` values.
  * ``Detections`` (only when ``DETECT``): for each of the ``NUM_CLOCKS``
    cycles, the sum of 50 samples starting 20 samples into the cycle.

  Samples missing because the trace is too short count as zero. Sums are
  accumulated in float64, so integer sample types cannot overflow.

  Returns:
    ``Samples`` when ``DETECT`` is false, else ``(Samples, Detections)``; both
    are float64 arrays.
  """
  trace = np.asarray(trace)
  Trace_Cut = np.hstack([trace[OFFSET+PPC*lo:OFFSET+PPC*hi]
                         for lo, hi in zip(Record_Lower, Record_Upper)])
  ################################################
  Samples = _window_sums(Trace_Cut, np.arange(OUT_Len)*Resample_Rate, Resample_Rate)
  if not DETECT:
    return Samples
  ################################################
  Detections = _window_sums(Trace_Cut, np.arange(NUM_CLOCKS)*PPC+20, 50)
  return Samples, Detections

def _write_string_dataset(h5file, name, values):
  """Write *values* as an (N, 1) gzip-compressed fixed-width byte-string dataset.

  Each value is ASCII-encoded with non-ASCII characters dropped. The width is
  the longest encoded value (at least 1), so values of differing length are
  not truncated.
  """
  encoded = [v.encode("ascii", "ignore") for v in values]
  width = max([len(e) for e in encoded] + [1])
  h5file.create_dataset(name, (len(encoded), 1), 'S'+str(width),
                        compression="gzip", compression_opts=9, data=encoded)

def Record(tag, keys, nonces, a_data, texts, results, Name):
  """Capture, process and store one trace per input vector in ``Name + '.hdf5'``.

  For each index ``t`` the target is run on ``keys[t]``, ``nonces[t]``,
  ``a_data[t]`` and ``texts[t]``. The processed trace goes to the 'traces'
  dataset (and 'detects' when ``DETECT``). The response is compared with
  ``results[t]``.

  Returns:
    0 when every response matched. 1 as soon as a response differs; the file
    then holds the input datasets and the traces captured so far, without
    'responses'/'counters'.

  The instruments and the HDF5 file are released on every exit path,
  including exceptions raised during capture.
  """
  tS = time.time()
  print(time.asctime())
  T1 = Test(tag)
  print(time.asctime())
  Responses = [""]*len(keys)
  Counters = [""]*len(keys)
  # NOTE(review): gains/offsets are collected per trace but never written to
  # the file -- confirm this is intended.
  Gains = [0.0]*len(keys)
  Offsets = [0.0]*len(keys)
  print(Record_Lower)
  print(Record_Upper)
  ports_open = True
  try:
    ###########################################
    with h5py.File(Name+'.hdf5', 'w') as FILE:
      _write_string_dataset(FILE, 'keys', keys)
      _write_string_dataset(FILE, 'nonces', nonces)
      _write_string_dataset(FILE, 'plaintexts', texts)
      _write_string_dataset(FILE, 'ciphertags', results)
      TRACE = FILE.create_dataset('traces', (len(keys), OUT_Len), dtype='f8', compression="gzip", compression_opts=9)
      if DETECT:
        DETECTION = FILE.create_dataset('detects', (len(keys), NUM_CLOCKS), dtype='f8', compression="gzip", compression_opts=9)
      for t in range(len(keys)):
        print("=======================================")
        print(Name+": Trace #"+str(t).zfill(4))
        response, trace, Gains[t], Offsets[t], Counters[t] = T1.test1(keys[t], nonces[t], a_data[t], texts[t])
        if DETECT:
          TRACE[t,:], DETECTION[t,:] = Trace_Proc(trace)
        else:
          TRACE[t,:] = Trace_Proc(trace)
        print("Result   : "+results[t])
        print("Counter  : "+Counters[t])
        print(results[t]==response)
        if results[t]!=response:
          return 1
        Responses[t]=response
      # Release the instruments before the final, instrument-free writes.
      ports_open = False
      T1.close()
      _write_string_dataset(FILE, 'responses', Responses)
      _write_string_dataset(FILE, 'counters', Counters)
  finally:
    if ports_open:
      T1.close()
  tE = time.time()
  print(tE-tS)
  return 0

if __name__=='__main__':
  import shutil
  # Usage: python <script> NUM -- record the first NUM vectors from data/*_0000.npy
  if len(sys.argv) < 2:
    sys.exit("usage: "+sys.argv[0]+" NUM")
  NUM = int(sys.argv[1])
  Keys       = np.load('data/key_0000.npy')[0:NUM]
  Nonces     = np.load('data/nonce_0000.npy')[0:NUM]
  A_DATA     = ['']*NUM
  Plaintexts = np.load('data/plaintext_0000.npy')[0:NUM]
  Ciphertags = np.load('data/ciphertag_0000.npy')[0:NUM]
  ########################################################
  Record("ENC_128", Keys, Nonces, A_DATA, Plaintexts, Ciphertags, "Record_CPA")
  # Remove the bytecode cache portably (`rmdir /s/q` only worked under Windows cmd).
  shutil.rmtree('__pycache__', ignore_errors=True)
