#define _LARGEFILE_SOURCE
#define _FILE_OFFSET_BITS 64

#include <stdlib.h>
#include <stdio.h>
#include <time.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/time.h>
#include <unistd.h>
#include <errno.h>
#include <string.h>
#include <math.h>
#include <glob.h>
#include <assert.h>

#include <algorithm>
#include <iostream>
#include <limits>
#include <string>
#include <vector>

#include <db_cxx.h>

#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <net/ethernet.h>
#include <netinet/ip.h>
#include <pcap.h>

#include <pcrecpp.h>

// Prefix for every fatal error message written to stderr.
static std::string ERRPFX = "Error pcap-idx: ";

/* Returns a - b in milliseconds, truncated toward zero.
 * The difference is first computed exactly in 64-bit microseconds, so
 * splitting seconds and microseconds cannot mis-round (e.g. 1 us apart is
 * 0 ms, not 1 ms) and large second differences cannot overflow the
 * intermediate arithmetic. The result must still fit in an int
 * (roughly +/- 24.8 days). */
static int timedif(timeval a, timeval b)
{
    long long us = ((long long) a.tv_sec - (long long) b.tv_sec) * 1000000LL
                 + ((long long) a.tv_usec - (long long) b.tv_usec);
    return (int) (us / 1000);
}

/* Parses a "<seconds>[.<fraction>]" timestamp into a timeval.
 * The fraction is a decimal fraction of a second: "10.5" is 10 s + 500000 us,
 * "10.000001" is 10 s + 1 us. Digits beyond microsecond precision are
 * truncated. Parsing stops at the first character that does not fit this
 * form; a missing fraction yields tv_usec == 0. */
static timeval strtotv(const char *s)
{
    timeval r;
    char *end;
    r.tv_sec = strtol(s, &end, 10);
    r.tv_usec = 0;
    if (*end == '.') {
        // weight of the current fractional digit, in microseconds
        long scale = 100000;
        for (const char *p = end + 1; *p >= '0' && *p <= '9' && scale > 0;
             ++p, scale /= 10)
            r.tv_usec += (*p - '0') * scale;
    }
    return r;
}

/* Returns 1 when a is at or after b (a >= b), otherwise 0.
 * Seconds decide unless they are equal, in which case microseconds do. */
static int timegt(timeval a, timeval b)
{
    if (a.tv_sec != b.tv_sec)
        return (a.tv_sec > b.tv_sec) ? 1 : 0;
    return (a.tv_usec >= b.tv_usec) ? 1 : 0;
}

/* A captured packet plus pointers to the link-layer and network-layer
 * headers found by dissecting it. The packet bytes are not copied: the
 * buffer passed to the constructor must outlive the Packet. */
class Packet
{
public:
    /* Wraps caplen_ captured bytes (of a len_-byte packet) starting at data_
     * and dissects them according to the pcap link type dl_type_.
     * Throws datalink_exception for link types that cannot be dissected. */
    Packet(const uint8_t *data_, size_t caplen_, size_t len_,
            timeval ts_, int dl_type_);

    //const uint8_t *data() const { return data; }
    size_t cap_length() const { return caplen; }
    size_t length() const { return len; }

    // Header pointers are null when the corresponding layer was not found.
    const uint8_t *link_header() const { return link_hdr; }
    const uint8_t *network_header() const { return net_hdr; }
    // Only meaningful when is_ip() is true; null otherwise.
    const iphdr *ip_header() const { return (const iphdr *) net_hdr; }

    bool is_ip() const { return net_type == ETHERTYPE_IP; }
    bool is_tcp() const
    {
        return is_ip() && ip_header()->protocol == IPPROTO_TCP;
    }
    bool is_udp() const
    {
        return is_ip() && ip_header()->protocol == IPPROTO_UDP;
    }
private:
    void dissect(int dl_type);

    void parse_ethhdr(const uint8_t *d, size_t l);
    void parse_iphdr(const uint8_t *d, size_t l);

    const uint8_t *data;
    size_t caplen;
    size_t len;
    timeval ts;

    int link_type;
    int net_type;
    const uint8_t *link_hdr;
    const uint8_t *net_hdr;

