#! /usr/bin/env python




import re
# Patterns for directory-name components of the form "<name><value>", where
# <name> is a run of lowercase letters and <value> selects the feature type.
bool_feat_re = re.compile(r"^([a-z]+)(True|False)$")
int_feat_re = re.compile(r"^([a-z]+)([0-9]+)$")
real_feat_re = re.compile(r"^([a-z]+)([0-9]+\.?[0-9]*)$")
str_feat_re = re.compile(r"^([a-z]+)([A-Z][A-Za-z_0-9]+)$")


def parse_dir_feature(feat, number):
    """Turn one directory-name component into a feature triple.

    :arg feat: the component text, e.g. ``"n5"``, ``"dt0.5"``, ``"pTrue"``
        or ``"meshCube"``.
    :arg number: position of the component; only used to build a name for
        components that match none of the patterns.
    :returns: a tuple ``(name, sql_type, value)`` where *sql_type* is one of
        ``"integer"``, ``"real"`` or ``"text"``.
    """
    # Booleans are tried first: "True"/"False" would otherwise be picked up
    # by the string pattern.
    bool_match = bool_feat_re.match(feat)
    if bool_match is not None:
        return (bool_match.group(1), "integer",
                int(bool_match.group(2) == "True"))

    # Integers are tried before reals because the real pattern also accepts
    # plain digit strings.
    int_match = int_feat_re.match(feat)
    if int_match is not None:
        # The value is tagged "integer", so hand back an int, not a float.
        return (int_match.group(1), "integer", int(int_match.group(2)))

    real_match = real_feat_re.match(feat)
    if real_match is not None:
        return (real_match.group(1), "real", float(real_match.group(2)))

    str_match = str_feat_re.match(feat)
    if str_match is not None:
        return (str_match.group(1), "text", str_match.group(2))

    # Nothing matched: keep the whole component as text under a positional
    # name.
    return ("dirfeat%d" % number, "text", feat)




def larger_sql_type(type_a, type_b):
    """Return the SQL type that can hold values of both given types.

    Each argument is ``None`` (no type known), ``"text"``, ``"real"`` or
    ``"integer"``. ``None`` yields to the other argument; otherwise
    ``"text"`` wins over ``"real"``, which wins over ``"integer"``.
    """
    allowed = [None, "text", "real", "integer"]
    assert type_a in allowed
    assert type_b in allowed

    if type_a is None:
        return type_b
    if type_b is None:
        return type_a

    pair = (type_a, type_b)
    # Walk from the widest type downwards; the first one present wins.
    for candidate in ("text", "real"):
        if candidate in pair:
            return candidate

    assert type_a == type_b == "integer"
    return "integer"




def sql_type_and_value(value):
    """Map a Python value to a ``(sql_type, value)`` pair.

    ``None`` gives ``(None, None)``; booleans are stored as integer 0/1;
    ints and floats pass through as ``"integer"`` and ``"real"``; anything
    else is converted with ``str`` and stored as ``"text"``.
    """
    if value is None:
        return None, None

    # bool has to be handled before int, since bool is a subclass of int.
    if isinstance(value, bool):
        return "integer", int(value)

    for py_type, sql_type in ((int, "integer"), (float, "real")):
        if isinstance(value, py_type):
            return sql_type, value

    return "text", str(value)




def sql_type_and_value_from_str(value):
    """Guess the SQL type of a value given as text.

    :arg value: the textual value, e.g. ``"None"``, ``"True"``, ``"42"``,
        ``"2.5"`` or arbitrary text.
    :returns: a ``(sql_type, value)`` pair. ``"None"`` gives
        ``(None, None)``; ``"True"``/``"False"`` give integer 1/0; text
        accepted by ``int`` or ``float`` gives ``"integer"`` or ``"real"``;
        everything else is returned unchanged as ``"text"``.
    """
    if value == "None":
        return None, None

    if value in ("True", "False"):
        # Return a plain int (not a bool) so the value really is an integer.
        return "integer", int(value == "True")

    # Try the narrowest numeric type first.
    for sql_type, convert in (("integer", int), ("real", float)):
        try:
            return sql_type, convert(value)
        except ValueError:
            continue

    return "text", str(value)




