#! /usr/bin/env python
###############################################################################
#                                                                             #
#   Copyright 2005 University of Cambridge Computer Laboratory.               #
#                                                                             #
#   This file is part of Nprobe.                                              #
#                                                                             #
#   Nprobe is free software; you can redistribute it and/or modify            #
#   it under the terms of the GNU General Public License as published by      #
#   the Free Software Foundation; either version 2 of the License, or         #
#   (at your option) any later version.                                       #
#                                                                             #
#   Nprobe is distributed in the hope that it will be useful,                 #
#   but WITHOUT ANY WARRANTY; without even the implied warranty of            #
#   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the             #
#   GNU General Public License for more details.                              #
#                                                                             #
#   You should have received a copy of the GNU General Public License         #
#   along with Nprobe; if not, write to the Free Software                     #
#   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA #
#                                                                             #
###############################################################################


import os, sys, re, shutil
from sys import argv
from getopt import getopt
from string import replace
from time import asctime
	    
##############################################################################
##############################################################################


# One level of indentation in the generated C code; gen_code builds its
# indentation table (inds) from multiples of it and write_head/write_tail
# use it directly.
indc = '  '

def get_fields(fieldsf):
    """Read the field-specification file and return its field entries.

    ``fieldsf`` is an open file.  Blank lines and lines whose first
    non-blank character is '#' are ignored.  Every other line must be one of:

      * a direction directive, ``way = server``, ``way = client`` or
        ``way = client/server``, which applies to the field lines after it;
      * a field line, ``<field> <fn> [<top> [<lenp>]]``.

    Returns a list of ``(field, (fn, top, lenp, way))`` tuples in file order.
    ``top`` and ``lenp`` are None when absent, and ``way`` is the
    SERVER/CLIENT bitmask defined below.

    Prints an error and exits with status 1 on any of:
      * a repeated field line;
      * a repeated field name;
      * a field line before any direction directive;
      * a line that is neither a field line nor a directive.

    A direction directive may be repeated, so the fields of one direction
    can appear in more than one group.
    """

    #
    # Originating source - these must duplicate the constant values defined in
    # common/hdr/basic_defs.h
    #
    SERVER = 0x1
    CLIENT = 0x2
    CLIENT_OR_SERVER = (CLIENT | SERVER)

    ways = {'server': SERVER, 'client': CLIENT, 'client/server': CLIENT_OR_SERVER}

    te = re.compile(r'(?P<field>[-_a-zA-Z]+)\s+(?P<fn>[_a-zA-Z]+)(\s+(?P<top>\S+))?(\s+(?P<lenp>\S+))?.*')
    # 'client/server' is tried first because 'client' is a prefix of it.
    de = re.compile(r'way\s*=\s*(?P<way>(client/server)|(server)|(client))')

    lined = {}   # field line text -> line number where it first appeared
    fieldd = {}  # field name -> line number where it first appeared
    tlist = []
    way = 0      # 0 until the first direction directive is seen

    lno = 0
    for line in fieldsf.readlines():
        lno += 1
        text = line.strip()
        if not text or text[0] == '#':
            continue

        m = te.match(text)
        if m:
            # Duplicate detection applies to field lines only.  Repeating a
            # 'way = ...' directive is legitimate and must not be rejected.
            if text in lined:
                print('Error: duplicate lines %d and %d \'%s\'' % (lined[text], lno, text))
                sys.exit(1)
            lined[text] = lno

            field = m.group('field')
            if field in fieldd:
                print('Error: duplicate fields lines %d and %d \'%s\'' % (fieldd[field], lno, field))
                sys.exit(1)
            fieldd[field] = lno

            if not way:
                print('Error: field specified but no way specified line %d' % (lno))
                sys.exit(1)
            tlist.append((field, (m.group('fn'), m.group('top'), m.group('lenp'), way)))
            continue

        m = de.match(text)
        if m:
            way = ways[m.group('way')]
        else:
            print('don\'t understand input line %d \'%s\'' % (lno, line.replace('\n', '')))
            sys.exit(1)

    return tlist
	    
##############################################################################

def pt(node):
    """Compress chains of single-child nodes in a character trie, in place.

    A trie node is a two-element list ``[children, value]``.  ``children``
    maps a key string to a child node, and ``value`` is the payload stored
    when a field name ends at that node (falsy otherwise).

    A child with exactly one child of its own and no payload is merged into
    that grandchild: the two keys are concatenated, and the grandchild takes
    the child's place.  Children are compressed before their parent examines
    them, so a whole chain collapses into a single edge.  The node passed in
    is never merged away itself.
    """
    children = node[0]
    # Snapshot the keys: the loop adds and deletes entries of ``children``.
    for key in list(children.keys()):
        child = children[key]
        if not child:
            # Not a trie node - skip it rather than abandoning the siblings.
            continue
        pt(child)
        if len(child[0]) == 1 and not child[1]:
            suffix, grandchild = list(child[0].items())[0]
            children[key + suffix] = grandchild
            del children[key]
	    