    /* Thrown when the capture's data link type is not supported.
     * The message lives in a per-object buffer, so what() stays valid for
     * the lifetime of this exception and is not overwritten by other
     * instances (the previous function-static buffer was shared). */
    class datalink_exception : public std::exception
    {
    public:
        datalink_exception(int dl_type_) : dl_type(dl_type_)
        {
            snprintf(msg, sizeof(msg), "Unsupported data link type `%d`",
                    dl_type);
        }
        virtual const char* what() const throw()
        {
            return msg;
        }
    private:
        int dl_type;
        char msg[64];
    };
};

Packet::Packet(const uint8_t *data_, size_t caplen_, size_t len_,
        timeval ts_, int dl_type_)
    : data(data_),
      caplen(caplen_),
      len(len_),
      ts(ts_),
      link_type(0),
      net_type(0),
      link_hdr(0),
      net_hdr(0)
{
    // Headers are located eagerly; this throws for unsupported link types.
    dissect(dl_type_);
}

/* Hands the captured bytes to the parser for the given pcap link type:
 * Ethernet (DLT_EN10MB) or raw IP (DLT_RAW). Any other link type raises
 * datalink_exception. */
void Packet::dissect(int dl_type)
{
    switch (dl_type) {
    case DLT_EN10MB:
        parse_ethhdr(data, caplen);
        break;
    case DLT_RAW:
        parse_iphdr(data, caplen);
        break;
    default:
        throw datalink_exception(dl_type);
    }
}

/* Records an Ethernet link header at d (l bytes available). If its
 * EtherType is IPv4, the payload after the header is passed on to
 * parse_iphdr. Buffers shorter than an Ethernet header are ignored. */
void Packet::parse_ethhdr(const uint8_t *d, size_t l)
{
    const size_t hdr_len = sizeof(ether_header);
    if (l < hdr_len)
        return;

    link_type = DLT_EN10MB;
    link_hdr = d;

    const ether_header *eth = (const ether_header *) d;
    if (ntohs(eth->ether_type) != ETHERTYPE_IP)
        return;
    parse_iphdr(d + hdr_len, l - hdr_len);
}

/* Records an IPv4 network header at d (l bytes available).
 * The packet is left classified as non-IP when fewer than sizeof(iphdr)
 * bytes are available, or when the version field is not 4. A DLT_RAW link
 * can carry IPv6 too, and an IPv6 header read through the iphdr layout would
 * yield bogus protocol and address fields. */
void Packet::parse_iphdr(const uint8_t *d, size_t l)
{
    if (l < sizeof(iphdr))
        return;
    const iphdr *ip_hdr = (const iphdr *) d;
    if (ip_hdr->version != 4)
        return;
    net_type = ETHERTYPE_IP;
    net_hdr = d;
    // TODO
}

// Regex fragment capturing a dotted-quad IPv4 address (digit groups only).
static std::string pattern_ip = "(\\d+\\.\\d+\\.\\d+\\.\\d+)";
// Regex fragment capturing the transport protocol name.
static std::string pattern_ipProto = "(tcp|udp)";


/* A fixed-size, zero-initialized byte buffer used as a Berkeley DB key.
 * Bytes are appended sequentially with write(). The Key owns its buffer
 * and deep-copies it on copy construction and copy assignment, including
 * the current write position. */
class Key
{
public:
    Key(size_t size) : sz(size), buf(new uint8_t[size]), wpos(0)
    {
        memset(buf, 0, sz);
    }
    Key(const Key& k) : sz(k.sz), buf(new uint8_t[k.sz]), wpos(k.wpos)
    {
        memcpy(buf, k.buf, sz);
    }
    Key& operator=(const Key& k)
    {
        if (this != &k) {
            // Allocate before releasing, so *this is unchanged if new[] throws.
            uint8_t *nbuf = new uint8_t[k.sz];
            memcpy(nbuf, k.buf, k.sz);
            delete[] buf;
            buf = nbuf;
            sz = k.sz;
            wpos = k.wpos;
        }
        return *this;
    }
    ~Key() { delete[] buf; }

    size_t size() const { return sz; }
    void *data() const { return buf; }
    /* Appends len bytes from x after the bytes written so far.
     * Writing past the end of the buffer is a programming error. */
    void write(const void *x, size_t len)
    {
        // wpos <= sz always holds, so this cannot underflow.
        assert(len <= sz - wpos);
        memcpy(buf + wpos, x, len);
        wpos += len;
    }
private:

    size_t sz;
    uint8_t *buf;
    size_t wpos;
};

/* Interface for a packet index. It maps packets (when building an index)
 * and textual queries (when searching) onto Berkeley DB keys of the same
 * format. */
struct Index
{
    virtual ~Index() {}