class FeatureGatherer:
    """Collect the features (name, SQL type, value) describing one run.

    Features come from up to three places: an optional features file keyed
    by directory name, the components of the directory name itself, and the
    constants stored in the run's log.
    """

    def __init__(self, features_from_dir=False, features_file=None):
        """
        :arg features_from_dir: if true, also derive features from the
            "-"-separated components of each database's directory name.
        :arg features_file: optional name of a text file with lines of the
            form ``dirname: key=value, key=value``.
        :raises ValueError: if a non-blank line of *features_file* lacks the
            ``:`` separator or one of its entries lacks ``=``.
        """
        self.features_from_dir = features_from_dir

        # Maps directory name -> list of (name, sql_type, value) tuples.
        self.dir_to_features = {}
        if features_file is not None:
            # The context manager makes sure the file gets closed again.
            with open(features_file, "r") as inf:
                for line in inf:
                    if not line.strip():
                        # Blank lines carry no information.
                        continue

                    dir_name, colon, rest = line.partition(":")
                    if not colon:
                        raise ValueError(
                                "%s: expected 'dirname: key=value, ...', "
                                "got %r" % (features_file, line.rstrip()))

                    features = []
                    for entry in rest.split(","):
                        entry = entry.strip()
                        key, equal, value = entry.partition("=")
                        if not equal:
                            raise ValueError(
                                    "%s: expected 'key=value', got %r"
                                    % (features_file, entry))
                        features.append(
                                (key,) + sql_type_and_value_from_str(value))

                    self.dir_to_features[dir_name] = features

    def get_db_features(self, dbname, logmgr):
        """Return the list of ``(name, sql_type, value)`` features for a run.

        :arg dbname: file name of the run's database; only its directory
            part is used.
        :arg logmgr: object whose ``constants`` mapping supplies further
            features.
        """
        from os.path import dirname
        dn = dirname(dbname)

        # Copy, so that extending the result never alters the stored list.
        features = list(self.dir_to_features.get(dn, []))

        if self.features_from_dir:
            features.extend(parse_dir_feature(feat, i)
                    for i, feat in enumerate(dn.split("-")))

        # items() exists on both Python 2 and 3, unlike iteritems().
        for name, value in logmgr.constants.items():
            features.append((name,) + sql_type_and_value(value))

        return features




def scan(fg, dbnames, progress=True):
    """Survey the input databases before any data is copied.

    :arg fg: a feature gatherer; its ``get_db_features(dbname, logmgr)`` is
        called for every database.
    :arg dbnames: file names of the databases to open.
    :arg progress: whether to show a progress bar.
    :returns: a tuple ``(features, dbname_to_run_id)``. *features* maps each
        feature name to the widest SQL type seen for it, and
        *dbname_to_run_id* maps each database name to a run number starting
        at 1. Databases sharing a ``unique_run_id`` constant share a run
        number; databases without that constant each get their own.
    """
    from pytools import ProgressBar
    from pytools.log import LogManager

    features = {}
    dbname_to_run_id = {}
    uid_to_run_id = {}
    next_run_id = 1

    pb = None
    if progress:
        pb = ProgressBar("Scanning...", len(dbnames))

    for dbname in dbnames:
        logmgr = LogManager(dbname, "r")
        try:
            unique_run_id = logmgr.constants.get("unique_run_id")
            run_id = uid_to_run_id.get(unique_run_id)

            if run_id is None:
                run_id = next_run_id
                next_run_id += 1

                # Only remember real ids; a missing id must never make two
                # unrelated databases look like the same run.
                if unique_run_id is not None:
                    uid_to_run_id[unique_run_id] = run_id

            dbname_to_run_id[dbname] = run_id

            if pb is not None:
                pb.progress()

            for fname, ftype, _fvalue in fg.get_db_features(dbname, logmgr):
                if fname in features:
                    features[fname] = larger_sql_type(ftype, features[fname])
                else:
                    # A feature whose first value has no type defaults to
                    # text.
                    if ftype is None:
                        ftype = "text"
                    features[fname] = ftype
        finally:
            # Release the database even if reading it failed.
            logmgr.close()

    if pb is not None:
        pb.finished()

    return features, dbname_to_run_id




def make_name_map(map_str):
    """Parse a rename specification such as ``"F1=FNAME1,F2=FNAME2"``.

    :arg map_str: comma-separated ``OLD=NEW`` entries, where both sides
        consist of letters, digits and underscores; may be ``None`` or empty.
    :returns: a dict mapping each ``OLD`` name to its ``NEW`` name (empty if
        *map_str* is ``None`` or empty).
    :raises ValueError: if an entry does not have the ``OLD=NEW`` form.
    """
    import re

    if not map_str:
        return {}

    map_re = re.compile(r"^([a-z_A-Z0-9]+)=([a-z_A-Z0-9]+)$")

    result = {}
    for fmap_entry in map_str.split(","):
        match = map_re.match(fmap_entry)
        if match is None:
            # Raise explicitly: an assert would vanish under "python -O".
            raise ValueError(
                    "invalid name map entry %r (expected OLD=NEW)"
                    % fmap_entry)
        result[match.group(1)] = match.group(2)

    return result




