#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2014-2015, NORDUnet A/S.
# See LICENSE for licensing information.

import sys
import json
from time import sleep
from base64 import b64encode, b64decode
from certtools import timing_point, \
     create_ssl_context
from mergetools import get_curpos, get_logorder, chunks, get_missingentries, \
     sendsth, sendlog, sendentry, read_chain, parse_args

def merge_dist(args, localconfig, frontendnodes, timestamp):
    """
    Distribute the current STH, the log order and missing entries to the
    given frontend nodes.

    Reads the STH from <mergedb>/sth. If it cannot be read or parsed, or
    its timestamp is not newer than `timestamp`, nothing is distributed.
    Otherwise, for every node in `frontendnodes`:
      1. ask the node for its current position in the log order,
      2. send the base64-encoded log-order entries from that position on,
         in chunks of 1000 (each chunk retried a few times),
      3. send the chain of every entry the node reports as missing,
      4. send the STH.
    Any failure reported while talking to a node terminates the process
    with exit status 1.

    args -- parsed command line; only `args.timing` is read here.
    localconfig -- local configuration; "paths" and "nodename" are read.
    frontendnodes -- iterable of dicts with "name" and "address" keys.
    timestamp -- timestamp of the most recently distributed STH.

    Returns the timestamp of the STH that was distributed, or `timestamp`
    unchanged when there was nothing new to distribute.
    """
    paths = localconfig["paths"]
    own_key = (localconfig["nodename"],
               "%s/%s-private.pem" % (paths["privatekeys"],
                                      localconfig["nodename"]))
    mergedb = paths["mergedb"]
    chainsdir = mergedb + "/chains"
    logorderfile = mergedb + "/logorder"
    sthfile = mergedb + "/sth"
    create_ssl_context(cafile=paths["https_cacertfile"])
    timing = timing_point()

    sth = _read_sth(sthfile)
    if sth is None:
        print >>sys.stderr, "No valid STH file found in", sthfile
        return timestamp
    if sth['timestamp'] < timestamp:
        print >>sys.stderr, "New STH file older than the previous one:", \
          sth['timestamp'], "<", timestamp
        return timestamp
    if sth['timestamp'] == timestamp:
        # This STH has already been distributed; nothing to do.
        return timestamp
    timestamp = sth['timestamp']

    logorder = get_logorder(logorderfile, sth['tree_size'])
    timing_point(timing, "get logorder")
    # Report this measurement here: every node below gets its own timing
    # record, so it would otherwise never be shown.
    if args.timing:
        print >>sys.stderr, "timing: merge_dist:", timing["deltatimes"]
        sys.stderr.flush()

    for frontendnode in frontendnodes:
        _distribute_to_node(args, frontendnode, own_key, paths, chainsdir,
                            logorder, sth)

    return timestamp

def _read_sth(sthfile):
    """Return the STH parsed from `sthfile`, or None if unreadable/invalid."""
    try:
        with open(sthfile, 'r') as f:
            return json.loads(f.read())
    except (IOError, ValueError):
        return None

def _check_result(what, result):
    """Print `result` and exit with status 1 unless it reports "ok"."""
    # sendlog visibly returns None on failure; treat a None from the other
    # senders the same way instead of crashing on None["result"].
    if result is None or result["result"] != "ok":
        print >>sys.stderr, "%s:" % what, result
        sys.exit(1)

def _sendlog_with_retries(nodename, nodeaddress, own_key, paths, request,
                          tries=5, retrydelay=10):
    """Call sendlog up to `tries` times; return its first non-None result."""
    for tries_left in range(tries - 1, -1, -1):
        result = sendlog(nodename, nodeaddress, own_key, paths, request)
        if result is not None:
            return result
        if tries_left == 0:
            break
        print >>sys.stderr, "tries left:", tries_left
        sys.stderr.flush()
        sleep(retrydelay)
    return None