##############################################################################


def print_node(n, s, indent):
    """Debug helper: print one trie node on a single line.

    Prints ``s`` (the node's key) indented by ``indent`` two-space levels,
    then the keys of the node's children, then the node's payload ``n[1]``.
    """
    print '%s%s: [{' % ('  '*indent, s),
    # The trailing commas suppress the newline, so the node fits on one line.
    for k in n[0].keys():
        print k + ', ',
    print '},',
    print n[1],
    print ']'

def print_ttree(t, s, indent):
    """Debug helper: print the trie rooted at ``t``, one node per line.

    Nodes are printed depth-first with print_node, each parent before its
    children and children in dictionary order; ``indent`` grows by one per level.
    """
    pending = [(t, s, indent)]
    while pending:
        node, label, depth = pending.pop()
        print_node(node, label, depth)
        # Push children in reverse so the first child is popped next.
        children = list(node[0].items())
        children.reverse()
        for key, child in children:
            pending.append((child, key, depth + 1))
        
    


# Module-level state used by the code generator in gen_code().

# Number given to the next switch gen_code opens; it appears in the
# "//N" and "end switch N" comments of the generated C.
sw_ref = 0
# sw_ref numbers of the switches that are currently open.
sw_stack = []

# Characters whose case labels are currently open (for "end case" comments).
case_stack = []

# Number of entries precomputed in the indentation table inds.
max_indent = 100
#indc = '  '
# Current indentation depth in the generated code (an index into inds).
ind = 2
# inds[i] is the indentation string for depth i; filled in by gen_code.
inds = []

