/*
 * l7.{cc,hh} -- classify TCP/UDP flows using the l7-filter engine
 * 
 * Enrico Badino
 * Marco Canini
 * Sergio Mangialardi
 *
 * Copyright (c) 2007-09 by University of Genova - DIST - TNT laboratory
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version
 * 2 of the License, or (at your option) any later version.
 * http://www.gnu.org/licenses/gpl.txt
 *
 * $Id: l7.cc 2577 2009-10-05 10:14:43Z marco $
 */
/*
 * Portions of the file are taken from the userspace version of the
 * original l7-filter userspace version 0.3, distributed under the GPL license.
 * Copyright (c) 2006-2007 by Ethan Sommer <sommere@users.sf.net> and
 * Matthew Strait <quadong@users.sf.net>
 * http://l7-filter.sf.net
 */

#include <click/config.h>
#include <click/error.hh>
#include <click/hashmap.hh>
#include <click/straccum.hh>
#include <click/confparse.hh>
#include <clicknet/ip.h>
#include <clicknet/tcp.h>
#include <clicknet/udp.h>
#include <click/packet_anno.hh>
#include <click/handlercall.hh>

/*
 * from l7-classify
 */
#include <iostream>
#include <fstream>
#include <exception>
#include <vector>
#include <string>
#include <sstream>
#include <stdio.h>
#include <stdlib.h>
#include <errno.h>
#include <signal.h>
#include <dirent.h>

#include "l7.hh"
#include "nipquad.hh"

// File-local helpers used to locate and parse l7-filter pattern (.pat)
// files. They are declared here so the member functions below can use them;
// the definitions are in the second anonymous namespace later in this file.
namespace
{
    std::string find_pattern_file(const std::string& protocol, const std::string& location);
    std::string get_protocol_name(const std::string& line);
    std::string basename(const std::string& filename);
    std::string attribute(const std::string& line);
    bool is_comment(const std::string& line);
    std::string value(const std::string& line);
    int hex2dec(char c);
}

CLICK_DECLS

int L7::configure(Vector<String>& conf, ErrorHandler* errh)
{
    // Default before parsing; USE_APPMARKS may override it below.
    use_appmarks_ = true;

    String pattern_list;
    const int parse_result = cp_va_kparse(conf, this, errh, 
            "PATTERNS", cpkM, cpString, &pattern_list,
            "MAX_PKTS", 0, cpUnsigned, &buffer_count_,
            "MAX_BYTES", 0, cpUnsigned, &byte_count_,
            "PATTERNS_DIR", 0, cpFilename, &location_,
            "USE_APPMARKS", 0, cpBool, &use_appmarks_,
            cpEnd);
    if (parse_result < 0)
        return -1;

    // Per-flow state is stored in an upstream FlowCache element; without
    // one this element cannot work.
    flow_cache_ = FlowCache::upstream_instance(this);
    if (!flow_cache_)
        return errh->error("Initialization failure!");
    flow_cache_->register_flow_state_holder(this, name());

    // PATTERNS holds whitespace-separated protocol names.
    cp_spacevec(pattern_list, patternNames_);
    return 0;
}

/*
 * Loads and compiles one pattern file per configured protocol name.
 *
 * Returns 0 when at least one pattern was loaded. Otherwise it reports an
 * error and returns a negative value, so the router fails to initialize
 * through Click's normal path instead of terminating the process from
 * inside an element.
 */
int L7::initialize(ErrorHandler* errh)
{
    int nrules = 0;

    for (Vector<String>::iterator it = patternNames_.begin();
        it != patternNames_.end(); ++it)
    {
        const std::string file = find_pattern_file(it->c_str(), location_.c_str());
        if (add_pattern_from_file(file))
            ++nrules;
    }

    if (nrules < 1)
        return errh->error("No valid rules, exiting");

    return 0;
}

void L7::write_header(DataExport& exporter) const
{
    if (!use_appmarks_) {
        // Private enumeration: the two "not classified" labels (suffixed
        // with MAX_PKTS) followed by every configured protocol name.
        DataExport::Column::EnumType labels;
        labels.push_back("NC+" + String(buffer_count_));
        labels.push_back("NC-" + String(buffer_count_));

        Vector<String>::const_iterator name_it = patternNames_.begin();
        Vector<String>::const_iterator name_end = patternNames_.end();
        for (; name_it != name_end; ++name_it)
            labels.push_back(*name_it);

        exporter += column("L7Mark", labels, false);
    } else {
        // Shared AppMarks enumeration.
        exporter += AppMarks::default_instance()->column("L7Mark");
    }

    exporter += constraint("L7Mark", DataExport::Constraint::KEY);
}

