###############################################################################
#                                                                             #
#   Copyright 2005 University of Cambridge Computer Laboratory.               #
#                                                                             #
#   This file is part of Nprobe.                                              #
#                                                                             #
#   Nprobe is free software; you can redistribute it and/or modify            #
#   it under the terms of the GNU General Public License as published by      #
#   the Free Software Foundation; either version 2 of the License, or         #
#   (at your option) any later version.                                       #
#                                                                             #
#   Nprobe is distributed in the hope that it will be useful,                 #
#   but WITHOUT ANY WARRANTY; without even the implied warranty of            #
#   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the             #
#   GNU General Public License for more details.                              #
#                                                                             #
#   You should have received a copy of the GNU General Public License         #
#   along with Nprobe; if not, write to the Free Software                     #
#   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA #
#                                                                             #
###############################################################################


##############################################################################
## 
##
## Stats collector for tcp analysis
## 
##
## 
############################################################################

#from np_TCP import *
from np_packet_markers import *
from np_statscollector import StatsCollector
from nprobe import SERVER, intoa_string
from np_plot import DataSet, np_Plot, EmptyDataSetError, WrongDataTypeError, \
     DATA_TS
#from np_TCPDisplay import TcpDisplay
import np_TCPDisplay
from np_notesel import Sellist, CallBackError

############################################################################


class TCPStats(StatsCollector):

    """Stats collector for TCP analysis.

    add_TCPStats() accumulates per-connection timing samples and data
    bandwidths; results() logs a ranking of the most-connected servers and
    (unless quiet) opens a Sellist whose draw buttons call drawfun() to plot
    the collected samples or redisplay the underlying connections.
    """

    def __init__(self, args, trace=0, quiet=0, savedata=1, logpath=None):

        # NOTE(review): savedata is accepted but not used here; kept for
        # interface compatibility.
        StatsCollector.__init__(self, args, trace=trace, quiet=quiet,
                                logpath=logpath)

        self.slat = []       # server initial delay, fields = (tm, delay, ob size)
        self.splat = []      # server following delays, fields = (tm, delay, ob size)

        self.sprtts = []     # server *known* prtts, fields = (tm, sprtt)
        self.allsprtts = []  # as sprtts, but including interpolated sprtts
        self.cprtts = []     # client prtts, fields = (conn id, tm, prtt)

        # server-direction packet samples, fields =
        #   (conn id, tm, trig, delay, prtt, indx)
        self.conntimes = []

        # apparent initial server delay (or prtt) ie. repstart - reqend,
        # fields = (conn id, repstart, repstart - reqend)
        self.aslat = []
        self.asplat = []     # as aslat, for any following transactions

        # data bandwidths, fields = (conn id, conn.open / 1e6, bandwidths)
        self.cdbws = []
        self.sdbws = []

        #
        # This is the mapping between packet triggers, stats draw (extract)
        # buttons, and button groups
        # fields: [0] type/button text, [1] groups list, [2] trigger,
        #  [3] delimiter after button flag, [4] plot title
        #
        self.extract_what = [
            ('Connections', ['TCPConn', 'Rank Server'],  None, 1, ''),
            ('S Latency',  ['Rank Server'], TRIG_RESP_FIRST, 0,
             'Initial server latencies'),
            ('S PLatency',  ['Rank Server'], TRIG_RESP_DEL, 0,
             'Subsequent server latencies'),
            ('S PRTT',  ['Rank Server'], TRIG_RTT, 0,
             'Genuine server PRTTs'),
            ('S AllPRTT',  ['Rank Server'], TRIG_ALL, 1,
             'Genuine and interpolated server PRTTs'),
            ('Bandwidths', ['Client data bandwidths', 'Server data bandwidths'], None, 1, '')
            ]

        # button text -> (trigger, delimiter flag, plot title)
        self.extract_dir = dict([(e[0], e[2:]) for e in self.extract_what])

