#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2014-2015, NORDUnet A/S.
# See LICENSE for licensing information.
#
# Copy entries indicated by file 'fetched' to all secondary merge nodes.
# See catlfish/doc/merge.txt for more about the merge process.
#
import sys
import base64
import select
import requests
from time import sleep
from base64 import b64encode, b64decode
from certtools import timing_point, build_merkle_tree, write_file, \
     create_ssl_context
from mergetools import chunks, backup_sendlog, get_logorder, \
     get_verifiedsize, get_missingentriesforbackup, \
     hexencode, setverifiedsize, sendentries_merge, verifyroot, \
     get_nfetched, parse_args, perm

def backup_loop(nodename, nodeaddress, own_key, paths, verifiedsize, chunk):
    """
    Send one chunk of log hashes to a secondary, retrying on failure.

    Calls backup_sendlog() with `verifiedsize` as the start position and
    `chunk` as the hashes. An attempt counts as failed when
    backup_sendlog() returns None. At most five attempts are made, with
    a ten second pause after each failed attempt except the last.

    Returns the first result from backup_sendlog() that is not None, or
    None if all five attempts failed.
    """
    for trynumber in range(5, 0, -1):
        sendlogresult = \
          backup_sendlog(nodename, nodeaddress, own_key, paths,
                         {"start": verifiedsize, "hashes": chunk})
        if sendlogresult is not None:
            return sendlogresult
        if trynumber == 1:
            # That was the last attempt; give up without waiting.
            return None
        # select() with no file descriptors and a timeout is used as a
        # plain ten second sleep.
        select.select([], [], [], 10.0)
        # trynumber counts the attempt that just failed, so one fewer
        # attempt actually remains.
        print >>sys.stderr, "tries left:", trynumber - 1
        sys.stderr.flush()
    return None

# NOTE(review): this constant is not referenced anywhere else in the
# visible part of this file; sendlog_helper() uses a literal chunk size
# of 1000. Presumably a leftover or used by another module -- confirm
# before removing.
sendlog_discover_chunksize = 100000

def sendlog_helper(entries, verifiedsize, nodename, nodeaddress, own_key, paths):
    """
    Send the hashes in `entries` to a secondary, chunk by chunk.

    `entries` is split with chunks(entries, 1000) and each chunk is sent
    through backup_loop(). `verifiedsize` is the start position passed
    along with the first chunk; it is advanced locally by the length of
    every chunk that was accepted and is not returned to the caller.

    Progress (the running position after each chunk) is written to
    stderr. The process exits with status 1 if backup_loop() gives up
    on a chunk or if the secondary answers with a result other than
    "ok".
    """
    print >>sys.stderr, "sending log:",
    sys.stderr.flush()
    for chunk in chunks(entries, 1000):
        sendlogresult = backup_loop(nodename, nodeaddress, own_key, paths, verifiedsize, chunk)
        if sendlogresult is None:
            # backup_loop() has used up all its retries. Terminate the
            # progress line and say why we are exiting.
            print >>sys.stderr
            print >>sys.stderr, "backup_sendlog: no result from node", \
              nodename, "at position", verifiedsize
            sys.stderr.flush()
            sys.exit(1)
        if sendlogresult["result"] != "ok":
            print >>sys.stderr, "backup_sendlog:", sendlogresult
            sys.exit(1)
        verifiedsize += len(chunk)
        print >>sys.stderr, verifiedsize,
        sys.stderr.flush()
    print >>sys.stderr
    print >>sys.stderr, "log sent"
    sys.stderr.flush()

