#!/usr/bin/env python

#
# Part of: https://github.com/mateidavid/fast5
#
# (c) 2017: Matei David, Ontario Institute for Cancer Research
# MIT License
#

import argparse
import datetime
import dateutil.parser
import logging
import math
import os
import sys

import fast5

import signal
# Restore the default SIGPIPE disposition: Python normally ignores SIGPIPE and
# reports a broken-pipe error on write instead; with SIG_DFL the process is
# simply terminated by the signal when the reader of its output goes away
# (e.g. when piping into `head`).
signal.signal(signal.SIGPIPE, signal.SIG_DFL)

def add_fast5(fn, rel_dn, args):
    """
    Log that a fast5 file is being added and return it as a one-element list.

    fn     -- path of the fast5 file.
    rel_dn -- directory of the file relative to the scanned root; only logged.
    args   -- parsed command-line options; not used here.
    """
    message = "adding fast5 fn=" + fn + " rel_dn=" + rel_dn
    logger.info(message)
    selected = [fn]
    return selected

def add_dir(dn, args):
    """
    Collect the valid fast5 files found under directory `dn`.

    Only the top-level directory is scanned unless `args.recurse` is set,
    in which case the whole tree produced by os.walk is visited.
    Returns the list of file paths accepted by fast5.File.is_valid_file.
    """
    found = []
    logger.info("processing dir dn=" + dn)
    for cur_dn, _sub_dirs, file_names in os.walk(dn):
        rel_dn = os.path.relpath(cur_dn, dn)
        for rel_fn in file_names:
            fn = os.path.join(cur_dn, rel_fn)
            if fast5.File.is_valid_file(fn):
                found.extend(add_fast5(fn, rel_dn, args))
        # os.walk yields the root first, so stopping here limits the scan
        # to the top-level directory.
        if not args.recurse:
            break
    return found

def add_fofn(fn, args):
    """
    Read a file of file names and collect the valid fast5 files it lists.

    fn   -- path of the list file, or "-" to read the names from stdin.
    args -- parsed command-line options, forwarded to add_fast5.

    Each line is stripped of surrounding whitespace; blank lines are skipped.
    Lines that fast5.File.is_valid_file rejects are reported with a warning.
    Returns the list of accepted file paths.
    """
    l = []
    logger.info("processing fofn fn=" + fn)
    f = sys.stdin if fn == "-" else open(fn)
    try:
        for p in f:
            p = p.strip()
            if not p:
                # An empty line names no file; do not warn about it.
                continue
            if fast5.File.is_valid_file(p):
                l += add_fast5(p, "", args)
            else:
                logger.warning("fofn line not a fast5 file: " + p)
    finally:
        # Close the handle even if a line fails to process, but never close
        # stdin, which this function does not own.
        if f is not sys.stdin:
            f.close()
    return l

def add_paths(pl, args):
    """
    Expand a list of input paths into a list of fast5 file paths.

    pl   -- list of paths; each may be a directory, a fast5 file, or a file
            of file names. When empty, "-" is used, which add_fofn reads
            from stdin.
    args -- parsed command-line options, forwarded to the add_* helpers.

    The caller's list is left unmodified. Returns the list of fast5 files.
    """
    l = []
    # Fall back to stdin through a local list rather than appending to `pl`,
    # so the argument passed by the caller is not mutated.
    paths = pl if pl else ["-"]
    for p in paths:
        if os.path.isdir(p):
            l += add_dir(p, args)
        elif fast5.File.is_valid_file(p):
            l += add_fast5(p, "", args)
        else:
            # Anything that is neither a directory nor a fast5 file is
            # treated as a file of file names.
            l += add_fofn(p, args)
    return l

