#!/usr/bin/python
# -*- Python -*-
# Depends: openssl, python >= 2.2, python-optik or python >= 2.3, iproute, textutils

from __future__ import generators
import os, os.path, errno, string, sys

# Characters allowed verbatim in names derived from host names (file names,
# config entries): ASCII letters, digits and underscore.  makesafe() maps
# every other character to '_'.
safechars = string.ascii_letters + string.digits + '_'
def makesafe_char(c):
    """Return C unchanged if it appears in safechars, otherwise '_'."""
    if c not in safechars:
        return '_'
    return c

def makesafe(s):
    """Return S with every character not in safechars replaced by '_'."""
    return ''.join(map(makesafe_char, s))

def is_safe(s):
    """Return True if every character of S is in safechars (True for '')."""
    for ch in s:
        if ch not in safechars:
            # One bad character is enough; stop scanning.
            return False
    return True

class TempFile(file):
    def __init__(self, name, mode=0600):
        self.origname=name
        self.finalmode=mode
        file.__init__(self, name+'.tmp', 'w')

    def close(self):
        file.close(self)
        # only on _explicit_ close
        os.rename(self.name, self.origname)
        os.chmod(self.origname, self.finalmode)

class ExecTempFile(TempFile):
    def close(self):
        os.chmod(self.name, 0755)
        TempFile.close(self)

def mkdir(path, mode=0700):
    try:
        os.mkdir(path, mode)
    except OSError, e:
        if e.errno==errno.EEXIST:
            pass
        else:
            raise
    os.chmod(path, mode)

