"""
This module represents a simulation of the TBox behavior, as described
by Omar Choudary and Frank Stajano in the paper
'Make noise and whisper: a solution to relay attacks'

Authors: Mike Bond and Omar Choudary
Last updated: 2 July 2011
"""

from random import randint
import argparse # you need Python v2.7 or later


# Presumably the default length of the exchanged sequences; it has the same
# value (16) as the default of the --length command-line option.
LEN = 16


def new_bitstring(length):
    """Return a list of ``length`` random bits.

    Each element is drawn independently with ``randint(0, 1)``; a
    non-positive length yields an empty list.
    """
    return [randint(0, 1) for _ in range(length)]


def new_tristring(length):
    """Return a list of ``length`` random ternary values (0, 1 or 2).

    Each element is drawn independently with ``randint(0, 2)``; a
    non-positive length yields an empty list.
    """
    return [randint(0, 2) for _ in range(length)]


def empty_bitstring(length):
    """Return a new list of ``length`` placeholders, each equal to -1.

    -1 marks a slot that has not been filled in yet.
    """
    return [-1] * length

def listen_bitstring(length):
    """Return a new list of ``length`` entries, each equal to 2.

    2 is the value the TRISTATE endpoints use to mean "listen".
    """
    return [2] * length


class AbstractTBox:
    """
    Abstract definition of the T-Box.

    A T-Box connects two endpoints. The concrete T-Boxes in this module read
    one signal from each endpoint per step (``sig_out``) and feed a single
    combined signal back to both of them (``sig_in``); subclasses define the
    combining rule by overriding ``step``.

    Contains methods:
    __init__ (constructor)
    tbox_type
    step
    """

    def __init__(self, enda, endb):
        """Constructor taking as parameters two Endpoint objects."""
        self.enda = enda
        self.endb = endb
        # Overwritten by subclasses; the protocol-run classes compare it with
        # the endpoints' etype to check that all parts are compatible.
        self.ttype = "Abstract"

    def tbox_type(self):
        """Return the type of T-Box."""
        return self.ttype

    def step(self, step):
        """Produce the output of the TBox for the given step.

        Must be overridden by subclasses. The base implementation raises
        NotImplementedError so that a subclass which forgets to override it
        fails loudly instead of silently leaving the endpoints' observed
        data untouched.
        """
        raise NotImplementedError(
            "%s does not implement step()" % self.__class__.__name__)


class TBoxTypeAND(AbstractTBox):
    """
    T-Box variant that combines two binary inputs like an AND gate.

    At every step both endpoints receive the bitwise AND of the two bits
    they sent in that step.

    Contains methods:
    __init__ (constructor extends AbstractTBox constructor)
    step (overrides AbstractTBox)
    """

    def __init__(self, enda, endb):
        """Wire the two Endpoint objects together and tag the box as AND."""
        AbstractTBox.__init__(self, enda, endb)
        self.ttype = "AND"

    def step(self, step):
        """Broadcast the AND of both endpoints' signals for slot ``step``."""
        bit_a = self.enda.sig_out(step)
        bit_b = self.endb.sig_out(step)
        combined = bit_a & bit_b
        for endpoint in (self.enda, self.endb):
            endpoint.sig_in(step, combined)


class TBoxTypeTRISTATE(AbstractTBox):
    """
    T-Box variant whose two inputs are ternary: 0, 1 or 2.

    Each endpoint either drives a '0' or a '1' or listens, which is
    represented by the input '2'.

    If one endpoint listens (sends '2'), the T-Box outputs whatever the
    other endpoint sent; if both listen, it outputs '1'.

    If neither listens but they drive different values ('0' against '1'),
    the T-Box signals a short circuit, represented by the output '2'.

    Contains methods:
    __init__ (constructor extends AbstractTBox constructor)
    step (overrides AbstractTBox)
    """

    def __init__(self, enda, endb):
        """Wire the two Endpoint objects together and tag the box TRISTATE."""
        AbstractTBox.__init__(self, enda, endb)
        self.ttype = "TRISTATE"

    def step(self, step):
        """Combine both endpoints' signals for slot ``step`` and broadcast."""
        sig_a = self.enda.sig_out(step)
        sig_b = self.endb.sig_out(step)
        if (sig_a, sig_b) in ((1, 0), (0, 1)):
            # Opposite bits driven at the same time: short circuit.
            out = 2
        elif sig_a == 2:
            # A listens: echo B, or 1 when B is listening as well.
            out = 1 if sig_b == 2 else sig_b
        else:
            out = sig_a
        for endpoint in (self.enda, self.endb):
            endpoint.sig_in(step, out)


