#!/usr/bin/env python

# Copyright (c) 2014, NORDUnet A/S.
# See LICENSE for licensing information.

import argparse
import sys
import yaml
import re

class Symbol(str):
    """A ``str`` that gen_erlang renders as an Erlang atom instead of a string.

    It adds no behaviour of its own; only its type is inspected.
    """

# Characters that may appear inside a double-quoted Erlang string (or a
# single-quoted atom) without any escaping. \Z is used instead of $ because
# $ also matches just before a trailing "\n", which would let a string with
# a final newline slip through as "clean".
clean_string = re.compile(r'^[-.:_/A-Za-z0-9 ]*\Z')
# Characters allowed in an unquoted Erlang atom. quote_erlang_symbol also
# requires the first character to be a lowercase letter.
clean_symbol = re.compile(r'^[_A-Za-z0-9]*\Z')

# Erlang reserved words: as atoms they must always be single-quoted.
# ("else" and "maybe" are reserved when the maybe_expr feature is enabled;
# quoting them is harmless either way.)
_ERLANG_RESERVED_WORDS = frozenset([
    "after", "and", "andalso", "band", "begin", "bnot", "bor", "bsl",
    "bsr", "bxor", "case", "catch", "cond", "div", "else", "end", "fun",
    "if", "let", "maybe", "not", "of", "or", "orelse", "receive", "rem",
    "try", "when", "xor",
])

def quote_erlang_string(s):
    """Return Erlang source text for the string *s*.

    Strings made only of "clean" characters become a double-quoted literal.
    Anything else becomes an Erlang list of character codes, e.g.
    'a"b' -> [97,34,98], so no escaping is ever needed.
    """
    if clean_string.match(s):
        return '"' + s + '"'
    else:
        return "[" + ",".join([str(ord(c)) for c in s]) + "]"

def quote_erlang_symbol(s):
    """Return Erlang source text for the atom named *s*.

    *s* is emitted bare only if it is a valid unquoted atom: it starts with
    a lowercase letter, contains only [_A-Za-z0-9], and is not a reserved
    word. Otherwise it is single-quoted, provided it contains only clean
    characters. Anything else is reported on stderr and exits with
    status 1.
    """
    # "a" <= s[:1] is False for the empty string, so '' gets quoted.
    if "a" <= s[:1] <= "z" and clean_symbol.match(s) and s not in _ERLANG_RESERVED_WORDS:
        return s
    elif clean_string.match(s):
        return "'" + s + "'"
    else:
        print >>sys.stderr, "Cannot generate symbol", s
        sys.exit(1)

def gen_erlang(term, level=1):
    """Render a Python value as Erlang term source text.

    Mapping:
      Symbol          -> atom (via quote_erlang_symbol)
      str / unicode   -> string (via quote_erlang_string)
      bool            -> the atom true / false
      int             -> integer literal
      tuple           -> {...}, on one line unless an element spans lines
      list            -> [...], elements separated by ",\\n" plus indentation

    level: current nesting depth; sets the indentation (one space per level)
      used between list elements and multi-line tuple elements.

    Any other type is reported on stderr and exits with status 1.
    """
    indent = " " * level
    separator = ",\n" + indent
    if isinstance(term, Symbol):
        return quote_erlang_symbol(term)
    elif isinstance(term, basestring):
        return quote_erlang_string(term)
    elif isinstance(term, bool):
        # bool is a subclass of int: without this branch str(True) would
        # give "True", which Erlang parses as a variable, not an atom.
        return "true" if term else "false"
    elif isinstance(term, int):
        return str(term)
    elif isinstance(term, tuple):
        tuplecontents = [gen_erlang(e, level=level+1) for e in term]
        # Keep short tuples on a single line.
        if "\n" not in "".join(tuplecontents):
            separator = ", "
        return "{" + separator.join(tuplecontents) + "}"
    elif isinstance(term, list):
        listcontents = [gen_erlang(e, level=level+1) for e in term]
        return "[" + separator.join(listcontents) + "]"
    else:
        print >>sys.stderr, "unknown type", type(term)
        sys.exit(1)