def _send_log(nodename, nodeaddress, own_key, paths, logorder, curpos,
              chunksize=1000, tries=5):
    """Send logorder[curpos:] to one node in chunks; exit on failure."""
    entries = [b64encode(entry) for entry in logorder[curpos:]]
    print >>sys.stderr, "sending log:",
    sys.stderr.flush()
    for chunk in chunks(entries, chunksize):
        sendlogresult = _sendlog_with_retries(
            nodename, nodeaddress, own_key, paths,
            {"start": curpos, "hashes": chunk}, tries=tries)
        if sendlogresult is None:
            print >>sys.stderr
            print >>sys.stderr, "sendlog: giving up after", tries, \
              "failed tries"
            sys.exit(1)
        _check_result("sendlog", sendlogresult)
        curpos += len(chunk)
        print >>sys.stderr, curpos,
        sys.stderr.flush()
    print >>sys.stderr

def _send_missing_entries(nodename, nodeaddress, own_key, paths, chainsdir,
                          missingentries):
    """Send the chain of each base64-encoded entry hash; exit on failure."""
    print >>sys.stderr, "missing entries:", len(missingentries)
    sys.stderr.flush()
    print >>sys.stderr, "send missing entries",
    sys.stderr.flush()
    for sent_entries, missingentry in enumerate(missingentries, 1):
        ehash = b64decode(missingentry)
        sendentryresult = sendentry(nodename, nodeaddress, own_key, paths,
                                    read_chain(chainsdir, ehash), ehash)
        _check_result("sendentry", sendentryresult)
        # Progress indicator every 1000 entries.
        if sent_entries % 1000 == 0:
            print >>sys.stderr, sent_entries,
            sys.stderr.flush()
    print >>sys.stderr
    sys.stderr.flush()

def _distribute_to_node(args, frontendnode, own_key, paths, chainsdir,
                        logorder, sth):
    """Bring one frontend node up to date with `logorder` and `sth`."""
    nodeaddress = "https://%s/" % frontendnode["address"]
    nodename = frontendnode["name"]
    timing = timing_point()

    print >>sys.stderr, "distributing for node", nodename
    sys.stderr.flush()
    curpos = get_curpos(nodename, nodeaddress, own_key, paths)
    timing_point(timing, "get curpos")
    print >>sys.stderr, "current position", curpos
    sys.stderr.flush()

    _send_log(nodename, nodeaddress, own_key, paths, logorder, curpos)
    timing_point(timing, "sendlog")
    print >>sys.stderr, "log sent"
    sys.stderr.flush()

    missingentries = get_missingentries(nodename, nodeaddress, own_key,
                                        paths)
    timing_point(timing, "get missing")
    _send_missing_entries(nodename, nodeaddress, own_key, paths, chainsdir,
                          missingentries)
    timing_point(timing, "send missing")

    print >>sys.stderr, "sending sth to node", nodename
    sys.stderr.flush()
    sendsthresult = sendsth(nodename, nodeaddress, own_key, paths, sth)
    _check_result("sendsth", sendsthresult)
    timing_point(timing, "send sth")

    if args.timing:
        print >>sys.stderr, "timing: merge_dist:", timing["deltatimes"]
        sys.stderr.flush()

def main():
    """
    Distribute missing entries and the STH to all frontend nodes.

    If node names are given on the command line (args.node), only the
    frontend nodes with those names are served, and any name that matches
    no configured frontend node is reported on stderr. When args.interval
    is set, distribution is repeated every args.interval seconds;
    otherwise it runs once.
    """
    args, config, localconfig = parse_args()
    timestamp = 0

    frontendnodes = config["frontendnodes"]
    # `not args.node` also covers args.node being None, where len() would
    # raise.
    if not args.node:
        nodes = frontendnodes
    else:
        # NOTE(review): assumes args.node is a list of node names.
        nodes = [n for n in frontendnodes if n["name"] in args.node]
        known = set(n["name"] for n in frontendnodes)
        unknown = [name for name in args.node if name not in known]
        if unknown:
            print >>sys.stderr, "unknown frontend node(s):", \
              ", ".join(unknown)
            sys.stderr.flush()

    while True:
        timestamp = merge_dist(args, localconfig, nodes, timestamp)
        if args.interval is None:
            break
        print >>sys.stderr, "sleeping", args.interval, "seconds"
        sys.stderr.flush()
        sleep(args.interval)

if __name__ == '__main__':
    # Script entry point: use main()'s return value as the exit status.
    exit_status = main()
    sys.exit(exit_status)