#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2014-2015, NORDUnet A/S.
# See LICENSE for licensing information.

import sys
import struct
import subprocess
from time import sleep
from mergetools import get_logorder, verify_entry, get_new_entries, \
     chunks, fsync_logorder, get_entries, add_to_logorder, \
     hexencode, parse_args, perm
from certtools import timing_point, write_file, create_ssl_context

def merge_fetch(args, config, localconfig):
    """
    Fetch entries not yet in the local logorder from all storage nodes.

    Every storage node is asked for its new entries. Each entry hash that is
    not already in the logorder is fetched from the first storage node (in
    config["storagenodes"] order) that reported it. Every fetched entry is
    passed through the external verifycert helper and then added to the
    chains database and appended to the logorder file.

    If anything raises while entries are being fetched or stored, the
    verifycert child process is killed and reaped before the exception
    propagates, so no helper process is left behind.

    Returns (tree_size, last_hash): the number of entries in the logorder
    after this run and the last entry's hash, or (0, '') when the logorder
    is empty.
    """
    paths = localconfig["paths"]
    storagenodes = config["storagenodes"]
    mergedb = paths["mergedb"]
    logorderfile = mergedb + "/logorder"
    chainsdb = perm(localconfig.get("dbbackend", "filedb"), mergedb + "/chains")
    own_key = (localconfig["nodename"],
               "%s/%s-private.pem" % (paths["privatekeys"],
                                      localconfig["nodename"]))
    timing = timing_point()

    logorder = get_logorder(logorderfile)
    timing_point(timing, "get logorder")

    # Set copy of the logorder for O(1) membership tests.
    certsinlog = set(logorder)

    new_entries_per_node = {}
    new_entries = set()
    entries_to_fetch = {}

    for storagenode in storagenodes:
        print >>sys.stderr, "getting new entries from", storagenode["name"]
        sys.stderr.flush()
        new_entries_per_node[storagenode["name"]] = \
          set(get_new_entries(storagenode["name"],
                              "https://%s/" % storagenode["address"],
                              own_key, paths))
        new_entries.update(new_entries_per_node[storagenode["name"]])
        entries_to_fetch[storagenode["name"]] = []
    timing_point(timing, "get new entries")

    # Only entries the logorder does not already contain are of interest.
    new_entries -= certsinlog
    print >>sys.stderr, "adding", len(new_entries), "entries"
    sys.stderr.flush()

    # Assign each new entry to the first node that has it, so every entry
    # is fetched exactly once.
    for ehash in new_entries:
        for storagenode in storagenodes:
            if ehash in new_entries_per_node[storagenode["name"]]:
                entries_to_fetch[storagenode["name"]].append(ehash)
                break

    verifycert = subprocess.Popen(
        [paths["verifycert_bin"], paths["known_roots"]],
        stdin=subprocess.PIPE, stdout=subprocess.PIPE)
    verifycert_shut_down = False
    try:
        added_entries = 0
        for storagenode in storagenodes:
            print >>sys.stderr, "getting %d entries from %s:" % \
              (len(entries_to_fetch[storagenode["name"]]), storagenode["name"]),
            sys.stderr.flush()
            for chunk in chunks(entries_to_fetch[storagenode["name"]], 100):
                entries = get_entries(storagenode["name"],
                                      "https://%s/" % storagenode["address"],
                                      own_key, paths, chunk)
                for ehash in chunk:
                    entry = entries[ehash]
                    verify_entry(verifycert, entry, ehash)
                    chainsdb.add(ehash, entry)
                    add_to_logorder(logorderfile, ehash)
                    logorder.append(ehash)
                    certsinlog.add(ehash)
                    added_entries += 1
                # Running progress count on the same stderr line.
                print >>sys.stderr, added_entries,
                sys.stderr.flush()
            print >>sys.stderr
            sys.stderr.flush()
        chainsdb.commit()
        fsync_logorder(logorderfile)
        timing_point(timing, "add entries")
        print >>sys.stderr, "added", added_entries, "entries"
        sys.stderr.flush()

        # Send a 4-byte native-endian zero to the helper, close its stdin and
        # wait for it to exit. Presumably the helper treats a zero length
        # word as end of input (TODO confirm against verifycert).
        verifycert.communicate(struct.pack("I", 0))
        verifycert_shut_down = True
    finally:
        # On the error path the helper was never told to stop. Kill it and
        # reap it instead of leaking a child process. poll() returns
        # non-None (and reaps) if it already exited.
        if not verifycert_shut_down and verifycert.poll() is None:
            verifycert.kill()
            verifycert.wait()

    if args.timing:
        print >>sys.stderr, "timing: merge_fetch:", timing["deltatimes"]
        sys.stderr.flush()

    if not logorder:
        return (0, '')
    return (len(logorder), logorder[-1])

def main():
    """
    Fetch new entries from all storage nodes.

    Indicate current position by writing the index in the logorder file
    (0-based) to the 'fetched' file, together with the hex-encoded hash of
    that entry. An empty logorder is recorded as index -1 with the hex
    encoding of an empty hash.

    If args.interval is None, run once and return. Otherwise sleep
    args.interval seconds and start over, forever.
    """
    args, config, localconfig = parse_args()
    paths = localconfig["paths"]
    mergedb = paths["mergedb"]
    currentsizefile = mergedb + "/fetched"
    create_ssl_context(cafile=paths["https_cacertfile"])

    while True:
        logsize, last_hash = merge_fetch(args, config, localconfig)
        # logsize - 1 is the 0-based index of the last entry (-1 if empty).
        currentsize = {"index": logsize - 1, "hash": hexencode(last_hash)}
        write_file(currentsizefile, currentsize)
        if args.interval is None:
            break
        print >>sys.stderr, "sleeping", args.interval, "seconds"
        # Flush like every other progress message in this file.
        sys.stderr.flush()
        sleep(args.interval)

if __name__ == '__main__':
    # Run the fetcher when executed as a script; main()'s return value
    # becomes the process exit status.
    exit_status = main()
    sys.exit(exit_status)