#!/bin/env python
"""
Converts an NS-2 tracefile to a dictionary, then generates a plottable trace.
Copyright (c) 2008 Nicta, Olivier Mehani <olivier.mehani@nicta.com.au>
All rights reserved.

$Id: nsaggregator.py 179 2008-04-21 07:05:25Z omehani $

Generic (or intended to be) NS-2 tracefile to Python dictionary converter. A
plottable, tab-separated trace is also printed to standard output according to
the user-given selectors.

"""

# Version string. "$Revision$" is a version-control keyword (presumably
# Subversion, given the format of the $Id$ line above) that gets expanded on
# checkout.
__version__ = "$Revision: 179 $"

import sys
import re
import math

class NsParser():
    """Reads an NS-2 tracefile and converts selected lines into dictionaries.

    A line is considered only if it starts with an event matching
    ``event_re`` and contains a whitespace-delimited field matching
    ``packet_type_re``. It is then parsed with the first supported trace
    format whose regular expression matches it:

    - the normal (wired) trace;
    - the old wireless trace;
    - the old wireless trace with node position.

    Every capture group becomes a dictionary value, keyed by the name at the
    same position in the matching ``_items_trace_*`` list. The "void*" keys
    hold sub-groups nested inside another field.
    """
    _filename = None
    _event_re = None
    _packet_type_re = None

    _file = None
    _re = None
    _re_compiled = None

    _re_trace_normal = None
    _re_trace_wireless = None
    _re_trace_wireless_pos = None
    _re_trace_wireless_new = None

    def __init__(self, filename, event_re, packet_type_re):
        """Open the tracefile and compile the line-matching regexps.

        filename: path of the tracefile, or "-" to read standard input.
        event_re, packet_type_re: regular-expression fragments selecting
            the events and packet types of interest.
        Raises IOError if the tracefile cannot be opened, and re.error if a
        fragment is not a valid regular expression.
        """
        self._filename = filename
        self._event_re = event_re
        self._packet_type_re = packet_type_re

        if self._filename == "-":
            self._file = sys.stdin
        else:
            # open() rather than the Python-2-only file() builtin.
            self._file = open(self._filename, 'r')
        # Cheap pre-filter: the event at the start of the line, and the packet
        # type somewhere as a whitespace-delimited field.
        self._re = "^" + self._event_re + r"\s.*\s" + \
                self._packet_type_re + r"\s"
        self._re_compiled = re.compile(self._re)

        # Time SrcNode DstNode Type Size -Flags- Flow SrcAddr.Port DstAddr.Port
        #   Seq ID ...
        self._re_trace_normal = re.compile(r"^(.) (\d+(\.\d+)?) (\d+) (\d+) "
                r"([_\w]+) (\d+) ([-\w]{7}) (\d+) ((\d+\.){1,3}\d+) "
                r"((\d+\.){1,3}\d+) (-?\d+) (\d+)")
        # Time _Node_ TraceName Reason ID Type Size [TTS DstMac SrcMac Type]
        wireless_re = (r"^(.) (\d+(\.\d+)?) _(\d+)_ (\w{3}) ([- \w]{4}) "
                r"(-?\d+) ([_\w]+) (\d+)  \[([abcdef\d]+) ([abcdef\d]+) "
                r"([abcdef\d]+) ([abcdef\d]+)\]")
        self._re_trace_wireless = re.compile(wireless_re)
        # Time Node (X Y) TraceName Reason ID Type Size [TTS DstMac SrcMac
        #   Type]
        self._re_trace_wireless_pos = re.compile(wireless_re.replace(
            r"_(\d+)_", r"(\d+) \(([ \d]+\.\d+) ([ \d]+\.\d+)\)"))
        # FIXME: TBI (new wireless trace format). "FIXME^" can never match
        # (without MULTILINE, "^" cannot follow consumed characters), so the
        # corresponding branch of parse_next_line is currently unreachable.
        self._re_trace_wireless_new = re.compile(r"FIXME^")

        # Group names, in the order in which the groups' opening parentheses
        # appear in the regexps above. The outer time group opens before its
        # nested fractional-part group, so "time" always precedes "void0".
        self._items_trace_normal = ["event", "time", "void0", "src_node",
                "dst_node", "pkt_type", "size", "flags", "flow",
                "src_addr_port", "void1", "dst_addr_port", "void2", "seq",
                "id"]
        self._items_trace_wireless = ["event", "time", "void0", "dst_node",
                "trace_name", "reason", "id", "pkt_type", "size", "tts",
                "dst_mac", "src_mac", "frm_type"]
        self._items_trace_wireless_pos = ["event", "time", "void0",
                "dst_node", "pos_x", "pos_y", "trace_name", "reason", "id",
                "pkt_type", "size", "tts", "dst_mac", "src_mac", "frm_type"]

    def __del__(self):
        """Close the tracefile, unless it is stdin or was never opened."""
        # Avoid module globals here: they may already be gone at shutdown.
        if self._file is not None and self._filename != "-":
            self._file.close()

    def parse_next_line(self, line=None):
        """Read the next selected trace line and return it as a dictionary.

        line: ignored. It is kept only for backward compatibility; the line
            is always read from the tracefile.
        Raises EofException at the end of the tracefile, and
        UnmatchedLineException when a line passes the pre-filter but matches
        none of the supported trace formats.
        """
        line = self._get_next_matching_line()

        if self._re_trace_normal.match(line):
            return self._parse_trace_normal(line)
        if self._re_trace_wireless.match(line):
            return self._parse_trace_wireless(line)
        if self._re_trace_wireless_pos.match(line):
            return self._parse_trace_wireless_pos(line)
        if self._re_trace_wireless_new.match(line):
            return self._parse_trace_wireless_new(line)
        raise UnmatchedLineException(line)

    def _match_to_dict(self, compiled_re, items, line):
        """Map the capture groups of compiled_re on line to the names in items.

        Optional groups that did not participate in the match map to "".
        """
        return dict(zip(items, compiled_re.match(line).groups("")))

    def _parse_trace_normal(self, line):
        """Parse a line already known to match the normal trace format."""
        return self._match_to_dict(self._re_trace_normal,
                self._items_trace_normal, line)

    def _parse_trace_wireless(self, line):
        """Parse a line already known to match the old wireless format."""
        return self._match_to_dict(self._re_trace_wireless,
                self._items_trace_wireless, line)

    def _parse_trace_wireless_pos(self, line):
        """Parse a line known to match the old wireless format with position."""
        return self._match_to_dict(self._re_trace_wireless_pos,
                self._items_trace_wireless_pos, line)

    def _parse_trace_wireless_new(self, line):
        """Placeholder for the new wireless trace format (not implemented)."""
        raise NotImplementedError("NsParser._parse_trace_wireless_new()")

    def _get_next_matching_line(self):
        """Return the next line that passes the event/packet-type pre-filter."""
        line = self._get_next_line()
        while not self._re_compiled.match(line):
            line = self._get_next_line()
        return line

    def _get_next_line(self):
        """Return the next line without trailing whitespace.

        Raises EofException at the end of the tracefile.
        """
        line = self._file.readline()
        if line == "":
            raise EofException()
        return line.rstrip()