    // Identifier appended, after ".idx", to a trace's file name to name its index.
    virtual std::string name() const = 0;
    // Key under which pkt is stored in the index.
    virtual Key get_key(const Packet& pkt) const = 0;
    // Key to look up for a user query; implementations may throw on bad input.
    virtual Key parse_query(const std::string& query) const = 0;
    // Human-readable description of the accepted query syntax.
    virtual std::string syntax() const = 0;
};

/* Index keyed on the pair of IPv4 endpoints plus the transport protocol.
 * The two addresses are stored in sorted order, so both directions of a
 * conversation share one key. */
class IpsProtoIndex : public Index
{
public:
    virtual std::string name() const { return "ipsproto"; }
    virtual Key get_key(const Packet& pkt) const;
    virtual Key parse_query(const std::string& query) const;
    virtual std::string syntax() const
    {
        return "ip.add.re.ss1 ip.add.re.ss2 tcp|udp";
    }
private:
    // Serializes the (sorted) address pair and protocol into a Key.
    Key make_key(uint32_t hosta, uint32_t hostb, uint8_t proto) const;
    // Regex matching the query syntax returned by syntax().
    static pcrecpp::RE re;
};

/* Builds the key from the packet's IPv4 source, destination and protocol
 * fields. pkt must be an IP packet; ip_header() is null otherwise. */
Key IpsProtoIndex::get_key(const Packet& pkt) const
{
    const iphdr *hdr = pkt.ip_header();
    return make_key(hdr->saddr, hdr->daddr, hdr->protocol);
}

/* Serializes (smaller address, larger address, proto) into a 9-byte key.
 * The key therefore does not depend on which endpoint is the source.
 * Addresses are written in whatever byte order the caller supplies. */
Key IpsProtoIndex::make_key(uint32_t hosta, uint32_t hostb, uint8_t proto) const
{
    const uint32_t lo = (hosta > hostb) ? hostb : hosta;
    const uint32_t hi = (hosta > hostb) ? hosta : hostb;

    Key key(sizeof(lo) + sizeof(hi) + sizeof(proto));
    key.write(&lo, sizeof(lo));
    key.write(&hi, sizeof(hi));
    key.write(&proto, sizeof(proto));
    return key;
}

// Query grammar: "<ipv4> <ipv4> <tcp|udp>", fields separated by whitespace.
pcrecpp::RE IpsProtoIndex::re(
    pattern_ip + "\\s+" +
    pattern_ip + "\\s+" +
    pattern_ipProto);

/* Parses "<ipv4> <ipv4> <tcp|udp>" into the key get_key() would build for
 * a matching packet (addresses in network byte order, like iphdr fields).
 * Throws std::string on malformed input. This includes dotted quads that
 * match the regex but are not valid IPv4 addresses (e.g. "999.1.1.1"):
 * inet_addr() would have returned INADDR_NONE for those, silently
 * aliasing 255.255.255.255. */
Key IpsProtoIndex::parse_query(const std::string& query) const
{
    std::string ipa, ipb, ipProto;

    if (!re.FullMatch(query, &ipa, &ipb, &ipProto))
        throw std::string("Invalid query: `" + query + "`. " +
            "Syntax: " + syntax());

    in_addr addra, addrb;
    if (inet_pton(AF_INET, ipa.c_str(), &addra) != 1)
        throw std::string("Invalid IPv4 address: `" + ipa + "`");
    if (inet_pton(AF_INET, ipb.c_str(), &addrb) != 1)
        throw std::string("Invalid IPv4 address: `" + ipb + "`");

    uint8_t proto = (ipProto == "udp") ? IPPROTO_UDP : IPPROTO_TCP;
    return make_key(addra.s_addr, addrb.s_addr, proto);
}


class CreateCommand
{
public:
    CreateCommand() : read_bytes(0), read_pkts(0), trace_size(0),
                      index(0), cur(0), pcap(0)
    {
        pcap_err[0] = '\0';
    }
    void run(std::string filename);

private:
    void show_progress();
    void process_packet(const Packet& pkt);
    static void pcap_handler(uint8_t *user, const pcap_pkthdr *pcap_hdr,
            const uint8_t *data);

    off_t read_bytes;
    uint32_t read_pkts;
    timeval start_time;
    off_t trace_size;

    Index *index;
    Dbc *cur;
    
    pcap_t *pcap;
    char pcap_err[PCAP_ERRBUF_SIZE];
};

