#!/usr/bin/env python

# Copyright (c) 2014-2016, NORDUnet A/S.
# See LICENSE for licensing information.

import argparse
import sys
import readconfig
import re
import base64
from datetime import datetime

class Symbol(str):
    """String marker type: gen_erlang() emits instances as Erlang symbols
    (bare or single-quoted) instead of as string literals."""

class Binary(str):
    """String marker type: gen_erlang() emits instances as an Erlang binary,
    i.e. the comma-separated character codes between << and >>."""

# Character whitelists used by quote_erlang_string() / quote_erlang_symbol()
# to decide whether a value can be written out verbatim.
# The end anchor is \Z rather than $: $ also matches just before a trailing
# newline, which would let a value such as "abc\n" pass as clean.
clean_string = re.compile(r'^[-.:_/A-Za-z0-9 ]*\Z')
clean_symbol = re.compile(r'^[_A-Za-z0-9]*\Z')

def quote_erlang_string(s):
    """Return s formatted as an Erlang string term.

    Values made up only of characters accepted by clean_string are wrapped
    in double quotes; anything else is written as a bracketed list of
    character codes so that no escaping is needed.
    """
    if not clean_string.match(s):
        codes = [str(ord(ch)) for ch in s]
        return "[" + ",".join(codes) + "]"
    return '"' + s + '"'

def quote_erlang_symbol(s):
    """Return s formatted as an Erlang atom, single-quoting it when needed.

    A symbol is written bare only if it matches clean_symbol and starts
    with a lowercase ASCII letter.  Other symbols that match clean_string
    are wrapped in single quotes.  If s matches neither pattern, an error
    is printed to stderr and the program exits with status 1.
    """
    # An unquoted Erlang atom must begin with a lowercase letter; "Foo" or
    # "_x" would be read as a variable, "1a" as a number and "" as nothing
    # at all, so those have to be quoted to stay atoms.
    if clean_symbol.match(s) and "a" <= s[:1] <= "z":
        return s
    if clean_string.match(s):
        return "'" + s + "'"
    sys.stderr.write("Cannot generate symbol %s\n" % (s,))
    sys.exit(1)

def gen_erlang(term, level=1):
    """Serialize a Python value as the text of an Erlang term.

    Supported values:
      Symbol      -> atom (see quote_erlang_symbol)
      Binary      -> <<c1,c2,...>> built from the character codes
      other str   -> string (see quote_erlang_string)
      bool        -> true / false
      int, long   -> decimal number
      tuple       -> {...}
      list        -> [...]

    level is the nesting depth; it sets the indentation used after each
    line break between elements.  List elements always go on separate
    lines; tuple elements stay on one line unless one of them already
    contains a line break.

    Any other type prints an error to stderr and exits with status 1.
    """
    indent = " " * level
    separator = ",\n" + indent
    if isinstance(term, Symbol):
        return quote_erlang_symbol(term)
    elif isinstance(term, Binary):
        return "<<" + ",".join([str(ord(c)) for c in term]) + ">>"
    elif isinstance(term, basestring):
        return quote_erlang_string(term)
    elif isinstance(term, bool):
        # bool is a subclass of int; str(True) would give "True", which is
        # not a valid term in an Erlang config file.
        return "true" if term else "false"
    elif isinstance(term, (int, long)):
        return str(term)
    elif isinstance(term, tuple):
        tuplecontents = [gen_erlang(e, level=level+1) for e in term]
        if "\n" not in "".join(tuplecontents):
            separator = ", "
        return "{" + separator.join(tuplecontents) + "}"
    elif isinstance(term, list):
        listcontents = [gen_erlang(e, level=level+1) for e in term]
        return "[" + separator.join(listcontents) + "]"
    else:
        # Report on stderr like every other fatal error in this script.
        sys.stderr.write("unknown type %s\n" % (type(term),))
        sys.exit(1)

