#!/usr/bin/env python

# Copyright (c) 2014, NORDUnet A/S.
# See LICENSE for licensing information.

import argparse
import sys
import yaml
import re

class Symbol(str):
    """A str subclass that gen_erlang() renders as an Erlang atom.

    Plain strings become Erlang string literals; wrapping a value in Symbol
    asks for an atom instead (e.g. Symbol("info") -> info).
    """

# Characters that can appear verbatim between Erlang double quotes (strings)
# or single quotes (atoms) without any escaping.  "\Z" anchors at the very end
# of the input; "$" would also accept a string with a trailing newline.
clean_string = re.compile(r'^[-.:_/A-Za-z0-9 ]*\Z')
# Atoms that may be written without quotes.  Erlang requires an unquoted atom
# to start with a lowercase letter (an uppercase or "_" start lexes as a
# variable, a digit start as a number), followed by letters, digits and "_".
# Everything else, including the empty atom, must go through the quoted form.
clean_symbol = re.compile(r'^[a-z][_A-Za-z0-9]*\Z')

def quote_erlang_string(s):
    """Return Erlang source text for the string *s*.

    Strings made only of characters accepted by clean_string are written as
    a double-quoted literal.  Anything else is written as an Erlang list of
    character codes (e.g. 'a"b' -> [97,34,98]), which needs no escaping.
    """
    if not clean_string.match(s):
        codes = ",".join(str(ord(ch)) for ch in s)
        return "[" + codes + "]"
    return '"' + s + '"'

# Erlang reserved words.  They lex as keywords, so an atom with one of these
# names must be written in single quotes.  "maybe" and "else" are reserved
# when the maybe expression is enabled (newer OTP releases).  Quoting them is
# always harmless.
_erlang_reserved_words = frozenset([
    "after", "and", "andalso", "band", "begin", "bnot", "bor", "bsl",
    "bsr", "bxor", "case", "catch", "cond", "div", "else", "end", "fun",
    "if", "let", "maybe", "not", "of", "or", "orelse", "receive", "rem",
    "try", "when", "xor",
])

def quote_erlang_symbol(s):
    """Return Erlang source text for the atom named *s*.

    The bare name is used when it matches clean_symbol and is not a reserved
    word.  Otherwise it is single-quoted, provided it matches clean_string.
    Names that would need escaping are rejected with a message on stderr and
    exit status 1.
    """
    if clean_symbol.match(s) and s not in _erlang_reserved_words:
        return s
    elif clean_string.match(s):
        return "'" + s + "'"
    else:
        sys.stderr.write("Cannot generate symbol %s\n" % (s,))
        sys.exit(1)

try:
    # Python 2: str and unicode share basestring; ints may also be long.
    _string_types = basestring
    _integer_types = (int, long)
except NameError:  # Python 3
    _string_types = str
    _integer_types = (int,)

def gen_erlang(term, level=1):
    """Render a Python value as Erlang term source text.

    Supported values:
      * Symbol            -> atom (via quote_erlang_symbol)
      * str / unicode     -> string (via quote_erlang_string)
      * bool              -> the atoms true / false
      * int / long        -> decimal integer
      * tuple             -> {...}, on one line unless an element spans lines
      * list              -> [...], one element per line

    *level* is the nesting depth.  It sets the indentation used after each
    ",\n" separator.  Any other type prints a message on stderr and exits
    with status 1.
    """
    indent = " " * level
    separator = ",\n" + indent
    if isinstance(term, Symbol):
        # Checked before the string case: Symbol is a str subclass.
        return quote_erlang_symbol(term)
    elif isinstance(term, _string_types):
        return quote_erlang_string(term)
    elif isinstance(term, bool):
        # bool is an int subclass; str(True) would emit the Erlang
        # *variable* True instead of the atom true.
        return "true" if term else "false"
    elif isinstance(term, _integer_types):
        return str(term)
    elif isinstance(term, tuple):
        tuplecontents = [gen_erlang(e, level=level+1) for e in term]
        # Keep short tuples such as {key, value} on a single line.
        if "\n" not in "".join(tuplecontents):
            separator = ", "
        return "{" + separator.join(tuplecontents) + "}"
    elif isinstance(term, list):
        listcontents = [gen_erlang(e, level=level+1) for e in term]
        return "[" + separator.join(listcontents) + "]"
    else:
        sys.stderr.write("unknown type %s\n" % (type(term),))
        sys.exit(1)

