#include <arpa/inet.h>
#include <stdlib.h>
#include <string.h>
#include <err.h>
#include <assert.h>
#include <pcap.h>

#include "demux.h"
#include "proto.h"

/* Local replacements for the queue.h accessors that work with the
   anonymous TAILQ_HEAD()s used in this file: they take the element
   type instead of a head struct name. */

/* Last element of a tail queue, or NULL if the queue is empty (i.e.
   tqh_last still points at the head's own tqh_first).  Otherwise
   tqh_last is taken to be the address of the last element's
   field.tqe_next, and the element is recovered by subtracting that
   member's offset within `type'. */
#define TAILQ_LAST(head, type, field)                               \
((head)->tqh_last == &(head)->tqh_first ? NULL :                    \
 ((type *)((unsigned long)(head)->tqh_last -                        \
	   (unsigned long)&((type *)0)->field.tqe_next)))

/* Element preceding elm, or NULL if elm is the first one (its
   tqe_prev points at the head's tqh_first).  Uses the same
   offset-subtraction trick as TAILQ_LAST, applied to elm's tqe_prev. */
#define TAILQ_PREV(head, type, field, elm)                          \
((elm)->field.tqe_prev == &(head)->tqh_first ? NULL :               \
 ((type *)((unsigned long)(elm)->field.tqe_prev -                   \
	   (unsigned long)&((type *)0)->field.tqe_next)))

/* EtherType value that parse_packet compares against the Ethernet
   header's h_proto to recognise IPv4 frames. */
#define ETH_P_IP 0x0800

/* Stream all log messages are written to.  main() sets it to the
   result of open_tee(stdout, f), where f is the file "logfile". */
FILE *
logfile;

/* Global counters, updated throughout this file. */
struct statblock
stats;

/* Timestamp of the packet most recently handed to process_packet. */
static timestamp_t
last_packet_time;

/* Number of fragmented datagrams currently awaiting completion:
   incremented when find_datagram creates one, decremented when
   handle_packet completes one. */
static unsigned
incomplete_datagrams;

/* Connection hash table; buckets are indexed by hash_streamid() and
   chained through tcpconn.hash_chain. */
static LIST_HEAD(, tcpconn)
hash_table[HASH_SIZE];
/* Live connections in recency order: find_flow puts a connection at
   the head every time it is used, collect_live_conns evicts from the
   tail. */
static TAILQ_HEAD(, tcpconn)
lru_head;
/* Pool of unused connection records (linked through lru_chain). */
static TAILQ_HEAD(, tcpconn)
free_tcpconns;

/* pcap link-layer type (DLT_*) of the capture, consulted by
   parse_packet.  It is not assigned anywhere in this file. */
unsigned
link_type;

/* ------------------------------------------------------------------------ */
/* Connection pool management */
/* Return a connection record to the free pool by pushing it on the
   front of free_tcpconns.  The record's lru_chain linkage is reused
   for the free list, so the caller must already have unlinked it
   from lru_head (and from its hash bucket). */
static void
release_tcpconn(struct tcpconn *tp)
{
	TAILQ_INSERT_HEAD(&free_tcpconns, tp, lru_chain);
}

/* Evict one live connection: the one at the tail of lru_head, which
   is the least recently used since find_flow moves every connection
   it touches to the head.  The connection is unlinked from its hash
   bucket and the LRU list, the statistics are adjusted, and the
   record is returned to the free pool.

   Does nothing if there are no live connections; callers that need a
   free record afterwards must check for that themselves. */
static void
collect_live_conns(void)
{
	struct tcpconn *tcp;

	tcp = TAILQ_LAST(&lru_head, struct tcpconn, lru_chain);
	if (tcp == NULL) {
		/* Empty LRU list: nothing to evict, and no statistics
		   to adjust. */
		return;
	}

	stats.gced_conns++;
	stats.nr_tcp_connections--;
	stats.live_connections--;
	if (tcp->comatose)
		stats.comatose_connections--;
	if (tcp->aborted)
		stats.aborted_connections--;

	LIST_REMOVE(tcp, hash_chain);
	TAILQ_REMOVE(&lru_head, tcp, lru_chain);
	release_tcpconn(tcp);
}