void L7::write_flow_state(DataExport& exporter, const BaseFlowState* fs_) const
{
    const FlowState* state = static_cast<const FlowState*>(fs_);

    if (use_appmarks_) {
        // Export the raw mark; the column is the AppMarks enumeration.
        exporter << *((const AppMark *) state);
        return;
    }

    // Private enumeration: map the two sentinel marks to their labels and
    // everything else to its protocol name.
    if (*state == AppMarks::empty_mark) {
        exporter << "NC-" + String(buffer_count_);
    } else if (*state == AppMarks::unknown_mark) {
        exporter << "NC+" + String(buffer_count_);
    } else {
        exporter << AppMarks::default_instance()->resolve_proto(*state);
    }
}

Packet* L7::simple_action(Packet* p)
{
    FlowState* state = static_cast<FlowState*>(flow_cache_->state(this, p));

    if (!state)
    {
        // First packet of this flow seen here: create its state. The
        // allocation is checked before the slot is filled; on failure the
        // slot simply stays empty, as it already was.
        state = new FlowState();
        if (!state)
        {
            click_chatter("out of memory!");
            p->kill();
            return 0;
        }
        flow_cache_->state(this, p) = state;
    }

    state->handle_packet(p, this);
    return p;
}

L7::FlowState::FlowState()
{
    // Both directions start with no packets, no bytes and no buffer; the
    // buffer is allocated lazily on the first payload in that direction.
    for (int side = 0; side < 2; ++side)
    {
        packet_counter[side] = 0;
        lengthsofar[side] = 0;
        buffer[side] = 0;
    }

    set_mark(AppMarks::empty_mark);
}

L7::FlowState::~FlowState()
{
    // Release whichever per-direction payload buffers are still allocated
    // (delete[] on a null pointer is a no-op).
    for (int side = 0; side < 2; ++side)
        delete [] buffer[side];
}

/*
 * Feeds one packet of the flow to the classifier and annotates the packet
 * with the flow's current mark.
 *
 * While the flow is unclassified (empty_mark), payload bytes are collected
 * per direction (nulls stripped, capped at MAX_BYTES) and matched against
 * the configured patterns after every packet carrying payload. The flow is
 * declared unknown once a direction reaches MAX_PKTS packets or MAX_BYTES
 * bytes without a match. Once a mark is set, both buffers are released and
 * later packets are only annotated.
 */
void L7::FlowState::handle_packet(Packet* p, L7* parent)
{
    // Direction within the flow as painted upstream. packet_counter,
    // lengthsofar and buffer have one slot per direction, so any value other
    // than 0/1 would index past them.
    // NOTE(review): assumes the upstream FlowCache paints only 0 or 1;
    // other values are now just annotated instead of corrupting memory.
    int dir = PAINT_ANNO(p);

    if (*this == AppMarks::empty_mark && (dir == 0 || dir == 1))
    {
        int thlen = 8; /* UDP header length */
        const click_ip* iph = p->ip_header();

        if (iph->ip_p == IP_PROTO_TCP)
            thlen = p->tcp_header()->th_off * 4;

        packet_counter[dir]++;

        // Measure the payload from the transport header to the end of the
        // packet data, so bytes in front of the IP header (if the packet
        // still carries any) are not counted and cannot be over-read.
        int payload = static_cast<int>(p->end_data() - p->transport_header()) - thlen;

        if (payload > 0)
        {
            if (buffer[dir] == 0)
                buffer[dir] = new char[parent->byte_count_ + 1];

            const char* start = reinterpret_cast<const char*>(p->transport_header() + thlen);
            append_to_buffer(start, payload, parent, buffer[dir], lengthsofar[dir]);
            set_mark(parent->classify(buffer[dir]));
        }

        if (*this != AppMarks::empty_mark)
        {
            // Matched: track the largest packet count needed for a match.
            if (packet_counter[dir] > parent->max_pkt_match_)
                parent->max_pkt_match_ = packet_counter[dir];
        }
        else if (packet_counter[dir] >= parent->buffer_count_ || lengthsofar[dir] == parent->byte_count_)
        {
            // Packet or byte budget exhausted without a match.
            set_mark(AppMarks::unknown_mark);
        }

        if (*this != AppMarks::empty_mark)
        {
            // Classification is final; the collected payload is no longer needed.
            delete [] buffer[0];
            delete [] buffer[1];
            buffer[0] = buffer[1] = 0;
        }
    }

    SET_APPMARK_ANNO(p, get_mark());
}