void CreateCommand::run(std::string filename)
{
    index = new IpsProtoIndex();

    std::string dbfile(filename);
    dbfile += std::string(".idx") + index->name();

    Db db(NULL, 0);

    db.set_flags(DB_DUP);
    
    try {
        struct stat s;
        if (stat(filename.c_str(), &s) != 0) {
            throw std::string(strerror(errno));
        }
        trace_size = s.st_size;
        
        db.open(NULL, dbfile.c_str(), NULL, DB_BTREE,
            DB_CREATE | DB_TRUNCATE, 0644);
        
        db.cursor(NULL, &cur, 0); // TODO close cursor with sth ala auto_ptr
        
        FILE *trace = fopen(filename.c_str(), "rb");
        if (trace == NULL) {
            throw std::string(strerror(errno));
        }
        
        pcap = pcap_fopen_offline(trace, pcap_err);
        if (trace == NULL) {
            throw std::string(pcap_err);
        }
        read_bytes = ftell(trace);
        gettimeofday(&start_time, NULL);

        if (pcap_loop(pcap, -1, CreateCommand::pcap_handler,
            (uint8_t *) this) < 0)
        {
            throw std::string(pcap_err);
        }
        show_progress();
        std::cerr << std::endl;

        pcap_close(pcap);

        cur->close();
        cur = 0;
        db.close(0);
    } catch (DbException& e) {
        std::cerr << ERRPFX << e.what() << std::endl;
        exit(1);
    } catch (std::exception& e) {
        std::cerr << ERRPFX << e.what() << std::endl;
        exit(1);
    } catch (std::string& e) {
        std::cerr << ERRPFX << e << std::endl;
        exit(1);
    }

    delete index;
} 

/* Handles one packet. The running offset read_bytes is advanced for every
 * packet, so it always points at the next record in the trace file. Only
 * TCP/UDP packets get an index entry; its value is the file offset of the
 * packet's record header. Progress is printed every 10000 packets. */
void CreateCommand::process_packet(const Packet& pkt)
{
    // A savefile record header is four 32-bit fields (16 bytes) on disk.
    // sizeof(pcap_pkthdr) is not that: its struct timeval is 16 bytes on
    // LP64 platforms, which would skew every offset after the first packet.
    static const off_t SAVEFILE_RECORD_HDR_LEN = 16;

    off_t off = read_bytes;
    read_bytes += SAVEFILE_RECORD_HDR_LEN + (off_t) pkt.cap_length();
    read_pkts += 1;
    // skip non TCP/UDP packets
    if (pkt.is_tcp() || pkt.is_udp())
    {
        Key key = index->get_key(pkt);
        Dbt k(key.data(), key.size());
        Dbt v(&off, sizeof(off_t));
        cur->put(&k, &v, DB_KEYLAST);
    }
    if (read_pkts % 10000 == 0)
        show_progress();
}

/* Writes a one-line progress report to stderr: percentage of the trace
 * consumed, bytes read / total, and packet and byte rates since start_time.
 * The line starts with '\r' so successive reports overwrite each other. */
void CreateCommand::show_progress()
{
    timeval end_time;
    gettimeofday(&end_time, NULL);
    double elapsed = timedif(end_time, start_time);
    // Within the first millisecond elapsed is 0; avoid inf/nan rates.
    if (elapsed <= 0)
        elapsed = 1;
    double rate_pkts = read_pkts * 1000.0 / elapsed;
    double rate_bytes = read_bytes * 1000.0 / elapsed;

    // Guard the integer division; an empty trace counts as fully read.
    int progress = (trace_size > 0) ? (int) (read_bytes * 100 / trace_size)
                                    : 100;
    // off_t is signed and its width varies; cast to match the format.
    fprintf(stderr, "\rProgress: %02d%% (%lld/%lld) %f pkt/s %f bytes/s",
        progress, (long long) read_bytes, (long long) trace_size,
        rate_pkts, rate_bytes);
}

void CreateCommand::pcap_handler(uint8_t *user, const pcap_pkthdr *pcap_hdr,
    const uint8_t *data)
{
    CreateCommand *cmd = (CreateCommand *) user;
    Packet pkt(data, pcap_hdr->caplen, pcap_hdr->len, pcap_hdr->ts,
        pcap_datalink(cmd->pcap));
    cmd->process_packet(pkt);
}

/* Answers a query against a directory of pcap traces and their index files.
 * The matching packets are written to stdout as a single pcap stream. */