/* Take a connection record out of the free pool and hand it back
   zero-filled.  If the pool is empty, one live connection is evicted
   first to make room; running out even after that is a fatal
   assertion failure. */
static struct tcpconn *
get_tcpconn(void)
{
	struct tcpconn *res = free_tcpconns.tqh_first;

	if (!res) {
		/* Pool exhausted: recycle a live connection. */
		collect_live_conns();
		res = free_tcpconns.tqh_first;
		assert(res != NULL);
	}

	TAILQ_REMOVE(&free_tcpconns, res, lru_chain);
	memset(res, 0, sizeof *res);
	return res;
}

/* ------------------------------------------------------------------------- */
/* Protocol parsing */

/* Locate the link-layer and IP headers inside p->payload, according
   to the capture's link_type.

   Ethernet: p->eth is set to the start of the payload, and p->ip to
   the bytes right after the Ethernet header only when h_proto is
   ETH_P_IP (otherwise p->ip is left untouched).
   DLT_RAW / 101: there is no link header; p->eth is NULL and p->ip is
   the start of the payload.
   Anything else terminates the program via errx(). */
static void
parse_packet(struct packet *p)
{
	if (link_type == DLT_EN10MB) {
		p->eth = p->payload;
		if (net_to_host(p->eth->h_proto) == ETH_P_IP)
			p->ip = (struct iphdr *)(p->eth + 1);
	} else if (link_type == DLT_RAW || link_type == 101) {
		p->eth = NULL;
		p->ip = p->payload;
	} else {
		errx(1, "unknown BPF linktype %d", link_type);
	}
}

/* Fill in the transport-level fields of a datagram from its head
   packet.  Only IPv4 datagrams carrying protocol 6 (TCP) are
   touched; for anything else the datagram is left as it was.

   The stream id is stored in a canonical orientation so that both
   directions of a connection produce the same id: the endpoint with
   the numerically smaller port (ties broken on the raw address
   value) becomes the "source".  flownr records whether this datagram
   travels in the canonical direction (0) or the reverse (1).
   tcpdata is the IP total length minus the IP and TCP header
   lengths. */
static void
parse_datagram(struct datagram *dg)
{
	struct packet *p = dg->head_packet;
	int canonical;

	if (!p->ip || p->ip->version != 4 || p->ip->protocol != 6)
		return;

	dg->ip_payload = (void *)p->ip + p->ip->hlen * 4;

	canonical = net_to_host(dg->tcp->sport) <
			net_to_host(dg->tcp->dport) ||
		(dg->tcp->sport.a == dg->tcp->dport.a &&
		 p->ip->saddr.a < p->ip->daddr.a);

	if (!canonical) {
		/* Swap the endpoints. */
		dg->sid.saddr = p->ip->daddr;
		dg->sid.daddr = p->ip->saddr;
		dg->sid.sport = dg->tcp->dport;
		dg->sid.dport = dg->tcp->sport;
		dg->flownr = 1;
	} else {
		dg->sid.saddr = p->ip->saddr;
		dg->sid.daddr = p->ip->daddr;
		dg->sid.sport = dg->tcp->sport;
		dg->sid.dport = dg->tcp->dport;
		dg->flownr = 0;
	}

	dg->tcpdata = net_to_host(p->ip->tot_len) -
		p->ip->hlen * 4 -
		dg->tcp->doff * 4;
}

/* ------------------------------------------------------------------------- */
/* This is the reassembly engine proper */

/* Placeholder for connection timeouts.  Always returns 0, so
   find_flow never treats an existing flow as timed out; both
   arguments are currently unused. */
static int
flow_timed_out(struct tcpconn *tcp, timestamp_t now)
{
	return 0;
}