# Fixed settings emitted under the "sasl" application key of every generated
# node config (see gen_config).
saslconfig = [(Symbol("sasl_error_logger"), Symbol("false")),
              (Symbol("errlog_type"), Symbol("error")),
              (Symbol("error_logger_mf_dir"), "sasl_log"),
              (Symbol("error_logger_mf_maxbytes"), 10485760),
              (Symbol("error_logger_mf_maxfiles"), 10),
              ]

def parse_address(address):
    """Split a "host:port" string into a (host, port) tuple.

    The port is returned as an int.  If address does not consist of exactly
    two colon-separated parts, or the port part is not an integer, an error
    is printed to stderr and the program exits with status 1.
    """
    parts = address.split(":")
    if len(parts) == 2:
        try:
            return (parts[0], int(parts[1]))
        except ValueError:
            # Non-numeric port: fall through to the common error report
            # instead of dying with a traceback.
            pass
    sys.stderr.write("Invalid address format %s\n" % (address,))
    sys.exit(1)

def get_node_config(nodename, config):
    """Look up nodename in every node list of the system configuration.

    Returns (types, configs): types is the set of node-list keys
    ("frontendnodes", "storagenodes", "signingnodes", "mergenodes") that
    contain an entry named nodename, and configs maps each such key to
    that entry.

    Prints an error to stderr and exits with status 1 if the node is not
    found, or if it was matched more than once and the matched types are
    not exactly frontendnodes plus storagenodes.
    """
    matched = []
    found = {}
    for category in ["frontendnodes", "storagenodes", "signingnodes", "mergenodes"]:
        for entry in config[category]:
            if entry["name"] != nodename:
                continue
            matched.append(category)
            found[category] = entry

    if not matched:
        sys.stderr.write("Cannot find config for node %s\n" % (nodename,))
        sys.exit(1)

    is_frontend_plus_storage = set(matched) == set(["frontendnodes", "storagenodes"])
    if len(matched) >= 2 and not is_frontend_plus_storage:
        sys.stderr.write("Node type unsupported: %s\n" % (matched,))
        sys.exit(1)

    return (set(matched), found)

def get_address(bind_address, nodeconfig):
    """Return the (host, port) a service should listen on.

    An explicit bind_address ("host:port") is used as given.  Without one,
    the port is taken from nodeconfig["address"] and the host is the
    wildcard address "0.0.0.0".
    """
    if not bind_address:
        configured = parse_address(nodeconfig["address"])
        return ("0.0.0.0", configured[1])
    (host, port) = parse_address(bind_address)
    return (host, port)

def gen_http_servers(nodetype, nodeconfig, bind_addresses, bind_publicaddress, bind_publichttpaddress):
    """Build the listener lists for a node.

    nodetype is the set of node-list keys the node belongs to and
    nodeconfig maps each of those keys to the node's entry.  bind_addresses
    maps "frontend", "storage", "signing" and "merge" to an optional
    "host:port" override; bind_publicaddress and bind_publichttpaddress are
    optional overrides for the public HTTPS and plain-HTTP listeners of a
    frontend node.

    Returns (http_servers, https_servers), each a list of
    (Symbol name, host, port, Symbol api) tuples.

    Exits with status 1 if the node is both frontend and merge node, or if
    nodetype contains an unknown key.
    """
    if "frontendnodes" in nodetype and "mergenodes" in nodetype:
        sys.stderr.write("cannot have both frontend node and merge node at the same time %s\n" % (nodetype,))
        sys.exit(1)

    plain_listeners = []
    tls_listeners = []

    if "frontendnodes" in nodetype:
        frontend = nodeconfig["frontendnodes"]
        (host, port) = get_address(bind_addresses["frontend"], frontend)
        if bind_publicaddress:
            public = parse_address(bind_publicaddress)
        else:
            # Only the port comes from the configured public address; listen
            # on all interfaces.
            public = ("0.0.0.0", parse_address(frontend["publicaddress"])[1])

        if bind_publichttpaddress:
            (httphost, httpport) = parse_address(bind_publichttpaddress)
            plain_listeners.append((Symbol("external_http_api"), httphost, httpport, Symbol("v1")))
        tls_listeners.append((Symbol("external_https_api"), public[0], public[1], Symbol("v1")))
        tls_listeners.append((Symbol("frontend_https_api"), host, port, Symbol("frontend")))

    # (node-list key, bind_addresses key, listener name, api name)
    internal_services = (
        ("storagenodes", "storage", "storage_https_api", "storage"),
        ("signingnodes", "signing", "signing_https_api", "signing"),
        ("mergenodes", "merge", "frontend_https_api", "frontend"),
    )
    for (kind, bindkey, listener, api) in internal_services:
        if kind in nodetype:
            (host, port) = get_address(bind_addresses[bindkey], nodeconfig[kind])
            tls_listeners.append((Symbol(listener), host, port, Symbol(api)))

    unknown = nodetype - set(["frontendnodes", "storagenodes", "signingnodes", "mergenodes"])
    if unknown:
        sys.stderr.write("unknown nodetype %s\n" % (nodetype,))
        sys.exit(1)

    return (plain_listeners, tls_listeners)