def fill_in_missing_entries(nodename, nodeaddress, own_key, paths, chainsdb, timing):
    """
    Send entries to a secondary until it reports none missing.

    Each round asks the secondary for its missing entries with
    get_missingentriesforbackup(). The returned values are base64
    decoded, looked up in `chainsdb` and sent with sendentries_merge()
    in batches of 100, using one HTTP session per round. Rounds repeat
    until the secondary returns an empty list of missing entries.

    Timing points "get missing" and "send missing" are recorded in
    `timing`. The process exits with status 1 if sendentries_merge()
    answers with a result other than "ok".
    """
    while True:
        missingentries = get_missingentriesforbackup(nodename, nodeaddress,
                                                     own_key, paths)
        timing_point(timing, "get missing")
        if not missingentries:
            return

        print >>sys.stderr, "missing entries:", len(missingentries)
        sys.stderr.flush()

        print >>sys.stderr, "sending missing entries",
        sys.stderr.flush()
        with requests.sessions.Session() as session:
            for missingentry_chunk in chunks(missingentries, 100):
                # The decoded value is used both as the lookup key in
                # chainsdb and as the hash sent alongside the entry.
                entry_hashes = [base64.b64decode(encoded)
                                for encoded in missingentry_chunk]
                hashes_and_entries = [(entry_hash, chainsdb.get(entry_hash))
                                      for entry_hash in entry_hashes]
                sendentryresult = sendentries_merge(nodename, nodeaddress,
                                                    own_key, paths,
                                                    hashes_and_entries, session)
                if sendentryresult["result"] != "ok":
                    print >>sys.stderr, "sendentries_merge:", sendentryresult
                    sys.exit(1)
        # Terminate the "sending missing entries" progress line.
        print >>sys.stderr
        sys.stderr.flush()
        timing_point(timing, "send missing")

def check_root(logorder, nodename, nodeaddress, own_key, paths, tree_size, timing):
    """
    Compare the local root hash with the one a secondary reports.

    A tree is built with build_merkle_tree() over the first `tree_size`
    items of `logorder`, and the first element of its last item is taken
    as the expected root hash. The secondary is then asked through
    verifyroot() for its root hash at the same tree size.

    Returns the locally computed root hash when both agree. Exits the
    process with status 1 if verifyroot() answers with a result other
    than "ok" or if the two hashes differ. Timing points "build tree"
    and "verifyroot" are recorded in `timing`.
    """
    expected_root = build_merkle_tree(logorder[:tree_size])[-1][0]
    timing_point(timing, "build tree")

    reply = verifyroot(nodename, nodeaddress, own_key, paths, tree_size)
    if reply["result"] != "ok":
        print >>sys.stderr, "verifyroot:", reply
        sys.exit(1)

    reported_root = base64.b64decode(reply["root_hash"])
    if expected_root == reported_root:
        timing_point(timing, "verifyroot")
        return expected_root

    # Mismatch: show both hashes and give up.
    print >>sys.stderr, "secondary root hash was", \
      hexencode(reported_root)
    print >>sys.stderr, "  expected", hexencode(expected_root)
    sys.exit(1)
        