/* Find the connection a datagram belongs to, creating one if there
   is none.

   The hash bucket for the datagram's stream id is scanned for a
   connection with the same id.  A match is accepted outright when
   the connection is neither comatose nor aborted; for comatose and
   aborted connections the datagram is only attached if it looks like
   a straggler of the old connection (see the comments below),
   otherwise the scan continues and a fresh connection is eventually
   created.

   A connection that is found is moved to the front of both its hash
   bucket and the LRU list.  A new connection is taken from the pool
   (possibly evicting the least recently used one), given the
   datagram's complete stream id, and linked in at the front of
   both.  Never returns NULL. */
static struct tcpconn *
find_flow(struct datagram *p)
{
	unsigned hash;
	struct tcpconn *tcp;

	stats.conn_lookups++;
	hash = hash_streamid(&p->sid);
	tcp = hash_table[hash].lh_first;
	while (tcp) {
		stats.hash_probes++;
		if (!flow_timed_out(tcp, p->last_packet->ts) &&
		    !cmp_streamid(&p->sid, &tcp->sid)) {
			if (tcp->comatose) {
				/* Comatose connection.  The datagram
				   we're just picking up could be
				   either a new connection, or a
				   retransmit due to last ack getting
				   dropped.  Decide based on sequence
				   number. */
				/* (i.e. we've seen acknowledged fins
				   in both directions on this
				   connection) */
				if (p->tcp->fin &&
				    seq_eq(seq_plus(p->tcp->seq, p->tcpdata + 1),
					   tcp->flow[p->flownr].fin_seq)) {
					/* Retransmitted fin */
					stats.datagrams_on_comatose_conns++;
					break;
				}
				if (p->tcp->ack &&
				    seq_gt(p->tcp->ack_seq,
					   seq_sub(tcp->flow[1-p->flownr].fin_seq,
						   64*1024*1024)) &&
				    seq_le(seq_sub(p->tcp->ack_seq, 1),
					   tcp->flow[1-p->flownr].fin_seq)) {
					/* Retransmitted last ack */
					stats.datagrams_on_comatose_conns++;
					break;
				}
				if (p->tcp->rst) {
					/* RST -> probably belongs to
					 * this connection. */
					stats.datagrams_on_comatose_conns++;
					break;
				}
			} else if (tcp->aborted) {
				/* We've seen an RST on this
				   connection.  Start a new connection
				   if this frame is a SYN, otherwise
				   continue to add stuff to this
				   one. */
				if (!p->tcp->syn)
					break;
			} else {
				/* Connection is neither comatose nor
				   aborted -> continue to add datagrams
				   to it. */
				break;
			}
		}
		tcp = tcp->hash_chain.le_next;
	}
	if (tcp) {
		/* Move-to-front in both the bucket and the LRU list. */
		LIST_REMOVE(tcp, hash_chain);
		LIST_INSERT_HEAD(&hash_table[hash], tcp, hash_chain);
		TAILQ_REMOVE(&lru_head, tcp, lru_chain);
		TAILQ_INSERT_HEAD(&lru_head, tcp, lru_chain);
		return tcp;
	}
	stats.live_connections++;
	tcp = get_tcpconn();
	/* Copy all four components of the stream id.  get_tcpconn()
	   zero-fills the record, so a component left out here would
	   stay 0 and cmp_streamid() would not match the connection on
	   the next lookup. */
	tcp->sid.saddr = p->sid.saddr;
	tcp->sid.daddr = p->sid.daddr;
	tcp->sid.sport = p->sid.sport;
	tcp->sid.dport = p->sid.dport;
	LIST_INSERT_HEAD(&hash_table[hash], tcp, hash_chain);
	TAILQ_INSERT_HEAD(&lru_head, tcp, lru_chain);
	stats.nr_tcp_connections++;
	return tcp;
}

/* Hook invoked by handle_datagram for TCP datagrams with SYN set.
   Currently does nothing. */
static void
handle_syn(struct tcpconn *tcp, struct datagram *p)
{
}

/* Hook invoked by handle_datagram for TCP datagrams whose tcpdata is
   non-zero.  Currently does nothing. */
static void
handle_data(struct tcpconn *tcp, struct datagram *p)
{
}