###############################################################################

    def add_TCPStats(self, TCPMach):
        """Record the timing and bandwidth samples of one TCP connection.

        Server-direction packets are appended to self.conntimes and all other
        packets to self.cprtts.  If the connection has a first transaction
        (TCPMach.ttms[0]), its apparent server latency goes to self.aslat and
        those of the following transactions to self.asplat.  The connection's
        client/server data bandwidths are appended to self.cdbws / self.sdbws.
        """
        conn = TCPMach.conn
        connid = conn.id
        opened = conn.open / 1000000.0

        ct_add = self.conntimes.append
        cprtts_add = self.cprtts.append
        for p in conn.pktlist:
            if p.dir == SERVER:
                ct_add((connid, p.tm, p.trig, p.delay, p.prtt, p.indx))
            else:
                cprtts_add((connid, p.tm, p.prtt))

        # apparent rtts - guard against a connection with no transactions,
        # which previously raised IndexError
        ttms = TCPMach.ttms
        if ttms and ttms[0]:
            first = ttms[0]
            self.aslat.append((connid, first.rps, first.rps - first.rqe))
            for t in ttms[1:]:
                self.asplat.append((connid, t.rps, t.rps - t.rqe))

        self.cdbws.append((connid, opened, conn.cdbw))
        self.sdbws.append((connid, opened, conn.sdbw))

###############################################################################

    def _reconstruct_conn(self, rec, **kw):
        """Reconstruct a connection from its record, times relative to its start."""
        conn = rec.reconstruct(**kw)
        conn.adjust_tm_offsets(conn.abstart)
        return conn

    def _display_conns(self, connlist):
        """Show connlist in a TcpDisplay, then drop the display object."""
        t = np_TCPDisplay.TcpDisplay(connlist, standalone='no',
                                     logfun=self.nullf, trace=1)
        del(t.display)
        del(t)

###############################################################################

    def draw_conns(self, what, oblist):
        """Reconstruct the selected connections and show them in a TcpDisplay.

        oblist -- selections indexed as (obdict key, label, type); note it
                  is sorted in place.  A 'TCPConn' key maps to one connection
                  record, a 'Rank Server' key to a list of connection keys.
        """
        import sys

        obdict = self.obdict
        oblist.sort()
        clist = []
        for m in oblist:
            kind = m[2]
            if kind == 'TCPConn':
                clist.append(self._reconstruct_conn(obdict[m[0]], trace=1))
            elif kind == 'Rank Server':
                for c in obdict[m[0]]:
                    clist.append(self._reconstruct_conn(obdict[c], trace=1))
            else:
                # NOTE(review): whoops is not defined in this file; presumably
                # it comes from a star import - confirm.
                whoops('Can\'t draw TCP Connection type %s\n' % (kind))
                # this is an error path - do not report success to the shell
                sys.exit(1)

        np_TCPDisplay.TcpDisplay(clist, standalone='no', logfun=self.nullf,
                                 trace=1)

###############################################################################

    def draw_times(self, what, l):
        """Plot server latencies or PRTTs (ms) against time (s).

        what -- an 'S ...' button text; selects the trigger bits a sample
                must carry and the plot title.  Latency buttons plot the
                conntimes delay field, the others the prtt field.
        l    -- selections indexed as (obdict key, label, ...); each key maps
                to a list of connection ids.  One data set per selection.
        """
        obdict = self.obdict
        if what in ['S Latency', 'S PLatency']:
            field = 3   # conntimes delay
        else:
            field = 4   # conntimes prtt
        trig, _delim, title = self.extract_dir[what]
        print('trig %x field %d' % (trig, field))

        stuff = self.conntimes
        sets = []
        nsets = len(l)
        i = 0
        for o in l:
            label = o[1]
            if nsets > 1:
                tag = i
            else:
                tag = None
            wanted = set(obdict[o[0]])
            data = [[d[1] / 1000000.0, d[field] / 1000.0, d[0], [obdict[d[0]]]]
                    for d in stuff
                    if d[0] in wanted and d[2] & trig
                    and not d[2] & TRIG_RTT_INVALID]
            try:
                sets.append(DataSet(data, DATA_TS, label, tag,
                                    callback=self.show_fun))
            except EmptyDataSetError as s:
                print('Data set %s empty' % (s))
                continue
            i += 1

        if sets:
            np_Plot(sets, standalone='no', path=obdict['filepath'],
                    title=title,
                    xlab='time s', ylab='Delay ms')