class NsAggregator():
    """Aggregates parsed trace events over fixed intervals of simulated time.

    Each monitored node is represented by an NsNodeInfo built from one entry
    of ``nodes_param_list``. Events are forwarded to every node, and every
    ``granularity`` seconds the nodes are told to aggregate what they
    collected.
    """
    # Monitored nodes. A fresh list is created per instance in __init__; a
    # class-level list would be shared (and grown) by every instance.
    _nodes = None
    # Granularity of the output, in seconds
    _granularity = None
    _last_event_time = 0.
    _last_aggregation_time = 0.

    def __init__(self, nodes_param_list, granularity = 1.):
        """Create one NsNodeInfo per parameter sequence.

        nodes_param_list: iterable of argument sequences, each unpacked into
            NsNodeInfo(*node_param).
        granularity: aggregation interval in seconds (anything float()
            accepts).
        Raises ValueError if granularity is not strictly positive. Zero or a
        negative value would make _compute_aggregation_steps loop forever,
        and NaN would disable aggregation altogether.
        """
        self._granularity = float(granularity)
        if not self._granularity > 0:
            raise ValueError("granularity must be strictly positive, got %r"
                    % (granularity,))
        self._last_event_time = 0.
        self._last_aggregation_time = 0.
        self._nodes = [NsNodeInfo(*node_param)
                for node_param in nodes_param_list]

    def new_event(self, event_dict):
        """Forward an event to every node, unless aggregation is due first.

        event_dict: a parsed trace line; this method itself only reads its
            'time' entry.
        Raises MustAggregateException when more than one granularity has
        elapsed since the last aggregation. The event is then NOT recorded;
        the exception carries the pending aggregation times and the event,
        and the caller must pass the event again after aggregating.
        """
        self._last_event_time = float(event_dict['time'])
        time_delta = self._last_event_time - self._last_aggregation_time

        if time_delta > self._granularity:
            aggregation_steps = list(self._compute_aggregation_steps())
            raise MustAggregateException(self._last_event_time,
                    aggregation_steps, event_dict)

        for node in self._nodes:
            node.new_event(event_dict)

    def _compute_aggregation_steps(self):
        """Yield each interval boundary after the last aggregation, up to and
        including the last event time."""
        t = self._last_aggregation_time + self._granularity
        # FIXME: is this really correct? shouldn't it be '<'?
        # (With '<=' an event exactly on a boundary is counted in the next
        # interval when the boundary is reached through this path.)
        while t <= self._last_event_time:
            yield t
            t += self._granularity

    def aggregate(self, t=None):
        """Aggregate every node at time t (default: the last event time)."""
        if t is None:
            t = self._last_event_time
        for node in self._nodes:
            node.aggregate(t)
        self._last_aggregation_time = t

    def aggregate_last(self):
        """Aggregate the trailing partial interval, up to the last event.

        Does nothing when no event was seen since the last aggregation.
        Previously this called node.aggregate() with no time, which
        NsNodeInfo treats as a no-op, so the tail was silently dropped.
        """
        if self._last_event_time != self._last_aggregation_time:
            self.aggregate(self._last_event_time)

    def report_header(self):
        """Return the tab-separated column header line."""
        # Tabs between node headers keep them aligned with report_last_data.
        return "#time\t" + "\t".join(
                [node.report_header() for node in self._nodes])

    def report_last_data(self):
        """Return the tab-separated data line for the last aggregation."""
        fields = [str(self._last_aggregation_time)]
        fields.extend([node.report_last_data() for node in self._nodes])
        return "\t".join(fields)