def transfer_data_table(db_conn, tbl_name, data_table):
    """Insert every row of *data_table* into table *tbl_name* of *db_conn*.

    *data_table* supplies the column names via ``column_names`` and the rows
    via ``data``. The table and column names are interpolated into the
    statement text; the row values are passed as bound parameters.
    """
    columns = data_table.column_names
    placeholders = ", ".join(["?"] * len(columns))
    statement = "insert into %s (%s) values (%s)" % (
            tbl_name, ", ".join(columns), placeholders)

    db_conn.executemany(statement, data_table.data)

    


def gather_single_file(outfile, infiles):
    """Merge several log databases of one run into a single-run database.

    :arg outfile: name of the SQLite file to write.
    :arg infiles: names of the log databases to read.

    Warnings and quantity data of every input are copied. Each constant and
    each quantity's description row is written only the first time its name
    is encountered.
    """
    import sqlite3
    from pickle import dumps

    from pytools import ProgressBar
    from pytools.log import LogManager, _set_up_schema

    pb = ProgressBar("Importing...", len(infiles))

    db_conn = sqlite3.connect(outfile)
    try:
        _set_up_schema(db_conn)

        seen_constants = set()
        seen_quantities = set()

        for dbname in infiles:
            pb.progress()

            logmgr = LogManager(dbname, "r")
            try:
                # transfer warnings
                transfer_data_table(db_conn, "warnings",
                        logmgr.get_warnings())

                # transfer constants
                for key, val in logmgr.constants.items():
                    if key not in seen_constants:
                        # sqlite3.Binary is the portable way to mark the
                        # pickle as a blob (it is "buffer" on Python 2).
                        db_conn.execute(
                                "insert into constants values (?,?)",
                                (key, sqlite3.Binary(dumps(val))))
                        seen_constants.add(key)

                # transfer quantities
                for qname, qdata in logmgr.quantity_data.items():
                    if qname not in seen_quantities:
                        # Describe and create each quantity once, no matter
                        # how many input files contain it.
                        db_conn.execute(
                                "insert into quantities values (?,?,?,?)",
                                (qname, qdata.unit, qdata.description,
                                    sqlite3.Binary(
                                        dumps(qdata.default_aggregator))))
                        db_conn.execute("create table %s "
                                "(step integer, rank integer, value real)"
                                % qname)
                        seen_quantities.add(qname)

                    transfer_data_table(db_conn, qname,
                            logmgr.get_table(qname))
            finally:
                # Release the input database even if copying failed.
                logmgr.close()

        pb.finished()

        db_conn.commit()
    finally:
        db_conn.close()