# Application environment for OTP's sasl application, included unchanged in
# every generated node config.  Per the sasl documentation, error reports go
# to a rotating multi-file log under sasl_log (10 files, 10485760 bytes =
# 10 MiB each) rather than to the tty.
# Built from (name, value) pairs; a generator expression is used so that no
# loop variable leaks into the module namespace under Python 2.
saslconfig = list((Symbol(name), value) for (name, value) in (
    ("sasl_error_logger", Symbol("false")),
    ("errlog_type", Symbol("error")),
    ("error_logger_mf_dir", "sasl_log"),
    ("error_logger_mf_maxbytes", 10485760),
    ("error_logger_mf_maxfiles", 10),
))

def parse_address(address):
    """Split a "host:port" string into a (host, port) tuple with an int port.

    Exits with status 1 and a message on stderr in two cases:
      * *address* does not contain exactly one colon;
      * the port is not an integer in the range 0-65535.
    Previously a bad port escaped as a ValueError traceback.
    """
    parsed_address = address.split(":")
    if len(parsed_address) != 2:
        sys.stderr.write("Invalid address format %s\n" % (address,))
        sys.exit(1)
    (host, portstring) = parsed_address
    try:
        port = int(portstring)
    except ValueError:
        port = None
    if port is None or not 0 <= port <= 65535:
        sys.stderr.write("Invalid port in address %s\n" % (address,))
        sys.exit(1)
    return (host, port)

def get_node_config(nodename, config):
    """Find the node called *nodename* in the system configuration.

    Searches the "frontendnodes", "storagenodes" and "signingnodes" sections
    of *config*.  A missing section, or one left empty in YAML (which loads as
    None), is treated as having no nodes.  If several nodes share the name,
    the last one found wins.

    Returns (nodetype, nodeconfig), where nodetype is the section name.
    Exits with status 1 and a message on stderr if no node matches.
    """
    nodetype = None
    nodeconfig = None
    for t in ["frontendnodes", "storagenodes", "signingnodes"]:
        for node in config.get(t) or []:
            if node["name"] == nodename:
                nodetype = t
                nodeconfig = node
    if nodeconfig is None:
        sys.stderr.write("Cannot find config for node %s\n" % (nodename,))
        sys.exit(1)
    return (nodetype, nodeconfig)

def gen_http_servers(nodetype, nodeconfig, bind_address, bind_publicaddress, bind_publichttpaddress):
    """Build the listener lists for a node as (http_servers, https_servers).

    Each entry is (Symbol(name), host, port, Symbol(api)).  The internal
    listener binds to *bind_address* when given.  Otherwise it binds to
    0.0.0.0 on the port from nodeconfig["address"].

    Frontend nodes also get:
      * a public HTTPS listener, from *bind_publicaddress* or 0.0.0.0 plus
        the port of nodeconfig["publicaddress"];
      * a public plain-HTTP listener, only when *bind_publichttpaddress* is
        set.

    Returns None for an unrecognised *nodetype*.
    """
    if not bind_address:
        host = "0.0.0.0"
        port = parse_address(nodeconfig["address"])[1]
    else:
        host, port = parse_address(bind_address)

    if nodetype == "storagenodes":
        return ([], [(Symbol("storage_https_api"), host, port, Symbol("storage"))])
    if nodetype == "signingnodes":
        return ([], [(Symbol("signing_https_api"), host, port, Symbol("signing"))])
    if nodetype != "frontendnodes":
        return None

    if not bind_publicaddress:
        publichost = "0.0.0.0"
        publicport = parse_address(nodeconfig["publicaddress"])[1]
    else:
        publichost, publicport = parse_address(bind_publicaddress)

    plain_listeners = []
    if bind_publichttpaddress:
        httphost, httpport = parse_address(bind_publichttpaddress)
        plain_listeners.append((Symbol("external_http_api"), httphost, httpport, Symbol("v1")))

    tls_listeners = [
        (Symbol("external_https_api"), publichost, publicport, Symbol("v1")),
        (Symbol("frontend_https_api"), host, port, Symbol("frontend")),
    ]
    return (plain_listeners, tls_listeners)