class NsNodeInfo():
    """Per-node accumulator of received sizes, packet counts and position.

    An event is counted only if all of the following hold:
    - it is addressed to this node (``dst_node``);
    - its ``event`` and ``pkt_type`` fields match the node's regexps;
    - when the event carries a ``trace_name``, that name is "AGT".
    """
    _id = None
    _event_re = ".*"
    _event_re_compiled = None
    _pkt_type_re = ".*"
    _pkt_type_re_compiled = None
    _with_pos = False

    # Not read anywhere in this file as shown; kept for compatibility.
    _time = -1
    _last_event_time = 0
    _previous_aggregation_time = 0
    _last_aggregation_time = 0

    # Sizes of the packets counted since the last aggregation (reset in
    # __init__ so that instances never share a list).
    _last_sizes = None
    _last_sum = 0
    _complete_sum = 0
    _last_packets = 0
    _all_packets = 0

    _current_position = (None, None)
    _last_position = (None, None)

    def __init__(self, id, event_re, pkt_type_re, with_pos = None):
        """Create the accumulator for node ``id``.

        id: node number (anything int() accepts).
        event_re, pkt_type_re: regexps matched against the event and packet
            type of each event.
        with_pos: "yes" to report the node position; None keeps the default
            (False).
        """
        self._id = int(id)
        self._event_re = event_re
        self._event_re_compiled = re.compile(self._event_re)
        self._pkt_type_re = pkt_type_re
        self._pkt_type_re_compiled = re.compile(self._pkt_type_re)
        if with_pos is not None:
            self._with_pos = with_pos == "yes"
        self._last_sizes = []
        # Was mistakenly "self._sum", leaving _last_sum to the class default.
        self._last_sum = 0
        self._complete_sum = 0
        self._last_packets = 0
        self._all_packets = 0

    def new_event(self, event_dict):
        """Record the event if it concerns this node; return whether it did.

        When the event is addressed to this node (and, if present, its
        trace_name is "AGT"), any pos_x/pos_y it carries update the current
        position, even when the event/packet-type regexps then reject it.
        """
        if int(event_dict['dst_node']) != self._id:
            return False
        if event_dict.get('trace_name', "AGT") != "AGT":
            return False
        if 'pos_x' in event_dict and 'pos_y' in event_dict:
            self._current_position = (event_dict['pos_x'], event_dict['pos_y'])
        if not self._event_re_compiled.match(event_dict['event']) or \
                not self._pkt_type_re_compiled.match(event_dict['pkt_type']):
            return False
        self._last_event_time = float(event_dict['time'])
        self._last_sizes.append(int(event_dict['size']))
        return True

    def aggregate(self, time=None):
        """Close the current interval at ``time`` and update the totals.

        A missing or zero ``time``, or one equal to the last aggregation
        time, leaves the state untouched.
        """
        if not time:
            time = self._last_aggregation_time
        if time == self._last_aggregation_time:
            return
        self._previous_aggregation_time = self._last_aggregation_time
        self._last_aggregation_time = time
        self._last_sum = sum(self._last_sizes)
        self._complete_sum += self._last_sum
        self._last_packets = len(self._last_sizes)
        self._all_packets += self._last_packets
        self._last_sizes = []
        self._last_position = self._current_position
        self._current_position = (None, None)

    def report_header(self):
        """Return this node's tab-separated column headers."""
        header = "#" + str(self._id) + " " + self._pkt_type_re + \
                " BW (Mbps)\tAVG\tPkts\tPkts tot"
        if self._with_pos:
            header += "\tX\tY"
        return header

    def report_last_data(self):
        """Return this node's tab-separated values for the last interval.

        Bandwidths are computed as 8 * bytes / seconds / 1024**2 (labelled
        Mbps). Raises ZeroDivisionError when the last interval has zero
        length, e.g. before any aggregation; the __main__ driver reports
        that as a probable selector error.
        """
        interval = self._last_aggregation_time - self._previous_aggregation_time
        fields = [str(8. * self._last_sum / interval / 1024**2)]
        if self._last_aggregation_time == 0:
            # Cannot average over a zero-length run: report the raw total
            # (this previously emitted the literal text "self._complete_sum").
            fields.append(str(self._complete_sum))
        else:
            fields.append(str(8. * self._complete_sum /
                    self._last_aggregation_time / 1024**2))
        fields.append(str(self._last_packets))
        fields.append(str(self._all_packets))
        if self._with_pos:
            if self._last_position == (None, None):
                fields.extend([".", "."])
            else:
                fields.extend([str(self._last_position[0]),
                        str(self._last_position[1])])
        return "\t".join(fields)