class QueryCommand
{
public:
    void run(std::string filename, std::string query, timeval ts_start_, timeval ts_end_);
private:
    void run_file(std::string filename, std::string dbfile, bool wrote_hdr);

    Index *index;           // turns `query` into a lookup key
    std::string query;      // query text (command-line words joined by spaces)
    timeval ts_start;       // packets earlier than this are skipped
    timeval ts_end;         // scanning a trace stops at the first packet later than this
    
    // Magic number of a pcap savefile, compared against the header's magic
    // field to detect the file's byte order.
#define PCAP_MAGIC 0xa1b2c3d4
    // pcap file header exactly as stored on disk. It is read and written
    // verbatim, so field order and sizes must not change.
    struct pcap_file_header {
        bpf_u_int32 magic;
        u_short version_major;
        u_short version_minor;
        bpf_int32 thiszone;     /* gmt to local correction */
        bpf_u_int32 sigfigs;    /* accuracy of timestamps */
        bpf_u_int32 snaplen;    /* max length saved portion of each pkt */
        bpf_u_int32 linktype;   /* data link type (LINKTYPE_*) */
    };
};

/* Returns the start time encoded in a trace file name ("<sec>.<usec>"),
 * ignoring any leading directory components of `path`. */
static timeval trace_start_time(const char *path)
{
    const char *base = strrchr(path, '/');
    return strtotv(base ? base + 1 : path);
}

/* Runs `query_` over the traces in directory `ts_path`, i.e. files whose
 * names end in "<digit>.<six digits>", taken in glob() order. Processing
 * starts at the last trace that begins before ts_start_ (or at the first
 * trace). It stops after processing the first trace that begins after
 * ts_end_. Matching packets are written to stdout as one pcap stream.
 * Exits with status 1 if the glob finds no trace. */
void QueryCommand::run(std::string ts_path, std::string query_, timeval ts_start_, timeval ts_end_)
{
    index = new IpsProtoIndex();
    query = query_;
    ts_start = ts_start_;
    ts_end = ts_end_;

    glob_t traces;
    std::string traces_pattern(ts_path);
    traces_pattern += "/*[0-9].[0-9][0-9][0-9][0-9][0-9][0-9]";
    int r = glob(traces_pattern.c_str(), 0, 0, &traces);
    if (r != 0)
    {
        std::cerr << ERRPFX << "cannot find traces. glob() failed." << std::endl;
        exit(1);
    }

    // Find the trace that may contain ts_start: the last one starting before it.
    size_t first = 0;
    for (size_t tidx = 0; tidx < traces.gl_pathc; ++tidx)
    {
        if (timegt(trace_start_time(traces.gl_pathv[tidx]), ts_start))
            break;
        first = tidx;
    }

    bool wrote_hdr = false;
    for (size_t tidx = first; tidx < traces.gl_pathc; ++tidx)
    {
        std::string filename(traces.gl_pathv[tidx]);
        timeval t = trace_start_time(filename.c_str());

        std::cerr << "Processing " << filename << std::endl;

        std::string dbfile(filename);
        dbfile += std::string(".idx") + index->name();

        run_file(filename, dbfile, wrote_hdr);
        wrote_hdr = true;

        if (!timegt(ts_end, t))
            break;
    }

    globfree(&traces);
    delete index;
    index = 0;
}