def allowed_clients_frontend(mergenodenames, primarymergenodename):
    """Return the (path, client names) pairs for the frontend API paths.

    publish-sth and verify-entries list only the primary merge node; the
    remaining paths list every merge node.
    """
    # (path, restricted to the primary merge node?)
    endpoints = (
        ("/plop/v1/frontend/sendentry", False),
        ("/plop/v1/frontend/sendlog", False),
        ("/plop/v1/frontend/publish-sth", True),
        ("/plop/v1/frontend/verify-entries", True),
        ("/plop/v1/frontend/currentposition", False),
        ("/plop/v1/frontend/missingentries", False),
    )
    rules = []
    for (path, primary_only) in endpoints:
        if primary_only:
            rules.append((path, [primarymergenodename]))
        else:
            rules.append((path, mergenodenames))
    return rules

def allowed_clients_mergesecondary(primarymergenodename):
    """Return the (path, client names) pairs for the merge API paths; each
    path lists only the primary merge node."""
    merge_paths = (
        "/plop/v1/merge/sendentry",
        "/plop/v1/merge/sendlog",
        "/plop/v1/merge/verifyroot",
        "/plop/v1/merge/verifiedsize",
        "/plop/v1/merge/setverifiedsize",
        "/plop/v1/merge/missingentries",
    )
    # A fresh one-element list per path, so entries never share a list object.
    return [(path, [primarymergenodename]) for path in merge_paths]

def allowed_clients_public():
    """Return the (path, noauth) pairs for the /ct/v1 API paths; every path
    is paired with the Symbol noauth instead of a list of client names."""
    noauth = Symbol("noauth")
    public_paths = (
        "/ct/v1/add-chain",
        "/ct/v1/add-pre-chain",
        "/ct/v1/get-sth",
        "/ct/v1/get-sth-consistency",
        "/ct/v1/get-proof-by-hash",
        "/ct/v1/get-entries",
        "/ct/v1/get-entry-and-proof",
        "/ct/v1/get-roots",
    )
    return [(path, noauth) for path in public_paths]

def allowed_clients_signing(frontendnodenames, primarymergenodename):
    """Return the (path, client names) pairs for the signing API paths: sct
    lists the frontend nodes, sth lists only the primary merge node."""
    sct_rule = ("/plop/v1/signing/sct", frontendnodenames)
    sth_rule = ("/plop/v1/signing/sth", [primarymergenodename])
    return [sct_rule, sth_rule]

def allowed_clients_storage(frontendnodenames, mergenodenames):
    """Return the (path, client names) pairs for the storage API paths:
    sendentry and entrycommitted list the frontend nodes, fetchnewentries
    and getentry list the merge nodes."""
    frontend_paths = (
        "/plop/v1/storage/sendentry",
        "/plop/v1/storage/entrycommitted",
    )
    merge_paths = (
        "/plop/v1/storage/fetchnewentries",
        "/plop/v1/storage/getentry",
    )
    rules = [(path, frontendnodenames) for path in frontend_paths]
    rules.extend([(path, mergenodenames) for path in merge_paths])
    return rules