class EofException (Exception):
    """Raised by NsParser when the tracefile has no more lines to read."""


class UnmatchedLineException (Exception):
    """Raised when a pre-filtered trace line matches no supported format."""
    # Prefix of every instance message; the offending line is appended.
    message = "Unmatched traceline: "

    def __init__(self, line):
        # Read the (possibly subclass-overridden) class prefix, then store
        # the full text both as .message and in Exception.args, so that
        # str(exc) is meaningful too.
        self.message = self.message + line
        Exception.__init__(self, self.message)

class MustAggregateException(Exception):
    """Raised by NsAggregator.new_event when aggregation must happen first.

    time: time of the event that triggered the exception.
    missing_steps: list of aggregation times to process, in order.
    next_event: the event dictionary, not yet recorded; pass it to
        new_event again once the steps have been aggregated.
    """
    time = 0
    # Immutable placeholder; instances always get the real list.
    missing_steps = ()
    next_event = None

    def __init__(self, time, missing_steps, next_event):
        # Populate Exception.args so that repr/str show the payload.
        Exception.__init__(self, time, missing_steps, next_event)
        self.time = time
        self.missing_steps = missing_steps
        self.next_event = next_event

if __name__ == "__main__":

    def print_usage():
	sys.stderr.write("usage: " + sys.argv[0] + " FILENAME|- " + \
		"GRANULARITY NODE EVENT_RE PKT_TYPE_RE WITH_POS(yes|no) [NODE EVENT_RE PKT_TYPE_RE WITH_POS ...]\n")

    if len(sys.argv) < 6:
	print_usage()
	sys.exit(1)
    elif (len(sys.argv) - 3)%4:
	print_usage()
	sys.stderr.write("	There must be exactly four parameters (NODE, " + \
		"EVENT_RE, PKT_TYPE_RE, WITH_POS) per node\n")
	sys.exit(1)
    else:
	nodes_param_list = []
	event_list = []
	pkt_type_list = []
	for i in xrange(len(sys.argv[2:])/4):
	    idx = 3+i*4
	    nodes_param_list.append(sys.argv[idx:idx+4])
	    event = sys.argv[idx+1]
	    pkt_type = sys.argv[idx+2]
	    if event not in event_list:
		event_list.append(event)
	    if pkt_type not in pkt_type_list:
		pkt_type_list.append(pkt_type)

	event_re = "(" + "|".join(event_list) + ")"
	pkt_type_re = "(" + "|".join(pkt_type_list) + ")"
	
	re.compile(event_re)

	parser = NsParser(sys.argv[1], event_re, pkt_type_re)

	aggregator = NsAggregator(nodes_param_list, sys.argv[2])
	print aggregator.report_header()

	while 1:
	    try:
		aggregator.new_event(parser.parse_next_line())
	    except UnmatchedLineException, e:
		sys.stderr.write("Warning! " + e.message + "\n")
		continue
	    except MustAggregateException, e:
		#sys.stderr.write("Aggregating steps at %f: %s before %s\n"  % (e.time, str(e.missing_steps), str(e.next_event)))
		for time in e.missing_steps:
		    aggregator.aggregate(time)
		    print aggregator.report_last_data()
		aggregator.new_event(e.next_event)
		continue
	    except EofException:
		try:
		    aggregator.aggregate_last()
		    print aggregator.report_last_data()
		except ZeroDivisionError:
		    sys.stderr.write("error: division by zero when finalising, this " + \
			"usually means the specified EVENT_TYPE or " + \
			"PKT_TYPE_RE is incorrect.\n")
		    sys.exit(1)
		sys.exit(0)