def gen_code(ttree, swf):
    """Write C code that recognises the field names held in ``ttree``.

    ``ttree`` is a (path-compressed) trie as built by gen_parse.  Each node
    is ``[children, vals]``: ``children`` maps a key string to a child node,
    and ``vals`` is the ``(fn, top, lenp, way)`` tuple of a field that ends
    at that node, or None.  The generated text is written to ``swf``.

    Where the trie branches, or where a field ends but longer fields
    continue, a ``switch (*buf++)`` is emitted with a case per child.  The
    remaining characters of each edge are checked with a case-independent
    comparison.  Indentation and switch/case bookkeeping use the
    module-level ``ind``, ``inds``, ``sw_ref``, ``sw_stack`` and
    ``case_stack`` variables.
    """

    def write(s):
        swf.write(s)

    def make_indents():
        # Rebuild the table in place rather than appending to it, so a second
        # call to gen_code (e.g. a template with two insertion markers) does
        # not leave wrong entries beyond the first max_indent.
        inds[:] = [indc * i for i in range(max_indent)]

    def other_case(c):
        # The opposite-case form of c.  Characters without case ('-', '_')
        # come back unchanged.
        if c.islower():
            return c.upper()
        return c.lower()

    def case(c):
        global ind
        # Switch on both cases - someone may be using the wrong case.  A
        # character with no case gets a single label, because a duplicated
        # case label does not compile.
        write('%scase \'%s\':\n' % (inds[ind], c))
        c2 = other_case(c)
        if c2 != c:
            write('%scase \'%s\':\n' % (inds[ind], c2))
        ind += 1
        case_stack.append(c)

    def endcase():
        global ind
        write('%sbreak; // end case %s\n' % (inds[ind], case_stack.pop()))
        ind -= 1

    def cmp_prefix(pref):
        # Emit a test that buf starts with pref, ignoring case; if it does
        # not, jump to "unrecognised".  Then open the block that runs on a
        # match, with buf advanced past pref.  close_cmp() closes the block.
        global ind
        if len(pref) > 1:
            # use case-independent comparison
            write('%sif(ci_seqstrcmp(buf, \"%s\"))\n%s{\n' % (inds[ind], pref.lower(), inds[ind+1]))
        else:
            # single character - accept either case
            write('%sif(!(*buf == \'%s\' || *buf == \'%s\'))\n%s{\n' % (inds[ind], pref, other_case(pref), inds[ind+1]))
        write('%sgoto unrecognised;\n%s}\n' % (inds[ind+2], inds[ind+1]))
        write('%selse\n%s{\n' % (inds[ind], inds[ind+1]))
        write('%sbuf += %d;\n' % (inds[ind+2], len(pref)))
        ind += 1

    def close_cmp(pref):
        global ind
        write('%s} // end cmp %s\n' % (inds[ind], pref))
        ind -= 1

    def strip_return():
        write('%sADJ_BUF(pp, buf);\n%scontinue;\n' % (inds[ind], inds[ind]))

    def open_sw():
        global sw_ref, ind
        sw_stack.append(sw_ref)
        write('%sswitch (*buf++) //%d\n%s{\n' % (inds[ind], sw_ref, inds[ind+1]))
        sw_ref += 1
        ind += 2

    def close_sw(vals, outer):
        global ind
        if vals:
            # NOTE(review): a field ends at this switch node.  These
            # statements follow the last case's "break;" with no "default:"
            # label, so the generated C never reaches them.  Also,
            # "goto check_end" is only emitted for get_hdrline.  Confirm the
            # intended C against the template before relying on fields that
            # are prefixes of other fields.
            write('%sgetval = %s;\n' % (inds[ind], vals[0]))
            write('%swaymask = 0x%x;\n' % (inds[ind], vals[3]))
            if vals[0] == 'get_hdrline':
                write('%stop = %s;\n' % (inds[ind], vals[1]))
                write('%scharlen = %s;\n' % (inds[ind], vals[2]))
                write('%sgoto check_end;\n%sbreak;\n' % (inds[ind], inds[ind]))
        else:
            if outer:
                # Only the outermost switch handles a carriage return.
                write('%scase \'\\r\':\n' % (inds[ind]))
                ind += 1
                strip_return()
                ind -= 1
            write('%sdefault:\n%sgoto unrecognised;\n%sbreak;\n' % (inds[ind], inds[ind+1], inds[ind+1]))
        write('%s}//end switch %d\n' % (inds[ind-1], sw_stack.pop()))
        ind -= 2

    def check_end(vals):
        # A field name ends here: record its handler and jump to check_end.
        write('%sgetval = %s;\n' % (inds[ind], vals[0]))
        write('%swaymask = 0x%x;\n' % (inds[ind], vals[3]))
        if vals[0] == 'get_hdrline':
            write('%stop = %s;\n' % (inds[ind], vals[1]))
            write('%scharlen = %s;\n' % (inds[ind], vals[2]))
        write('%sgoto check_end;\n' % (inds[ind]))

    def wt(node, s, outer_sw):
        # Emit the code for one trie node (``s`` is its key; unused).
        global ind
        ind += 1
        const = node[1]

        # Longest keys first; equal lengths keep dict order (sort is stable).
        ents = sorted(node[0].items(), key=lambda kv: len(kv[0]), reverse=True)

        # A switch is needed when the trie branches here, or when a field
        # ends here but longer fields continue past it.
        in_sw = len(ents) > 1 or bool(const and ents)
        if in_sw:
            open_sw()

        for k, n in ents:
            if in_sw:
                # The switch consumes k[0]; compare whatever is left.
                case(k[0])
                rest = k[1:]
            else:
                rest = k
            if rest:
                cmp_prefix(rest)
            wt(n, k, 0)
            if rest:
                close_cmp(rest)
            if in_sw:
                endcase()

        if const and not ents:
            check_end(const)
        if in_sw:
            close_sw(const, outer_sw)
        ind -= 1

    make_indents()
    wt(ttree, '', 1)
            


	    
##############################################################################

def gen_parse(fieldsf, parsef):
    """Generate the field-recognition code for ``fieldsf`` into ``parsef``.

    Builds a character trie from the entries returned by get_fields, with
    each field's value tuple stored at the node where its name ends.  The
    trie is compressed with pt() and the C code is written with gen_code().
    """
    root = [{}, None]
    for name, vals in get_fields(fieldsf):
        node = root
        for ch in name:
            node = node[0].setdefault(ch, [{}, None])
        node[1] = vals

    pt(root)
    gen_code(root, parsef)

	    
##############################################################################

    

def copy_template(fieldsf, templf, parsef):
    """Copy the template ``templf`` to ``parsef``, splicing in generated code.

    Every template line is copied verbatim, except the insertion-marker
    comment line, which is replaced by the code gen_parse() generates from
    ``fieldsf``.

    NOTE(review): get_fields reads ``fieldsf`` to the end, so any marker
    after the first produces no generated code.
    """
    marker = re.compile(r'\s*/\*+\s*DO NOT DELETE THIS LINE - IT MARKS THE INSERTION POINT FOR AUTOMATICALLY GENERATED CODE\s*\*+/\s*')

    for line in templf.readlines():
        if not marker.match(line):
            parsef.write(line)
            continue
        gen_parse(fieldsf, parsef)

    
	    
