#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2014-2015, NORDUnet A/S.
# See LICENSE for licensing information.
#
# Fetch new entries from all storage nodes.
# See catlfish/doc/merge.txt for more about the merge process.
#
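# Options referenced in this file (all defined by mergetools.parse_args()):
# --mergeinterval, --timing, --loglevel and --logdir.  Illustrative
# invocation -- the config file option names here are assumptions, use
# whatever parse_args() actually defines:
#
#   merge_fetch.py --config merge.cfg --localconfig merge-local.cfg \
#       --mergeinterval 60
#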
import sys
import struct
import subprocess
import requests
import signal
import logging
from time import sleep
from multiprocessing import Process, Pipe
from random import Random
from mergetools import get_logorder, verify_entry, get_new_entries, \
     chunks, fsync_logorder, get_entries, add_to_logorder, \
     hexencode, hexdecode, parse_args, perm, flock_ex_or_fail, Status, \
     terminate_child_procs
from certtools import timing_point, write_file, create_ssl_context

def merge_fetch_sequenced(args, config, localconfig):
    paths = localconfig["paths"]
    storagenodes = config["storagenodes"]
    mergedb = paths["mergedb"]
    logorderfile = mergedb + "/logorder"
    statusfile = mergedb + "/merge_fetch.status"
    s = Status(statusfile)
    chainsdb = perm(localconfig.get("dbbackend", "filedb"), mergedb + "/chains")
    own_key = (localconfig["nodename"],
               "%s/%s-private.pem" % (paths["privatekeys"],
                                      localconfig["nodename"]))
    timing = timing_point()

    logorder = get_logorder(logorderfile)
    timing_point(timing, "get logorder")

    certsinlog = set(logorder)

    new_entries_per_node = {}
    new_entries = set()
    entries_to_fetch = {}

    for storagenode in storagenodes:
        logging.info("getting new entries from %s", storagenode["name"])
        new_entries_per_node[storagenode["name"]] = \
          set(get_new_entries(storagenode["name"],
                              "https://%s/" % storagenode["address"],
                              own_key, paths))
        new_entries.update(new_entries_per_node[storagenode["name"]])
        entries_to_fetch[storagenode["name"]] = []
    timing_point(timing, "get new entries")

    new_entries -= certsinlog
    logging.info("adding %d entries", len(new_entries))

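    # Assign each new entry to exactly one storage node -- the first one,
    # in configuration order, that has it -- so no entry is fetched twice.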
    for ehash in new_entries:
        for storagenode in storagenodes:
            if ehash in new_entries_per_node[storagenode["name"]]:
                entries_to_fetch[storagenode["name"]].append(ehash)
                break

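    # verifycert is an external helper process, started with the known
    # roots file; verify_entry() talks to it over stdin/stdout to check
    # each fetched chain.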
    verifycert = subprocess.Popen(
        [paths["verifycert_bin"], paths["known_roots"]],
        stdin=subprocess.PIPE, stdout=subprocess.PIPE)

    added_entries = 0
    for storagenode in storagenodes:
        nentries = len(entries_to_fetch[storagenode["name"]])
        logging.info("getting %d entries from %s", nentries, storagenode["name"])
        with requests.sessions.Session() as session:
            for chunk in chunks(entries_to_fetch[storagenode["name"]], 100):
                entries = get_entries(storagenode["name"],
                                      "https://%s/" % storagenode["address"],
                                      own_key, paths, chunk, session=session)
                for ehash in chunk:
                    entry = entries[ehash]
                    verify_entry(verifycert, entry, ehash)
                    chainsdb.add(ehash, entry)
                    add_to_logorder(logorderfile, ehash)
                    logorder.append(ehash)
                    certsinlog.add(ehash)
                    added_entries += 1
                s.status("PROG getting %d entries from %s: %d" %
                         (nentries, storagenode["name"], added_entries))
    chainsdb.commit()
    fsync_logorder(logorderfile)
    timing_point(timing, "add entries")
    logging.info("added %d entries", added_entries)

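    # Shut the verifycert helper down: communicate() writes a 32-bit zero
    # (presumably the helper's end-of-input marker), closes its stdin and
    # waits for it to exit.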
    verifycert.communicate(struct.pack("I", 0))

    if args.timing:
        logging.debug("timing: merge_fetch: %s", timing["deltatimes"])

    tree_size = len(logorder)
    if tree_size == 0:
        return (0, '')
    else:
        return (tree_size, logorder[tree_size-1])

def merge_fetch_worker(args, localconfig, storagenode, pipe):
    paths = localconfig["paths"]
    mergedb = paths["mergedb"]
    chainsdb = perm(localconfig.get("dbbackend", "filedb"), mergedb + "/chains")
    own_key = (localconfig["nodename"],
               "%s/%s-private.pem" % (paths["privatekeys"],
                                      localconfig["nodename"]))
    to_fetch = set()
    timeout = max(3, args.mergeinterval / 10)
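    # Poll the parent pipe for at most a tenth of the merge interval,
    # but never less than 3 seconds, between fetch rounds.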
    while True:
        if pipe.poll(timeout):
            msg = pipe.recv().split()
            if len(msg) < 2:
                continue
            cmd = msg[0]
            ehash = msg[1]
            if cmd == 'FETCH':
                to_fetch.add(hexdecode(ehash))
            else:
                logging.warning("%s: unknown command from parent: %s",
                                storagenode["name"], msg)

        if len(to_fetch) > 0:
            logging.info("%s: fetching %d entries", storagenode["name"],
                         len(to_fetch))
            # TODO: Consider running the verifycert process longer.
            verifycert = subprocess.Popen(
                [paths["verifycert_bin"], paths["known_roots"]],
                stdin=subprocess.PIPE, stdout=subprocess.PIPE)
            # Work in chunks so that other workers get a chance to take
            # the chainsdb lock between commits.
            for chunk in chunks(list(to_fetch), 100):
                chainsdb.lock_ex()
                with requests.sessions.Session() as session:
                    entries = get_entries(storagenode["name"],
                                          "https://%s/" % storagenode["address"],
                                          own_key, paths, chunk, session=session)
                    for ehash in chunk:
                        entry = entries[ehash]
                        verify_entry(verifycert, entry, ehash)
                        chainsdb.add(ehash, entry)
                    chainsdb.commit()
                    chainsdb.release_lock()
                    for ehash in chunk:
                        pipe.send('FETCHED %s' % hexencode(ehash))
                        to_fetch.remove(ehash)
            verifycert.communicate(struct.pack("I", 0))

        new_entries = get_new_entries(storagenode["name"],
                                      "https://%s/" % storagenode["address"],
                                      own_key, paths)
        if len(new_entries) > 0:
            logging.info("%s: got %d new entries", storagenode["name"],
                         len(new_entries))
            for ehash in new_entries:
                pipe.send('NEWENTRY %s' % hexencode(ehash))

def term(signum, frame):
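    """Signal handler for SIGTERM: terminate child processes and exit."""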
    terminate_child_procs()
    sys.exit(1)

def newworker(name, args):
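    """
    Start a merge_fetch_worker process for storage node `name'.

    Return (name, parent end of the pipe, Process object).
    """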
    my_conn, child_conn = Pipe()
    p = Process(target=merge_fetch_worker,
                args=tuple(args + [child_conn]),
                name='merge_fetch_%s' % name)
    p.daemon = True
    p.start()
    logging.debug("%s started, pid %d", name, p.pid)
    return (name, my_conn, p)

def merge_fetch_parallel(args, config, localconfig):
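    """
    Start one fetch worker per storage node and loop forever: collect
    NEWENTRY and FETCHED messages from the workers, hand out FETCH
    commands (spreading the load over the nodes that have each entry),
    and keep 'logorder' and the 'fetched' file up to date.
    """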
    paths = localconfig["paths"]
    storagenodes = config["storagenodes"]
    mergedb = paths["mergedb"]
    logorderfile = mergedb + "/logorder"
    currentsizefile = mergedb + "/fetched"

    rand = Random()
    signal.signal(signal.SIGTERM, term)

    procs = {}
    storagenodes_by_name = {}
    for storagenode in storagenodes:
        name = storagenode['name']
        storagenodes_by_name[name] = storagenode
        procs[name] = newworker(name, [args, localconfig, storagenode])

    logorder = get_logorder(logorderfile)   # List of entries in log.
    entries_in_log = set(logorder)          # Set of entries in log.
    entries_to_fetch = set()                # Set of entries to fetch.
    fetch = {}                              # Entry hash -> list of nodes to ask, rotated.
    while procs:
        assert(not entries_to_fetch)
        # Poll worker processes.
        for name, pipe, p in procs.values():
            if not p.is_alive():
                logging.warning("%s is gone, restarting", name)
                procs[name] = newworker(name, [args, localconfig,
                                               storagenodes_by_name[name]])
                continue
            logging.info("polling %s", name)
            if pipe.poll(1):
                msg = pipe.recv().split()
                if len(msg) < 2:
                    logging.warning("unknown command from %s: %s", name, msg)
                    continue
                cmd = msg[0]
                ehash = msg[1]
                if cmd == 'NEWENTRY':
                    logging.info("NEWENTRY at %s: %s", name, ehash)
                    entries_to_fetch.add(ehash)
                    logging.debug("entries_to_fetch: %s", entries_to_fetch)
                elif cmd == 'FETCHED':
                    logging.info("FETCHED from %s: %s", name, ehash)
                    logorder.append(ehash)
                    add_to_logorder(logorderfile, hexdecode(ehash))
                    fsync_logorder(logorderfile)
                    entries_in_log.add(ehash)
                    if ehash in entries_to_fetch:
                        entries_to_fetch.remove(ehash)
                    # More than one node may report FETCHED for the same
                    # entry; drop it only if it's still pending.
                    fetch.pop(ehash, None)
                else:
                    logging.warning("unknown command from %s: %s", name, msg)

        # Ask workers to fetch entries.
        logging.debug("nof entries to fetch including entries in log: %d",
                      len(entries_to_fetch))
        entries_to_fetch -= entries_in_log
        logging.info("entries to fetch: %d", len(entries_to_fetch))
        # Add entries in entries_to_fetch as keys in dictionary fetch,
        # values being a list of storage nodes, in randomised order.
        for e in entries_to_fetch:
            if e not in fetch:
                nodes = list(procs.values())
                rand.shuffle(nodes)
                fetch[e] = nodes
        # For each entry to fetch, treat its list of nodes as a
        # circular list and ask the one in the front to fetch the
        # entry.
        while entries_to_fetch:
            ehash = entries_to_fetch.pop()
            nodes = fetch[ehash]
            node = nodes.pop(0)
            nodes.append(node)  # Rotate in place; fetch[ehash] still refers to this list.
            name, pipe, p = node
            logging.info("asking %s to FETCH %s", name, ehash)
            pipe.send("FETCH %s" % ehash)

        # Update the 'fetched' file.
        logsize = len(logorder)
        if logsize == 0:
            last_hash = ''
        else:
            last_hash = logorder[logsize - 1]
        logging.info("updating 'fetched' file: %d %s", logsize-1, last_hash)
        currentsize = {"index": logsize - 1, "hash": last_hash}
        logging.debug("writing to %s: %s", currentsizefile, currentsize)
        write_file(currentsizefile, currentsize)

    return 0

def main():
    """
    If no `--mergeinterval':
      Fetch new entries from all storage nodes, in sequence, updating
      the 'logorder' file and the 'chains' database.

      Write 'fetched' to reflect how far in 'logorder' we've successfully
      fetched and verified.

    If `--mergeinterval':
      Start one worker process per storage node and read from its pipe to
      learn about two things: (i) new entries ready for fetching ("NEWENTRY")
      and (ii) entries having been successfully fetched ("FETCHED").

      Write to the worker's pipe ("FETCH") when it should fetch another entry.
      Update 'logorder' and the 'chains' database as we see new FETCHED
      messages.

      Write 'fetched' to reflect how far in 'logorder' we've successfully
      fetched and verified.

      Keep doing this forever.

    NOTE: The point of having 'fetched' is that it can be atomically
    written while 'logorder' cannot (unless we're fine with rewriting it
    for each and every update, which we're not).

    TODO: Deduplicate some code.
    """
    args, config, localconfig = parse_args()
    paths = localconfig["paths"]
    mergedb = paths["mergedb"]
    currentsizefile = mergedb + "/fetched"
    lockfile = mergedb + "/.merge_fetch.lock"

    loglevel = getattr(logging, args.loglevel.upper())
    if args.mergeinterval is None:
        logging.basicConfig(level=loglevel)
    else:
        logging.basicConfig(filename=args.logdir + "/merge_fetch.log",
                            level=loglevel)

    if not flock_ex_or_fail(lockfile):
        logging.critical("unable to take lock %s", lockfile)
        return 1

    create_ssl_context(cafile=paths["https_cacertfile"])

    if args.mergeinterval:
        return merge_fetch_parallel(args, config, localconfig)
    else:
        logsize, last_hash = merge_fetch_sequenced(args, config, localconfig)
        currentsize = {"index": logsize - 1, "hash": hexencode(last_hash)}
        logging.debug("writing to %s: %s", currentsizefile, currentsize)
        write_file(currentsizefile, currentsize)
        return 0

if __name__ == '__main__':
    sys.exit(main())