def allowed_servers_frontend(signingnodenames, storagenodenames):
    """Return the (path, server names) pairs used for a frontend node: the
    storage sendentry and entrycommitted paths list the storage nodes, the
    signing sct path lists the signing nodes."""
    storage_paths = (
        "/plop/v1/storage/sendentry",
        "/plop/v1/storage/entrycommitted",
    )
    servers = [(path, storagenodenames) for path in storage_paths]
    servers.append(("/plop/v1/signing/sct", signingnodenames))
    return servers

def allowed_servers_primarymerge(frontendnodenames, secondarymergenames):
    """Return the (path, server names) pairs used for the primary merge
    node: frontend API paths list the frontend nodes, merge API paths list
    the secondary merge nodes."""
    frontend_paths = (
        "/plop/v1/frontend/verify-entries",
        "/plop/v1/frontend/sendlog",
        "/plop/v1/frontend/sendentry",
        "/plop/v1/frontend/publish-sth",
    )
    secondary_paths = (
        "/plop/v1/merge/verifiedsize",
        "/plop/v1/merge/verifyroot",
        "/plop/v1/merge/setverifiedsize",
        "/plop/v1/merge/sendlog",
        "/plop/v1/merge/sendentry",
    )
    servers = [(path, frontendnodenames) for path in frontend_paths]
    servers.extend([(path, secondarymergenames) for path in secondary_paths])
    return servers

def parse_ratelimit_expression(expression):
    """Parse one ratelimit expression.

    "none" yields Symbol("none").  "<frequency> per second|minute|hour"
    yields (int frequency, Symbol unit).  Anything else, including a
    frequency that is not an integer, prints an error to stderr and exits
    with status 1.
    """
    if expression == "none":
        return Symbol("none")
    parts = expression.split(" ")
    if len(parts) == 3 and parts[1] == 'per' and parts[2] in ["second", "minute", "hour"]:
        try:
            return (int(parts[0]), Symbol(parts[2]))
        except ValueError:
            # Non-integer frequency: report it as a format error below
            # rather than letting the traceback escape.
            pass
    sys.stderr.write("Ratelimit expressions must have the format \"<frequency> per second|minute|hour\" or \"none\"\n")
    sys.exit(1)

def parse_ratelimit(ratelimit):
    """Parse one (name, description) ratelimit item.

    description is a comma-separated list of ratelimit expressions.
    Returns (Symbol(name), [parsed expression, ...]).  A message is printed
    to stderr when the description does not contain exactly one expression.
    """
    # Unpack in the body: tuple parameters are not valid syntax in Python 3,
    # and the old parameter name shadowed the builtin type().
    (name, description) = ratelimit
    descriptions = [parse_ratelimit_expression(s.strip()) for s in description.split(",")]
    if len(descriptions) != 1:
        # NOTE(review): this only warns and still returns every parsed
        # expression -- confirm whether it should exit like other errors.
        sys.stderr.write("%s: Only one ratelimit expression supported right now\n" % (name,))
    return (Symbol(name), descriptions)

def api_keys(config):
    """Return a (nodename, Binary key) pair for every entry of
    config["apikeys"]; the key is the base64-decoded "publickey" value."""
    keys = []
    for entry in config["apikeys"]:
        name = entry["nodename"]
        decoded = base64.b64decode(entry["publickey"])
        keys.append((name, Binary(decoded)))
    return keys