# Environment for the Erlang "sasl" application; gen_config emits it
# unchanged into every generated node config.
saslconfig = [
    (Symbol("sasl_error_logger"), Symbol("false")),
    (Symbol("errlog_type"), Symbol("error")),
    (Symbol("error_logger_mf_dir"), "sasl_log"),
    (Symbol("error_logger_mf_maxbytes"), 10 * 1024 * 1024),  # 10 MiB
    (Symbol("error_logger_mf_maxfiles"), 10),
]

def parse_address(address):
    """Split a "host:port" string into a (host, port) tuple, port as int.

    Exits with status 1 unless *address* contains exactly one colon. A
    non-numeric port propagates the ValueError raised by int().
    """
    fields = address.split(":")
    if len(fields) == 2:
        host, port = fields
        return (host, int(port))
    print >>sys.stderr, "Invalid address format", address
    sys.exit(1)

def get_node_config(nodename, config):
    """Find every node list in *config* that contains *nodename*.

    Scans config["frontendnodes"], ["storagenodes"], ["signingnodes"] and
    ["mergenodes"] for entries whose "name" equals *nodename*.

    Returns (types, entries):
      types   -- set of the list keys the node was found in;
      entries -- dict mapping each of those keys to the node's entry
                 (the last one, if a list names the node twice).

    Exits with status 1 if the node is not found. It also exits if the node
    matched more than once and the matched keys are not exactly
    frontendnodes + storagenodes.
    """
    roles = []
    entries = {}
    for role in ("frontendnodes", "storagenodes", "signingnodes", "mergenodes"):
        for entry in config[role]:
            if entry["name"] != nodename:
                continue
            roles.append(role)
            entries[role] = entry
    if not roles:
        print >>sys.stderr, "Cannot find config for node", nodename
        sys.exit(1)
    # Frontend + storage is the only supported multi-role combination.
    if len(roles) > 1 and set(roles) != set(["frontendnodes", "storagenodes"]):
        print >>sys.stderr, "Node type unsupported:", roles
        sys.exit(1)
    return (set(roles), entries)

def get_address(bind_address, nodeconfig):
    """Return the (host, port) a node's listener should bind to.

    An explicit *bind_address* ("host:port") wins. Otherwise the node binds
    to all interfaces ("0.0.0.0") on the port of nodeconfig["address"].
    """
    if not bind_address:
        port = parse_address(nodeconfig["address"])[1]
        return ("0.0.0.0", port)
    host, port = parse_address(bind_address)
    return (host, port)

def gen_http_servers(nodetype, nodeconfig, bind_address, bind_publicaddress, bind_publichttpaddress):
    """Build the catlfish HTTP and HTTPS listener lists for one node.

    nodetype: set of node-list keys the node belongs to, as returned by
      get_node_config.
    nodeconfig: dict mapping each of those keys to the node's entry in
      the system config.
    bind_address, bind_publicaddress, bind_publichttpaddress: optional
      "host:port" overrides from the local config (None/empty when unset).

    Returns (http_servers, https_servers). Each is a list of
    (Symbol(listener name), host, port, Symbol(api)) tuples.

    Exits with status 1 if nodetype contains an unknown type, or if it
    combines frontend and merge.
    """
    knowntypes = set(["frontendnodes", "storagenodes", "signingnodes", "mergenodes"])
    if "frontendnodes" in nodetype and "mergenodes" in nodetype:
        print >>sys.stderr, "cannot have both frontend node and merge node at the same time", nodetype
        sys.exit(1)
    if nodetype - knowntypes:
        print >>sys.stderr, "unknown nodetype", nodetype
        sys.exit(1)

    http_servers = []
    https_servers = []

    if "frontendnodes" in nodetype:
        frontendconfig = nodeconfig["frontendnodes"]
        (host, port) = get_address(bind_address, frontendconfig)
        if bind_publicaddress:
            (publichost, publicport) = parse_address(bind_publicaddress)
        else:
            # nodeconfig is keyed by node type, so the public address lives
            # in the frontend node's own entry, not at the top level.
            (_, publicport) = parse_address(frontendconfig["publicaddress"])
            publichost = "0.0.0.0"

        # The plain-HTTP public API is only served when explicitly bound.
        if bind_publichttpaddress:
            (publichttphost, publichttpport) = parse_address(bind_publichttpaddress)
            http_servers.append((Symbol("external_http_api"), publichttphost, publichttpport, Symbol("v1")))
        https_servers.append((Symbol("external_https_api"), publichost, publicport, Symbol("v1")))
        https_servers.append((Symbol("frontend_https_api"), host, port, Symbol("frontend")))
    if "storagenodes" in nodetype:
        (host, port) = get_address(bind_address, nodeconfig["storagenodes"])
        https_servers.append((Symbol("storage_https_api"), host, port, Symbol("storage")))
    if "signingnodes" in nodetype:
        (host, port) = get_address(bind_address, nodeconfig["signingnodes"])
        https_servers.append((Symbol("signing_https_api"), host, port, Symbol("signing")))
    if "mergenodes" in nodetype:
        (host, port) = get_address(bind_address, nodeconfig["mergenodes"])
        https_servers.append((Symbol("frontend_https_api"), host, port, Symbol("frontend")))

    return (http_servers,
            https_servers)