def gather_multi_file(outfile, infiles, fmap, qmap, fg, features,
        dbname_to_run_id):
    """Merge log databases from several runs into one multi-run database.

    :arg outfile: name of the SQLite file to write.
    :arg infiles: names of the log databases to read.
    :arg fmap: dict renaming feature names (used as column names of the
        ``runs`` table); unlisted names are kept.
    :arg qmap: dict renaming quantity names (used as table names); unlisted
        names are kept.
    :arg fg: feature gatherer providing ``get_db_features(dbname, logmgr)``.
    :arg features: dict mapping feature name to SQL type; defines the
        feature columns of the ``runs`` table.
    :arg dbname_to_run_id: dict mapping each entry of *infiles* to its run
        number.

    The output has a ``runs`` table with one row per run, a ``quantities``
    table describing each quantity, and one table per quantity holding
    ``(run_id, step, rank, value)`` rows.
    """
    import sqlite3
    from os.path import dirname

    from pytools import ProgressBar
    from pytools.log import LogManager

    pb = ProgressBar("Importing...", len(infiles))

    db_conn = sqlite3.connect(outfile)
    try:
        run_columns = [
                "id integer primary key",
                "dirname text"] + ["%s %s" % (fmap.get(fname, fname), ftype)
                        for fname, ftype in features.items()]
        db_conn.execute("create table runs (%s)" % ",".join(run_columns))
        db_conn.execute("create index runs_id on runs (id)")

        db_conn.execute("""create table quantities (
                id integer primary key,
                name text, 
                unit text, 
                description text, 
                rank_aggregator text
                )""")

        created_tables = set()
        written_run_ids = set()

        for dbname in infiles:
            pb.progress()

            run_id = dbname_to_run_id[dbname]

            logmgr = LogManager(dbname, "r")
            try:
                # Only the first database of a run contributes its row in
                # the runs table.
                if run_id not in written_run_ids:
                    dbfeatures = fg.get_db_features(dbname, logmgr)
                    qry = "insert into runs (%s) values (%s)" % (
                        ",".join(["id", "dirname"]
                            + [fmap.get(f[0], f[0]) for f in dbfeatures]),
                        ",".join("?" * (len(dbfeatures)+2)))
                    db_conn.execute(qry,
                            [run_id, dirname(dbname)]
                            + [f[2] for f in dbfeatures])

                    written_run_ids.add(run_id)

                for qname, qdat in logmgr.quantity_data.items():
                    tgt_qname = qmap.get(qname, qname)

                    if tgt_qname not in created_tables:
                        created_tables.add(tgt_qname)
                        db_conn.execute("create table %s ("
                          "run_id integer, step integer, rank integer, "
                          "value real)" % tgt_qname)

                        db_conn.execute(
                                "create index %s_main on %s "
                                "(run_id,step,rank)" % (
                                    tgt_qname, tgt_qname))

                        # Store the aggregator by name when it has one,
                        # otherwise as its string form (or NULL for None).
                        agg = qdat.default_aggregator
                        try:
                            agg = agg.__name__
                        except AttributeError:
                            if agg is not None:
                                agg = str(agg)

                        db_conn.execute("insert into quantities "
                                "(name,unit,description,rank_aggregator)"
                                "values (?,?,?,?)",
                                (tgt_qname, qdat.unit, qdat.description,
                                    agg))

                    # The run id is selected as a literal so that every
                    # copied row carries it as its first column.
                    cursor = logmgr.db_conn.execute(
                            "select %s,step,rank,value from %s" % (
                                run_id, qname))
                    db_conn.executemany(
                            "insert into %s values (?,?,?,?)" % tgt_qname,
                            cursor)
            finally:
                # Release the input database even if copying failed.
                logmgr.close()

        pb.finished()

        db_conn.commit()
    finally:
        db_conn.close()




def main():
    """Command-line entry point: ``OUTDB DBFILES ...``.

    Scans all input databases, then either prints the features found
    (``--show-features``), writes a single-run file (``--single``) or writes
    a multi-run file (the default). Exits with status 1 after printing the
    help text if fewer than two positional arguments are given.

    :raises ValueError: if ``--single`` is requested but the inputs belong
        to more than one run.
    """
    import sys
    from optparse import OptionParser

    parser = OptionParser(usage="%prog OUTDB DBFILES ...")
    parser.add_option("-1", "--single", action="store_true",
            help="Gather single-run instead of multi-run file")
    parser.add_option("-s", "--show-features", action="store_true",
            help="Only print the features found and quit")
    parser.add_option("-d", "--dir-features", action="store_true",
            help="Extract features from directory names")
    parser.add_option("-f", "--file-features", default=None,
            metavar="FILENAME",
            help="Read additional features from file, with lines like: "
            "'dirname: key=value, key=value'")
    parser.add_option("-m", "--feature-map", default=None,
            help="Specify a feature name map.",
            metavar="F1=FNAME1,F2=FNAME2")
    parser.add_option("-q", "--quantity-map", default=None,
            help="Specify a quantity name map.",
            metavar="Q1=QNAME1,Q2=QNAME2")
    options, args = parser.parse_args()

    if len(args) < 2:
        parser.print_help()
        sys.exit(1)

    outfile = args[0]
    infiles = args[1:]

    # list of run features as {name: sql_type}
    fg = FeatureGatherer(options.dir_features, options.file_features)
    features, dbname_to_run_id = scan(fg, infiles)

    fmap = make_name_map(options.feature_map)
    qmap = make_name_map(options.quantity_map)

    if options.show_features:
        for feat_name, feat_type in features.items():
            # A single formatted string prints identically whether print is
            # a statement (Python 2) or a function (Python 3).
            print("%s %s" % (fmap.get(feat_name, feat_name), feat_type))
        sys.exit(0)

    if options.single:
        if len(set(dbname_to_run_id.values())) > 1:
            raise ValueError(
                    "data seems to come from more than one run--"
                    "can't write single-run file")
        gather_single_file(outfile, infiles)
    else:
        gather_multi_file(outfile, infiles, fmap, qmap, fg, features,
                dbname_to_run_id)





# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()