def stat_file(ifn, args):
    """
    Collect summary information about one fast5 file.

    ifn  -- path of the fast5 file to open.
    args -- parsed command-line options; not used here.

    Returns a dict with these keys:
      "cid", "tid" -- channel id and tracking id params.
      "rs_rn_l"    -- raw samples read names; "rs" maps each name to
                      {"params": ..., "packed": ...}.
      "bc_gr_l"    -- basecall group names; "bc" maps each group to a dict
                      with "desc", "summary", per-strand "start", "length",
                      "count", "packed_fastq", "packed_events", and
                      "packed_alignment".
      "bc_desc", "bc_summary" -- empty dicts, kept so the result layout
                      does not change.
      "ed_gr_l"    -- eventdetection group names; "ed" maps each group to
                      {"rn_l": [...], "rn": {read_name: {"packed": ...}}}.
    Every "packed*" value is the negation of the matching have_*_unpack call.

    If fast5 raises RuntimeError at any point, a warning is logged and an
    empty dict is returned; no partial results are kept.
    """
    d = dict()
    try:
        f = fast5.File(ifn)
        # cid params
        d["cid"] = f.get_channel_id_params()
        d["tid"] = f.get_tracking_id_params()
        # raw samples
        d["rs_rn_l"] = f.get_raw_samples_read_name_list()
        d["rs"] = dict()
        for rn in d["rs_rn_l"]:
            d["rs"][rn] = {
                "params": f.get_raw_samples_params(rn),
                "packed": not f.have_raw_samples_unpack(rn),
            }
        # basecall groups
        d["bc_gr_l"] = f.get_basecall_group_list()
        d["bc"] = dict()
        d["bc_desc"] = dict()
        d["bc_summary"] = dict()
        for gr in d["bc_gr_l"]:
            bc = {
                "desc": f.get_basecall_group_description(gr),
                "summary": f.get_basecall_summary(gr),
                "start": dict(),
                "length": dict(),
                # NOTE: "count" is never populated here; filling it would
                # require reading the event table itself.
                "count": dict(),
                "packed_fastq": dict(),
                "packed_events": dict(),
            }
            d["bc"][gr] = bc
            for st in [0, 1, 2]:
                bc["packed_fastq"][st] = not f.have_basecall_fastq_unpack(st, gr)
                if st == 2:
                    # Strand 2 carries an alignment instead of events.
                    bc["packed_alignment"] = not f.have_basecall_alignment_unpack(gr)
                    continue
                if not bc["desc"]["have_events"][st]:
                    continue
                bc["packed_events"][st] = not f.have_basecall_events_unpack(st, gr)
                ev_params = f.get_basecall_events_params(st, gr)
                start = ev_params["start_time"]
                length = ev_params["duration"]
                # Values below 1e-3 are replaced by NaN so that they are
                # reported as 'nan' rather than as a real time.
                bc["start"][st] = float('nan') if start < 1e-3 else start
                bc["length"][st] = float('nan') if length < 1e-3 else length
        # eventdetection groups
        d["ed_gr_l"] = f.get_eventdetection_group_list()
        d["ed"] = dict()
        for gr in d["ed_gr_l"]:
            ed = {
                "rn_l": f.get_eventdetection_read_name_list(gr),
                "rn": dict(),
            }
            d["ed"][gr] = ed
            for rn in ed["rn_l"]:
                ed["rn"][rn] = {
                    "packed": not f.have_eventdetection_events_unpack(gr, rn),
                }
    except RuntimeError as e:
        # Report the failure instead of dropping it silently, then fall back
        # to the empty result that callers already handle.
        logger.warning("error reading fast5 file %s: %s", ifn, e)
        d = dict()
    return d

def as_time(v, r):
    """
    Format `v / r` seconds as "H:MM:SS.mmm".

    v -- a duration or time point, in units of 1/r seconds (e.g. samples).
    r -- the rate to divide by (e.g. the sampling rate; 1.0 for seconds).

    Returns the string 'nan' when `v` is NaN. Otherwise the value is rounded
    to the nearest millisecond before being split into fields.
    """
    if math.isnan(v):
        return 'nan'
    # Convert to whole milliseconds once. Truncating the float product
    # instead loses a millisecond whenever it lands just below an integer,
    # e.g. 4004 / 4000 * 1000 == 1000.9999999999999.
    total_ms = int(round(float(v) / r * 1000))
    s, ms = divmod(total_ms, 1000)
    m, s = divmod(s, 60)
    h, m = divmod(m, 60)
    return "%d:%02d:%02d.%03d" % (h, m, s, ms)

def print_path(p, v, args):
    """
    Print one "path<delim0>value" line to stdout.

    p    -- list of path elements; joined with args.delim[1].
    v    -- the value; a list (or list subclass) has its elements joined
            with args.delim[1], anything else is printed via str().
    args -- options object whose `delim` holds the two delimiters:
            delim[0] between path and value, delim[1] between elements.
    """
    elem_delim = args.delim[1]
    path = elem_delim.join(str(e) for e in p)
    if isinstance(v, list):
        value = elem_delim.join(str(e) for e in v)
    else:
        value = str(v)
    print(path + args.delim[0] + value)