class AbstractEndpoint:
    """
    Abstract class for the endpoint.

    Contains methods:
    __init__ (constructor)
    endpoint_type
    sig_out
    sig_in
    detect_relay
    get_auth
    send_auth
    get_data_challenge
    get_data_observed
    get_data_auth
    """

    def __init__(self, name, threshold):
        """Initialise the endpoint with its name and detection threshold.

        threshold is a real value in the interval [0, 1] used for relay
        detection (the concrete endpoints report a relay when their
        detection score exceeds it). Provide a small value to increase the
        detection probability and a higher value to decrease detection
        probability.

        Raises ValueError if threshold lies outside [0, 1].
        """
        # Explicit check instead of assert, so the validation is not
        # stripped when Python runs with -O.
        if not 0 <= threshold <= 1:
            raise ValueError(
                "threshold must be in [0, 1], got %r" % (threshold,))
        self.name = name
        self.threshold = threshold
        self.observed = []      # signal received from the T-Box, per step
        self.challenge = []     # signal this endpoint sends, per step
        self.auth_data = []     # authentication data from the other endpoint
        self.relay = False      # whether a relay attack has been detected
        self.score = 0          # detection score set by detect_relay()
        self.etype = "Abstract"
        self.guess = False      # only used by EndpointTypeRelay

    def endpoint_type(self):
        """Return the type of endpoint."""
        return self.etype

    def sig_out(self, step):
        """Return the signal for the given step."""
        return self.challenge[step]

    def sig_in(self, step, sig):
        """Get the output signal (from the TBox) for the current step."""
        self.observed[step] = sig

    def detect_relay(self):
        """Returns True if a relay attack was detected, False otherwise."""
        return self.relay

    def get_auth(self, data):
        """Get the authentication data from the other endpoint."""
        self.auth_data = data

    def send_auth(self):
        """Send the authentication data to the other endpoint."""
        return self.observed

    def get_data_challenge(self):
        """Return the input data sent by this endpoint."""
        return self.challenge

    def get_data_observed(self):
        """Return the data observed by this endpoint from the T-Box."""
        return self.observed

    def get_data_auth(self):
        """Return the authentication data received from the other endpoint."""
        return self.auth_data


class EndpointTypeAND(AbstractEndpoint):
    """
    Endpoint for use with the AND type of T-Box.

    The endpoint sends a fixed bit sequence (its challenge), records what
    the AND T-Box outputs, and later compares that record with the
    authentication data received from the other side.

    Contains methods:
    __init__ (constructor extends AbstractEndpoint constructor)
    detect_relay (overrides AbstractEndpoint method)
    """

    def __init__(self, name, threshold, challenge=False):
        """Set up the endpoint; ``challenge`` is the bit list to send.

        When challenge is left at its default (False), the endpoint keeps
        the empty challenge inherited from AbstractEndpoint.
        """
        AbstractEndpoint.__init__(self, name, threshold)
        self.etype = "AND"
        if challenge != False:
            self.challenge = challenge
        # One -1 placeholder per slot, overwritten by the T-Box via sig_in.
        self.observed = empty_bitstring(len(self.challenge))
        header = name.ljust(15) + "will send:".ljust(15)
        print(header + str(self.challenge))

    def detect_relay(self):
        """Decide whether a relay attack took place and record the score.

        Returns False without scoring when no authentication data has been
        received. A length mismatch, or any slot where the received data
        differs from what this endpoint observed, counts as a relay with
        score 1.0. Otherwise the score is |1 - 2e/k|, where k is the number
        of slots in which this endpoint sent a 1 and e the number of those
        in which the T-Box nevertheless output 0; a relay is reported when
        that score exceeds the threshold.
        """
        if self.auth_data == []:
            self.relay = False
            return self.relay
        if len(self.auth_data) != len(self.challenge):
            self.score = 1.0
            self.relay = True
            return self.relay

        print(self.name + ' is verifying the following data:')
        for label, values in (("data_sent:", self.challenge),
                              ("data_observed:", self.observed),
                              ("data_received:", self.auth_data)):
            print(label.ljust(20) + str(values))

        for i, received in enumerate(self.auth_data):
            if received != self.observed[i]:
                self.score = 1.0
                self.relay = True
                return self.relay

        ones_sent = 0
        zeros_seen = 0
        for i, sent in enumerate(self.challenge):
            if sent != 1:
                continue
            ones_sent += 1
            if self.observed[i] == 0:
                zeros_seen += 1

        if ones_sent > 0:
            self.score = abs(1.0 - 2.0 * zeros_seen / float(ones_sent))
        else:
            self.score = 0.0
        if self.score > self.threshold:
            self.relay = True

        return self.relay


