#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2014-2015, NORDUnet A/S.
# See LICENSE for licensing information.
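#
# Distribute the merged log to each secondary merge node and verify
# that it reports the same tree head before recording its new verified
# size.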

import sys
import base64
import select
import requests
from time import sleep
from certtools import timing_point, build_merkle_tree, write_file, \
     create_ssl_context
from mergetools import chunks, backup_sendlog, get_logorder, \
     get_verifiedsize, get_missingentriesforbackup, read_chain, \
     hexencode, setverifiedsize, sendentries_merge, verifyroot, \
     get_nfetched, parse_args

def backup_loop(nodename, nodeaddress, own_key, paths, verifiedsize, chunk):
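    """Send one chunk of hashes to a secondary via backup_sendlog(),
    retrying up to five times with a ten second pause between attempts.
    Return the parsed response, or None if every attempt failed."""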
    for trynumber in range(5, 0, -1):
        sendlogresult = \
          backup_sendlog(nodename, nodeaddress, own_key, paths,
                         {"start": verifiedsize, "hashes": chunk})
        if sendlogresult is None:
            if trynumber == 1:
                return None
            # Plain sleep before the next attempt; select() with empty
            # fd sets just waits out the timeout.
            select.select([], [], [], 10.0)
            print >>sys.stderr, "tries left:", trynumber - 1
            sys.stderr.flush()
            continue
        return sendlogresult


def merge_backup(args, config, localconfig, secondaries):
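    """Back up the current log to each secondary merge node: send the
    entry hashes and any missing chain data, verify that the secondary
    reports the same root hash as ours, and record the verified tree
    head in a verified.<nodename> file."""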
    paths = localconfig["paths"]
    own_key = (localconfig["nodename"],
               "%s/%s-private.pem" % (paths["privatekeys"],
                                      localconfig["nodename"]))
    mergedb = paths["mergedb"]
    chainsdir = mergedb + "/chains"
    logorderfile = mergedb + "/logorder"
    currentsizefile = mergedb + "/fetched"
    timing = timing_point()

    nfetched = get_nfetched(currentsizefile, logorderfile)
    timing_point(timing, "get nfetched")
    logorder = get_logorder(logorderfile, nfetched)
    tree_size = len(logorder)
    timing_point(timing, "get logorder")

    tree = build_merkle_tree(logorder)
    root_hash = tree[-1][0]
    timing_point(timing, "build tree")

    for secondary in secondaries:
        if secondary["name"] == config["primarymergenode"]:
            continue
        nodeaddress = "https://%s/" % secondary["address"]
        nodename = secondary["name"]
        timing = timing_point()
        print >>sys.stderr, "backing up to node", nodename
        sys.stderr.flush()
        verifiedsize = get_verifiedsize(nodename, nodeaddress, own_key, paths)
        timing_point(timing, "get verified size")
        print >>sys.stderr, "verified size", verifiedsize
        sys.stderr.flush()

        entries = [base64.b64encode(entry) for entry in logorder[verifiedsize:]]

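        # Determine where the secondary's copy of the log ends by probing
        # with only the first ten hashes of each large chunk.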
        print >>sys.stderr, "determining end of log:",
        for chunk in chunks(entries, 100000):
            sendlogresult = backup_loop(nodename, nodeaddress, own_key, paths, verifiedsize, chunk[:10])
            if sendlogresult is None:
                print >>sys.stderr, "sendlog result was None"
                sys.exit(1)
            if sendlogresult["result"] != "ok":
                print >>sys.stderr, "backup_sendlog:", sendlogresult
                sys.exit(1)
            verifiedsize += len(chunk)
            print >>sys.stderr, verifiedsize,
            sys.stderr.flush()

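        # Back off one full chunk so the tail of the log is re-sent in
        # full below.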
        if verifiedsize > 100000:
            verifiedsize -= 100000
        else:
            verifiedsize = 0

        timing_point(timing, "checklog")

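        # Send the hashes from the backed-off position to the end of the
        # log, in smaller chunks this time.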
        entries = [base64.b64encode(entry) for entry in logorder[verifiedsize:]]
        print >>sys.stderr, "sending log:",
        sys.stderr.flush()
        for chunk in chunks(entries, 1000):
            sendlogresult = backup_loop(nodename, nodeaddress, own_key, paths, verifiedsize, chunk)
            if sendlogresult is None:
                print >>sys.stderr, "sendlog result was None"
                sys.exit(1)
            if sendlogresult["result"] != "ok":
                print >>sys.stderr, "backup_sendlog:", sendlogresult
                sys.exit(1)
            verifiedsize += len(chunk)
            print >>sys.stderr, verifiedsize,
            sys.stderr.flush()
        print >>sys.stderr
        timing_point(timing, "sendlog")
        print >>sys.stderr, "log sent"
        sys.stderr.flush()

        missingentries = get_missingentriesforbackup(nodename, nodeaddress,
                                                     own_key, paths)
        timing_point(timing, "get missing")

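        # The secondary reports the hashes it lacks entry data for; read
        # each chain from the local chains directory and send it over.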
        while missingentries:
            print >>sys.stderr, "missing entries:", len(missingentries)
            sys.stderr.flush()

            fetched_entries = 0
            print >>sys.stderr, "fetching missing entries",
            sys.stderr.flush()
            with requests.sessions.Session() as session:
                for missingentry_chunk in chunks(missingentries, 100):
                    missingentry_hashes = [base64.b64decode(missingentry) for missingentry in missingentry_chunk]
                    hashes_and_entries = [(ehash, read_chain(chainsdir, ehash)) for ehash in missingentry_hashes]
                    sendentryresult = sendentries_merge(nodename, nodeaddress,
                                                        own_key, paths,
                                                        hashes_and_entries, session)
                    if sendentryresult["result"] != "ok":
                        print >>sys.stderr, "sendentry_merge:", sendentryresult
                        sys.exit(1)
                    fetched_entries += len(missingentry_hashes)
                    print >>sys.stderr, fetched_entries,
                    sys.stderr.flush()
            print >>sys.stderr
            sys.stderr.flush()
            timing_point(timing, "send missing")

            missingentries = get_missingentriesforbackup(nodename, nodeaddress,
                                                         own_key, paths)
            timing_point(timing, "get missing")

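        # Ask the secondary for its root hash at our tree size and check
        # that it matches the root hash computed locally.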
        verifyrootresult = verifyroot(nodename, nodeaddress, own_key, paths,
                                      tree_size)
        if verifyrootresult["result"] != "ok":
            print >>sys.stderr, "verifyroot:", verifyrootresult
            sys.exit(1)
        secondary_root_hash = base64.b64decode(verifyrootresult["root_hash"])
        if root_hash != secondary_root_hash:
            print >>sys.stderr, "secondary root hash was", \
              hexencode(secondary_root_hash)
            print >>sys.stderr, "  expected", hexencode(root_hash)
            sys.exit(1)
        timing_point(timing, "verifyroot")

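        # Record the new verified size on the secondary and note the
        # verified tree head locally.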
        setverifiedsize(nodename, nodeaddress, own_key, paths, tree_size)
        backuppath = mergedb + "/verified." + nodename
        backupdata = {"tree_size": tree_size,
                      "sha256_root_hash": hexencode(root_hash)}
        #print >>sys.stderr, "DEBUG: writing to", backuppath, ":", backupdata
        write_file(backuppath, backupdata)

        if args.timing:
            print >>sys.stderr, "timing: merge_backup:", timing["deltatimes"]
            sys.stderr.flush()

def main():
    """
    Read the logorder file up to the size indicated by the fetched file
    and build the tree.

    Distribute entries to all secondaries, writing tree size and tree
    head to a verified.<secondary> file as each secondary is verified
    to have the entries.

    Sleep a while and start over.
    """
    args, config, localconfig = parse_args()
    all_secondaries = \
      [n for n in config.get('mergenodes', []) if \
       n['name'] != config['primarymergenode']]
    paths = localconfig["paths"]
    create_ssl_context(cafile=paths["https_cacertfile"])

    if len(args.node) == 0:
        nodes = all_secondaries
    else:
        nodes = [n for n in all_secondaries if n["name"] in args.node]

    while True:
        merge_backup(args, config, localconfig, nodes)
        if args.interval is None:
            break
        print >>sys.stderr, "sleeping", args.interval, "seconds"
        sleep(args.interval)

if __name__ == '__main__':
    sys.exit(main())