/* Record a FIN seen in the datagram's direction: mark that direction
   as having sent its FIN, and remember seq + tcpdata + 1 in the
   flow's fin_seq (the value handle_ack later compares the peer's
   ack_seq against). */
static void
handle_fin(struct tcpconn *tcp, struct datagram *p)
{
	if (!p->flownr)
		tcp->sent_fin_0 = 1;
	else
		tcp->sent_fin_1 = 1;

	tcp->flow[p->flownr].fin_seq =
		seq_plus(p->tcp->seq, p->tcpdata + 1);
}

/* Process an ACK.  If the opposite direction has already sent its
   FIN and this ACK's ack_seq equals that FIN's recorded fin_seq, the
   sender's direction is marked as having acknowledged the FIN.  Once
   both directions have done so the connection becomes comatose (and
   is counted in stats.comatose_connections the first time). */
static void
handle_ack(struct tcpconn *tcp, struct datagram *p)
{
	int peer_sent_fin;
	int peer_acked_fin;

	/* Deliberately flipped: we care about the FIN sent by the
	   other side. */
	if (p->flownr)
		peer_sent_fin = tcp->sent_fin_0;
	else
		peer_sent_fin = tcp->sent_fin_1;

	if (!peer_sent_fin)
		return;
	if (!seq_eq(p->tcp->ack_seq, tcp->flow[1-p->flownr].fin_seq))
		return;

	/* This direction has now acked the peer's FIN; look up whether
	   the peer has acked ours. */
	if (p->flownr) {
		tcp->acked_fin_1 = 1;
		peer_acked_fin = tcp->acked_fin_0;
	} else {
		tcp->acked_fin_0 = 1;
		peer_acked_fin = tcp->acked_fin_1;
	}
	if (!peer_acked_fin)
		return;

	if (!tcp->comatose)
		stats.comatose_connections++;
	tcp->comatose = 1;
}

/* Mark the connection as aborted after an RST.  The aborted-
   connection counter is bumped only on the first RST seen for the
   connection. */
static void
handle_rst(struct tcpconn *tcp, struct datagram *p)
{
	const int already_aborted = (tcp->aborted != 0);

	tcp->aborted = 1;
	if (!already_aborted)
		stats.aborted_connections++;
}

/* Run a datagram through the per-connection handlers.  Non-TCP
   datagrams are ignored entirely.  For TCP, the handlers fire in a
   fixed order - SYN, payload, FIN, ACK, RST - for whichever of those
   the segment carries, and the TCP datagram counter is bumped. */
static void
handle_datagram(struct tcpconn *tcp, struct datagram *p)
{
	if (p->head_packet->ip->protocol != IPPROTO_TCP) {
		/* Not TCP: bypasses the state machine. */
		return;
	}

	if (p->tcp->syn)
		handle_syn(tcp, p);
	if (p->tcpdata)
		handle_data(tcp, p);
	if (p->tcp->fin)
		handle_fin(tcp, p);
	if (p->tcp->ack)
		handle_ack(tcp, p);
	if (p->tcp->rst)
		handle_rst(tcp, p);

	stats.tcp_datagrams++;
}

/* Entry point for a complete (unfragmented or reassembled) datagram:
   count it, and if its head packet has an IP header, look up or
   create its connection and feed the datagram to it. */
static void
process_datagram(struct datagram *dg)
{
	struct tcpconn *conn;

	stats.total_datagrams++;

	if (dg->head_packet->ip == NULL)
		return;

	conn = find_flow(dg);
	if (conn != NULL)
		handle_datagram(conn, dg);
}

/* ------------------------------------------------------------------------- */
/* Fragment handling.  This is annoying. */
/* Make a heap copy of a packet so it can outlive the caller's
   buffer.  The payload bytes are duplicated, and the eth/ip header
   pointers are rebased so that they point at the same offsets inside
   the new payload.  A header pointer that is NULL in the original
   (e.g. eth on link types without an Ethernet header) stays NULL in
   the copy.  The timestamp and length are copied as well; all other
   fields keep whatever xcalloc() initialised them to.

   Exits the program if the payload buffer cannot be allocated.
   Release the result with release_packet(). */
