#! /usr/bin/env python
"""
    cmdline utility to perform cluster reconnaissance
"""


from eventlet.green import urllib2
from swift.common.ring import Ring
from urlparse import urlparse
try:
    import simplejson as json
except ImportError:
    import json
from hashlib import md5
import datetime
import eventlet
import optparse
import sys
import os


class Scout(object):
    """
    Obtain swift recon information

    Requests one recon check from a host over HTTP and decodes the JSON
    body of the reply.

    :param recon_type: recon check to request; must be one of "ringmd5",
                       "async", "replication", "load", "diskusage",
                       "unmounted", "quarantined" or "sockstat"
    :param verbose: print every url requested together with its result
    :param suppress_errors: don't print request errors (has no effect when
                            verbose is set)
    :param timeout: seconds to wait for a response from a host
    :raises Exception: if recon_type is not one of the known checks
    """

    def __init__(self, recon_type, verbose=False, suppress_errors=False,
                 timeout=5):
        recon_uri = ["ringmd5", "async", "replication", "load", "diskusage",
                     "unmounted", "quarantined", "sockstat"]
        if recon_type not in recon_uri:
            raise Exception("Invalid scout type requested")
        self.recon_type = recon_type
        self.verbose = verbose
        self.suppress_errors = suppress_errors
        self.timeout = timeout

    def _print_error(self, url, err):
        """Print a failed request unless errors are being suppressed."""
        if not self.suppress_errors or self.verbose:
            print("-> %s: %s" % (url, err))

    def scout_host(self, base_url, recon_type):
        """
        Perform the actual HTTP request to obtain swift recon telemetry.

        :param base_url: the base url of the host you wish to check. str of the
                        format 'http://127.0.0.1:6000/recon/'
        :param recon_type: the swift recon check to request.
        :returns: tuple of (recon url used, response body, and status).
                  On success the body is the decoded JSON and status is 200.
                  On failure the body is the exception that was caught and
                  status is the HTTP error code, or -1 when the request
                  failed without one or the reply was not valid JSON.
        """
        url = base_url + recon_type
        try:
            response = urllib2.urlopen(url, timeout=self.timeout)
            try:
                body = response.read()
            finally:
                # always release the connection, even if the read fails
                response.close()
            content = json.loads(body)
            if self.verbose:
                print("-> %s: %s" % (url, content))
            status = 200
        except urllib2.HTTPError as err:
            # must be tested before URLError, which it derives from
            self._print_error(url, err)
            content = err
            status = err.code
        except urllib2.URLError as err:
            self._print_error(url, err)
            content = err
            status = -1
        except ValueError as err:
            # the host answered, but the body could not be decoded as JSON;
            # report it like any other failed request instead of crashing
            self._print_error(url, err)
            content = err
            status = -1
        return url, content, status

    def scout(self, host):
        """
        Obtain telemetry from a host running the swift recon middleware.

        :param host: host to check, as an (ip, port) pair
        :returns: tuple of (recon url used, response body, and status)
        """
        base_url = "http://%s:%s/recon/" % (host[0], host[1])
        return self.scout_host(base_url, self.recon_type)