class EndpointTypeTRISTATE(AbstractEndpoint):
    """
    Endpoint for use with the TRISTATE type of T-Box.

    The honest endpoints are given the same secret bit sequence (the
    challenge). In every slot an endpoint either drives its challenge bit
    or listens (signal 2), according to its mask.

    Contains methods:
    __init__ (constructor extends AbstractEndpoint constructor)
    sig_out (overrides AbstractEndpoint method)
    detect_relay (overrides AbstractEndpoint method)
    """

    def __init__(self, name, threshold, challenge, mask=False):
        """Initialise the Endpoint with the given challenge.

        challenge is a sequence of '0' and '1' values.
        mask, if given, must also be a sequence of '0' and '1' of the same
        length as the challenge. In every slot where the mask holds a '1'
        the endpoint listens instead of sending the challenge bit. If no
        mask is provided, a random one is created instead (recommended).
        """
        AbstractEndpoint.__init__(self, name, threshold)
        self.etype = "TRISTATE"
        self.challenge = challenge
        self.observed = empty_bitstring(len(self.challenge))
        if mask == False:
            # No mask supplied: draw one at random (one bit per slot).
            self.mask = new_bitstring(len(challenge))
        else:
            assert (len(challenge) == len(mask))
            self.mask = mask
        # Outgoing signal per slot: the challenge bit, or 2 (listen) where
        # the mask holds a 1.
        self.output = [2 if self.mask[i] == 1 else challenge[i]
                       for i in range(len(challenge))]
        print(self.name.ljust(15) + "will send:".ljust(15) + str(self.output))
        print("for sequence:".ljust(30) + str(self.challenge))

    def sig_out(self, step):
        """Return the signal for the given step."""
        return self.output[step]

    def detect_relay(self):
        """Check the observed sequence for a relay and record the score.

        A short circuit (2) in any slot, or a 0 observed in a slot whose
        secret bit is 1, is an immediate relay with score 1.0. Otherwise,
        over the slots where the secret bit is 0 and this endpoint
        listened, the score is |1 - 2o/s|: s is the number of such slots
        and o the number in which the T-Box output 1 (for an honest
        partner, meaning it listened too). A relay is reported when the
        score exceeds the threshold.
        """
        own_listens = 0
        both_listened = 0
        print(self.name.ljust(15) +
              "is checking relay attack, for observed sequence: " +
              str(self.observed))
        for i, secret in enumerate(self.challenge):
            seen = self.observed[i]
            if seen == 2:
                print(self.name.ljust(15) + "has seen a short circuit")
                self.score = 1.0
                self.relay = True
                return self.relay
            if secret == 1 and seen == 0:
                print(self.name.ljust(15) + "has seen a bad secret value")
                self.score = 1.0
                self.relay = True
                return self.relay
            if secret == 0 and self.mask[i] == 1:
                own_listens += 1
                if seen == 1:
                    both_listened += 1

        print("%swas listening in %d slots (when the sequence was 0)"
              " from which the other party was listening in %d"
              % (self.name.ljust(15), own_listens, both_listened))

        if own_listens > 0:
            self.score = abs(1.0 - 2.0 * both_listened / float(own_listens))
        else:
            self.score = 0.0
        if self.score > self.threshold:
            self.relay = True

        return self.relay