static struct packet *
copy_packet(const struct packet *in)
{
	struct packet *out;
	const char *src = (const char *)in->payload;
	char *buf;

	out = xcalloc(sizeof(*out), 1);

	/* Ask for at least one byte so that a NULL return always means
	   failure, even for a zero-length capture. */
	buf = malloc(in->len ? in->len : 1);
	if (buf == NULL)
		err(1, "allocating packet copy");
	memcpy(buf, src, in->len);
	out->payload = buf;

	if (in->eth)
		out->eth = (void *)(buf + ((const char *)in->eth - src));
	else
		out->eth = NULL;
	if (in->ip)
		out->ip = (void *)(buf + ((const char *)in->ip - src));
	else
		out->ip = NULL;

	out->ts = in->ts;
	out->len = in->len;
	return out;
}

/* Free a packet together with its payload buffer.

   Note that most packets are on the stack, with a payload that
   belongs to someone else; only call this function on packets
   obtained from copy_packet. */
static void
release_packet(struct packet *p)
{
	free((void *)p->payload);
	free(p);
}

/* Dispose of a heap-allocated datagram: release every packet on its
   offset-ordered list, unlink the datagram from the list it is on
   (through its pdg back-pointer), and free it. */
static void
release_datagram(struct datagram *dg)
{
	struct packet *cur = dg->head_packet;

	while (cur != NULL) {
		/* Grab the successor first; cur is about to be freed. */
		struct packet *following = cur->next_off;

		release_packet(cur);
		cur = following;
	}

	/* Unlink: make whatever pointed at us point at our successor,
	   and hand our back-pointer on to it. */
	*dg->pdg = dg->next;
	if (dg->next != NULL)
		dg->next->pdg = dg->pdg;

	free(dg);
}

/* Head of the list of partially reassembled datagrams, linked through
   datagram.next with datagram.pdg as the back-pointer. */
static struct datagram *
head_dg;

/* Find the partially reassembled datagram identified by source
   address, destination address and IP id, creating it if necessary.

   Existing datagrams are matched on the IP header of their
   first_packet.  When there is no match, a zeroed datagram is
   appended to the end of the head_dg list and incomplete_datagrams
   is incremented.  A freshly created datagram has no packets yet;
   the caller must attach one before the next lookup, because the
   lookup dereferences first_packet. */
static struct datagram *
find_datagram(n32 saddr, n32 daddr, n16 fragid)
{
	/* Address of the pointer a new datagram would be stored in. */
	struct datagram **link = &head_dg;
	struct datagram *dg = head_dg;

	while (dg != NULL) {
		if (dg->first_packet->ip->saddr.a == saddr.a &&
		    dg->first_packet->ip->daddr.a == daddr.a &&
		    dg->first_packet->ip->id.a == fragid.a)
			return dg;

		link = &dg->next;
		dg = dg->next;
		if (dg != NULL) {
			assert(dg->first_packet);
			assert(dg->first_packet->ip);
		}
	}

	/* Not found: append a new, empty datagram. */
	incomplete_datagrams++;
	dg = xcalloc(sizeof(*dg), 1);
	dg->pdg = link;
	*link = dg;
	return dg;
}