##############################################################################

def write_head(f):
    """Write the opening of a parse_http_fields() C function to ``f``.

    Emits the signature, the local declarations and the start of the
    ``while`` loop over the buffer; write_tail() emits the matching braces.

    NOTE(review): nothing in this file calls this helper.  The
    ``int (*getval)(`` declaration it emits is unfinished (no parameter list
    or closing ``);``), so the output is not yet valid C.
    TODO: complete it with the real getval prototype.
    """
    f.write('int\nparse_http_fields(prec_t *pp, tcp_conn_t *tconnp, int way)\n{\n')
    # The format arguments were missing here, which put literal '%s' text
    # into the generated C.
    f.write('%sint status;\n%sint (*getval)(' % (indc, indc))
    f.write('%swhile (pp->len > sizeof(char) && *(pp->buf) != \'\\n\')\n%s{\n' % (indc, indc*2))
    

    
	    
##############################################################################

def write_tail(f):
    """Write the closing braces of the loop and function begun by write_head."""
    loop_close = indc * 2 + '}'
    f.write('\n'.join([loop_close, '}']))
    
	    
##############################################################################

def gen_outf(args, scriptname):
    """Generate the parser source file.

    ``args`` is ``[fields_file, template_file, output_file]``.  Writes to the
    output file, in order:
      * a header comment naming the inputs, the user, the host and the time;
      * the template with the generated code spliced in (see copy_template);
      * a closing comment.
    Exits with status 1 if any of the files cannot be opened.  All three
    files are closed before returning.
    """
    import socket

    def opf(fnm, mode):
        # Open fnm, or report the error and exit.
        try:
            return open(fnm, mode)
        except IOError:
            # sys.exc_info() rather than "except ... as": it works on every
            # Python version this script may be run with.
            s = sys.exc_info()[1]
            print('%s ERROR: %s' % (scriptname, s))
            sys.exit(1)

    # LOGNAME, and especially HOSTNAME (a shell variable that is often not
    # exported), may be absent from the environment; fall back rather than
    # crashing with KeyError.
    user = os.environ.get('LOGNAME') or os.environ.get('USER') or 'unknown'
    host = os.environ.get('HOSTNAME') or socket.gethostname()

    fieldsf = opf(args[0], 'r')
    templf = opf(args[1], 'r')
    parsef = opf(args[2], 'w')
    try:
        parsef.write('\n\n/*\n * %s\n * Automatically generated from %s and %s\n * by %s@%s on %s\n */\n\n' % (args[2], args[0], args[1], user, host, asctime()))

        copy_template(fieldsf, templf, parsef)

        # The closing comment names the generated file, matching the header.
        parsef.write('\n\n/*\n * End %s\n */\n\n' % (args[2]))
    finally:
        parsef.close()
        templf.close()
        fieldsf.close()
	    
##############################################################################

def backup(files):
    """Copy each existing regular file in ``files`` to ``<name>.bak``.

    Names that are not existing regular files are skipped.  If a copy fails,
    an error is printed and the script exits with status 1.
    """

    def copy(s, d):
        try:
            shutil.copy2(s, d)
        except (IOError, OSError):
            # copy2 can also fail in copystat() with OSError, not only IOError.
            e = sys.exc_info()[1]
            print('ERROR backing up %s: %s' % (s, e))
            sys.exit(1)

    for f in files:
        if os.path.isfile(f):
            print('Backing up %s' % (f))
            copy(f, f + '.bak')
	    
##############################################################################
	
def main():
    """Command-line entry point.

    Usage: <script> FIELDS_FILE TEMPLATE_FILE OUTPUT_FILE

    Backs up any existing OUTPUT_FILE to OUTPUT_FILE.bak, then generates it.
    Exits with status 1 on an unrecognised option or the wrong number of
    arguments.
    """
    # The module imports the getopt *function* by name, so "getopt.error"
    # would raise AttributeError here.  Import the exception class directly.
    from getopt import GetoptError

    scriptname = os.path.basename(argv[0])

    try:
        args = getopt(argv[1:], '')[1]
    except GetoptError:
        print('%s: Unrecognised option' % (scriptname))
        sys.exit(1)

    if len(args) != 3:
        print('%s ERROR: Takes three arguments - %d given' % (scriptname, len(args)))
        sys.exit(1)

    backup([args[2]])
    gen_outf(args, scriptname)
        
	    
##############################################################################


# Run main() only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