def allowed_clients_frontend(mergenodenames, primarymergenode):
    """Client access list for a frontend node's /plop/v1/frontend/ paths.

    All merge nodes may call every path except sendsth, which only the
    primary merge node may call. Returns a list of (path, client names).
    """
    prefix = "/plop/v1/frontend/"
    clients_by_endpoint = [
        ("sendentry", mergenodenames),
        ("sendlog", mergenodenames),
        ("sendsth", [primarymergenode]),
        ("currentposition", mergenodenames),
        ("missingentries", mergenodenames),
    ]
    return [(prefix + endpoint, clients) for (endpoint, clients) in clients_by_endpoint]

def allowed_clients_mergesecondary(primarymergenode):
    """Client access list for a merge node's /plop/v1/merge/ paths.

    Only the primary merge node may call any of them. Returns a list of
    (path, [primarymergenode]) pairs, each with its own fresh list.
    """
    endpoints = ("sendentry", "sendlog", "verifyroot", "verifiedsize",
                 "setverifiedsize", "missingentries")
    return [("/plop/v1/merge/" + endpoint, [primarymergenode]) for endpoint in endpoints]

def allowed_clients_public():
    """Client access list for the public /ct/v1/ API paths.

    Every path is paired with the Symbol "noauth" rather than a list of
    client names.
    """
    anyone = Symbol("noauth")
    endpoints = [
        "add-chain",
        "add-pre-chain",
        "get-sth",
        "get-sth-consistency",
        "get-proof-by-hash",
        "get-entries",
        "get-entry-and-proof",
        "get-roots",
    ]
    return [("/ct/v1/" + endpoint, anyone) for endpoint in endpoints]

def allowed_clients_signing(frontendnodenames, primarymergenode):
    """Client access list for a signing node's /plop/v1/signing/ paths.

    Frontend nodes may request SCTs; only the primary merge node may
    request STHs. Returns a list of (path, client names).
    """
    prefix = "/plop/v1/signing/"
    return [(prefix + "sct", frontendnodenames), (prefix + "sth", [primarymergenode])]

def allowed_clients_storage(frontendnodenames, mergenodenames):
    """Client access list for a storage node's /plop/v1/storage/ paths.

    Frontend nodes may submit and commit entries; merge nodes may fetch
    them. Returns a list of (path, client names).
    """
    prefix = "/plop/v1/storage/"
    frontend_endpoints = ["sendentry", "entrycommitted"]
    merge_endpoints = ["fetchnewentries", "getentry"]
    return ([(prefix + e, frontendnodenames) for e in frontend_endpoints] +
            [(prefix + e, mergenodenames) for e in merge_endpoints])

def allowed_servers_frontend(signingnodenames, storagenodenames):
    """Server access list for a frontend node's outgoing request paths.

    Returns a list of (path, server names): the storage paths map to the
    storage nodes, the SCT signing path to the signing nodes.
    (Presumably plop uses this to check which peer may answer each
    request; confirm against plop's allowed_servers handling.)
    """
    storage = "/plop/v1/storage/"
    return [
        (storage + "sendentry", storagenodenames),
        (storage + "entrycommitted", storagenodenames),
        ("/plop/v1/signing/sct", signingnodenames),
    ]