class SwiftRecon(object):
    """
    Retrieve and report cluster info from hosts running recon middleware.
    """

    def __init__(self):
        """
        Set the default reporting options and create the green pool.

        verbose, suppress_errors and timeout are handed to every Scout the
        checks create; main() overwrites them from the command line.
        """
        # number of green threads available for querying hosts concurrently
        self.pool_size = 30
        self.pool = eventlet.GreenPool(self.pool_size)
        self.timeout = 5
        self.suppress_errors = False
        self.verbose = False

    def get_devices(self, zone_filter, ring_file):
        """
        Get a list of hosts in the ring

        :param zone_filter: Only list hosts in the zone matching this value;
                            None disables filtering. Zone 0 is a valid
                            filter.
        :param ring_file: Ring file to obtain hosts from
        :returns: a set of tuples containing the ip and port of hosts
        """
        ring_data = Ring(ring_file)
        # falsy entries in the device list are skipped
        devs = (dev for dev in ring_data.devs if dev)
        # compare against None explicitly: a truthiness test would treat
        # zone 0 as "no filter" and return every host
        if zone_filter is not None:
            devs = (dev for dev in devs if dev['zone'] == zone_filter)
        return set((dev['ip'], dev['port']) for dev in devs)

    def get_ringmd5(self, hosts, ringfile):
        """
        Compare ring md5sum's with those on remote host

        Prints a line for every host whose md5sum differs from the local
        one, followed by a summary of matches and errors. A host counts as
        an error when the request fails or when its reply has no entry for
        ringfile.

        :param hosts: set of hosts to check. in the format of:
            set([('127.0.0.1', 6020), ('127.0.0.2', 6030)])
        :param ringfile: The local ring file to compare the md5sum with.
        """
        matches = 0
        errors = 0
        md5sum = md5()
        with open(ringfile, 'rb') as f:
            # hash the file in 4096 byte chunks until read() returns nothing
            for block in iter(lambda: f.read(4096), b''):
                md5sum.update(block)
        ring_sum = md5sum.hexdigest()
        recon = Scout("ringmd5", self.verbose, self.suppress_errors,
                      self.timeout)
        now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        print("[%s] Checking ring md5sum's on %s hosts..." %
              (now, len(hosts)))
        if self.verbose:
            print("-> On disk md5sum: %s" % ring_sum)
        for url, response, status in self.pool.imap(recon.scout, hosts):
            if status != 200:
                errors += 1
                continue
            try:
                # the reply is looked up by the local ring file path
                remote_sum = response[ringfile]
            except (KeyError, TypeError):
                # the host answered but reported nothing for this ring
                # path; count it as an error rather than abort the run
                errors += 1
                if not self.suppress_errors or self.verbose:
                    print("!! %s did not report an md5sum for %s" %
                          (url, ringfile))
                continue
            if remote_sum != ring_sum:
                print("!! %s (%s) doesn't match on disk md5sum" %
                      (url, remote_sum))
            else:
                matches += 1
                if self.verbose:
                    print("-> %s matches." % url)
        print("%s/%s hosts matched, %s error[s] while checking hosts." %
              (matches, len(hosts), errors))
        print("=" * 79)
    def async_check(self, hosts):
        """
        Report async pending counts gathered from the given hosts

        Prints the lowest, highest, average and total 'async_pending' value
        of the hosts that answered, or an error line if none did.

        :param hosts: set of hosts to check. in the format of:
            set([('127.0.0.1', 6020), ('127.0.0.2', 6030)])
        """
        scout = Scout("async", self.verbose, self.suppress_errors,
                      self.timeout)
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        print("[%s] Checking async pendings on %s hosts..." %
              (timestamp, len(hosts)))
        # url -> async pending count, only for hosts that answered
        pending = {}
        for url, response, status in self.pool.imap(scout.scout, hosts):
            if status != 200:
                continue
            pending[url] = response['async_pending']
        if pending:
            counts = pending.values()
            lowest = min(counts)
            highest = max(counts)
            total = sum(counts)
            print("Async stats: low: %d, high: %d, avg: %d, total: %d" %
                  (lowest, highest, total / len(pending), total))
        else:
            print("Error: No hosts available or returned valid information.")
        print("=" * 79)

    def umount_check(self, hosts):
        """
        Check for and print unmounted drives

        Prints one line for every unmounted device a host reports; a host
        with several unmounted devices produces several lines.

        :param hosts: set of hosts to check. in the format of:
            set([('127.0.0.1', 6020), ('127.0.0.2', 6030)])
        """
        # url -> list of every unmounted device reported by that host
        unmounted = {}
        recon = Scout("unmounted", self.verbose, self.suppress_errors,
                      self.timeout)
        now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        print("[%s] Getting unmounted drives from %s hosts..." %
              (now, len(hosts)))
        for url, response, status in self.pool.imap(recon.scout, hosts):
            if status == 200:
                # keep all devices; storing a single value per url would
                # report only the last device of each host
                devices = [entry['device'] for entry in response]
                if devices:
                    unmounted[url] = devices
        for host in unmounted:
            node = urlparse(host).netloc
            for device in unmounted[host]:
                print("Not mounted: %s on %s" % (device, node))
        print("=" * 79)

    def replication_check(self, hosts):
        """
        Report object replication times gathered from the given hosts

        Prints the shortest, longest and average 'object_replication_time'
        of the hosts that answered, or an error line if none did.

        :param hosts: set of hosts to check. in the format of:
            set([('127.0.0.1', 6020), ('127.0.0.2', 6030)])
        """
        scout = Scout("replication", self.verbose, self.suppress_errors,
                      self.timeout)
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        print("[%s] Checking replication times on %s hosts..." %
              (timestamp, len(hosts)))
        # url -> replication time, only for hosts that answered
        times = {}
        for url, response, status in self.pool.imap(scout.scout, hosts):
            if status != 200:
                continue
            times[url] = response['object_replication_time']
        if times:
            durations = times.values()
            shortest = min(durations)
            longest = max(durations)
            mean = sum(durations) / len(times)
            print("[Replication Times] shortest: %s, longest: %s, avg: %s" %
                  (shortest, longest, mean))
        else:
            print("Error: No hosts available or returned valid information.")
        print("=" * 79)

    def load_check(self, hosts):
        """
        Report load averages gathered from the given hosts

        For each of the '1m', '5m' and '15m' values prints the lowest,
        highest and average across the hosts that answered, or an error
        line if none did.

        :param hosts: set of hosts to check. in the format of:
            set([('127.0.0.1', 6020), ('127.0.0.2', 6030)])
        """
        # interval -> {url: load average for that interval}
        stats = {"1m": {}, "5m": {}, "15m": {}}
        scout = Scout("load", self.verbose, self.suppress_errors,
                      self.timeout)
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        print("[%s] Checking load avg's on %s hosts..." %
              (timestamp, len(hosts)))
        for url, response, status in self.pool.imap(scout.scout, hosts):
            if status != 200:
                continue
            for interval in ("1m", "5m", "15m"):
                stats[interval][url] = response[interval]
        for item in stats:
            if not stats[item]:
                print("Error: No hosts available or returned valid info.")
                continue
            loads = stats[item].values()
            lowest = min(loads)
            highest = max(loads)
            mean = sum(loads) / len(stats[item])
            print("[%s load average] lowest: %s, highest: %s, avg: %s" %
                  (item, lowest, highest, mean))
        print("=" * 79)

    def quarantine_check(self, hosts):
        """
        Report quarantine counts gathered from the given hosts

        For each of 'objects', 'containers' and 'accounts' prints the
        lowest, highest, average and total across the hosts that answered,
        or an error line if none did.

        :param hosts: set of hosts to check. in the format of:
            set([('127.0.0.1', 6020), ('127.0.0.2', 6030)])
        """
        # kind -> {url: quarantined count of that kind}
        stats = {"objects": {}, "containers": {}, "accounts": {}}
        scout = Scout("quarantined", self.verbose, self.suppress_errors,
                      self.timeout)
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        print("[%s] Checking quarantine on %s hosts..." %
              (timestamp, len(hosts)))
        for url, response, status in self.pool.imap(scout.scout, hosts):
            if status != 200:
                continue
            for kind in ("objects", "containers", "accounts"):
                stats[kind][url] = response[kind]
        for item in stats:
            if not stats[item]:
                print("Error: No hosts available or returned valid info.")
                continue
            counts = stats[item].values()
            lowest = min(counts)
            highest = max(counts)
            total = sum(counts)
            mean = total / len(stats[item])
            print("[Quarantined %s] low: %d, high: %d, avg: %d, total: %d"
                  % (item, lowest, highest, mean, total))
        print("=" * 79)

    def socket_usage(self, hosts):
        """
        Report socket statistics gathered from the given hosts

        For each sockstat value prints the lowest, highest, average and
        total across the hosts that answered, or an error line if none did.

        :param hosts: set of hosts to check. in the format of:
            set([('127.0.0.1', 6020), ('127.0.0.2', 6030)])
        """
        # sockstat key -> {url: value reported for that key}
        stats = {"tcp_in_use": {}, "tcp_mem_allocated_bytes": {},
                 "tcp6_in_use": {}, "time_wait": {},
                 "orphan": {}}
        fields = ("tcp_in_use", "tcp_mem_allocated_bytes", "tcp6_in_use",
                  "time_wait", "orphan")
        scout = Scout("sockstat", self.verbose, self.suppress_errors,
                      self.timeout)
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        print("[%s] Checking socket usage on %s hosts..." %
              (timestamp, len(hosts)))
        for url, response, status in self.pool.imap(scout.scout, hosts):
            if status != 200:
                continue
            for field in fields:
                stats[field][url] = response[field]
        for item in stats:
            if not stats[item]:
                print("Error: No hosts or info available.")
                continue
            values = stats[item].values()
            lowest = min(values)
            highest = max(values)
            total = sum(values)
            mean = total / len(stats[item])
            print("[%s] low: %d, high: %d, avg: %d, total: %d" %
                  (item, lowest, highest, mean, total))
        print("=" * 79)

    def disk_usage(self, hosts):
        """
        Report disk usage gathered from the given hosts

        Computes the used percentage of every mounted drive, then prints a
        distribution graph (drives per whole percent) and the lowest,
        highest and average usage. Hosts that answered without any mounted
        drive get an error line of their own.

        :param hosts: set of hosts to check. in the format of:
            set([('127.0.0.1', 6020), ('127.0.0.2', 6030)])
        """
        scout = Scout("diskusage", self.verbose, self.suppress_errors,
                      self.timeout)
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        print("[%s] Checking disk usage on %s hosts..." %
              (timestamp, len(hosts)))

        # url -> used percentage (2 decimals) of each mounted drive
        usage_by_host = {}
        for url, response, status in self.pool.imap(scout.scout, hosts):
            if status != 200:
                continue
            usage_by_host[url] = [
                round(float(entry['used']) / float(entry['size']) * 100.0, 2)
                for entry in response if entry['mounted']]

        host_lows = []
        host_highs = []
        host_averages = []
        # whole percent -> number of drives at that usage
        percents = {}
        for url in usage_by_host:
            usage = usage_by_host[url]
            if not usage:
                print("-> %s: Error. No drive info available." % url)
                continue
            host_lows.append(min(usage))
            host_highs.append(max(usage))
            host_averages.append(sum(usage) / len(usage))
            for percent in usage:
                bucket = int(percent)
                percents[bucket] = percents.get(bucket, 0) + 1

        if host_lows:
            lowest = min(host_lows)
            highest = max(host_highs)
            mean = sum(host_averages) / len(host_averages)
            #dist graph shamelessly stolen from https://github.com/gholt/tcod
            print("Distribution Graph:")
            # scale the bars so the fullest bucket is 69 characters wide
            scale = 69.0 / max(percents.values())
            for bucket in sorted(percents):
                bar = '*' * int(percents[bucket] * scale)
                print('% 3d%%%5d %s' % (bucket, percents[bucket], bar))
            print("Disk usage: lowest: %s%%, highest: %s%%, avg: %s%%" %
                  (lowest, highest, mean))
        else:
            print("Error: No hosts available or returned valid information.")
        print("=" * 79)

    def main(self):
        """
        Retrieve and report cluster info from hosts running recon middleware.

        Parses the command line, reads the hosts to query from
        object.ring.gz in the swift directory, then runs every check that
        was requested (all of them with --all). Prints the help text and
        exits with status 0 when called without any arguments.
        """
        print("=" * 79)
        usage = '''
        usage: %prog [-v] [--suppress] [-a] [-r] [-u] [-d] [-l] [-q]
                     [--objmd5] [--sockstat] [--all] [-z ZONE] [-t SECONDS]
                     [--swiftdir SWIFTDIR]
        '''
        args = optparse.OptionParser(usage)
        args.add_option('--verbose', '-v', action="store_true",
                        help="Print verbose info")
        args.add_option('--suppress', action="store_true",
                        help="Suppress most connection related errors")
        args.add_option('--async', '-a', action="store_true",
                        help="Get async stats")
        args.add_option('--replication', '-r', action="store_true",
                        help="Get replication stats")
        args.add_option('--unmounted', '-u', action="store_true",
                        help="Check cluster for unmounted devices")
        args.add_option('--diskusage', '-d', action="store_true",
                        help="Get disk usage stats")
        args.add_option('--loadstats', '-l', action="store_true",
                        help="Get cluster load average stats")
        args.add_option('--quarantined', '-q', action="store_true",
                        help="Get cluster quarantine stats")
        args.add_option('--objmd5', action="store_true",
                        help="Get md5sums of object.ring.gz and compare to "
                             "local copy")
        args.add_option('--sockstat', action="store_true",
                        help="Get cluster socket usage stats")
        args.add_option('--all', action="store_true",
                        help="Perform all checks. Equal to -arudlq --objmd5 "
                             "--sockstat")
        args.add_option('--zone', '-z', type="int",
                        help="Only query servers in specified zone")
        args.add_option('--timeout', '-t', type="int", metavar="SECONDS",
                        help="Time to wait for a response from a server",
                        default=5)
        args.add_option('--swiftdir', default="/etc/swift",
                        help="Default = /etc/swift")
        options = args.parse_args()[0]

        if len(sys.argv) <= 1:
            args.print_help()
            sys.exit(0)

        obj_ring = os.path.join(options.swiftdir, 'object.ring.gz')

        self.verbose = options.verbose
        self.suppress_errors = options.suppress
        self.timeout = options.timeout

        # options.zone is None unless --zone was given. Hand it over as is:
        # testing its truthiness here would turn "--zone 0" into no filter.
        hosts = self.get_devices(options.zone, obj_ring)

        # (option name, check) in the order the checks are run
        checks = [
            ('async', self.async_check),
            ('unmounted', self.umount_check),
            ('replication', self.replication_check),
            ('loadstats', self.load_check),
            ('diskusage', self.disk_usage),
            ('objmd5', lambda targets: self.get_ringmd5(targets, obj_ring)),
            ('quarantined', self.quarantine_check),
            ('sockstat', self.socket_usage),
        ]
        for option_name, check in checks:
            # getattr rather than attribute access because some option
            # names (async) are reserved words in newer Python versions
            if options.all or getattr(options, option_name):
                check(hosts)


if __name__ == '__main__':
    # an interrupt (Ctrl-C) ends the run with blank lines, not a traceback
    try:
        SwiftRecon().main()
    except KeyboardInterrupt:
        print('\n')
