import ascon as official
import random
import sys
import time

class COMPARE:
  """Cross-check an Ascon AEAD implementation under test against `official`.

  For each trial, random inputs are run through the reference module
  (imported as `official`). The same hex-encoded inputs are then fed to
  the `Test` object from the matching ``test_AEAD_*`` module, and its
  response is compared with the reference result.

  The instance can be used as a context manager so the harness is always
  closed.
  """

  # (variant, function) -> name of the harness module that provides `Test`.
  _HARNESS_MODULES = {
    ("Ascon-128", "ENC"): "test_AEAD_128_Enc",
    ("Ascon-128", "DEC"): "test_AEAD_128_Dec",
    ("Ascon-128a", "ENC"): "test_AEAD_128a_Enc",
    ("Ascon-128a", "DEC"): "test_AEAD_128a_Dec",
  }

  # Plaintext length in bytes per variant. 7 and 15 are one byte short of
  # the Ascon-128 / Ascon-128a rates (8 / 16 bytes). Presumably this keeps
  # each test to a single padded block; confirm against the harness.
  _PLAINTEXT_LEN = {"Ascon-128": 7, "Ascon-128a": 15}

  def __init__(self, variant, exe_func):
    """Load and instantiate the harness for one variant/function pair.

    Args:
      variant: "Ascon-128" or "Ascon-128a".
      exe_func: "ENC" (test encryption) or "DEC" (test decryption).

    Raises:
      ValueError: if the (variant, exe_func) pair is not supported.
    """
    import importlib  # local: only needed to pick the harness module

    self.variant = variant
    self.func = exe_func
    try:
      module_name = self._HARNESS_MODULES[(variant, exe_func)]
    except KeyError:
      raise ValueError(
        "unsupported variant/function: %r/%r" % (variant, exe_func)
      ) from None
    self.TEST = importlib.import_module(module_name).Test()

  def close(self):
    """Close the underlying harness object."""
    self.TEST.close()

  def __enter__(self):
    return self

  def __exit__(self, exc_type, exc, tb):
    self.close()
    return False  # never suppress exceptions

  def compare_single(self, count):
    """Run one randomized trial and report whether it passed.

    Args:
      count: trial index; used only in the printed header.

    Returns:
      True if the reference round-trip succeeds and the harness output
      matches the reference output; False otherwise. A diagnostic line is
      printed in every case.
    """
    print("================================================")
    print("#" + str(count) + ":")
    key = official.get_random_bytes(16)
    nonce = official.get_random_bytes(16)
    A_data = official.get_random_bytes(0)  # associated data is always empty
    P_data = official.get_random_bytes(self._PLAINTEXT_LEN[self.variant])
    C_data = official.ascon_encrypt(key, nonce, A_data, P_data, self.variant)
    R_data = official.ascon_decrypt(key, nonce, A_data, C_data, self.variant)
    if R_data is None:
      print("verification failed in official codes.")
      return False
    official.demo_print([("K", key), ("N", nonce), ("P", P_data),
                         ("A", A_data), ("C+T", C_data), ("R", R_data)])
    if P_data != R_data:
      print("wrong answer in official codes")
      return False
    print("  +++++++++++++++++++++++++++++++++++++++")

    # The harness receives hex strings. Its response is compared directly
    # with bytes.hex() output, so it is expected to be a lowercase hex str.
    K_off = key.hex()
    N_off = nonce.hex()
    A_off = A_data.hex()
    P_off = P_data.hex()
    C_off = C_data.hex()  # ciphertext followed by tag ("C+T" above)
    if self.func == "ENC":
      self.TEST.set_data(K_off, N_off, A_off, P_off)
      self.TEST.encrypt()
      expected, label = C_off, "encoding"
    else:  # "DEC"; __init__ rejects any other function
      self.TEST.set_data(K_off, N_off, A_off, C_off)
      self.TEST.decrypt()
      expected, label = P_off, "decoding"

    if self.TEST.get_response() != expected:
      print(label + " failed.")
      return False
    print(label + " passed.")
    return True

  def compare_multiple(self, number):
    """Run `number` trials labelled 0..number-1 and return their results in order."""
    return [self.compare_single(t) for t in range(number)]

if __name__ == '__main__':
  # perf_counter is monotonic and high-resolution, suited to elapsed time.
  tS = time.perf_counter()
  USAGE = "usage: %s {Ascon-128|Ascon-128a} {ENC|DEC} NUMBER" % sys.argv[0]
  if len(sys.argv) < 4:
    sys.exit(USAGE)  # prints to stderr, exit status 1
  Variant = sys.argv[1]   # either 'Ascon-128' or 'Ascon-128a'
  Exe_Func = sys.argv[2]  # either 'ENC' or 'DEC'
  try:
    Number = int(sys.argv[3])  # number of random trials
  except ValueError:
    sys.exit(USAGE)
  if (Variant not in ('Ascon-128', 'Ascon-128a')
      or Exe_Func not in ('ENC', 'DEC')
      or Number < 0):
    sys.exit(USAGE)
  Experiment = COMPARE(Variant, Exe_Func)
  try:
    Results = Experiment.compare_multiple(Number)
  finally:
    # Release the harness even if a trial raises.
    Experiment.close()
  Success = Results.count(True)
  print("===================================================================")
  print(str(Success)+"/"+str(Number)+" succeed.")
  tE = time.perf_counter()
  print((tE-tS))