###############################################################################

    def draw_bw(self, oblist):
        """Plot client and/or server data bandwidths against connection open time.

        Each selection's text (m[1]) picks client ('Client data') or server
        ('Server data') bandwidths; anything else aborts with a message.
        Empty data sets are reported and skipped rather than raising, and
        nothing is plotted if no set has data.
        """
        bwsl = []
        for m in oblist:
            get = m[1]
            if 'Client data' in get:
                bwsl.append(self.cdbws)
            elif 'Server data' in get:
                bwsl.append(self.sdbws)
            else:
                print('Don\'t know what bandwidth to draw')
                return

        sets = []
        for i, bws in enumerate(bwsl):
            # one point per bandwidth value; the callback gets the conn id
            data = [[bwr[1], bw, None, [bwr[0]]]
                    for bwr in bws for bw in bwr[2]]
            try:
                sets.append(DataSet(data, DATA_TS, '', i,
                                    callback=self.bwconn))
            except EmptyDataSetError as s:
                print('Data set %s empty' % (s))
        print(self.basepath)
        if sets:
            np_Plot(sets, standalone='no', path=self.basepath,
                    title='', xlab='elapsed time s',
                    ylab='')

###############################################################################

    def drawfun(self, what, oblist, obdict):
        """Sellist call-back: draw the stats selected by button text `what`.

        obdict is part of the call-back signature but unused; self.obdict
        is used instead.
        """
        print(what)
        print(oblist)

        if what == 'Connections':
            self.draw_conns(what, oblist)
        elif what == 'Bandwidths':
            self.draw_bw(oblist)
        elif what in ['S Latency', 'S PLatency', 'S PRTT', 'S AllPRTT']:
            self.draw_times(what, oblist)
        else:
            print('drawfun - don\'t recognise what to draw (given %s)' % (what))

###############################################################################

    def bwconn(self, l):
        """Plot call-back for bandwidth points: display the given conn ids."""
        self._display_conns([self._reconstruct_conn(self.obdict[c], trace=1)
                             for c in l])

###############################################################################

    def show_fun(self, recs):
        """Plot call-back for latency/PRTT points: display the given records.

        recs -- connection records (the obdict values attached to each point).
        """
        self._display_conns([self._reconstruct_conn(o) for o in recs])

###############################################################################

    @staticmethod
    def _mean(values):
        """Return the float mean of values, or 0.0 for no values."""
        if not values:
            return 0.0
        return float(sum(values)) / len(values)

    def calc_ave_prtts(self):
        """Set mean server (saprtt), client (caprtt) and total (tartt) PRTTs.

        A side with no samples has a mean of 0.0.  Previously this raised
        AttributeError, and the averages truncated under integer division.
        """
        self.saprtt = self._mean([p[1] for p in self.sprtts])
        # cprtts entries are (conn id, tm, prtt): the prtt is field 2, not 1
        self.caprtt = self._mean([p[2] for p in self.cprtts])
        self.tartt = self.saprtt + self.caprtt

###############################################################################

    def results(self):
        """Log the ten most-connected servers and present or dump the results.

        Each ranked server's log line doubles as the obdict key for its
        sorted connection list, which backs the 'Rank Server' draw buttons.
        In quiet mode counters and log go straight to the log file.
        Otherwise the log is saved and a Sellist with the extract_what draw
        buttons is opened.
        """
        # serv_dict items are (server address, [connection keys]); most
        # connections first - sorted() is stable, so ties keep their order
        slist = sorted(self.serv_dict.items(), key=lambda item: len(item[1]),
                       reverse=True)

        for i, (server, conns) in enumerate(slist[:10]):
            msg = 'Rank Server %s most connected server %s (%d):' % (
                self.rank_str(i), intoa_string(server), len(conns))
            self.write_log(msg)
            conns.sort()
            self.obdict[msg] = conns

        self.write_log('XMsg Bandwidths: Client data bandwidths: Client data bandwidths: Client data bandwidths: ')
        self.write_log('XMsg Bandwidths: Server data bandwidths: Server data bandwidths: Server data bandwidths: ')

        if self.quiet:
            self.ectrs.printself(self.tfilt, f=self.logf, leader='')
            self.adict.printself_tofile(self.logf)
            self.dump_log()
            self.close_log()
        else:
            self.ectrs.printself_tolist(self.log, self.tfilt)
            self.adict.printself_tolist(self.log)
            self.save_log()
            # draw button spec from stats object
            draw_menu = [[e[0], e[1], e[3]] for e in self.extract_what]

            Sellist(draw_menu, self.log, self.obdict, self.drawfun)

	