void L7::FlowState::append_to_buffer(const char* app_data, int appdatalen, L7* parent, char* buffer, uint32_t& lengthsofar)
{
    // Appends app_data to buffer (which already holds `lengthsofar` bytes),
    // dropping null bytes so the result stays a valid C string for
    // regexec(). At most MAX_BYTES bytes in total are kept; the buffer is
    // always null-terminated and lengthsofar is advanced by the bytes kept.
    const uint32_t room = parent->byte_count_ - lengthsofar;
    const uint32_t total = static_cast<uint32_t>(appdatalen);

    char* out = buffer + lengthsofar;
    uint32_t copied = 0;
    uint32_t idx = 0;

    while (copied < room && idx < total)
    {
        const char c = app_data[idx++];
        if (c == '\0')
            continue;
        out[copied++] = c;
    }

    out[copied] = '\0';
    lengthsofar += copied;
}

/*
 * Builds a compiled pattern from an l7-filter .pat file.
 *
 * Throws std::exception (after reporting an error) if the file cannot be
 * parsed. A regex compile failure or a protocol name unknown to AppMarks is
 * reported through ErrorHandler::fatal.
 *
 * All user-controlled text (file name, protocol name, regex) is passed as a
 * "%s" argument: ErrorHandler messages are printf-style, and a '%' in a
 * regex used as the format string would be undefined behavior.
 */
L7::Pattern::Pattern(const std::string& filename):
    cflags_(REG_EXTENDED | REG_ICASE | REG_NOSUB), eflags_(0)
{
    if (!parse_pattern_file(filename))
    {
        ErrorHandler::default_handler()->error("error reading pattern file %s", filename.c_str());
        throw std::exception();
    }

    if (regcomp(&preg_, pre_process(pattern_.c_str()).c_str(), cflags_) != 0)
    {
        ErrorHandler::default_handler()->fatal("error compiling %s -- %s", name_.c_str(), pattern_.c_str());
    }
    
    mark_ = AppMarks::default_instance()->lookup_by_proto_name(name_.c_str());
    if (mark_ == AppMarks::empty_mark)
    {
        ErrorHandler::default_handler()->fatal("error looking up AppMark for protocol name: %s", name_.c_str());
    }
}

AppMark L7::classify(char* buffer)
{
    // Patterns are tried in configuration order and the first match wins;
    // empty_mark means "nothing matched yet".
    for (std::vector<Pattern>::size_type i = 0; i < patterns_.size(); ++i)
    {
        Pattern& candidate = patterns_[i];
        if (candidate.matches(buffer))
            return candidate.get_mark();
    }

    return AppMarks::empty_mark;
}

bool L7::Pattern::matches(char* buffer)
{
    // Match the null-terminated payload without requesting submatches.
    const int status = regexec(&preg_, buffer, 0, 0, eflags_);
    return status == 0;
}

// Compiles the pattern in `filename` and appends it to patterns_.
// Returns true on success, false on failure. Pattern's constructor has
// already reported the reason, so the exception itself is discarded.
bool L7::add_pattern_from_file(const std::string& filename)
{
    bool added = true;

    try
    {
        patterns_.push_back(Pattern(filename));
    }
    catch (...)
    {
        added = false;
    }

    return added;
}

