#!/usr/bin/python
import calendar
import collections
import logging
import os
import os.path
import sqlite3
import sys
import threading
import time
import traceback
import xml.sax
import xml.sax.handler
from socket import *
from xml.sax.saxutils import XMLGenerator

class ControlConnection:
    """
    Class implementing a connection to a DVBStreamer daemon.

    Commands are written as single lines.  Every response consists of zero
    or more payload lines terminated by a status line of the form
    ``DVBStreamer/<version>/<code> <message>``.
    """

    # The daemon for adapter N is contacted on TCP port BASE_PORT + N.
    BASE_PORT = 54197

    def __init__(self, host, adapter, username='dvbstreamer', password='control'):
        """
        Create a connection object to communicate with a DVBStreamer daemon.

        No network activity happens here; call open() to connect.
        """
        self.host = host
        self.adapter = adapter
        self.opened = False
        self.version = None
        self.welcome_message = None
        self.my_ip = None
        self.socket = None
        self.socket_file = None

        self.username = username
        self.password = password

    def open(self):
        """
        Open the connection to the DVBStreamer daemon.

        Returns True when the connection is open (including when it was
        already open) and False when the daemon's greeting carried a
        non-zero status code.  Socket errors propagate to the caller.
        """
        if self.opened:
            return True

        self.socket = socket(AF_INET, SOCK_STREAM)
        try:
            self.socket.connect((self.host, self.adapter + self.BASE_PORT))
            self.my_ip = self.socket.getsockname()[0]
            self.socket_file = self.socket.makefile('r+')
        except Exception:
            # Do not leak the descriptor when the connect itself fails.
            self._release()
            raise

        # read_response() goes through socket_file, so mark as open first.
        self.opened = True
        try:
            (error_code, error_message, lines) = self.read_response()
        except Exception:
            self._release()
            raise

        if error_code != 0:
            self._release()
        else:
            self.welcome_message = error_message

        return self.opened

    def close(self):
        """
        Close the connection to the DVBStreamer daemon.

        The logout command is best effort: the sockets are released even
        when the peer has already gone away and the write fails.
        """
        if not self.opened:
            return
        try:
            self.send_command('logout')
        except (IOError, OSError, ValueError):
            pass
        self._release()

    def _release(self):
        """
        Mark the connection closed and close the file wrapper and socket,
        ignoring errors raised while closing.
        """
        self.opened = False
        for resource in (self.socket_file, self.socket):
            if resource is None:
                continue
            try:
                resource.close()
            except (IOError, OSError):
                pass

    def send_command(self, command):
        """
        Send a command to the connected DVBStreamer daemon.

        Raises RuntimeError when the connection is not open.
        """
        if not self.opened:
            raise RuntimeError('not connected')

        self.socket_file.write(command + '\n')
        self.socket_file.flush()

    def read_response(self):
        """
        Read a response from the DVBStreamer daemon after a command has been
        sent.
        Returns a tuple of error code, error message and response lines.
        The error code is -1 when the stream ends before a status line.
        """
        lines = []
        error_code = -1
        error_message = ''
        while True:
            line = self.socket_file.readline()

            if line.startswith('DVBStreamer/'):
                # Limit the split so a '/' inside the message is preserved.
                sections = line.split('/', 2)
                self.version = sections[1]
                error_sections = sections[2].split(' ', 1)
                error_code = int(error_sections[0])
                if len(error_sections) > 1:
                    error_message = error_sections[1].strip()
                break
            if line == '':
                # End of stream without a status line.
                break
            lines.append(line.strip('\n\r'))

        return (error_code, error_message, lines)

    def execute_command(self, command):
        """
        Send command and wait for response
        Returns a tuple of error code, error message and response lines.
        """
        self.send_command(command)
        return self.read_response()

    def authenticate(self):
        """
        Send the authenticate command using the username and password specified
        when the connection was created.

        Raises RuntimeError when the daemon answers with a non-zero code.
        """
        ec, em, lines = self.execute_command('auth %s %s' % (self.username, self.password))
        if ec != 0:
            raise RuntimeError('Failed to authenticate')

    def get_services(self, mux=''):
        """
        Get the list of services available on all or a specific multiplex.

        Returns None on error, otherwise one list per response line holding
        the fields separated by the first ' : ' (callers in this file unpack
        them as service id and name).
        """
        (errcode, errmsg, lines) = self.execute_command('lsservices -id %s' % mux)
        if errcode != 0:
            return None
        # Split once only: a service name may itself contain ' : '.
        return [line.split(' : ', 1) for line in lines]

    def get_service_info(self, service):
        """
        Retrieve information about the specified service.

        Returns None on error, otherwise a dictionary built from the
        'name : value' response lines.
        """
        (errcode, errmsg, lines) = self.execute_command('serviceinfo "%s"' % service)
        if errcode != 0:
            return None
        info = {}
        for line in lines:
            # Only the first ':' separates name from value; the value may
            # contain further colons.
            name, separator, value = line.partition(':')
            if not separator:
                continue
            info[name.strip()] = value.strip()
        return info

    def capture_epg_data(self):
        """
        Start capturing EPG data and send it to this connection.

        After this call the daemon's reply to 'epgdata' is left unread on
        socket_file for the caller to consume.
        """
        self.authenticate()
        self.execute_command('epgcaprestart')
        self.send_command('epgdata')