static void
handle_packet(struct packet *p)
{
	struct packet *np, *tmp;
	struct datagram *dg;

	if (!p->ip || p->ip->version != 4) {
		return;
	}
	if (!(net_to_host(p->ip->frag_off) & 0x3fff)) {
		struct datagram triv_dg;
		triv_dg.head_packet = triv_dg.tail_packet = p;
		triv_dg.first_packet = triv_dg.last_packet = p;
		parse_datagram(&triv_dg);
		process_datagram(&triv_dg);
		return;
	}

	/* Okay, we have a fragment.  Find the datagram, and shove it
	 * in. */
	np = copy_packet(p);
	dg = find_datagram(p->ip->saddr, p->ip->daddr, p->ip->id);

	np->prev_time = dg->last_packet;
	np->next_time = NULL;
	if (dg->last_packet)
		dg->last_packet->next_time = np;
	else
		dg->last_packet = dg->first_packet = np;
	for (tmp = dg->head_packet;
	     tmp && (net_to_host(tmp->ip->frag_off) & 0x3fff) >
		     (net_to_host(np->ip->frag_off) & 0x3fff);
	     tmp = tmp->next_off)
		;
	if (!tmp) {
		/* Insert at the end */
		np->prev_off = dg->tail_packet;
		np->next_off = NULL;
		if (dg->tail_packet)
			dg->tail_packet->next_off = np;
		else
			dg->head_packet = dg->tail_packet = np;
	} else {
		/* Insert immediately before tmp */
		np->prev_off = tmp->prev_off;
		np->next_off = tmp;
		if (tmp->prev_off)
			tmp->prev_off->next_off = np;
		else
			dg->head_packet = np;
		tmp->prev_off = np;
		if (dg->tail_packet)
			dg->tail_packet = np;
	}

	fprintf(logfile, "Have a fragment for datagram %D.\n", dg);
	if (net_to_host(p->ip->frag_off) & 0x4000) {
		fprintf(logfile,
			"Huh? Fragment had DONT_FRAGMENT set (datagram %D)\n",
			dg);
	}

	/* Now walk the packet list and see if we've got a complete
	   datagram. */
	if ((net_to_host(dg->head_packet->ip->frag_off) & 0x1fff) != 0) {
		/* Frag off of first packet != 0 -> not complete. */
		return;
	}
	if (net_to_host(dg->tail_packet->ip->frag_off) & 0x2000) {
		/* Last fragment has MORE_FRAGMENTS set -> not
		 * complete */
		return;
	}
	for (tmp = dg->head_packet;
	     tmp != dg->last_packet;
	     tmp = tmp->next_off) {
		unsigned tmp_ends;
		assert((net_to_host(tmp->ip->frag_off) & 0x1fff) <=
		       (net_to_host(tmp->next_off->ip->frag_off) & 0x1fff));
		tmp_ends = net_to_host(tmp->ip->frag_off) & 0x1fff;
		tmp_ends += net_to_host(tmp->ip->tot_len);
		tmp_ends -= sizeof(struct iphdr);
		if (tmp_ends <
		    (net_to_host(tmp->next_off->ip->frag_off) & 0x1fff))
			return;
		if (tmp_ends >
		    (net_to_host(tmp->next_off->ip->frag_off) & 0x1fff)) {
			fprintf(logfile,
				"WARNING: overlapping fragments on datagram %D\n",
				dg);
		}
	}

	fprintf(logfile, "Completing datagram %D.\n", dg);

	/* If we get here, we now have a complete datagram.  Pass it
	   down for further processing. */
	parse_datagram(dg);
	process_datagram(dg);
	release_datagram(dg);

	incomplete_datagrams--;
}

/* Per-packet callback handed to process_file().

   Wraps the captured bytes in a stack-allocated struct packet,
   updates the packet statistics (initialising the time base and
   stats.start_time on the very first packet), locates the headers,
   and passes the packet on to handle_packet() if it carries IP
   protocol 6 (TCP).  Statistics are dumped after every 100000th
   packet that gets that far.

   hdr:  pcap header; supplies the timestamp and captured length.
   data: captured bytes.  They are referenced, not copied, so they
         only need to stay valid for the duration of the call. */
static void
process_packet(const struct pcap_pkthdr *hdr, const unsigned char *data)
{
	struct packet p;
	timestamp_t ts;

	if (!stats.total_packets)
		setup_time(&hdr->ts);

	ts = timeval_to_timestamp(hdr->ts);

	last_packet_time = ts;
	memset(&p, 0, sizeof(p));
	p.payload = data;
	p.len = hdr->caplen;
	p.ts = ts;
	if (stats.total_packets == 0)
		stats.start_time = ts;
	stats.total_packets++;
	parse_packet(&p);
	/* parse_packet() leaves p.ip NULL for Ethernet frames that do
	   not carry IPv4, so it has to be checked before the protocol
	   field can be read. */
	if (!p.ip || p.ip->protocol != 6)
		return;
	handle_packet(&p);

	if (stats.total_packets % 100000 == 0)
		dump_stats(ts);
}