def gen_config(nodename, config, localconfig):
    """Generate the Erlang configuration files for one local node.

    Writes two files into localconfig["paths"]["configdir"]:
      <nodename>.config      sasl, catlfish, lager and plop settings
      <nodename>.plopconfig  the settings collected in reloadableplopconfig

    config is the system configuration and localconfig the local one; which
    settings are emitted depends on the node types nodename has in config.

    Both files are opened only after every term has been built and
    serialized, so a failure part-way (missing key, unknown node, ...) does
    not leave truncated files behind, and the handles are always closed.

    Exits with status 1 on an unrecognized dbbackend, on permdb with other
    than exactly one local node, and on a signing node that has neither
    logprivatekey nor hsm configured.
    """
    print("generating config for %s" % (nodename,))
    paths = localconfig["paths"]
    apikeys = api_keys(config)
    bind_addresses = {
        "frontend": localconfig.get("frontendaddresses", {}).get(nodename),
        "storage": localconfig.get("storageaddresses", {}).get(nodename),
        "signing": localconfig.get("signingaddresses", {}).get(nodename),
        "merge": localconfig.get("mergeaddresses", {}).get(nodename),
        }
    bind_publicaddress = localconfig.get("ctapiaddresses", {}).get(nodename)
    bind_publichttpaddress = localconfig.get("publichttpaddresses", {}).get(nodename)
    options = localconfig.get("options", [])

    configfilename = paths["configdir"] + "/" + nodename + ".config"
    plopcontrolfilename = nodename + ".plopcontrol"
    plopconfigfilename = paths["configdir"] + "/" + nodename + ".plopconfig"

    (nodetype, nodeconfig) = get_node_config(nodename, config)
    if nodename == config["primarymergenode"]:
        # The primary merge node gets no listeners.
        (http_servers, https_servers) = [], []
    else:
        (http_servers, https_servers) = gen_http_servers(nodetype, nodeconfig, bind_addresses, bind_publicaddress, bind_publichttpaddress=bind_publichttpaddress)

    catlfishconfig = []
    plopconfig = []
    reloadableplopconfig = []

    if nodetype & set(["frontendnodes", "mergenodes"]):
        catlfishconfig.append((Symbol("known_roots_path"), paths["knownroots"]))
    if "frontendnodes" in nodetype:
        if "sctcaching" in options:
            catlfishconfig.append((Symbol("sctcache_root_path"), paths["db"] + "sctcache/"))
        if "ratelimits" in localconfig:
            # Build a real list: gen_erlang() only serializes list and tuple.
            ratelimits = [parse_ratelimit(item) for item in localconfig["ratelimits"].items()]
            catlfishconfig.append((Symbol("ratelimits"), ratelimits))

    catlfishconfig += [
        (Symbol("https_servers"), https_servers),
        (Symbol("http_servers"), http_servers),
        (Symbol("https_certfile"), paths["https_certfile"]),
        (Symbol("https_keyfile"), paths["https_keyfile"]),
    ]

    catlfishconfig.append((Symbol("mmd"), config["mmd"]))

    lagerconfig = [
        (Symbol("handlers"), [
            (Symbol("lager_console_backend"), Symbol("info")),
            (Symbol("lager_file_backend"), [(Symbol("file"), nodename + "-error.log"), (Symbol("level"), Symbol("error"))]),
            (Symbol("lager_file_backend"), [(Symbol("file"), nodename + "-debug.log"), (Symbol("level"), Symbol("debug"))]),
            (Symbol("lager_file_backend"), [(Symbol("file"), nodename + "-console.log"), (Symbol("level"), Symbol("info"))]),
        ])
    ]

    plopconfig += [
        (Symbol("https_cacertfile"), paths["https_cacertfile"]),
        (Symbol("https_cacert_fingerprint"), Binary(base64.b16decode(config["cafingerprint"]))),
    ]

    if "dbbackend" in localconfig:
        dbbackend = localconfig["dbbackend"]
        if dbbackend not in ("fsdb", "permdb"):
            sys.stderr.write("DB backend not recognized: %s\n" % (dbbackend,))
            sys.exit(1)
        plopconfig += [
            (Symbol("db_backend"), Symbol(dbbackend)),
        ]
        if dbbackend == "permdb" and len(localconfig["localnodes"]) != 1:
            sys.stderr.write("When using permdb, all services have to be in the same node\n")
            sys.exit(1)

    if nodetype & set(["frontendnodes", "storagenodes"]):
        plopconfig += [
            (Symbol("entry_root_path"), paths["db"] + "certentries"),
            (Symbol("entryhash_root_path"), paths["db"] + "entryhash"),
            (Symbol("indexforhash_root_path"), paths["db"] + "certindex"),
        ]
    if "frontendnodes" in nodetype:
        plopconfig += [
            (Symbol("index_path"), paths["db"] + "index"),
            (Symbol("sth_path"), paths["db"] + "sth"),
            (Symbol("sendsth_verified_path"), paths["db"] + "sendsth-verified"),
            (Symbol("entryhash_from_entry"),
             (Symbol("catlfish"), Symbol("entryhash_from_entry"))),
        ]
    if "storagenodes" in nodetype:
        plopconfig += [
            (Symbol("newentries_path"), paths["db"] + "newentries"),
            (Symbol("lastverifiednewentry_path"), paths["db"] + "lastverifiednewentry"),
        ]
    if nodetype & set(["frontendnodes", "mergenodes"]):
        plopconfig += [
            (Symbol("verify_entry"),
             (Symbol("catlfish"), Symbol("verify_entry"))),
        ]
    if "mergenodes" in nodetype:
        plopconfig += [
            (Symbol("verifiedsize_path"), paths["mergedb"] + "/verifiedsize"),
            (Symbol("index_path"), paths["mergedb"] + "/logorder"),
            (Symbol("entry_root_path"), paths["mergedb"] + "/chains"),
            ]

    signingnodes = config["signingnodes"]
    signingnodeaddresses = ["https://%s/plop/v1/signing/" % node["address"] for node in config["signingnodes"]]
    mergenodenames = [node["name"] for node in config["mergenodes"]]
    primarymergenodename = config["primarymergenode"]
    storagenodeaddresses = ["https://%s/plop/v1/storage/" % node["address"] for node in config["storagenodes"]]
    frontendnodenames = [node["name"] for node in config["frontendnodes"]]
    frontendnodeaddresses = ["https://%s/plop/v1/frontend/" % node["address"] for node in config["frontendnodes"]]

    allowed_clients = []
    allowed_servers = []
    services = set()

    if "frontendnodes" in nodetype:
        storagenodenames = [node["name"] for node in config["storagenodes"]]
        reloadableplopconfig.append((Symbol("storage_nodes"), storagenodeaddresses))
        reloadableplopconfig.append((Symbol("storage_nodes_quorum"), config["storage-quorum-size"]))
        services.add(Symbol("ht"))
        allowed_clients += allowed_clients_frontend(mergenodenames, primarymergenodename)
        allowed_clients += allowed_clients_public()
        allowed_servers += allowed_servers_frontend([node["name"] for node in signingnodes], storagenodenames)
    if "storagenodes" in nodetype:
        allowed_clients += allowed_clients_storage(frontendnodenames, mergenodenames)
    if "signingnodes" in nodetype:
        allowed_clients += allowed_clients_signing(frontendnodenames, primarymergenodename)
        # Add to the set instead of rebinding services to a list, so the
        # other branches can keep calling services.add().
        services.add(Symbol("sign"))
    if "mergenodes" in nodetype:
        reloadableplopconfig.append((Symbol("storage_nodes"), storagenodeaddresses))
        reloadableplopconfig.append((Symbol("storage_nodes_quorum"), config["storage-quorum-size"]))
        services.add(Symbol("ht"))
        if nodename == primarymergenodename:
            mergesecondarynames = [node["name"] for node in config["mergenodes"] if node["name"] != primarymergenodename]
            mergesecondaryaddresses = ["https://%s/plop/v1/merge/" % node["address"] for node in config["mergenodes"] if node["name"] != primarymergenodename]
            merge = localconfig["merge"]
            plopconfig.append((Symbol("db_backend_opt"), [(Symbol("write_flag"), Symbol("read"))]))
            plopconfig.append((Symbol("merge_delay"), merge["min-delay"]))
            plopconfig.append((Symbol("merge_dist_winsize"), merge["dist-window-size"]))
            plopconfig.append((Symbol("merge_dist_sendlog_chunksize"), merge["dist-sendlog-chunksize"]))
            plopconfig.append((Symbol("merge_dist_sendentries_chunksize"), merge["dist-sendentries-chunksize"]))
            reloadableplopconfig.append((Symbol("merge_backup_winsize"), merge["backup-window-size"]))
            reloadableplopconfig.append((Symbol("merge_backup_sendlog_chunksize"), merge["backup-sendlog-chunksize"]))
            reloadableplopconfig.append((Symbol("merge_backup_sendentries_chunksize"), merge["backup-sendentries-chunksize"]))
            plopconfig.append((Symbol("frontend_nodes"), frontendnodeaddresses))
            # list(): gen_erlang() needs a list, not a zip iterator.
            reloadableplopconfig.append((Symbol("merge_secondaries"), list(zip(mergesecondarynames, mergesecondaryaddresses))))
            plopconfig.append((Symbol("sth_path"), paths["mergedb"] + "/sth"))
            plopconfig.append((Symbol("fetched_path"), paths["mergedb"] + "/fetched"))
            plopconfig.append((Symbol("verified_path"), paths["mergedb"] + "/verified"))
            allowed_servers += allowed_servers_primarymerge(frontendnodenames, mergesecondarynames)
        else:
            allowed_clients += allowed_clients_mergesecondary(primarymergenodename)

    plopconfig += [
        (Symbol("services"), list(services)),
    ]
    if "signingnodes" in nodetype:
        hsm = localconfig.get("hsm")
        if "logprivatekey" in paths:
            plopconfig.append((Symbol("log_private_key"), paths["logprivatekey"]))
        if hsm:
            plopconfig.append((Symbol("hsm"), [hsm.get("library"), str(hsm.get("slot")), "ecdsa", hsm.get("label"), hsm.get("pin")]))
        if not ("logprivatekey" in paths or hsm):
            sys.stderr.write("Neither logprivatekey nor hsm configured for signing node %s\n" % (nodename,))
            sys.exit(1)
    plopconfig += [
        (Symbol("log_public_key"), Binary(base64.b64decode(config["logpublickey"]))),
        (Symbol("own_key"), (nodename, "%s/%s-private.pem" % (paths["privatekeys"], nodename))),
    ]
    if "frontendnodes" in nodetype:
        reloadableplopconfig.append((Symbol("signing_nodes"), signingnodeaddresses))
    plopconfig += [
        (Symbol("plopconfig"), plopconfigfilename),
        (Symbol("plopcontrol"), plopcontrolfilename),
    ]

    reloadableplopconfig += [
        (Symbol("allowed_clients"), allowed_clients),
        (Symbol("allowed_servers"), allowed_servers),
        (Symbol("apikeys"), apikeys),
        (Symbol("version"), config["version"]),
    ]

    erlangconfig = [
        (Symbol("sasl"), saslconfig),
        (Symbol("catlfish"), catlfishconfig),
        (Symbol("lager"), lagerconfig),
        (Symbol("plop"), plopconfig),
    ]

    # Serialize before touching the filesystem: gen_erlang() may exit.
    configtext = gen_erlang(erlangconfig)
    plopconfigtext = gen_erlang(reloadableplopconfig)

    # The header strings are written verbatim (two literal percent signs),
    # and each term is followed by ".", a newline and one blank line.
    with open(configfilename, "w") as configfile:
        configfile.write("%% catlfish configuration file (-*- erlang -*-)\n")
        configfile.write(configtext + ".\n\n")

    with open(plopconfigfilename, "w") as plopconfigfile:
        plopconfigfile.write("%% plop configuration file (-*- erlang -*-)\n")
        plopconfigfile.write(plopconfigtext + ".\n\n")