class EPGProcessor(xml.sax.handler.ContentHandler):
    """
    SAX content handler that reads the EPG XML stream of one DVBStreamer
    daemon on a background thread and forwards what it parses to the
    database object passed in.
    """

    # Timestamp layout used by the start/end attributes of <new>.
    TIME_FORMAT = '%Y-%m-%d %H:%M:%S'
    # Seconds to wait before reconnecting after the stream ends or fails.
    RETRY_DELAY = 2.0

    def __init__(self, dbase, host, adapter, username, password):
        """
        Create the control connection and start the capture thread.
        """
        xml.sax.handler.ContentHandler.__init__(self)
        self.connection = ControlConnection(host, adapter, username, password)
        self.__stop = False
        self.dbase = dbase
        # Parser state; set before the thread starts so the SAX callbacks
        # never see missing attributes.
        self.event_ref = None
        self.current_element = None
        self.detail_name = None
        self.detail_lang = None
        self.detail_content = ''
        self.thread = threading.Thread(target=self.__run,
                                       name='EPGProcessor-%s-%d' % (host, adapter))
        self.thread.start()

    def stop(self):
        """
        Ask the capture thread to stop, close the connection to unblock it
        and wait for the thread to finish.
        """
        self.__stop = True
        try:
            if self.connection:
                self.connection.close()
        except Exception:
            traceback.print_exc()
        self.thread.join()

    def __run(self):
        """
        Thread body: connect, request EPG data and parse it, reconnecting
        after RETRY_DELAY seconds until stop() is called.
        """
        while not self.__stop:
            try:
                try:
                    self.connection.open()
                    self.connection.capture_epg_data()
                    xml.sax.parse(self.connection.socket_file, self)
                finally:
                    # Always drop the connection so the next iteration starts
                    # from a fresh socket.  Errors raised by close() land in
                    # the handler below instead of killing the thread.
                    self.connection.close()
            except Exception:
                # stop() closes the socket under the parser, which is
                # expected to fail; only report errors while still running.
                if not self.__stop:
                    traceback.print_exc()

            if not self.__stop:
                time.sleep(self.RETRY_DELAY)

    def startElement(self, name, attrs):
        """
        SAX callback: record the current event reference, queue new events
        and ratings, and begin collecting the text of a detail element.
        """
        if name == 'event':
            self.event_ref = (parse_int(attrs.getValue('net')),
                              parse_int(attrs.getValue('ts')),
                              parse_int(attrs.getValue('source')),
                              parse_int(attrs.getValue('event')))

        elif name == 'new':
            new_start = time.strptime(attrs.getValue('start'), self.TIME_FORMAT)
            # The end time comes from the 'end' attribute, not from 'start'.
            new_end = time.strptime(attrs.getValue('end'), self.TIME_FORMAT)
            new_ca = attrs.getValue('ca') == 'yes'

            self.dbase.new_event(self.event_ref, new_start, new_end, new_ca)

        elif name == 'detail':
            self.detail_name = attrs.getValue('name')
            self.detail_lang = attrs.getValue('lang')
            self.detail_content = ''

        elif name == 'rating':
            rating_system = attrs.getValue('system')
            rating_value = attrs.getValue('value')
            self.dbase.update_rating(self.event_ref, rating_system, rating_value)

        self.current_element = name

    def endElement(self, name):
        """
        SAX callback: store the collected text when a detail element ends.
        """
        if name == 'detail':
            self.dbase.update_detail(self.event_ref, self.detail_name,
                                     self.detail_lang, self.detail_content)
        # Text that follows a closing tag no longer belongs to that element.
        self.current_element = None

    def characters(self, content):
        """
        SAX callback: accumulate the text found inside a detail element.
        """
        if self.current_element == 'detail':
            self.detail_content += content

