#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2014-2015, NORDUnet A/S.
# See LICENSE for licensing information.
#
# Distribute the 'sth' file and all missing entries to all frontend nodes.
# See catlfish/doc/merge.txt for more about the merge process.
#
import sys
import json
import requests
import logging
from time import sleep
from base64 import b64encode, b64decode
from os import stat
from certtools import timing_point, create_ssl_context
from mergetools import get_curpos, get_logorder, chunks, get_missingentries, \
     publish_sth, sendlog, sendentries, parse_args, perm, \
     get_frontend_verifiedsize, frontend_verify_entries, \
     waitforfile, flock_ex_or_fail, Status

def sendlog_helper(entries, curpos, nodename, nodeaddress, own_key, paths,
                   statusupdates):
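    """Send the hashes in 'entries', starting at log position 'curpos',
    to the frontend node in chunks of 1000, reporting progress through
    'statusupdates'."""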
    logging.info("sending log")
    for chunk in chunks(entries, 1000):
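        # Retry each chunk up to five times, pausing between attempts;
        # give up and exit if the final attempt also fails.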
        for trynumber in range(5, 0, -1):
            sendlogresult = sendlog(nodename, nodeaddress,
                                    own_key, paths,
                                    {"start": curpos, "hashes": chunk})
            if sendlogresult is None:
                if trynumber == 1:
                    logging.error("sendlog failed, giving up")
                    sys.exit(1)
                logging.warning("sendlog failed, %d tries left", trynumber - 1)
                sleep(10)
                continue
            break
        if sendlogresult["result"] != "ok":
            logging.error("sendlog: %s", sendlogresult)
            sys.exit(1)
        curpos += len(chunk)
        statusupdates.status("PROG sending log: %d" % curpos)
    logging.info("log sent")

def fill_in_missing_entries(nodename, nodeaddress, own_key, paths, chainsdb,
                            timing, statusupdates):
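    """Ask the frontend node which entries it is missing and send them
    from 'chainsdb' in chunks of 100, repeating until the node reports
    no missing entries."""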
    missingentries = get_missingentries(nodename, nodeaddress, own_key,
                                        paths)
    timing_point(timing, "get missing")

    while missingentries:
        logging.info("about to send %d missing entries", len(missingentries))
        sent_entries = 0
        with requests.sessions.Session() as session:
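            # Reuse a single HTTPS session for all sendentries requests
            # in this round.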
            for missingentry_chunk in chunks(missingentries, 100):
                missingentry_hashes = [b64decode(missingentry)
                                       for missingentry in missingentry_chunk]
                hashes_and_entries = [(ehash, chainsdb.get(ehash))
                                      for ehash in missingentry_hashes]
                sendentryresult = sendentries(nodename, nodeaddress,
                                              own_key, paths,
                                              hashes_and_entries, session)
                if sendentryresult["result"] != "ok":
                    logging.error("sendentries: %s", sendentryresult)
                    sys.exit(1)
                sent_entries += len(missingentry_hashes)
                statusupdates.status(
                    "PROG sending missing entries: %d" % sent_entries)
        timing_point(timing, "send missing")

        missingentries = get_missingentries(nodename, nodeaddress,
                                            own_key, paths)
        timing_point(timing, "get missing")

def merge_dist(args, localconfig, frontendnodes, timestamp):
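    """Distribute the current STH and any missing entries to the given
    frontend nodes.

    Return the timestamp of the distributed STH, or the old timestamp
    if the STH file is missing, invalid, older or unchanged."""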
    maxwindow = localconfig.get("maxwindow", 1000)
    paths = localconfig["paths"]
    own_key = (localconfig["nodename"],
               "%s/%s-private.pem" % (paths["privatekeys"],
                                      localconfig["nodename"]))
    mergedb = paths["mergedb"]
    chainsdb = perm(localconfig.get("dbbackend", "filedb"), mergedb + "/chains")
    logorderfile = mergedb + "/logorder"
    sthfile = mergedb + "/sth"
    statusfile = mergedb + "/merge_dist.status"
    s = Status(statusfile)
    create_ssl_context(cafile=paths["https_cacertfile"])
    timing = timing_point()

    try:
        with open(sthfile, 'r') as sth_fh:
            sth = json.loads(sth_fh.read())
    except (IOError, ValueError):
        logging.warning("No valid STH file found in %s", sthfile)
        return timestamp
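    # Only distribute an STH that is strictly newer than the one handled
    # in the previous round.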
    if sth['timestamp'] < timestamp:
        logging.warning("New STH file older than the previous one: %d < %d",
                     sth['timestamp'], timestamp)
        return timestamp
    if sth['timestamp'] == timestamp:
        return timestamp
    timestamp = sth['timestamp']

    logorder = get_logorder(logorderfile, sth['tree_size'])
    timing_point(timing, "get logorder")

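    # Bring each frontend node up to date with the log order and entries,
    # then hand it the new STH.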
    for frontendnode in frontendnodes:
        nodeaddress = "https://%s/" % frontendnode["address"]
        nodename = frontendnode["name"]
        timing = timing_point()

        logging.info("distributing for node %s", nodename)
        curpos = get_curpos(nodename, nodeaddress, own_key, paths)
        timing_point(timing, "get curpos")
        logging.info("current position %d", curpos)

        verifiedsize = get_frontend_verifiedsize(nodename, nodeaddress, own_key, paths)
        timing_point(timing, "get verified size")
        logging.info("verified size %d", verifiedsize)

        assert verifiedsize >= curpos

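        # Push the log in windows of at most 'maxwindow' entries: send
        # the ordered hashes, fill in any entries the node is missing,
        # and let the node verify them before advancing the window.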
        while verifiedsize < len(logorder):
            uptopos = min(verifiedsize + maxwindow, len(logorder))

            entries = [b64encode(entry)
                       for entry in logorder[verifiedsize:uptopos]]
            sendlog_helper(entries, verifiedsize, nodename, nodeaddress,
                           own_key, paths, s)
            timing_point(timing, "sendlog")

            fill_in_missing_entries(nodename, nodeaddress, own_key, paths,
                                    chainsdb, timing, s)

            verifiedsize = frontend_verify_entries(nodename, nodeaddress,
                                                   own_key, paths, uptopos)
        
        logging.info("sending sth to node %s", nodename)
        publishsthresult = publish_sth(nodename, nodeaddress, own_key, paths, sth)
        if publishsthresult["result"] != "ok":
            logging.info("publishsth: %s", publishsthresult)
            sys.exit(1)
        timing_point(timing, "send sth")

        if args.timing:
            logging.debug("timing: merge_dist: %s", timing["deltatimes"])

    return timestamp

def main():
    """
    Wait until 'sth' exists and read it.

    Distribute missing entries and the STH to all frontend nodes.

    If '--mergeinterval' is given, wait until 'sth' is updated, read it
    and start distributing again.
    """
    args, config, localconfig = parse_args()
    paths = localconfig["paths"]
    mergedb = paths["mergedb"]
    lockfile = mergedb + "/.merge_dist.lock"
    timestamp = 0

    loglevel = getattr(logging, args.loglevel.upper())
    if args.mergeinterval is None:
        logging.basicConfig(level=loglevel)
    else:
        logging.basicConfig(filename=args.logdir + "/merge_dist.log",
                            level=loglevel)

    if not flock_ex_or_fail(lockfile):
        logging.critical("unable to take lock %s", lockfile)
        return 1

    if len(args.node) == 0:
        nodes = config["frontendnodes"]
    else:
        nodes = [n for n in config["frontendnodes"] if n["name"] in args.node]

    if args.mergeinterval is None:
        if merge_dist(args, localconfig, nodes, timestamp) < 0:
            return 1
        return 0

    sth_path = localconfig["paths"]["mergedb"] + "/sth"
    sth_statinfo = waitforfile(sth_path)
    while True:
        if merge_dist(args, localconfig, nodes, timestamp) < 0:
            return 1
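        # Wait for the 'sth' file to change (as seen by stat()) before
        # starting the next distribution round.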
        sth_statinfo_old = sth_statinfo
        while sth_statinfo == sth_statinfo_old:
            sleep(args.mergeinterval / 30)
            sth_statinfo = stat(sth_path)
    return 0

if __name__ == '__main__':
    sys.exit(main())