class Host:
    def __init__(self, name,
                 address=None,
                 vpnaddress=None,
                 routes=()):
        self.name=name
        self.address=address
        self.vpnaddress=vpnaddress
        self.routes=routes
        self.portsTaken={}
        self.active = (self.current_host == name)
        if self.active:
            if not os.path.exists('keys'):
                mkdir('keys', 0755)
            self.createCA()
            self.createCertificateRequest()
            self.signCertificateRequest()

    def createCA(self):
        cadir=os.path.join('keys', 'ca')
        if not os.path.exists(cadir):
            mkdir(cadir, 0755)

        index=os.path.join(cadir, 'index.txt')
        if not os.path.exists(index):
            open(index, 'w').close()

        serial=os.path.join(cadir, 'serial')
        if not os.path.exists(serial):
            f=TempFile(serial)
            print >>f, '01'
            f.close()

        cakey=os.path.join(cadir, 'ca.key')
        cacert=os.path.join(cadir, 'ca.crt')
        if not os.path.exists(cakey) or not os.path.exists(cakey):
            pipe = os.popen(' '.join((
                'openssl', 'req', '-days', '3650', '-nodes', '-new',
                '-x509', '-keyout', cakey,
                '-out', cacert,
                '-config', '/proc/self/fd/0',
                )), 'w')
            print >>pipe, '''\
[req]
prompt = no
default_bits		= 1024
distinguished_name	= req_distinguished_name

[req_distinguished_name]
commonName=OpenVPN CA
'''
            status = pipe.close()
            if status is not None:
                if os.WIFSIGNALED(status):
                    print >>sys.stderr, '%s: openssl exited due to a signal %r' % (sys.argv[0], os.WTERMSIG(status))
                elif os.WIFEXITED(status):
                    rc = os.WEXITSTATUS(status)
                    if rc==127:
                        print >>sys.stderr, '%s: openssl not found' % sys.argv[0]
                    else:
                        print >>sys.stderr, '%s: openssl exited with error %r' % (sys.argv[0], rc)
                else:
                    print >>sys.stderr, '%s: openssl got an unknown error: %r' % (sys.argv[0], status)
                sys.exit(1)
            os.chmod(cacert, 0644)

    def createCertificateRequest(self):
        hostkeydir=os.path.join('keys', 'host')
        if not os.path.exists(hostkeydir):
            mkdir(hostkeydir, 0700)

        key=os.path.join(hostkeydir, self.safename()+'.key')
        csr=os.path.join(hostkeydir, self.safename()+'.csr')
        if os.path.exists(key) and os.path.exists(csr):
            return
        if not os.path.exists(key) or not os.path.exists(csr):
            pipe = os.popen(' '.join((
                'openssl', 'req', '-days', '3650', '-nodes', '-new',
                '-keyout', key,
                '-out', csr,
                '-config', '/proc/self/fd/0'
                )), 'w')
            print >>pipe, '''\
[req]
prompt = no
default_bits = 1024
distinguished_name = req_distinguished_name

[req_distinguished_name]
commonName = %s
''' % self.name
            status = pipe.close()
            if status is not None:
                if os.WIFSIGNALED(status):
                    print >>sys.stderr, '%s: openssl exited due to a signal %r' % (sys.argv[0], os.WTERMSIG(status))
                elif os.WIFEXITED(status):
                    rc = os.WEXITSTATUS(status)
                    if rc==127:
                        print >>sys.stderr, '%s: openssl not found' % sys.argv[0]
                    else:
                        print >>sys.stderr, '%s: openssl exited with error %r' % (sys.argv[0], rc)
                else:
                    print >>sys.stderr, '%s: openssl got an unknown error: %r' % (sys.argv[0], status)
                sys.exit(1)

    def signCertificateRequest(self):
        hostkeydir=os.path.join('keys', 'host')
        cadir=os.path.join('keys', 'ca')

        csr=os.path.join(hostkeydir, self.safename()+'.csr')
        crt=os.path.join(hostkeydir, self.safename()+'.crt')
        if os.path.exists(csr) and not os.path.exists(crt):
            pipe = os.popen(' '.join((
                'openssl', 'ca', '-batch', '-days', '3650',
                '-name', 'CA_default',
                '-in', csr,
                '-out', crt,
                '-cert', os.path.join(cadir, 'ca.crt'),
                '-keyfile', os.path.join(cadir, 'ca.key'),
                '-md', 'sha1',
                '-policy', 'policy_match',
                '-outdir', cadir,
                '-config', '/proc/self/fd/0'
                )), 'w')
            print >>pipe, '''\
[CA_default]
database = %(cadir)s/index.txt
serial = %(cadir)s/serial

[policy_match]
commonName		= supplied
''' % {'cadir': cadir}
            status = pipe.close()
            if status is not None:
                if os.WIFSIGNALED(status):
                    print >>sys.stderr, '%s: openssl exited due to a signal %r' % (sys.argv[0], os.WTERMSIG(status))
                elif os.WIFEXITED(status):
                    rc = os.WEXITSTATUS(status)
                    if rc==127:
                        print >>sys.stderr, '%s: openssl not found' % sys.argv[0]
                    else:
                        print >>sys.stderr, '%s: openssl exited with error %r' % (sys.argv[0], rc)
                else:
                    print >>sys.stderr, '%s: openssl got an unknown error: %r' % (sys.argv[0], status)
                sys.exit(1)

    def reservePort(self, port, taker):
        assert not self.portsTaken.has_key(port), \
               "Host %s port %d is in use by both %s and %s" \
               % (self, port, self.portsTaken[port], taker)
        self.portsTaken[port]=taker

    def safename(self):
        return makesafe(self.name)

    def __str__(self):
        return self.name

class EndPoint:
    """One end of a Tunnel: a Host plus the port used on it.

    Creating an EndPoint reserves PORT on HOST, so two EndPoints cannot
    share the same host/port pair.

    passive   -- if true, this side's config gets no 'remote' line
    tlsServer -- request the TLS-server role (Tunnel settles conflicts)
    ping      -- value for the 'ping' config option, or None to omit it
    rp_filter -- value written to this side's rp_filter file, or None
    """
    def __init__(self, host, port,
                 passive=False, tlsServer=False, ping=None,
                 rp_filter=None):
        assert isinstance(host, Host)
        self.host, self.port = host, port
        # Reserve the port before the remaining options are stored.
        host.reservePort(port, self)
        self.passive = passive
        self.tlsServer = tlsServer
        self.ping = ping
        self.rp_filter = rp_filter