def parse_int(str):
    """
    Convert a C-style integer literal to an int.

    A leading '0x' or '0X' selects hexadecimal, any other leading '0'
    followed by more digits selects octal, and everything else is decimal.
    A lone '0' is decimal zero.
    """
    # The parameter name shadows the builtin but is kept for compatibility.
    text = str
    if text[:2].lower() == '0x':
        return int(text[2:], 16)
    # Require at least one digit after the prefix so '0' itself still parses.
    if len(text) > 1 and text.startswith('0'):
        return int(text[1:], 8)
    return int(text)

class EPGDatabase:
    """
    SQLite-backed store for EPG events, their details and their ratings.

    new_event(), update_detail() and update_rating() do not touch the
    database themselves; they queue a message that the updater() loop
    (normally running on its own thread) applies in posting order.  The
    get_*() readers run on the calling thread, each thread using its own
    SQLite connection.
    """
    EVENTS_TABLE = 'Events'
    EVENTS_COLUMN_SOURCE= 'source'
    EVENTS_COLUMN_EVENT = 'event'
    EVENTS_COLUMN_START = 'start'
    EVENTS_COLUMN_END   = 'end'
    EVENTS_COLUMN_CA    = 'ca'
    EVENTS_COLUMNS = [EVENTS_COLUMN_SOURCE , EVENTS_COLUMN_EVENT,
                        EVENTS_COLUMN_START  , EVENTS_COLUMN_END  , EVENTS_COLUMN_CA]

    RATINGS_TABLE = 'Ratings'
    RATINGS_COLUMN_SOURCE = 'source'
    RATINGS_COLUMN_EVENT  = 'event'
    RATINGS_COLUMN_SYSTEM = 'system'
    RATINGS_COLUMN_VALUE  = 'value'
    RATINGS_COLUMNS = [RATINGS_COLUMN_SOURCE, RATINGS_COLUMN_EVENT, RATINGS_COLUMN_SYSTEM, RATINGS_COLUMN_VALUE]

    DETAILS_TABLE = 'Details'
    DETAILS_COLUMN_SOURCE = 'source'
    DETAILS_COLUMN_EVENT  = 'event'
    DETAILS_COLUMN_NAME   = 'name'
    DETAILS_COLUMN_LANG   = 'lang'
    DETAILS_COLUMN_VALUE  = 'value'
    DETAILS_COLUMNS = [DETAILS_COLUMN_SOURCE, DETAILS_COLUMN_EVENT, DETAILS_COLUMN_NAME,
                        DETAILS_COLUMN_LANG, DETAILS_COLUMN_VALUE]

    # Seconds between two ticks of the reaper thread (each tick posts either
    # a 'reap' or a 'commit' message).
    COMMIT_INTERVAL_SECS = 30.0
    # Minimum seconds between two 'reap' messages.
    REAP_INTERVAL_SECS = 360.0
    # Events whose end time is at least this many seconds old are reaped.
    EVENT_RETENTION_SECS = 24 * 3600

    def __init__(self, dbpath, updater=False, reaper=False):
        """
        Open (creating the tables if needed) the database at dbpath.

        updater -- start the thread that applies queued messages.
        reaper  -- start the daemon thread that periodically posts 'commit'
                   and 'reap' messages.
        """
        self.dbpath = dbpath
        self.create_tables()

        # One sqlite3 connection per thread, keyed by the Thread object.
        self.connections = {}

        self.message_q = collections.deque()
        self.message_event = threading.Event()
        self.__quit = False
        self.updater_thread = None
        self.reaper_thread = None

        if updater:
            self.updater_thread = threading.Thread(target=self.updater,
                                                   name='EPGDbase Updater')
            self.updater_thread.start()

        if reaper:
            self.reaper_thread = threading.Thread(target=self.reaper,
                                                  name='EPGDbase Reaper')
            self.reaper_thread.daemon = True
            self.reaper_thread.start()

    def close(self):
        """
        Stop the background threads.

        The updater applies every message queued before this call, commits
        and then exits; this method waits for it to finish.
        """
        self.__quit = True
        if self.updater_thread:
            self.post_message(('quit',))
            self.updater_thread.join()
            self.updater_thread = None

    def create_tables(self):
        """
        Create the Events, Ratings and Details tables when they do not
        exist yet.
        """
        connection = sqlite3.Connection(self.dbpath)
        try:
            cursor = connection.cursor()

            self.create_table(cursor, EPGDatabase.EVENTS_TABLE, EPGDatabase.EVENTS_COLUMNS,
                              [EPGDatabase.EVENTS_COLUMN_SOURCE, EPGDatabase.EVENTS_COLUMN_EVENT])

            self.create_table(cursor, EPGDatabase.RATINGS_TABLE, EPGDatabase.RATINGS_COLUMNS,
                              [EPGDatabase.RATINGS_COLUMN_SOURCE, EPGDatabase.RATINGS_COLUMN_EVENT,
                               EPGDatabase.RATINGS_COLUMN_SYSTEM])

            self.create_table(cursor, EPGDatabase.DETAILS_TABLE, EPGDatabase.DETAILS_COLUMNS,
                              [EPGDatabase.DETAILS_COLUMN_SOURCE, EPGDatabase.DETAILS_COLUMN_EVENT,
                               EPGDatabase.DETAILS_COLUMN_NAME, EPGDatabase.DETAILS_COLUMN_LANG])
            connection.commit()
        finally:
            connection.close()

    def create_table(self, cursor, table, fields, primary_key):
        """
        Execute CREATE TABLE IF NOT EXISTS for table with the given untyped
        columns and composite primary key.
        """
        sql = 'CREATE TABLE IF NOT EXISTS %s(%s,PRIMARY KEY(%s));' % (
            table, ','.join(fields), ','.join(primary_key))
        cursor.execute(sql)

    def get_connection(self):
        """
        Return the calling thread's sqlite3 connection, opening it on first
        use.
        """
        thread = threading.current_thread()
        if thread in self.connections:
            return self.connections[thread]
        connection = sqlite3.Connection(self.dbpath)
        self.connections[thread] = connection
        return connection

    def close_connection(self):
        """
        Close and forget the calling thread's connection, if it has one.
        """
        thread = threading.current_thread()
        if thread in self.connections:
            self.connections[thread].close()
            del self.connections[thread]

    def get_events(self, source):
        """
        Return all Events rows for source as tuples ordered like
        EVENTS_COLUMNS.
        """
        connection = self.get_connection()
        cursor = connection.cursor()
        cursor.execute('SELECT %s FROM %s WHERE %s=?;' % (','.join(EPGDatabase.EVENTS_COLUMNS),
                                                         EPGDatabase.EVENTS_TABLE,
                                                         EPGDatabase.EVENTS_COLUMN_SOURCE), (source,))
        return cursor.fetchall()

    def get_details(self, source, event):
        """
        Return the details of one event as a dictionary mapping each detail
        name to a list of (lang, value) tuples.
        """
        connection = self.get_connection()
        cursor = connection.cursor()
        cursor.execute('SELECT %s,%s,%s FROM %s WHERE %s=? AND %s=?;' % (EPGDatabase.DETAILS_COLUMN_NAME,
                                                         EPGDatabase.DETAILS_COLUMN_LANG,
                                                         EPGDatabase.DETAILS_COLUMN_VALUE,
                                                         EPGDatabase.DETAILS_TABLE,
                                                         EPGDatabase.DETAILS_COLUMN_SOURCE,
                                                         EPGDatabase.DETAILS_COLUMN_EVENT), (source, event))
        details = {}
        for name, lang, value in cursor:
            details.setdefault(name, []).append((lang, value))
        return details

    def get_ratings(self, source, event):
        """
        Return the ratings of one event as a list of (system, value) tuples.
        """
        connection = self.get_connection()
        cursor = connection.cursor()
        cursor.execute('SELECT %s,%s FROM %s WHERE %s=? AND %s=?;' % (EPGDatabase.RATINGS_COLUMN_SYSTEM,
                                                         EPGDatabase.RATINGS_COLUMN_VALUE,
                                                         EPGDatabase.RATINGS_TABLE,
                                                         EPGDatabase.RATINGS_COLUMN_SOURCE,
                                                         EPGDatabase.RATINGS_COLUMN_EVENT), (source, event))
        return cursor.fetchall()

    def __source_key(self, event_ref):
        """
        Format the first three numbers of event_ref as the dotted hex string
        stored in the 'source' column.
        """
        return '%04x.%04x.%04x' % event_ref[:3]

    def new_event(self, event_ref, start, end, ca):
        """
        Queue the insertion (or replacement) of an event.  start and end are
        time tuples interpreted as UTC.
        """
        self.post_message(('new', event_ref, start, end, ca))

    def __new_event(self, event_ref, start, end, ca):
        # calendar.timegm() converts a UTC time tuple to epoch seconds
        # without touching the process-wide TZ environment variable.
        start_secs = int(calendar.timegm(start))
        end_secs = int(calendar.timegm(end))
        self.__update(EPGDatabase.EVENTS_TABLE, EPGDatabase.EVENTS_COLUMNS,
                      (self.__source_key(event_ref), event_ref[3], start_secs, end_secs, ca))

    def update_detail(self, event_ref, name, lang, value):
        """
        Queue the insertion (or replacement) of one detail of an event.
        """
        self.post_message(('detail', event_ref, name, lang, value))

    def __update_detail(self, event_ref, name, lang, value):
        self.__update(EPGDatabase.DETAILS_TABLE, EPGDatabase.DETAILS_COLUMNS,
                      (self.__source_key(event_ref), event_ref[3], name, lang, value))

    def update_rating(self, event_ref, system, value):
        """
        Queue the insertion (or replacement) of one rating of an event.
        """
        self.post_message(('rating', event_ref, system, value))

    def __update_rating(self, event_ref, system, value):
        self.__update(EPGDatabase.RATINGS_TABLE, EPGDatabase.RATINGS_COLUMNS,
                      (self.__source_key(event_ref), event_ref[3], system, value))

    def __update(self, table, columns, values):
        """
        INSERT OR REPLACE one row on the calling thread's connection.  The
        change is not committed here.
        """
        connection = self.get_connection()
        cursor = connection.cursor()
        sql = 'INSERT OR REPLACE INTO %s ( %s) VALUES (%s);' % (table,
                                                                ','.join(columns),
                                                                ','.join('?' * len(values)))
        cursor.execute(sql, values)

    def post_message(self, msg):
        """
        Append a message tuple to the queue and wake the updater.
        """
        self.message_q.append(msg)
        self.message_event.set()

    def updater(self):
        """
        Apply queued messages in posting order until a 'quit' message has
        been processed, then close this thread's connection.

        A failure while applying one message is reported and does not stop
        the loop.
        """
        connection = self.get_connection()
        try:
            while True:
                if not self.message_q:
                    self.message_event.wait()

                if self.message_q:
                    # Oldest first: with INSERT OR REPLACE the latest update
                    # must be applied last, and a 'commit' must not overtake
                    # the updates posted before it.
                    msg = self.message_q.popleft()
                    quitting = msg[0] == 'quit'
                    try:
                        self.__process_message(connection, msg)
                    except Exception:
                        traceback.print_exc()
                    if quitting:
                        break

                if not self.message_q:
                    self.message_event.clear()
        finally:
            self.close_connection()

    def __process_message(self, connection, msg):
        """
        Apply a single queued message on the updater's connection.
        """
        kind = msg[0]
        if kind == 'commit' or kind == 'quit':
            connection.commit()
        elif kind == 'new':
            self.__new_event(*msg[1:])
        elif kind == 'detail':
            self.__update_detail(*msg[1:])
        elif kind == 'rating':
            self.__update_rating(*msg[1:])
        elif kind == 'reap':
            self.__reap()

    def reaper(self):
        """
        Every COMMIT_INTERVAL_SECS post a 'reap' message when at least
        REAP_INTERVAL_SECS have passed since the last one, otherwise a
        'commit' message.  Runs until close() is called.
        """
        last_reap = 0.0
        while not self.__quit:
            now = time.time()
            if (now - last_reap) > self.REAP_INTERVAL_SECS:
                # Reap old events
                self.post_message(('reap',))
                last_reap = now
            else:
                self.post_message(('commit',))
            time.sleep(self.COMMIT_INTERVAL_SECS)

    def __reap(self):
        """
        Delete events that ended at least EVENT_RETENTION_SECS ago together
        with the details and ratings that no longer belong to any event.
        """
        connection = self.get_connection()
        cursor = connection.cursor()
        past24 = int(time.time() - self.EVENT_RETENTION_SECS)
        cursor.execute('DELETE FROM Events WHERE end<=?;', (past24,))
        # Remove orphaned details and ratings
        cursor.execute('DELETE FROM Details WHERE NOT EXISTS ' +
                       '(SELECT * FROM Events WHERE source=Details.source AND event=Details.event);')

        cursor.execute('DELETE FROM Ratings WHERE NOT EXISTS ' +
                       '(SELECT * FROM Events WHERE source=Ratings.source AND event=Ratings.event);')