class EndpointTypeRelay(AbstractEndpoint):
    """
    Relay-attack endpoint usable with any type of T-Box.

    Models one of the two accomplices in a relay attack, in which data from
    one Endpoint is forwarded to another Endpoint. A relay attack needs two
    instances: one forwards data from the "sender" endpoint to the
    "receiver" endpoint, the other forwards data in the opposite direction.

    Contains methods:
    __init__ (constructor extends AbstractEndpoint constructor)
    get_auth (overrides AbstractEndpoint method)
    send_auth (overrides AbstractEndpoint method)
    sig_in (overrides AbstractEndpoint method)
    """

    def __init__(self, name, etype, receiver, challenge_pointer, auth_pointer,
                 guess=False):
        """Initialise this relay endpoint.

        etype selects the type of T-Box this endpoint connects to.
        receiver picks the side of the relay: True to relay data from the
        "sender", False to relay data from the "receiver".
        challenge_pointer is a list shared by both accomplices (initialised
        to the required length); it is filled in when receiver is True and
        sent as the challenge when receiver is False.
        auth_pointer is a list shared by both accomplices (initialised
        empty); it is filled with auth data when receiver is False and
        replayed when receiver is True.
        guess only matters for the TRISTATE type: it makes the attackers
        guess the sequence bit when the receiver observes a 1, instead of
        listening every time."""
        AbstractEndpoint.__init__(self, name, 0)
        self.etype = etype
        self.receiver = receiver
        self.guess = guess
        self.auth_data = auth_pointer

        if self.receiver != True:
            # Far side: replay whatever the partner wrote into the list.
            self.challenge = challenge_pointer
            self.observed = empty_bitstring(len(self.challenge))
            print(self.name.ljust(15) + "will forward data from complice")
            return

        # Near side: observations go straight into the shared list, while
        # this endpoint itself transmits random noise.
        self.observed = challenge_pointer
        if self.etype == "TRISTATE":
            make_noise = new_tristring
        else:
            make_noise = new_bitstring
        self.challenge = make_noise(len(self.observed))
        print(self.name.ljust(15) + "will send:".ljust(15) +
              str(self.challenge))

    def sig_in(self, step, sig):
        """Store the T-Box output for the current step.

        On the near side of a TRISTATE relay, an observed 1 is recorded as
        1 when guessing and as 2 (listen) otherwise; every other signal is
        stored unchanged.
        """
        masks_one = (self.etype == "TRISTATE" and self.receiver == True
                     and sig == 1)
        if not masks_one:
            self.observed[step] = sig
        else:
            self.observed[step] = 1 if self.guess == True else 2

    def get_auth(self, data):
        """Collect auth data into the shared list on the far side only.

        On the near side the call only performs the length check.
        """
        assert (len(data) == len(self.get_data_challenge()))
        if self.receiver != False:
            return
        self.auth_data.extend(data)

    def send_auth(self):
        """Return the shared auth data on the near side, else observations."""
        return self.auth_data if self.receiver == True else self.observed


class LegitProtocolRun:
    """
    Class representing a normal protocol run between two honest endpoints.

    Contains methods:
    __init__ (constructor)
    run
    """

    def __init__(self, alice, bob, tbox):
        """Initialise the class with two endpoints and a T-Box.

        Raises ValueError if the endpoints and the T-Box are not all of the
        same type, or if the two endpoints' challenges differ in length.
        """
        # Explicit checks rather than assert, so they survive python -O.
        if not (alice.etype == bob.etype == tbox.ttype):
            raise ValueError(
                "endpoint/T-Box type mismatch: %r, %r, %r"
                % (alice.etype, bob.etype, tbox.ttype))
        if len(alice.get_data_challenge()) != len(bob.get_data_challenge()):
            raise ValueError("alice and bob challenges differ in length")
        self.alice = alice
        self.bob = bob
        self.tbox = tbox

    def run(self):
        """Run the anti-relay protocol for the given endpoints and T-Box.

        Steps the T-Box once per slot of Alice's challenge, hands Bob's
        authentication data to Alice, then lets both endpoints check for a
        relay (both checks always run).

        Returns the tuple (detector, score), where detector is:
        "alice", if Alice has detected a relay attack (score is Alice's);
        "bob", if only Bob has detected one (score is Bob's);
        "none", if neither has; score is then the higher of the two
        endpoints' scores.
        """
        print("Running TBox anti-relay protocol for type " + self.alice.etype)
        for t in range(len(self.alice.get_data_challenge())):
            self.tbox.step(t)

        self.alice.get_auth(self.bob.send_auth())
        # Evaluate both endpoints before deciding: detect_relay() may also
        # update that endpoint's score, which the "none" case reports.
        relay_a = self.alice.detect_relay()
        relay_b = self.bob.detect_relay()
        if relay_a:
            return "alice", self.alice.score
        if relay_b:
            return "bob", self.bob.score
        return "none", max(self.alice.score, self.bob.score)