bool L7::Pattern::parse_pattern_file(const std::string& filename)
{
    std::ifstream file(filename.c_str());

    if (!file.is_open())
    {
        StringAccum sa;
        sa << "error opening pattern file " << filename.c_str();
        ErrorHandler::default_handler()->fatal(sa.take_string().c_str());
        return false;
    }
    // What we're looking for. It's either the protocol name, the kernel
    // pattern,
    // which we'll use if no other is present, or any of various (ok, two)
    // userspace config lines.
    enum
    {
        protocol,
        kpattern,
        userspace
    } state = protocol;

    std::string name;
    std::string line;

    while (!file.eof())
    {
        getline(file, line);

        if (is_comment(line))
            continue;

        if (state == protocol)
        {
            name_ = get_protocol_name(line);

            if (name_ != basename(filename))
            {
                std::cerr << "Error: Protocol declared in file does not match file name.\nFile name is ";
                std::cerr << basename(filename) << ", but the file says " << name_ << std::endl;
                return false;
            }
            state = kpattern;
            continue;
        }

        if (state == kpattern)
        {
            pattern_ = line;
            state = userspace;
            continue;
        }

        if (state == userspace)
        {
            if (line.find_first_of('=') == std::string::npos)
            {
                std::cerr << "Warning: ignored bad line in pattern file:\n\t" << line << std::endl;
                continue;
            }

            if (attribute(line) == "userspace pattern")
            {
                pattern_ = value(line);
            }
            else if (attribute(line) == "userspace flags")
            {
                if (!parse_flags(value(line)))
                    return false;
            }
            else
                std::cerr << "Warning: ignored unknown pattern file attribute \"" << attribute(line) << "\"" << std::endl;
        }
    }
    return true;
}

bool L7::Pattern::parse_flags(const std::string& line)
{
    std::string flag;
    cflags_ = 0;
    eflags_ = 0;
    
    for (unsigned int i=0; i<line.size(); ++i)
    {
        if (!isspace(line[i]))
            flag += line[i];
    
        if (isspace(line[i]) || i == line.size() - 1)
        {
            if (flag == "REG_EXTENDED")
                cflags_ |= REG_EXTENDED;
            else if (flag == "REG_ICASE")
                cflags_ |= REG_ICASE;
            else if (flag == "REG_NOSUB")
                cflags_ |= REG_NOSUB;
            else if (flag == "REG_NEWLINE")
                cflags_ |= REG_NEWLINE;
            else if (flag == "REG_NOTBOL")
                eflags_ |= REG_NOTBOL;
            else if (flag == "REG_NOTEOL")
                eflags_ |= REG_NOTEOL;
            else
            {
                std::cerr << "Error: encountered unknown flag in pattern file " << flag << std::endl;
                
                return false;
            }
            flag = "";
        }
    }
    return true;
}

/*
 * Expands "\xHH" hex escapes in an l7-filter regex into the literal bytes
 * they denote; all other characters are copied unchanged.
 *
 * It warns when an escape produces a regex control character (probably not
 * what the author meant) or a null byte. As before, the result is truncated
 * at the first null byte, because the regex is handed to regcomp() as a C
 * string.
 */
std::string L7::Pattern::pre_process(const std::string& s)
{
    const size_t len = s.size();
    std::string result;
    result.reserve(len);

    size_t i = 0;
    while (i < len)
    {
        // isxdigit() needs an unsigned char value; plain char may be negative.
        if (i + 3 < len && s[i] == '\\' && s[i + 1] == 'x'
            && isxdigit(static_cast<unsigned char>(s[i + 2]))
            && isxdigit(static_cast<unsigned char>(s[i + 3])))
        {
            const char c = static_cast<char>(hex2dec(s[i + 2]) * 16 + hex2dec(s[i + 3]));

            switch (c)
            {
            case '$':
            case '(':
            case ')':
            case '*':
            case '+':
            case '.':
            case '?':
            case '[':
            case ']':
            case '^':
            case '|':
            case '{':
            case '}':
            case '\\':
                std::cerr << "Warning: regexp contains a regexp control character, " << c;
                std::cerr << ", in hex (\\x" << s[i + 2] << s[i + 3];
                std::cerr << ").\nI recommend that you write this as " << c;
                std::cerr << " or \\" << c << " depending on what you meant." << std::endl;
                break;
            case '\0':
                std::cerr << "Warning: null (\\x00) in layer7 regexp. " << "A null terminates the regexp string!" << std::endl;
                break;
            default:
                break;
            }

            result += c;
            i += 4; /* consumed "\xHH" */
        }
        else
        {
            result += s[i];
            ++i;
        }
    }

    // Drop everything from the first embedded null (npos keeps the whole string).
    return result.substr(0, result.find('\0'));
}