def gen_testmakefile(config, testmakefile, machines, shellvars=False):
    """Write the test variable file testmakefile.

    The file holds a comment header followed by NODES, ERLANGNODES,
    MACHINES, TESTURLS and BASEURL assignments derived from config.
    machines is the number of machines; MACHINES lists 1..machines.  With
    shellvars=True every value is wrapped in double quotes so the file can
    be read as shell variable assignments.

    All lines are computed before the file is opened, so a missing config
    key does not leave a truncated file behind, and the handle is always
    closed.
    """
    primarymergenode = config["primarymergenode"]

    frontendnodenames = set([node["name"] for node in config["frontendnodes"]])
    storagenodenames = set([node["name"] for node in config["storagenodes"]])
    signingnodenames = set([node["name"] for node in config["signingnodes"]])
    mergenodenames = set([node["name"] for node in config["mergenodes"]])
    allnodenames = frontendnodenames | storagenodenames | signingnodenames | mergenodenames
    # The primary merge node is paired with the "merge" application, every
    # other node with "catlfish".
    erlangnodenames_and_apps = ['%s:%s' % (nn, 'catlfish' if nn != primarymergenode else "merge") for nn in allnodenames]

    frontendnodeaddresses = [node["publicaddress"] for node in config["frontendnodes"]]
    storagenodeaddresses = [node["address"] for node in config["storagenodes"]]
    signingnodeaddresses = [node["address"] for node in config["signingnodes"]]
    # The primary merge node is left out of the test URLs.
    mergenodeaddresses = [node["address"] for node in config["mergenodes"] if node["name"] != primarymergenode]

    delimiter = '"' if shellvars else ''

    lines = [
        "# %s generated by %s %s" % (testmakefile, sys.argv[0], datetime.now()),
        "NODES=" + delimiter + " ".join(allnodenames) + delimiter,
        "ERLANGNODES=" + delimiter + " ".join(erlangnodenames_and_apps) + delimiter,
        "MACHINES=" + delimiter + " ".join([str(e) for e in range(1, machines+1)]) + delimiter,
        "TESTURLS=" + delimiter + " ".join(frontendnodeaddresses+storagenodeaddresses+signingnodeaddresses+mergenodeaddresses) + delimiter,
        "BASEURL=" + delimiter + config["baseurl"] + delimiter,
    ]

    with open(testmakefile, "w") as configfile:
        configfile.write("\n".join(lines) + "\n")