class RelayProtocolRun:
    """
    Class representing a relay attack run.

    Contains methods:
    __init__ (constructor)
    run
    """

    def __init__(self, alice, tboxa, malbob, malvalice, tboxb, bob):
        """Initialise the class with 4 endpoints and two T-Boxes.

        alice is an honest endpoint
        tboxa is the tbox between alice and malbob
        malbob is one of the relay accomplices
        malvalice is the other relay accomplice
        tboxb is the tbox between malvalice and bob
        bob is the other honest endpoint

        Raises ValueError if the six parts are not all of the same type, or
        if alice's and bob's challenges differ in length.
        """
        # Explicit checks rather than assert, so they survive python -O.
        if not (alice.etype == tboxa.ttype == malbob.etype ==
                malvalice.etype == tboxb.ttype == bob.etype):
            raise ValueError(
                "endpoint/T-Box type mismatch: %r"
                % ([alice.etype, tboxa.ttype, malbob.etype,
                    malvalice.etype, tboxb.ttype, bob.etype],))
        if len(alice.get_data_challenge()) != len(bob.get_data_challenge()):
            raise ValueError("alice and bob challenges differ in length")
        self.alice = alice
        self.malbob = malbob
        self.tboxa = tboxa
        self.malvalice = malvalice
        self.tboxb = tboxb
        self.bob = bob

    def run(self):
        """Run the anti-relay protocol for the given endpoints and tboxes.

        Returns the tuple (detector, score), where detector is:
        "alice", if Alice has detected a relay attack (score is Alice's);
        "bob", if only Bob has detected one (score is Bob's);
        "none", if neither has; score is then the higher of the two honest
        endpoints' scores.
        """
        print("Running TBox anti-relay protocol for type " + self.alice.etype)
        for t in range(len(self.alice.get_data_challenge())):
            # Order matters: in the relay set-up built by main(), malbob's
            # sig_in writes slot t of the list that malvalice sends from,
            # so tboxa must step before tboxb.
            self.tboxa.step(t)
            self.tboxb.step(t)

        # Bob's auth data goes to malvalice; in the set-up built by main()
        # it lands in the list malbob replays to alice.
        self.malvalice.get_auth(self.bob.send_auth())
        self.alice.get_auth(self.malbob.send_auth())
        # Evaluate both honest endpoints before deciding, so both scores
        # are up to date for the "none" case.
        relay_a = self.alice.detect_relay()
        relay_b = self.bob.detect_relay()
        if relay_a:
            return "alice", self.alice.score
        if relay_b:
            return "bob", self.bob.score
        return "none", max(self.alice.score, self.bob.score)


def _build_parser():
    """Return the argument parser for the command-line interface."""
    parser = argparse.ArgumentParser(
        description='Python simulation for the noise-based anti-relay '
                    'solution')
    parser.add_argument('-t', '--ttype', default='AND',
                        type=str, choices=['AND', 'TRISTATE'],
                        help='the type of T-Box to use (default AND)')
    parser.add_argument('-a', '--action', default='legitimate',
                        type=str,
                        choices=['legitimate', 'relay'],
                        help='the type of protocol run (default legitimate)')
    parser.add_argument('-g', '--guess', action='store_true',
                        help='Only used for the TRISTATE type in the relay '
                             'scenario, to request the attackers to guess '
                             'values when MalBob receives a 1.')
    parser.add_argument('-r', '--threshold', default=0.3,
                        type=float,
                        help='the detection threshold (between [0, 1], '
                             'default 0.3)')
    parser.add_argument('-l', '--length', default=LEN,
                        type=int,
                        help='the length of transmitted data between parties '
                             '(default %(default)s)')
    parser.add_argument('-n', '--nruns', default=1,
                        type=int,
                        help='the number of protocol runs (default 1)')
    parser.add_argument('-v', '--verbose', action='store_true',
                        help='Produce more textual output than the default '
                             'detected/score')
    return parser