def allowed_clients_frontend(mergenodenames):
    """Return (path, mergenodenames) pairs for the frontend-internal API.

    The pairs go into a frontend node's plop allowed_clients setting.  Every
    /ct/frontend/ endpoint listed is mapped to the merge node names.
    """
    frontend_paths = (
        "/ct/frontend/sendentry",
        "/ct/frontend/sendlog",
        "/ct/frontend/sendsth",
        "/ct/frontend/currentposition",
        "/ct/frontend/missingentries",
    )
    return [(path, mergenodenames) for path in frontend_paths]

def allowed_clients_public():
    """Return (path, noauth) pairs for the public /ct/v1/ API.

    The pairs are appended to a frontend node's plop allowed_clients setting.
    The noauth atom presumably means no client authentication is required;
    confirm against plop.
    """
    public_paths = (
        "/ct/v1/add-chain",
        "/ct/v1/add-pre-chain",
        "/ct/v1/get-sth",
        "/ct/v1/get-sth-consistency",
        "/ct/v1/get-proof-by-hash",
        "/ct/v1/get-entries",
        "/ct/v1/get-entry-and-proof",
        "/ct/v1/get-roots",
    )
    anyone = Symbol("noauth")
    return [(path, anyone) for path in public_paths]

def allowed_clients_signing(frontendnodenames, mergenodenames):
    """Return the plop allowed_clients rules for a signing node.

    Frontend nodes may call /ct/signing/sct; merge nodes may call
    /ct/signing/sth.
    """
    sct_rule = ("/ct/signing/sct", frontendnodenames)
    sth_rule = ("/ct/signing/sth", mergenodenames)
    return [sct_rule, sth_rule]

def allowed_clients_storage(frontendnodenames, mergenodenames):
    """Return the plop allowed_clients rules for a storage node.

    Frontend nodes may call sendentry and entrycommitted.  Merge nodes may
    call fetchnewentries and getentry.
    """
    frontend_rules = [(path, frontendnodenames)
                      for path in ("/ct/storage/sendentry", "/ct/storage/entrycommitted")]
    merge_rules = [(path, mergenodenames)
                   for path in ("/ct/storage/fetchnewentries", "/ct/storage/getentry")]
    return frontend_rules + merge_rules

def allowed_servers_frontend(signingnodenames, storagenodenames):
    """Return the (path, server names) pairs for a frontend's plop allowed_servers.

    The storage endpoints sendentry and entrycommitted map to the storage
    node names.  /ct/signing/sct maps to the signing node names.
    """
    rules = [(path, storagenodenames)
             for path in ("/ct/storage/sendentry", "/ct/storage/entrycommitted")]
    rules.append(("/ct/signing/sct", signingnodenames))
    return rules