namespace
{
    // Looks in every subdirectory of `location` (scanning directory entries
    // in reverse alphabetical order) for "<protocol>.pat". Returns the path of
    // the first one that can be opened. If `location` cannot be scanned or no
    // file is found, it reports through ErrorHandler::fatal; should that
    // return, the result is an empty string.
    //
    // Names are passed as "%s" arguments: ErrorHandler messages are
    // printf-style, so user-supplied text must never be the format string.
    std::string find_pattern_file(const std::string& protocol, const std::string& location)
    {
        dirent** namelist;
        int n = scandir(location.c_str(), &namelist, 0, alphasort);
        std::string answer;
        bool found = false;

        if (n < 0)
        {
            ErrorHandler::default_handler()->fatal("Couldn't open %s", location.c_str());
        }
        else
        {
            while (n--)
            {
                const char* entry = namelist[n]->d_name;
                const std::string fulldirname = location + "/" + entry;
                DIR* scratchdir;

                // Only directories are searched; "." and ".." are skipped.
                if (!found && (scratchdir = opendir(fulldirname.c_str())))
                {
                    closedir(scratchdir);

                    if (strcmp(entry, ".") != 0 && strcmp(entry, "..") != 0)
                    {
                        const std::string filename = fulldirname + "/" + protocol + ".pat";
                        std::ifstream test(filename.c_str());

                        if (test.is_open())
                        {
                            answer = filename;
                            found = true;
                        }
                    }
                }
                free(namelist[n]);
            }
            free(namelist);
        }

        if (!found)
            ErrorHandler::default_handler()->fatal("Couldn't find a pattern definition file for %s", protocol.c_str());

        return answer;
    }

    // Returns the leading run of non-whitespace characters of `line`.
    std::string get_protocol_name(const std::string& line)
    {
        std::string::size_type end = 0;
        while (end < line.size() && !isspace(static_cast<unsigned char>(line[end])))
            ++end;
        return line.substr(0, end);
    }

    // Returns the given file name from the last slash to the next dot
    std::string basename(const std::string& filename)
    {
        size_t lastslash = filename.find_last_of('/');
        size_t nextdot = filename.find_first_of('.', lastslash);

        return filename.substr(lastslash + 1, nextdot - (lastslash + 1));
    }

    // Returns true if the line (from a pattern file) is a comment
    bool is_comment(const std::string& line)
    {
        // blank lines are comments
        if (line.empty())
            return true;

        // lines starting with # are comments
        if (line[0] == '#')
            return true;

        // lines with only whitespace are comments
        for (std::string::size_type i = 0; i < line.size(); ++i)
            if (!isspace(static_cast<unsigned char>(line[i])))
                return false;

        return true;
    }

    // Returns, e.g. "userspace pattern" if the line is "userspace pattern=.*foo"
    std::string attribute(const std::string& line)
    {
        return line.substr(0, line.find_first_of('='));
    }

    // Returns, e.g. ".*foo" if the line is "userspace pattern=.*foo"
    std::string value(const std::string& line)
    {
        return line.substr(line.find_first_of('=') + 1);
    }

    // Converts one hexadecimal digit to its value (0-15). Callers check
    // isxdigit() first. Any other character aborts the program, as before.
    // Uses standard comparisons instead of the GCC-only "case 'a' ... 'f':"
    // range extension.
    int hex2dec(char c)
    {
        if (c >= '0' && c <= '9')
            return c - '0';
        if (c >= 'a' && c <= 'f')
            return c - 'a' + 10;
        if (c >= 'A' && c <= 'F')
            return c - 'A' + 10;

        std::cerr << "Bad hex digit, " << c << ", in regular expression!" << std::endl;
        exit(1);
    }
}

/*
 * HANDLERS
 */

String L7::read_max_pkt_match(Element* e, void*)
{
    // Registered by L7::add_handlers, so the element is always an L7.
    const L7* self = static_cast<const L7*>(e);
    return String(self->max_pkt_match_);
}

// Registers this element's Click handlers.
void L7::add_handlers()
{
    // Read handler "max_pkt_match": the largest per-direction packet count
    // seen when a flow was matched (maintained in FlowState::handle_packet).
    add_read_handler("max_pkt_match", read_max_pkt_match, 0);
}

// Click element registration: L7 depends on the FlowCache element, which
// holds its per-flow state.
ELEMENT_REQUIRES(FlowCache)
EXPORT_ELEMENT(L7)

#include <click/vector.cc>

CLICK_ENDDECLS