def usage():
    """
    Print the command-line help to standard output and exit with status 1.
    """
    program = sys.argv[0]
    lines = [
        '%s -c' % program,
        'Connects to one or more DVBStreamer daemons, specified in .epgdb/adapters, to process and store EPG data.',
        '',
        '%s -x <outputfile>' % program,
        'Exports the database as an XMLTV file (<outputfile>) by connecting to the DVBStreamer daemons used to capture the data to find available services/channels',
        '',
        'Format of .epgdb/adapters file',
        '------------------------------',
        'One DVBStreamer host/adapter specified per line in the following format:',
        '<host>[:adapter[:<username>[:<password>]]]',
        '',
        'host     - Either an IP address or hostname',
        'adapter  - The adapter number used by the DVBStreamer daemon (Default:0)',
        'username - Username used to authenticate with the DVBStreamer daemon (Default:dvbstreamer)',
        'password - Password used to authenticate with the DVBStreamer daemon (Default:control)',
    ]
    # A single write keeps the output identical on Python 2 and Python 3.
    sys.stdout.write('\n'.join(lines) + '\n')
    sys.exit(1)

    
def process_host(host):
    """
    Parse one adapters-file entry of the form
    <host>[:adapter[:<username>[:<password>]]].

    Returns a (hostname, adapter, username, password) tuple.  Missing fields
    take the defaults 0, 'dvbstreamer' and 'control'; an empty adapter field
    also means 0.  An IPv6 host must be written in square brackets, which
    are kept in the returned hostname.  The password is everything after the
    third separator, so it may contain ':'.

    Raises ValueError for an empty entry, an unterminated '[' or a
    non-numeric adapter.
    """
    if not host:
        raise ValueError('empty host specification')

    if host.startswith('['):
        # IPv6
        closing = host.find(']')
        if closing == -1:
            raise ValueError('missing closing bracket in %r' % (host,))
        hostname = host[:closing + 1]
        remainder = host[closing + 1:]
        if remainder and not remainder.startswith(':'):
            raise ValueError('unexpected text after closing bracket in %r' % (host,))
        rest = remainder[1:]
    else:
        hostname, _, rest = host.partition(':')

    adapter = 0
    username = 'dvbstreamer'
    password = 'control'

    if rest:
        # Split at most twice so colons inside the password survive.
        parts = rest.split(':', 2)
        if parts[0]:
            adapter = int(parts[0])
        if len(parts) > 1:
            username = parts[1]
        if len(parts) > 2:
            password = parts[2]
    return (hostname, adapter, username, password)