def gen_config(nodename, config, localconfig):
    """Write <configdir>/<nodename>.config for one local node.

    Inputs:
      * *config* is the system-wide configuration: node lists and the
        storage quorum size.
      * *localconfig* holds this machine's settings: "paths", optional
        per-node bind "addresses" / "publicaddresses" / "publichttpaddresses",
        "options" and an optional "hsm" section.

    The node is looked up with get_node_config().  The sasl, catlfish, lager
    and plop application environments are rendered with gen_erlang().

    The complete Erlang text is built before the output file is opened.  A
    failure (unknown node, missing key, unrenderable value) therefore no
    longer leaves an empty or truncated config file behind, and "with"
    guarantees the file is closed.
    """
    sys.stdout.write("generating config for %s\n" % (nodename,))
    paths = localconfig["paths"]
    bind_address = localconfig.get("addresses", {}).get(nodename)
    bind_publicaddress = localconfig.get("publicaddresses", {}).get(nodename)
    bind_publichttpaddress = localconfig.get("publichttpaddresses", {}).get(nodename)
    options = localconfig.get("options", [])

    (nodetype, nodeconfig) = get_node_config(nodename, config)
    (http_servers, https_servers) = gen_http_servers(nodetype, nodeconfig, bind_address, bind_publicaddress, bind_publichttpaddress=bind_publichttpaddress)

    catlfishconfig = []
    plopconfig = []

    if nodetype == "frontendnodes":
        catlfishconfig.append((Symbol("known_roots_path"), paths["knownroots"]))
        if "sctcaching" in options:
            catlfishconfig.append((Symbol("sctcache_root_path"), paths["db"] + "sctcache/"))

    catlfishconfig += [
        (Symbol("https_servers"), https_servers),
        (Symbol("http_servers"), http_servers),
        (Symbol("https_certfile"), paths["https_certfile"]),
        (Symbol("https_keyfile"), paths["https_keyfile"]),
        (Symbol("https_cacertfile"), paths["https_cacertfile"]),
    ]

    lagerconfig = [
        (Symbol("handlers"), [
            (Symbol("lager_console_backend"), Symbol("info")),
            (Symbol("lager_file_backend"), [(Symbol("file"), nodename + "-error.log"), (Symbol("level"), Symbol("error"))]),
            (Symbol("lager_file_backend"), [(Symbol("file"), nodename + "-debug.log"), (Symbol("level"), Symbol("debug"))]),
            (Symbol("lager_file_backend"), [(Symbol("file"), nodename + "-console.log"), (Symbol("level"), Symbol("info"))]),
        ])
    ]

    # Database paths.  The order of entries follows the original layout.
    if nodetype in ("frontendnodes", "storagenodes"):
        plopconfig += [
            (Symbol("entry_root_path"), paths["db"] + "certentries/"),
        ]
    if nodetype == "frontendnodes":
        plopconfig += [
            (Symbol("index_path"), paths["db"] + "index"),
        ]
    elif nodetype == "storagenodes":
        plopconfig += [
            (Symbol("newentries_path"), paths["db"] + "newentries"),
        ]
    if nodetype in ("frontendnodes", "storagenodes"):
        plopconfig += [
            (Symbol("entryhash_root_path"), paths["db"] + "entryhash/"),
            (Symbol("indexforhash_root_path"), paths["db"] + "certindex/"),
        ]
    if nodetype == "frontendnodes":
        plopconfig += [
            (Symbol("sth_path"), paths["db"] + "sth"),
            (Symbol("entryhash_from_entry"),
             (Symbol("catlfish"), Symbol("entryhash_from_entry"))),
        ]

    signingnodeaddresses = ["https://%s/ct/signing/" % node["address"] for node in config["signingnodes"]]
    mergenodenames = [node["name"] for node in config["mergenodes"]]
    storagenodeaddresses = ["https://%s/ct/storage/" % node["address"] for node in config["storagenodes"]]
    frontendnodenames = [node["name"] for node in config["frontendnodes"]]

    allowed_clients = []
    allowed_servers = []

    if nodetype == "frontendnodes":
        storagenodenames = [node["name"] for node in config["storagenodes"]]
        signingnodenames = [node["name"] for node in config["signingnodes"]]
        plopconfig.append((Symbol("storage_nodes"), storagenodeaddresses))
        plopconfig.append((Symbol("storage_nodes_quorum"), config["storage-quorum-size"]))
        services = [Symbol("ht")]
        allowed_clients += allowed_clients_frontend(mergenodenames)
        allowed_clients += allowed_clients_public()
        allowed_servers += allowed_servers_frontend(signingnodenames, storagenodenames)
    elif nodetype == "storagenodes":
        allowed_clients += allowed_clients_storage(frontendnodenames, mergenodenames)
        services = []
    elif nodetype == "signingnodes":
        allowed_clients += allowed_clients_signing(frontendnodenames, mergenodenames)
        services = [Symbol("sign")]

    plopconfig += [
        (Symbol("publickey_path"), paths["publickeys"]),
        (Symbol("services"), services),
    ]
    if nodetype == "signingnodes":
        plopconfig.append((Symbol("log_private_key"), paths["logprivatekey"]))
        hsm = localconfig.get("hsm")
        if hsm:
            plopconfig.append((Symbol("hsm"), [hsm.get("library"), str(hsm.get("slot")), "ecdsa", hsm.get("label"), hsm.get("pin")]))
    plopconfig += [
        (Symbol("log_public_key"), paths["logpublickey"]),
        (Symbol("own_key"), (nodename, "%s/%s-private.pem" % (paths["privatekeys"], nodename))),
    ]
    if nodetype == "frontendnodes":
        plopconfig.append((Symbol("signing_nodes"), signingnodeaddresses))
    plopconfig += [
        (Symbol("allowed_clients"), allowed_clients),
        (Symbol("allowed_servers"), allowed_servers),
    ]

    erlangconfig = [
        (Symbol("sasl"), saslconfig),
        (Symbol("catlfish"), catlfishconfig),
        (Symbol("lager"), lagerconfig),
        (Symbol("plop"), plopconfig),
    ]
    erlangtext = gen_erlang(erlangconfig)

    # The output bytes are the same as the former print statements produced:
    # a header comment line, then the term terminated by ".\n" followed by
    # print's own newline.
    with open(paths["configdir"] + "/" + nodename + ".config", "w") as configfile:
        configfile.write("%% catlfish configuration file (-*- erlang -*-)\n")
        configfile.write(erlangtext + ".\n\n")