def _run_once(ttype, action, guess, threshold, length, run_number):
    """Build the parties for one protocol run, execute it and return the
    (detector, score) tuple produced by the run.

    Raises ValueError for an unsupported ttype/action combination.
    """
    if action == "legitimate":
        if ttype == "AND":
            print("\nLegitimate run %d with AND T-Box" % run_number)
            a = EndpointTypeAND("Alice", threshold, new_bitstring(length))
            b = EndpointTypeAND("Bob", threshold, new_bitstring(length))
            t = TBoxTypeAND(a, b)
            return LegitProtocolRun(a, b, t).run()
        if ttype == "TRISTATE":
            print("\nLegitimate run %d with Tristate T-Box" % run_number)
            # Both honest parties are given the same secret sequence.
            challenge = new_bitstring(length)
            a = EndpointTypeTRISTATE("Alice", threshold, challenge)
            b = EndpointTypeTRISTATE("Bob", threshold, challenge)
            t = TBoxTypeTRISTATE(a, b)
            return LegitProtocolRun(a, b, t).run()
    elif action == "relay":
        # challenge_pointer and auth_pointer are shared by the accomplices:
        # MalBob records what he observes into challenge_pointer, which
        # MalvAlice sends on; MalvAlice stores Bob's auth data in
        # auth_pointer, which MalBob hands to Alice.
        if ttype == "AND":
            print('\nRelay run %d with AND T-Box' % run_number)
            challenge_pointer = empty_bitstring(length)
            auth_pointer = []
            a = EndpointTypeAND("Alice", threshold, new_bitstring(length))
            mb = EndpointTypeRelay("MalBob", "AND", True,
                                   challenge_pointer, auth_pointer, guess)
            ta = TBoxTypeAND(a, mb)
            ma = EndpointTypeRelay("MalvAlice", "AND", False,
                                   challenge_pointer, auth_pointer)
            b = EndpointTypeAND("Bob", threshold, new_bitstring(length))
            tb = TBoxTypeAND(ma, b)
            return RelayProtocolRun(a, ta, mb, ma, tb, b).run()
        if ttype == "TRISTATE":
            print('\nRelay run %d with TRISTATE T-Box' % run_number)
            challenge_pointer = empty_bitstring(length)
            challenge = new_bitstring(length)
            auth_pointer = []
            a = EndpointTypeTRISTATE("Alice", threshold, challenge)
            mb = EndpointTypeRelay("MalBob", "TRISTATE", True,
                                   challenge_pointer, auth_pointer, guess)
            ta = TBoxTypeTRISTATE(a, mb)
            ma = EndpointTypeRelay("MalvAlice", "TRISTATE", False,
                                   challenge_pointer, auth_pointer)
            b = EndpointTypeTRISTATE("Bob", threshold, challenge)
            tb = TBoxTypeTRISTATE(ma, b)
            return RelayProtocolRun(a, ta, mb, ma, tb, b).run()
    raise ValueError("unsupported combination: ttype=%r, action=%r"
                     % (ttype, action))


def main():
    """Command-line entry point for the simulation in this module.

    Parses the command-line arguments, performs the requested number of
    protocol runs and prints one result line per run: "<detector> <score>",
    or a full sentence with the same information when --verbose is given.
    """
    parser = _build_parser()
    args = parser.parse_args()
    # Reject bad values with a usage message instead of an exception
    # raised from inside the endpoint constructors / an empty run that can
    # never detect anything.
    if not 0 <= args.threshold <= 1:
        parser.error("argument -r/--threshold: must be between 0 and 1")
    if args.length < 1:
        parser.error("argument -l/--length: must be at least 1")

    for i in range(args.nruns):
        result, score = _run_once(args.ttype, args.action, args.guess,
                                  args.threshold, args.length, i + 1)
        # A single pre-formatted string prints identically under the
        # Python 2 print statement and the Python 3 print function.
        if args.verbose:
            print("%s detected a relay attack, detection score = %s"
                  % (result, score))
        else:
            print("%s %s" % (result, score))


# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