def capture_start(args):
    """
    Run capture mode: start one EPGProcessor per configured adapter, idle
    until interrupted, then stop the processors and close the database.
    """
    database_file = os.path.join(app_dir, 'database.db')
    dbase = EPGDatabase(database_file, True, True)

    processors = [EPGProcessor(dbase, entry[0], entry[1], entry[2], entry[3])
                  for entry in get_adapters()]

    # Idle in one-second steps until Ctrl-C (or any other error) ends the run.
    running = True
    while running:
        try:
            time.sleep(1.0)
        except KeyboardInterrupt:
            running = False
        except:
            traceback.print_exc()
            running = False

    for worker in processors:
        worker.stop()
    dbase.close()


def get_services_info(services, adapter):
    """
    Query one DVBStreamer daemon for its services and record them.

    services -- dictionary updated in place; each service id maps to a
                (name, source) tuple where source is '<net>.<ts>.<hex id>'.
    adapter  -- argument tuple for ControlConnection, as returned by
                process_host().

    Errors are reported on standard error and never raised; services that
    were recorded before the error are kept.
    """
    connection = ControlConnection(*adapter)
    try:
        try:
            connection.open()

            listed = connection.get_services()
            if listed is None:
                sys.stderr.write('WARNING: Failed to list services on %s\n' % (adapter[0],))
                listed = []

            for service_id, name in listed:
                info = connection.get_service_info(service_id)
                if info is None:
                    # The daemon rejected the query for this one service.
                    continue
                net, ts, _ = service_id.split('.')
                # SECURITY: this value comes from the remote daemon, so it is
                # parsed as an integer literal and never passed to eval().
                # Base 0 accepts both '0x...' and plain decimal spellings.
                source = int(info['Source'], 0)
                services[service_id] = (name, '%s.%s.%04x' % (net, ts, source))
        finally:
            # Release the socket on the error path as well.
            connection.close()
    except Exception:
        traceback.print_exc()
        