class Tunnel:
    template = """\
# This file is auto-generated, DO NOT EDIT!

dev-node /dev/net/tun
user nobody
group nogroup
comp-lzo
verb 5
persist-key
mlock
dh /usr/share/doc/openvpn/examples/sample-keys/dh1024.pem
"""
    def __init__(self, left, right, tlsAuth=None):
        assert isinstance(left, EndPoint)
        assert isinstance(right, EndPoint)
        self.left=left
        self.right=right
        if tlsAuth is None:
            l=[left.host.name, right.host.name]
            l.sort()
            tlsAuth=' '.join(['CaRPaLTUnnEl']+l)
        self.tlsAuth=tlsAuth
        if not left.tlsServer and not right.tlsServer:
            left.tlsServer=True
        if left.tlsServer and right.tlsServer:
            right.tlsServer=False
        self.generate()

    def configPath(self, local=None, remote=None, file=None):
        assert local or not remote
        l=['config']
        if local:
            l.append(local.host.safename())
            if remote:
                l.append('peers')
                l.append(remote.host.safename())
        if file:
            l.append(file)
        return apply(os.path.join, l)

    def generateTunnel(self, local, remote):
        if local.host.active:
            remotecrt=os.path.join('peerkeys',
                                   remote.host.safename()+'.crt')
            if not os.path.exists(remotecrt):
                print >>sys.stderr, '%s: no peerkey for %s, skipping.' \
                      % (sys.argv[0], remote.host.name)
                return

            mkdir(self.configPath(), 0755)
            mkdir(self.configPath(local), 0755)
            mkdir(self.configPath(local, file='peers'), 0755)
            mkdir(self.configPath(local, remote), 0755)

            f=TempFile(self.configPath(local, remote, 'ca.crt'), 0644)
            for line in file(os.path.join('keys', 'ca', 'ca.crt')):
                assert line[-1]=='\n'
                line=line[:-1]
                print >>f, line
            for line in file(remotecrt):
                assert line[-1]=='\n'
                line=line[:-1]
                print >>f, line
            f.close()

            f=TempFile(self.configPath(local, remote, 'config'), 0644)
            print >>f, self.template
            print >>f, ('''
            up "%(upscript)s '%(local)s' '%(remote)s'"
            ''' % {
                'local': local.host.safename(),
                'remote': remote.host.safename(),
                'upscript': self.upscript,
                }).strip()
            if local.host.address is not None:
                print >>f, 'local', local.host.address
            print >>f, 'lport', local.port
            print >>f, 'rport', remote.port
            if local.tlsServer:
                print >>f, 'tls-server'
            else:
                print >>f, 'tls-client'
            print >>f, 'ca', os.path.join('peerkeys', remote.host.safename()+'.crt')
            hostkeydir=os.path.join('keys', 'host')
            print >>f, 'cert', os.path.join(hostkeydir,
                                            local.host.safename()+'.crt')
            print >>f, 'key', os.path.join(hostkeydir, 
                                           local.host.safename()+'.key')
            if not local.passive:
                print >>f, 'remote', remote.host.address
            if local.ping is not None:
                print >>f, 'ping', local.ping
            f.close()

            if local.host.vpnaddress is not None:
                f=TempFile(self.configPath(local, file='vpnaddress'), 0644)
                print >>f, local.host.vpnaddress
                f.close()

            f=TempFile(self.configPath(local, remote, 'routes'), 0644)
            for route in remote.host.routes:
                print >>f, route
            f.close()

            if local.rp_filter is not None:
                f=TempFile(self.configPath(local, remote, 'rp_filter'), 0644)
                print >>f, local.rp_filter
                f.close()

            f=TempFile(self.configPath(local, remote, 'tlsauth'), 0600)
            print >>f, self.tlsAuth
            f.close()

            try:
                os.symlink(
                    self.configPath(local, remote, 'config'),
                    local.host.safename() \
                    + '-' + remote.host.safename() \
                    + '.conf')
            except OSError, e:
                if e.errno==errno.EEXIST:
                    pass
                else:
                    raise

    def generate(self):
        self.generateTunnel(self.left, self.right)
        self.generateTunnel(self.right, self.left)