def merge_backup(args, config, localconfig, secondaries):
    """
    Bring every node in `secondaries` up to date with the local log.

    The log order is read from <mergedb>/logorder up to the size given
    by <mergedb>/fetched. For each secondary (the primary merge node is
    skipped) the node's verified size is fetched and then:

    * if it equals the local tree size, only the root hash is checked;
    * if it is smaller, the log is sent in windows of at most
      `maxwindow` entries (local config key "maxwindow", default 1000).
      After each window the missing entries are sent, the root hash is
      checked and the new verified size is stored on the secondary;
    * if it is larger, an error is reported and the process exits with
      status 1, because there is no local root hash to verify against.

    After a node has been verified, the tree size and hex encoded root
    hash are written to <mergedb>/verified.<nodename>. With args.timing
    set, the timing deltas for that node are printed to stderr.
    """
    maxwindow = localconfig.get("maxwindow", 1000)
    paths = localconfig["paths"]
    own_key = (localconfig["nodename"],
               "%s/%s-private.pem" % (paths["privatekeys"],
                                      localconfig["nodename"]))
    mergedb = paths["mergedb"]
    chainsdb = perm(localconfig.get("dbbackend", "filedb"), mergedb + "/chains")
    logorderfile = mergedb + "/logorder"
    currentsizefile = mergedb + "/fetched"
    # NOTE(review): `timing` is rebound for every secondary below, so the
    # two timing points recorded here are never printed.
    timing = timing_point()

    nfetched = get_nfetched(currentsizefile, logorderfile)
    timing_point(timing, "get nfetched")
    logorder = get_logorder(logorderfile, nfetched)
    tree_size = len(logorder)
    timing_point(timing, "get logorder")

    for secondary in secondaries:
        if secondary["name"] == config["primarymergenode"]:
            continue
        nodeaddress = "https://%s/" % secondary["address"]
        nodename = secondary["name"]
        timing = timing_point()
        print >>sys.stderr, "backing up to node", nodename
        sys.stderr.flush()
        verifiedsize = get_verifiedsize(nodename, nodeaddress, own_key, paths)
        timing_point(timing, "get verified size")
        print >>sys.stderr, "verified size", verifiedsize
        sys.stderr.flush()

        if verifiedsize > tree_size:
            # Without this check neither branch below would assign
            # root_hash: it would be unbound for the first secondary, or
            # still hold the previous secondary's value and be written
            # to this node's verified file.
            print >>sys.stderr, "verified size", verifiedsize, \
              "on node", nodename, "exceeds local tree size", tree_size
            sys.stderr.flush()
            sys.exit(1)

        if verifiedsize == tree_size:
            root_hash = check_root(logorder, nodename, nodeaddress, own_key, paths, tree_size, timing)
        else:
            while verifiedsize < tree_size:
                uptopos = min(verifiedsize + maxwindow, tree_size)

                entries = [b64encode(entry) for entry in logorder[verifiedsize:uptopos]]
                sendlog_helper(entries, verifiedsize, nodename, nodeaddress, own_key, paths)
                timing_point(timing, "sendlog")

                fill_in_missing_entries(nodename, nodeaddress, own_key, paths, chainsdb, timing)

                root_hash = check_root(logorder, nodename, nodeaddress, own_key, paths, uptopos, timing)

                # Only advance the secondary's verified size after its
                # root hash for this window has been checked.
                verifiedsize = uptopos
                setverifiedsize(nodename, nodeaddress, own_key, paths, verifiedsize)

        backuppath = mergedb + "/verified." + nodename
        backupdata = {"tree_size": tree_size,
                      "sha256_root_hash": hexencode(root_hash)}
        #print >>sys.stderr, "DEBUG: writing to", backuppath, ":", backupdata
        write_file(backuppath, backupdata)

        if args.timing:
            print >>sys.stderr, "timing: merge_backup:", timing["deltatimes"]
            sys.stderr.flush()

def main():
    """
    Read logorder file up until what's indicated by fetched file and
    build the tree.

    Distribute entries to the selected secondaries, write tree size and
    tree head to verified.<secondary> files as each secondary is
    verified to have the entries.

    The secondaries are all configured merge nodes except the primary,
    optionally narrowed down to the node names given in args.node.

    If args.interval is set, sleep that many seconds and start over;
    otherwise do a single run and return.
    """
    args, config, localconfig = parse_args()
    all_secondaries = \
      [n for n in config.get('mergenodes', []) if \
       n['name'] != config['primarymergenode']]
    paths = localconfig["paths"]
    create_ssl_context(cafile=paths["https_cacertfile"])

    if args.node:
        nodes = [n for n in all_secondaries if n["name"] in args.node]
    else:
        # No node names given: back up to every secondary.
        nodes = all_secondaries

    while True:
        merge_backup(args, config, localconfig, nodes)
        if args.interval is None:
            break
        print >>sys.stderr, "sleeping", args.interval, "seconds"
        sleep(args.interval)

# Run main() only when executed as a script. main() has no return value,
# so a run that finishes normally exits with status 0; failures inside
# the merge functions call sys.exit(1) themselves.
if __name__ == '__main__':
    sys.exit(main())