def format_time(secs):
    """
    Format epoch seconds as a UTC 'YYYYMMDDhhmmss' timestamp.
    """
    # %S is seconds (00-61); lowercase %s is a non-portable extension that
    # prints the whole epoch value on platforms that support it.
    return time.strftime('%Y%m%d%H%M%S', time.gmtime(secs))

def _write_text_element(generator, name, attrs, text):
    """Emit one element that contains only text."""
    generator.startElement(name, attrs)
    generator.characters(text)
    generator.endElement(name)


def _write_programme(generator, dbase, channel_id, source, event):
    """Emit the <programme> element for one Events row of source."""
    generator.startElement('programme', {'start': format_time(event[2]),
                                         'stop': format_time(event[3]),
                                         'channel': channel_id})

    details = dbase.get_details(source, event[1])

    # Titles and descriptions are written once per language.
    for detail_name, element in (('title', 'title'), ('description', 'desc')):
        for lang, value in details.get(detail_name, []):
            _write_text_element(generator, element, {'lang': lang}, value)

    # Only the first stored value is exported for these details.
    for detail_name in ('content', 'series'):
        if detail_name in details:
            _write_text_element(generator, detail_name, {}, details[detail_name][0][1])

    for system, value in dbase.get_ratings(source, event[1]):
        generator.startElement('rating', {'system': system})
        _write_text_element(generator, 'value', {}, value)
        generator.endElement('rating')

    generator.endElement('programme')