def hostname():
    """Return this machine's primary name as resolved by gethostbyname_ex().

    gethostname() may give only a short name.  Resolving it yields the
    primary (typically fully qualified) host name.
    """
    import socket
    primary, aliases, addresses = socket.gethostbyname_ex(socket.gethostname())
    return primary

def main(args):
    try:
        from optparse import OptionParser
    except ImportError:
        from optik import OptionParser
    parser=OptionParser()
    parser.add_option('-f', '--file',
                      help='load config file FILE', metavar='FILE',
                      default='carpaltunnel-config')
    parser.add_option('-H', '--host',
                      help='current host is HOST', metavar='HOST')
    parser.add_option('-d', '--directory',
                      help='set base directory to DIR', metavar='DIR',
                      default='/etc/openvpn')
    parser.add_option('-u', '--upscript',
                      help='run SCRIPT when tunnel is created', metavar='SCRIPT',
                      default='/usr/lib/carpaltunnel/tunnel-up')
    parser.add_option('--push', action='append',
                      help='rsync key to HOST', metavar='HOST')
    parser.add_option('--pull', action='append',
                      help='rsync key from HOST', metavar='HOST')
    parser.add_option('--rsync-arg', action='append',
                      help='pass ARG to rsync (default --perms)', metavar='ARG',
                      default=['--perms'])
    #parser.add_option('--push-all', action='store_true')
    #parser.add_option('--pull-all', action='store_true')

    options, args = parser.parse_args(args)
    if not options.host:
        options.host=hostname()
    Host.current_host = options.host
    Tunnel.upscript = options.upscript

    try:
        os.chdir(options.directory)
    except OSError, e:
        print >>sys.stderr, '%s: changing into base dir: %s' % (sys.argv[0], e)
        sys.exit(1)

    os.umask(077)

    if options.push:
        mkdir('peerkeys', 0755)
        cadir=os.path.join('keys', 'ca')
        for host in options.push:
            pid=os.fork()
            if pid:
                # parent
                newpid, status=os.waitpid(pid, 0)
                assert newpid==pid
                if status:
                    print >>sys.stderr, \
                          '%s: push to %s failed with exit status %d.' \
                          % (sys.argv[0], host, status)
                    sys.exit(1)
            else:
                # child
                os.execvp('rsync',
                          ['rsync'] + options.rsync_arg + [
                    '--',
                    os.path.join(cadir, 'ca.crt'),
                    host+':'+os.path.join(options.directory,
                                          'peerkeys',
                                          makesafe(hostname())
                                          +'.crt')
                    ])
                sys.exit(1)

    elif options.pull:
        mkdir('peerkeys', 0755)
        for host in options.pull:
            pid=os.fork()
            if pid:
                # parent
                newpid, status=os.waitpid(pid, 0)
                assert newpid==pid
                if status:
                    print >>sys.stderr, \
                    '%s: pull from %s failed with exit status %d.' \
                    % (sys.argv[0], host, status)
                    sys.exit(1)
            else:
                # child
                os.execvp('rsync',
                          ['rsync'] + options.rsync_arg + [
                    '--',
                    host+':'+os.path.join(options.directory,
                                          'keys', 'ca',
                                          'ca.crt'),
                    os.path.join('peerkeys',
                                 makesafe(hostname()+'.crt'))
                    ])
                sys.exit(1)

    else:
        try:
            execfile(options.file,
                     {
                'Host': Host,
                'Tunnel': Tunnel,
                'EndPoint': EndPoint,
                })
        except IOError, e:
            print >>sys.stderr, '%s: Reading configuration file: %s' \
                  % (sys.argv[0], e)
            sys.exit(1)

if __name__ == '__main__':
    # main() expects only the arguments, without the program name.
    main(args=sys.argv[1:])