def gen_testmakefile(config, testmakefile, machines):
    """Write makefile variable definitions for the test harness to *testmakefile*.

    Variables written:
      * NODES: the frontend, storage and signing node names, in that order.
      * MACHINES: the integers 1..*machines*.
      * TESTURLS: the frontend public addresses, then the storage and signing
        node addresses.
      * BASEURL: config["baseurl"].

    All values are computed before the file is opened, so a missing key no
    longer leaves a truncated file, and "with" guarantees the file is closed.
    """
    frontendnodenames = [node["name"] for node in config["frontendnodes"]]
    storagenodenames = [node["name"] for node in config["storagenodes"]]
    signingnodenames = [node["name"] for node in config["signingnodes"]]

    frontendnodeaddresses = [node["publicaddress"] for node in config["frontendnodes"]]
    storagenodeaddresses = [node["address"] for node in config["storagenodes"]]
    signingnodeaddresses = [node["address"] for node in config["signingnodes"]]

    lines = [
        "NODES=" + " ".join(frontendnodenames + storagenodenames + signingnodenames),
        "MACHINES=" + " ".join([str(e) for e in range(1, machines + 1)]),
        "TESTURLS=" + " ".join(frontendnodeaddresses + storagenodeaddresses + signingnodeaddresses),
        "BASEURL=" + config["baseurl"],
    ]

    with open(testmakefile, "w") as configfile:
        configfile.write("\n".join(lines) + "\n")


def main():
    """Command-line entry point.

    Loads the system configuration from --config, then does one of:
      * with --testmakefile and --machines, writes the test makefile
        variables;
      * with --localconfig, writes one Erlang config per entry in its
        "localnodes" list;
      * otherwise, prints "Nothing to do" on stderr and exits with status 1.
    """
    parser = argparse.ArgumentParser(description="")
    parser.add_argument('--config', help="System configuration", required=True)
    parser.add_argument('--localconfig', help="Local configuration")
    parser.add_argument("--testmakefile", metavar="file", help="Generate makefile variables for test")
    parser.add_argument("--machines", type=int, metavar="n", help="Number of machines")
    args = parser.parse_args()

    # safe_load builds only plain YAML types.  yaml.load without an explicit
    # Loader can construct arbitrary Python objects and is rejected outright
    # by PyYAML >= 6.
    with open(args.config) as configfile:
        config = yaml.safe_load(configfile)
    if args.testmakefile and args.machines:
        gen_testmakefile(config, args.testmakefile, args.machines)
    elif args.localconfig:
        with open(args.localconfig) as localconfigfile:
            localconfig = yaml.safe_load(localconfigfile)
        for localnode in localconfig["localnodes"]:
            gen_config(localnode, config, localconfig)
    else:
        sys.stderr.write("Nothing to do\n")
        sys.exit(1)

# Run the CLI only when executed as a script, so the module can be imported
# (e.g. to reuse gen_erlang) without side effects.
if __name__ == "__main__":
    main()