def list_file(ifn, include_fn, args):
    """
    Print a summary of one fast5 file as "path<delim>value" lines.

    ifn        -- path of the fast5 file.
    include_fn -- when true, a "file" line with the path is printed first.
    args       -- parsed command-line options, forwarded to the helpers.

    The information comes from stat_file(); if its result has no "cid" entry,
    nothing beyond the optional "file" line is printed. Sections are printed
    in this order: tid, cid, rs, bc, ed.
    """
    info = stat_file(ifn, args)
    if include_fn:
        print_path(["file"], ifn, args)
    if "cid" not in info:
        return

    # tid
    tid = info["tid"]
    for key in ("device_id", "asic_id", "flow_cell_id", "exp_script_purpose"):
        if key in tid:
            print_path(["tid", key], tid[key], args)
    if "exp_start_time" in tid:
        raw_start = tid["exp_start_time"]
        # A 'T' marks an ISO-style timestamp; anything else is parsed as an
        # integer epoch value.
        # NOTE(review): fromtimestamp() yields local time, while the parsed
        # branch keeps whatever zone the string carries -- confirm intended.
        if 'T' in raw_start:
            started = dateutil.parser.parse(raw_start)
        else:
            started = datetime.datetime.fromtimestamp(int(raw_start))
        print_path(["tid", "exp_start_date"], started.date().isoformat(), args)
        print_path(["tid", "exp_start_time"], started.time().isoformat(), args)

    # cid
    cid = info["cid"]
    for key in ("channel_number", "sampling_rate"):
        if key in cid:
            print_path(["cid", key], cid[key], args)
    sampling_rate = cid["sampling_rate"]

    # rs
    for rn in info["rs_rn_l"]:
        entry = info["rs"][rn]
        prefix = ["rs", rn]
        print_path(prefix + ["packed"], int(entry["packed"]), args)
        params = entry["params"]
        print_path(prefix + ["read_number"], params["read_number"], args)
        print_path(prefix + ["read_id"], params["read_id"], args)
        print_path(prefix + ["start"], as_time(params["start_time"], sampling_rate), args)
        print_path(prefix + ["length"], as_time(params["duration"], sampling_rate), args)

    # bc
    subgroup_names = ("basecall_1d_template", "basecall_1d_complement", "basecall_2d")
    for gr in info["bc_gr_l"]:
        bc = info["bc"][gr]
        desc = bc["desc"]
        print_path(["bc", gr, "id"], desc["name"] + ":" + desc["version"], args)
        for st in range(3):
            if not desc["have_subgroup"][st]:
                continue
            base = ["bc", gr, st]
            # fastq
            fq_len = 0
            print_path(base + ["fastq", "packed"], int(bc["packed_fastq"][st]), args)
            if desc["have_fastq"][st]:
                summary = bc["summary"]
                for key in ("sequence_length", "mean_qscore"):
                    fk = subgroup_names[st] + "/" + key
                    if fk in summary:
                        value = summary[fk]
                        print_path(base + ["fastq", key], value, args)
                        if key == "sequence_length":
                            fq_len = value
            if st == 2:
                # the third subgroup reports an alignment, not events/model
                print_path(base + ["alignment", "packed"], int(bc["packed_alignment"]), args)
                continue
            # events
            if desc["have_events"][st]:
                print_path(base + ["events", "packed"], int(bc["packed_events"][st]), args)
                for key in ("start", "length"):
                    print_path(base + ["events", key], as_time(float(bc[key][st]), 1.0), args)
                if st in bc["count"]:
                    print_path(base + ["events", "count"], bc["count"][st], args)
                # fastq length divided by the events length
                bps = float(fq_len) / bc["length"][st]
                print_path(base + ["bps"], "%.2f" % bps, args)
            # model
            print_path(base + ["model"], int(desc["have_model"][st]), args)
        if desc["have_subgroup"][2]:
            print_path(["bc", gr, "bc_1d_gr"], desc["bc_1d_gr"], args)
        if desc["have_subgroup"][0] or desc["have_subgroup"][1]:
            print_path(["bc", gr, "ed_gr"], desc["ed_gr"], args)

    # ed
    for gr in info["ed_gr_l"]:
        ed = info["ed"][gr]
        for rn in ed["rn_l"]:
            print_path(["ed", gr, rn, "packed"], int(ed["rn"][rn]["packed"]), args)


# Script entry point: parse options, configure logging, expand the inputs
# into a list of fast5 files, and print a summary of each one.
if __name__ == "__main__":
    description = """
    Summarize contents of ONT fast5 files.
    """
    parser = argparse.ArgumentParser(description=description, epilog="")
    parser.add_argument("--log-level", default="warning",
                        help="log level")
    #
    parser.add_argument("--delim", default="\t/",
                        help="Delimiters list; first char used between path and value, second char used between path elements.")
    parser.add_argument("-R", "--recurse", action="store_true",
                        help="Recurse in input directories.")
    #
    # nargs="*" already gathers every positional into one flat list, so the
    # default store action is used; an "append" action would wrap that list
    # in another list and force callers to index element 0.
    parser.add_argument("inputs", nargs="*", default=[],
                        help="Input directories, fast5 files, or files of fast5 file names.")
    args = parser.parse_args()

    # Translate the textual level (e.g. "warning") into logging's constant.
    numeric_log_level = getattr(logging, args.log_level.upper(), None)
    if not isinstance(numeric_log_level, int):
        raise ValueError("Invalid log level: '%s'" % args.log_level)
    logging.basicConfig(level=numeric_log_level,
                        format="%(asctime)s %(name)s.%(levelname)s %(message)s",
                        datefmt="%Y/%m/%d %H:%M:%S")
    # Module-global logger used by the functions defined in this file.
    logger = logging.getLogger(os.path.basename(__file__))
    fast5.Logger.set_levels_from_options([args.log_level.lower()])
    # fix delim: split into characters and pad to two entries, so that
    # delim[0] and delim[1] always exist
    args.delim = list(args.delim)
    while len(args.delim) < 2:
        args.delim.append("")
    logger.debug("args: " + str(args))

    fl = add_paths(args.inputs, args)
    # The file name is only printed when more than one file is summarized.
    for ifn in fl:
        list_file(ifn, len(fl) > 1, args)