def parse_ratelimit_expression(expression):
    """Parse one ratelimit expression into its Erlang term.

    Accepted forms:
      "none"                              -> Symbol("none")
      "<frequency> per second|minute|hour" -> (frequency, Symbol(unit))
    frequency must be an integer, and words may be separated by any run of
    whitespace. Anything else, including a non-integer frequency, is
    reported on stderr and exits with status 1. Previously a bad frequency
    crashed with a bare ValueError traceback.
    """
    if expression == "none":
        return Symbol("none")
    parts = expression.split()
    frequency = None
    if len(parts) == 3 and parts[1] == 'per' and parts[2] in ["second", "minute", "hour"]:
        try:
            frequency = int(parts[0])
        except ValueError:
            frequency = None  # reported below as a format error
    if frequency is None:
        print >>sys.stderr, "Ratelimit expressions must have the format \"<frequency> per second|minute|hour\" or \"none\""
        sys.exit(1)
    return (frequency, Symbol(parts[2]))

def parse_ratelimit(ratelimit):
    """Convert one (type, description) ratelimit entry to its Erlang form.

    ratelimit: a (type, description) pair, e.g. one item of
      localconfig["ratelimits"].items(). description is a comma-separated
      list of ratelimit expressions (see parse_ratelimit_expression).

    Returns (Symbol(type), [parsed expressions]).

    Tuple unpacking is done in the body rather than in the parameter list:
    tuple parameters are deprecated and were removed in Python 3 (PEP 3113).
    """
    (ratelimit_type, description) = ratelimit
    descriptions = [parse_ratelimit_expression(s.strip()) for s in description.split(",")]
    if len(descriptions) != 1:
        # NOTE(review): this only warns; all expressions are still passed
        # through. Confirm whether it should be fatal like other checks.
        print >>sys.stderr, "%s: Only one ratelimit expression supported right now" % (ratelimit_type,)
    return (Symbol(ratelimit_type), descriptions)