/* Program entry point.

   Sets up the connection pool, parses the command line (the only
   option implemented is -F / --bpf <filter>), opens the log, and
   then reads capture file names from stdin, one per line, running
   each through process_file() with process_packet as the callback.
   At end of input it reports any datagrams that never completed and
   evicts every remaining live connection.

   A file name is at most 4096 characters; a longer line is consumed
   in 4096-character pieces.  The final line is processed even when
   it is not terminated by a newline. */
int
main(int argc, char * argv[])
{
	int x;
	char fname[4097];
	FILE *f;
	const char *bpf_program;
	struct tcpconn *big_tcpconn_table;

	/* Sanity check configuration */
	if ((1 << 21) < MAX_BUFFERS)
		errx(1, "(1 << 21) < MAX_BUFFERS ( %d < %d)",
		     1 << 21, MAX_BUFFERS);

	setup_util();

	/* Build the pool of connection records.  The head must be
	   initialised explicitly: a zero-filled TAILQ head is not a
	   valid empty queue (tqh_last has to point at tqh_first). */
	TAILQ_INIT(&free_tcpconns);
	x = NR_LIVE_CONNS;
	big_tcpconn_table = calloc(x, sizeof(big_tcpconn_table[0]));
	if (!big_tcpconn_table)
		errx(1, "cannot allocate tcp connection table; try with smaller NR_LIVE_CONNS");
	while (x > 0) {
		x--;
		TAILQ_INSERT_HEAD(&free_tcpconns, &big_tcpconn_table[x],
				  lru_chain);
	}

	bpf_program = NULL;
	while (argc != 1) {
		if (argv[1][0] != '-')
			errx(1, "unknown argument %s", argv[1]);
		/* Woot, Duff's device: the 'F' label sits inside the
		   body of the --bpf test so that -F shares its code,
		   and the default label sits inside the else branch so
		   that every other option shares the usage message. */
		switch (argv[1][1]) {
		case '-':
			if (!strcmp(argv[1], "--bpf")) {
		case 'F':
				if (argc < 3)
					errx(1, "need an argument to --bpf");
				if (bpf_program)
					errx(1, "only allowed one bpf program per run");
				bpf_program = argv[2];
				argc -= 2;
				argv += 2;
			} else {
		default:
				fprintf(stderr, "Bad argument %s\n", argv[1]);
				fprintf(stderr, "-t,--timeout <proto> <timeout>\n");
				fprintf(stderr, "-F,--bpf <filter>\n");
				exit(1);
			}
		}
	}

	for (x = 0; x < HASH_SIZE; x++)
		LIST_INIT(&hash_table[x]);
	TAILQ_INIT(&lru_head);

	f = fopen("logfile", "a");
	if (!f)
		err(1, "openning logfile");
	setvbuf(f, NULL, _IONBF, 0);
	logfile = open_tee(stdout, f);

	gettimeofday(&stats.rt_start_time, NULL);
	/* One capture file name per input line.  fgets() also returns
	   a final line that lacks the trailing newline, so that file
	   is not lost. */
	while (fgets(fname, sizeof(fname), stdin) != NULL) {
		fname[strcspn(fname, "\n")] = 0;
		fprintf(logfile, "New file: %s\n", fname);
		process_file(fname, bpf_program, process_packet);
	}

	fprintf(logfile, "pcap_loop exitted; going to final cleanup.\n");

	if (head_dg) {
		fprintf(logfile, "WARNING: still have some uncomplete datagrams.\n");
	}
	/* Only reported, not freed; the process is about to exit. */
	while (head_dg) {
		fprintf(logfile, "Uncomplete datagrams: %D\n", head_dg);
		head_dg = head_dg->next;
	}

	/* Evict every remaining live connection so the statistics
	   account for them. */
	while (TAILQ_LAST(&lru_head, struct tcpconn, lru_chain))
		collect_live_conns();

	fclose(logfile);
	fclose(f);
	return 0;
}