def export_xmltv(args):
    """
    Export the database as an XMLTV file.

    args[0] is the output file name; usage() is called when it is missing.
    The configured daemons are queried for their services, then one
    <channel> per service and one <programme> per stored event are written.
    """
    if len(args) < 1:
        usage()

    dbase = EPGDatabase(os.path.join(app_dir, 'database.db'))
    outfilename = args[0]
    services = {}
    for adapter in get_adapters():
        get_services_info(services, adapter)

    outfile = open(outfilename, 'w')
    try:
        generator = XMLGenerator(outfile, 'UTF-8')
        generator.startDocument()
        generator.startElement('tv', {})

        # Channels
        for service_id, (name, source) in services.items():
            generator.startElement('channel', {'id': service_id})
            _write_text_element(generator, 'display-name', {}, name)
            generator.endElement('channel')

        # Programmes
        for service_id, (name, source) in services.items():
            for event in dbase.get_events(source):
                _write_programme(generator, dbase, service_id, source, event)

        generator.endElement('tv')
        generator.endDocument()
    finally:
        # Release the file and this thread's database connection even when
        # the export fails part way through.
        outfile.close()
        dbase.close_connection()
    
    
def get_adapters():
    """
    Read the adapters file in app_dir.

    Returns a list of (hostname, adapter, username, password) tuples, one
    per non-blank line.  Malformed lines are reported and skipped.  When the
    file cannot be opened or yields no entries, the single default entry
    ('localhost', 0, 'dvbstreamer', 'control') is returned.
    """
    adapters = []
    try:
        adaptersfile = open(os.path.join(app_dir, 'adapters'), 'r')
    except (IOError, OSError):
        traceback.print_exc()
        print('WARNING: Failed to open adapters file, defaulting to localhost:0:dvbstreamer:control')
    else:
        try:
            for line in adaptersfile:
                line = line.strip()
                if not line:
                    continue
                try:
                    adapters.append(process_host(line))
                except (ValueError, IndexError):
                    # One bad entry must not discard the rest of the file.
                    traceback.print_exc()
                    print('WARNING: Ignoring malformed adapters entry: %s' % line)
        finally:
            adaptersfile.close()

    if not adapters:
        adapters = [('localhost', 0, 'dvbstreamer', 'control')]
    return adapters
    
# Location of the user data for epgdb (database and adapters file).
# Empty until main() sets it to the expanded path of ~/.epgdb.
app_dir = ''

def main():
    """
    Entry point: make sure ~/.epgdb exists, then run capture mode for '-c',
    XMLTV export for '-x', or print the usage text for anything else.
    """
    global app_dir
    app_dir = os.path.expanduser('~/.epgdb')

    # Create the data directory on first use.
    if not os.path.exists(app_dir):
        os.mkdir(app_dir)

    mode = sys.argv[1] if len(sys.argv) >= 2 else None
    remaining = sys.argv[2:]

    if mode == '-c':
        capture_start(remaining)
    elif mode == '-x':
        export_xmltv(remaining)
    else:
        usage()
    

if __name__ == '__main__':
    # Script entry point.  Explicit sys.exit() calls made inside main() keep
    # their status; an unexpected error is reported and turned into a
    # non-zero exit status so that callers can detect the failure.
    try:
        main()
    except SystemExit:
        raise
    except Exception:
        traceback.print_exc()
        sys.exit(1)
    sys.exit(0)