def printnodenames(config):
    """Print the distinct node names found in the frontend, storage, signing
    and merge node lists of config on one line, separated by spaces."""
    nodetypes = ("frontendnodes", "storagenodes", "signingnodes", "mergenodes")
    (frontend, storage, signing, merge) = [
        set([node["name"] for node in config[nodetype]]) for nodetype in nodetypes
    ]
    allnames = frontend | storage | signing | merge
    print(" ".join(allnames))

def main():
    """Parse the command line and run the first requested action.

    In order of precedence: write test makefile variables (--testmakefile
    with --machines), write test shell variables (--testshellvars with
    --machines), print the node names (--getnodenames), or generate the
    config files for every node in the local configuration's "localnodes"
    (--localconfig).  If none applies, prints "Nothing to do" to stderr and
    exits with status 1.
    """
    parser = argparse.ArgumentParser(description="")
    parser.add_argument('--config', help="System configuration", required=True)
    parser.add_argument('--localconfig', help="Local configuration")
    parser.add_argument("--testmakefile", metavar="file", help="Generate makefile variables for test")
    parser.add_argument("--testshellvars", metavar="file", help="Generate shell variable file for test")
    parser.add_argument("--getnodenames", action='store_true', help="Get list of node names")
    parser.add_argument("--machines", type=int, metavar="n", help="Number of machines")
    args = parser.parse_args()

    if args.testmakefile and args.machines:
        gen_testmakefile(readconfig.read_config(args.config), args.testmakefile, args.machines)
        return

    if args.testshellvars and args.machines:
        gen_testmakefile(readconfig.read_config(args.config), args.testshellvars, args.machines, shellvars=True)
        return

    if args.getnodenames:
        printnodenames(readconfig.read_config(args.config))
        return

    if not args.localconfig:
        sys.stderr.write("Nothing to do\n")
        sys.exit(1)

    # The system configuration is read through verify_and_read_config with
    # the local configuration's "logadminkey" value.
    localconfig = readconfig.read_config(args.localconfig)
    config = readconfig.verify_and_read_config(args.config, localconfig["logadminkey"])
    for localnode in localconfig["localnodes"]:
        gen_config(localnode, config, localconfig)

# Script entry point.  NOTE(review): this runs unconditionally, so importing
# the module also parses the command line and executes main().
main()