void QueryCommand::run_file(std::string filename, std::string dbfile, bool wrote_hdr)
{
    Db db(NULL, 0);

    db.set_flags(DB_DUP);

    try {

        db.open(NULL, dbfile.c_str(), NULL, DB_BTREE,
                DB_RDONLY, 0);

        Dbc *cur;
        db.cursor(NULL, &cur, 0); // TODO close cursor with sth ala auto_ptr

        FILE *trace = fopen(filename.c_str(), "rb");
        if (trace == NULL) {
            throw std::string(strerror(errno));
        }

        pcap_file_header hdr;
        if (fread(&hdr, sizeof(hdr), 1, trace) != 1) {
            throw std::string(strerror(errno));
        }

        bool swap = (htonl(hdr.magic) == PCAP_MAGIC);
        if (!swap && hdr.magic != PCAP_MAGIC) {
            throw std::string("Invalid pcap file");
        }

        if (!wrote_hdr)
            fwrite(&hdr, sizeof(hdr), 1, stdout);

        Key key = index->parse_query(query);
        Dbt k(key.data(), key.size());
        Dbt v;

        pcap_pkthdr pcap_hdr;
        uint8_t buf[65536];

        int ret = cur->get(&k, &v, DB_SET);
        while (ret != DB_NOTFOUND) {
            off_t *off = (off_t *) v.get_data();
            fseeko(trace, *off, SEEK_SET);
            if (fread(&pcap_hdr, sizeof(pcap_hdr), 1, trace) != 1) {
                throw std::string(strerror(errno));
            }
            if (swap) {
                pcap_hdr.caplen = ntohs(pcap_hdr.caplen);
                pcap_hdr.ts.tv_sec = ntohl(pcap_hdr.ts.tv_sec);
                pcap_hdr.ts.tv_usec = ntohl(pcap_hdr.ts.tv_usec);
            }
            if (timegt(pcap_hdr.ts, ts_start)) {
                if (!timegt(ts_end, pcap_hdr.ts)) {
                    break;
                }
                fwrite(&pcap_hdr, sizeof(pcap_hdr), 1, stdout);
                if (fread(buf, 1, pcap_hdr.caplen, trace) != pcap_hdr.caplen) {
                    throw std::string(strerror(errno));
                }
                fwrite(buf, 1, pcap_hdr.caplen, stdout);
            }
            ret = cur->get(&k, &v, DB_NEXT_DUP);
        }

        fclose(trace);

        cur->close();
        cur = 0;
        db.close(0);
    } catch (DbException& e) {
        std::cerr << ERRPFX << e.what() << std::endl;
        exit(1);
    } catch (std::exception& e) {
        std::cerr << ERRPFX << e.what() << std::endl;
        exit(1);
    } catch (std::string& e) {
        std::cerr << ERRPFX << e << std::endl;
        exit(1);
    }
}

/* Prints the command-line synopsis for program `prog`.
 * With error == true the text goes to stderr and the process exits with
 * status 1; otherwise it goes to stdout and the function returns. */
void usage(const char *prog, bool error = false)
{
    FILE *out = stdout;
    if (error)
        out = stderr;
    fprintf(out, "Usage: %s <command> <arguments>\n"
        "Valid commands are:\n"
        "create <path to pcap trace>\n"
        "query <path to trace repos> <query string> [<time range>]\n"
        "time range := start <timestamp> end <timestamp>\n", prog);
    if (!error)
        return;
    exit(1);
}

/* Entry point. Dispatches to the "create" or "query" sub-command.
 * Missing arguments or an unknown command print the usage text and exit
 * with status 1. */
int main(int argc, char *argv[])
{
    if (argc < 2)
    {
        fprintf(stderr, "Error: %s called with no arguments!\n", argv[0]);
        usage(argv[0], true);
    }
    if (std::string("--help") == argv[1])
    {
        usage(argv[0]);
        exit(0);
    }
    std::string command(argv[1]);
    if (command == "create")
    {
        if (argc < 3)
        {
            std::cerr << "Expecting path to pcap trace" << std::endl;
            usage(argv[0], true);
        }
        std::string filename(argv[2]);
        CreateCommand cmd;
        cmd.run(filename);
    } else if (command == "query")
    {
        if (argc < 4)
        {
            std::cerr << "Expecting path to pcap trace repository and query" << std::endl;
            usage(argv[0], true);
        }
        std::string tr_path(argv[2]);

        timeval ts_start, ts_end;
        memset(&ts_start, 0, sizeof(ts_start));
        memset(&ts_end, 0, sizeof(ts_end));
        // Without an explicit range every packet should match, so the upper
        // bound defaults to the latest representable time. Leaving it at the
        // epoch made every query stop at its first packet and emit nothing.
        ts_end.tv_sec = std::numeric_limits<time_t>::max();

        int query_end = argc;
        if (argc >= 8)
        {
            if (std::string("start") == argv[argc - 4] &&
                std::string("end") == argv[argc - 2])
            {
                ts_start = strtotv(argv[argc - 3]);
                ts_end = strtotv(argv[argc - 1]);

                query_end -= 4;
            }
        }
        std::string query(argv[3]);
        for (int i = 4; i < query_end; ++i)
        {
            query += " ";
            query += argv[i];
        }

        QueryCommand cmd;
        cmd.run(tr_path, query, ts_start, ts_end);
    } else
    {
        fprintf(stderr, "Error: unknown command `%s`\n", argv[1]);
        usage(argv[0], true);
    }
    return 0;
}