def gen_config(nodename, config, localconfig):
    print "generating config for", nodename
    paths = localconfig["paths"]
    bind_address = localconfig.get("addresses", {}).get(nodename)
    bind_publicaddress = localconfig.get("publicaddresses", {}).get(nodename)
    bind_publichttpaddress = localconfig.get("publichttpaddresses", {}).get(nodename)
    options = localconfig.get("options", [])

    configfile = open(paths["configdir"] + "/" + nodename + ".config", "w")
    print >>configfile, "%% catlfish configuration file (-*- erlang -*-)"

    (nodetype, nodeconfig) = get_node_config(nodename, config)
    (http_servers, https_servers) = gen_http_servers(nodetype, nodeconfig, bind_address, bind_publicaddress, bind_publichttpaddress=bind_publichttpaddress)

    catlfishconfig = []
    plopconfig = []

    if nodetype & set(["frontendnodes", "mergenodes"]):
        catlfishconfig.append((Symbol("known_roots_path"), localconfig["paths"]["knownroots"]))
    if "frontendnodes" in nodetype:
        if "sctcaching" in options:
            catlfishconfig.append((Symbol("sctcache_root_path"), paths["db"] + "sctcache/"))
        if "ratelimits" in localconfig:
            ratelimits = map(parse_ratelimit, localconfig["ratelimits"].items())
            catlfishconfig.append((Symbol("ratelimits"), ratelimits))

    catlfishconfig += [
        (Symbol("https_servers"), https_servers),
        (Symbol("http_servers"), http_servers),
        (Symbol("https_certfile"), paths["https_certfile"]),
        (Symbol("https_keyfile"), paths["https_keyfile"]),
        (Symbol("https_cacertfile"), paths["https_cacertfile"]),
    ]

    catlfishconfig.append((Symbol("mmd"), config["mmd"]))

    lagerconfig = [
        (Symbol("handlers"), [
            (Symbol("lager_console_backend"), Symbol("info")),
            (Symbol("lager_file_backend"), [(Symbol("file"), nodename + "-error.log"), (Symbol("level"), Symbol("error"))]),
            (Symbol("lager_file_backend"), [(Symbol("file"), nodename + "-debug.log"), (Symbol("level"), Symbol("debug"))]),
            (Symbol("lager_file_backend"), [(Symbol("file"), nodename + "-console.log"), (Symbol("level"), Symbol("info"))]),
        ])
    ]

    if "dbbackend" in localconfig:
        plopconfig += [
            (Symbol("db_backend"), Symbol(localconfig["dbbackend"])),
        ]
        assert nodetype == set("mergenodes")

    print "nodetype", nodetype
    if nodetype & set(["frontendnodes", "storagenodes"]):
        plopconfig += [
            (Symbol("entry_root_path"), paths["db"] + "certentries"),
            (Symbol("entryhash_root_path"), paths["db"] + "entryhash"),
            (Symbol("indexforhash_root_path"), paths["db"] + "certindex"),
        ]
    if "frontendnodes" in nodetype:
        plopconfig += [
            (Symbol("index_path"), paths["db"] + "index"),
            (Symbol("sth_path"), paths["db"] + "sth"),
            (Symbol("sendsth_verified_path"), paths["db"] + "sendsth-verified"),
            (Symbol("entryhash_from_entry"),
             (Symbol("catlfish"), Symbol("entryhash_from_entry"))),
        ]
    if "storagenodes" in nodetype:
        plopconfig += [
            (Symbol("newentries_path"), paths["db"] + "newentries"),
            (Symbol("lastverifiednewentry_path"), paths["db"] + "lastverifiednewentry"),
        ]
    if nodetype & set(["frontendnodes", "mergenodes"]):
        plopconfig += [
            (Symbol("verify_entry"),
             (Symbol("catlfish"), Symbol("verify_entry"))),
        ]
    if "mergenodes" in nodetype:
        plopconfig += [
            (Symbol("verifiedsize_path"), paths["mergedb"] + "/verifiedsize"),
            (Symbol("index_path"), paths["mergedb"] + "/logorder"),
            (Symbol("entry_root_path"), paths["mergedb"] + "/chains"),
            ]

    signingnodes = config["signingnodes"]
    signingnodeaddresses = ["https://%s/plop/v1/signing/" % node["address"] for node in config["signingnodes"]]
    mergenodenames = [node["name"] for node in config["mergenodes"]]
    primarymergenode = config["primarymergenode"]
    storagenodeaddresses = ["https://%s/plop/v1/storage/" % node["address"] for node in config["storagenodes"]]
    frontendnodenames = [node["name"] for node in config["frontendnodes"]]

    allowed_clients = []
    allowed_servers = []
    services = set()

    if "frontendnodes" in nodetype:
        storagenodenames = [node["name"] for node in config["storagenodes"]]
        plopconfig.append((Symbol("storage_nodes"), storagenodeaddresses))
        plopconfig.append((Symbol("storage_nodes_quorum"), config["storage-quorum-size"]))
        services.add(Symbol("ht"))
        allowed_clients += allowed_clients_frontend(mergenodenames, primarymergenode)
        allowed_clients += allowed_clients_public()
        allowed_servers += allowed_servers_frontend([node["name"] for node in signingnodes], storagenodenames)
    if "storagenodes" in nodetype:
        allowed_clients += allowed_clients_storage(frontendnodenames, mergenodenames)
    if "signingnodes" in nodetype:
        allowed_clients += allowed_clients_signing(frontendnodenames, primarymergenode)
        services = [Symbol("sign")]
    if "mergenodes" in nodetype:
        storagenodenames = [node["name"] for node in config["storagenodes"]]
        plopconfig.append((Symbol("storage_nodes"), storagenodeaddresses))
        plopconfig.append((Symbol("storage_nodes_quorum"), config["storage-quorum-size"]))
        services.add(Symbol("ht"))
        allowed_clients += allowed_clients_mergesecondary(primarymergenode)

    plopconfig += [
        (Symbol("publickey_path"), paths["publickeys"]),
        (Symbol("services"), list(services)),
    ]
    if "signingnodes" in nodetype:
        hsm = localconfig.get("hsm")
        if "logprivatekey" in paths:
            plopconfig.append((Symbol("log_private_key"), paths["logprivatekey"]))
        if hsm:
            plopconfig.append((Symbol("hsm"), [hsm.get("library"), str(hsm.get("slot")), "ecdsa", hsm.get("label"), hsm.get("pin")]))
        if not ("logprivatekey" in paths or hsm):
            print >>sys.stderr, "Neither logprivatekey nor hsm configured for signing node", nodename
            sys.exit(1)
    plopconfig += [
        (Symbol("log_public_key"), paths["logpublickey"]),
        (Symbol("own_key"), (nodename, "%s/%s-private.pem" % (paths["privatekeys"], nodename))),
    ]
    if "frontendnodes" in nodetype:
        plopconfig.append((Symbol("signing_nodes"), signingnodeaddresses))
    plopconfig += [
        (Symbol("allowed_clients"), list(allowed_clients)),
        (Symbol("allowed_servers"), list(allowed_servers)),
    ]

    erlangconfig = [
        (Symbol("sasl"), saslconfig),
        (Symbol("catlfish"), catlfishconfig),
        (Symbol("lager"), lagerconfig),
        (Symbol("plop"), plopconfig),
    ]

    print >>configfile, gen_erlang(erlangconfig) + ".\n"
    
    configfile.close()


def gen_testmakefile(config, testmakefile, machines):
    """Write makefile variable definitions for the test setup.

    config: system configuration mapping (node lists, "primarymergenode",
      "baseurl").
    testmakefile: path of the file to create or overwrite.
    machines: number of test machines.

    Variables written, one per line:
      NODES       -- all frontend, storage, signing and merge node names.
      ERLANGNODES -- the same, without the primary merge node.
      MACHINES    -- 1 .. machines.
      TESTURLS    -- frontend public addresses, then storage, signing and
                     non-primary merge node addresses.
      BASEURL     -- config["baseurl"].

    Everything is computed before the file is opened, so a config error
    does not leave a truncated makefile behind.
    """
    primarymergenode = config["primarymergenode"]
    frontendnodenames = [node["name"] for node in config["frontendnodes"]]
    storagenodenames = [node["name"] for node in config["storagenodes"]]
    signingnodenames = [node["name"] for node in config["signingnodes"]]
    mergenodenames = [node["name"] for node in config["mergenodes"]]
    secondarymergenodenames = [name for name in mergenodenames if name != primarymergenode]
    erlangnodenames = frontendnodenames + storagenodenames + signingnodenames + secondarymergenodenames

    frontendnodeaddresses = [node["publicaddress"] for node in config["frontendnodes"]]
    storagenodeaddresses = [node["address"] for node in config["storagenodes"]]
    signingnodeaddresses = [node["address"] for node in config["signingnodes"]]
    mergenodeaddresses = [node["address"] for node in config["mergenodes"] if node["name"] != primarymergenode]

    lines = [
        "NODES=" + " ".join(frontendnodenames + storagenodenames + signingnodenames + mergenodenames),
        "ERLANGNODES=" + " ".join(erlangnodenames),
        "MACHINES=" + " ".join([str(e) for e in range(1, machines + 1)]),
        "TESTURLS=" + " ".join(frontendnodeaddresses + storagenodeaddresses + signingnodeaddresses + mergenodeaddresses),
        "BASEURL=" + config["baseurl"],
    ]
    with open(testmakefile, "w") as makefile:
        makefile.write("".join([line + "\n" for line in lines]))


def main():
    """Command-line entry point.

    With --testmakefile and --machines, writes test makefile variables.
    Otherwise, with --localconfig, writes an Erlang config file for every
    node listed under "localnodes" in the local config. Exits with status 1
    if neither applies.
    """
    parser = argparse.ArgumentParser(description="")
    parser.add_argument('--config', help="System configuration", required=True)
    parser.add_argument('--localconfig', help="Local configuration")
    parser.add_argument("--testmakefile", metavar="file", help="Generate makefile variables for test")
    parser.add_argument("--machines", type=int, metavar="n", help="Number of machines")
    args = parser.parse_args()

    # safe_load: these files hold plain YAML data (mappings, lists, strings,
    # numbers). It never constructs arbitrary Python objects, and unlike a
    # Loader-less yaml.load() it is accepted by every PyYAML version.
    with open(args.config) as configfile:
        config = yaml.safe_load(configfile)
    if args.testmakefile and args.machines:
        gen_testmakefile(config, args.testmakefile, args.machines)
    elif args.localconfig:
        with open(args.localconfig) as localconfigfile:
            localconfig = yaml.safe_load(localconfigfile)
        localnodes = localconfig["localnodes"]
        for localnode in localnodes:
            gen_config(localnode, config, localconfig)
    else:
        print >>sys.stderr, "Nothing to do"
        sys.exit(1)

# Run only when executed as a script, so the module can be imported
# (e.g. for testing) without parsing argv or writing files.
if __name__ == "__main__":
    main()