#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2014-2015, NORDUnet A/S.
# See LICENSE for licensing information.
#
# Fetch new entries from all storage nodes.
# See catlfish/doc/merge.txt for more about the merge process.
#
import sys
import struct
import subprocess
import requests
import signal
import logging
from time import sleep
from multiprocessing import Process, Pipe
from random import Random
from itertools import cycle
from mergetools import get_logorder, verify_entry, get_new_entries, \
     chunks, fsync_logorder, get_entries, add_to_logorder, \
     hexencode, hexdecode, parse_args, perm, flock_ex_or_fail, Status, \
     terminate_child_procs, loginit
from certtools import timing_point, write_file, create_ssl_context

def merge_fetch_sequenced(args, config, localconfig):
    """Fetch new entries from all storage nodes, one node after another.

    Each storage node is asked for its new entries; hashes already
    present in 'logorder' are dropped and every remaining hash is
    assigned to the first node (in configuration order) that announced
    it.  The entries are then fetched in chunks of 100, passed to
    verify_entry(), added to the chains database and appended to the
    'logorder' file.  Progress is reported through the
    'merge_fetch.status' file.

    Returns a tuple (tree_size, last_hash): the number of entries in
    logorder after the run and the last hash in it, or (0, '') when
    logorder is empty.
    """
    paths = localconfig["paths"]
    nodes = config["storagenodes"]
    mergedb = paths["mergedb"]
    logorderfile = mergedb + "/logorder"
    progress = Status(mergedb + "/merge_fetch.status")
    chainsdb = perm(localconfig.get("dbbackend", "filedb"), mergedb + "/chains")
    nodename = localconfig["nodename"]
    own_key = (nodename,
               "%s/%s-private.pem" % (paths["privatekeys"], nodename))
    timing = timing_point()

    logorder = get_logorder(logorderfile)
    timing_point(timing, "get logorder")

    known_hashes = set(logorder)

    # What each node announced, the union of all announcements, and the
    # per-node list of hashes we decide to fetch from that node.
    announced_by = {}
    candidates = set()
    fetch_plan = {}

    for node in nodes:
        node_name = node["name"]
        logging.info("getting new entries from %s", node_name)
        announced = set(get_new_entries(node_name,
                                        "https://%s/" % node["address"],
                                        own_key, paths))
        announced_by[node_name] = announced
        candidates.update(announced)
        fetch_plan[node_name] = []
    timing_point(timing, "get new entries")

    candidates -= known_hashes
    logging.info("adding %d entries", len(candidates))

    # Assign every missing hash to the first node that has it.
    for ehash in candidates:
        for node in nodes:
            node_name = node["name"]
            if ehash in announced_by[node_name]:
                fetch_plan[node_name].append(ehash)
                break

    verifycert = subprocess.Popen(
        [paths["verifycert_bin"], paths["known_roots"]],
        stdin=subprocess.PIPE, stdout=subprocess.PIPE)

    added_entries = 0
    for node in nodes:
        node_name = node["name"]
        node_url = "https://%s/" % node["address"]
        wanted = fetch_plan[node_name]
        nentries = len(wanted)
        logging.info("getting %d entries from %s", nentries, node_name)
        with requests.sessions.Session() as session:
            for chunk in chunks(wanted, 100):
                entries = get_entries(node_name, node_url, own_key, paths,
                                      chunk, session=session)
                for ehash in chunk:
                    entry = entries[ehash]
                    verify_entry(verifycert, entry, ehash)
                    chainsdb.add(ehash, entry)
                    add_to_logorder(logorderfile, ehash)
                    logorder.append(ehash)
                    known_hashes.add(ehash)
                    added_entries += 1
                progress.status("PROG getting %d entries from %s: %d" %
                                (nentries, node_name, added_entries))
    chainsdb.commit()
    fsync_logorder(logorderfile)
    timing_point(timing, "add entries")
    logging.info("added %d entries", added_entries)

    # Send a packed zero to verifycert and wait for it to finish.
    verifycert.communicate(struct.pack("I", 0))

    if args.timing:
        logging.debug("timing: merge_fetch: %s", timing["deltatimes"])

    tree_size = len(logorder)
    if tree_size == 0:
        return (0, '')
    return (tree_size, logorder[tree_size - 1])

def read_parent_messages(pipe, to_fetch):
    """Drain all pending messages from the parent's end of the pipe.

    Every message is a (command, hash) pair.  Hashes carried by 'FETCH'
    commands are added to the set `to_fetch'; any other command is
    ignored.  Returns as soon as the pipe has nothing more to read.
    """
    while True:
        if not pipe.poll():
            return
        command, entry_hash = pipe.recv()
        if command != 'FETCH':
            continue
        to_fetch.add(entry_hash)

def merge_fetch_worker(args, localconfig, storagenode, pipe):
    """Worker loop serving a single storage node; never returns.

    Over `pipe' the worker receives ('FETCH', hash) requests from the
    parent and sends back ('FETCHED', hash, entry) for every entry it
    has fetched and passed to verify_entry(), and ('NEWENTRY', hash)
    for every hash the storage node reports as new.

    When there is nothing left to fetch after a round, the worker
    sleeps for max(3, args.mergeinterval / 10) seconds.
    """
    paths = localconfig["paths"]
    nodename = localconfig["nodename"]
    own_key = (nodename,
               "%s/%s-private.pem" % (paths["privatekeys"], nodename))
    name = storagenode["name"]
    url = "https://%s/" % storagenode["address"]

    # NOTE: We should probably verifycert.communicate(struct.pack("I",0))
    # to ask the verifycert process to quit nicely.
    verifycert = subprocess.Popen(
        [paths["verifycert_bin"], paths["known_roots"]],
        stdin=subprocess.PIPE, stdout=subprocess.PIPE)

    pending = set()
    while True:
        ## Pick up whatever the parent has asked for so far.
        read_parent_messages(pipe, pending)

        ## Fetch the requested entries from the node.
        if pending:
            logging.info("%s: fetching %d entries", name, len(pending))
            with requests.sessions.Session() as session:
                # Iterate over a snapshot; 'pending' is modified below.
                for chunk in chunks(list(pending), 100):
                    entries = get_entries(name, url, own_key, paths, chunk,
                                          session=session)
                    for ehash in chunk:
                        entry = entries[ehash]
                        verify_entry(verifycert, entry, ehash)
                        pipe.send(('FETCHED', ehash, entry))
                        pending.remove(ehash)
                        # Keep the pipe drained while we are busy.
                        read_parent_messages(pipe, pending)

        ## Nothing left to fetch: announce the node's new entries.
        if not pending:
            for ehash in get_new_entries(name, url, own_key, paths):
                pipe.send(('NEWENTRY',  ehash))
                # Keep the pipe drained while we are busy.
                read_parent_messages(pipe, pending)

        ## Still nothing to do: wait a while before the next round.
        if not pending:
            sleep(max(3, args.mergeinterval / 10))

def term(signal, arg):
    """Signal handler: terminate the child processes, then exit with status 1.

    Both arguments are supplied by the signal machinery and are unused.
    """
    terminate_child_procs()
    raise SystemExit(1)

def newworker(name, args):
    """Start a daemonic merge_fetch_worker process called merge_fetch_<name>.

    `args' is a list of the leading arguments for merge_fetch_worker();
    the child's end of a freshly created Pipe is appended as the last
    argument.

    Returns the tuple (name, parent_connection, process).
    """
    parent_end, child_end = Pipe()
    worker_args = tuple(args + [child_end])
    worker = Process(target=merge_fetch_worker,
                     args=worker_args,
                     name='merge_fetch_%s' % name)
    worker.daemon = True
    worker.start()
    logging.debug("%s started, pid %d", name, worker.pid)
    return (name, parent_end, worker)

def read_worker_messages(procs, messages, args, localconfig):
    """Collect pending messages from all workers, restarting dead ones.

    `procs' maps worker names to (name, pipe, process, storagenode)
    tuples.  For each live worker, everything currently readable on its
    pipe is appended to `messages' as (name, message).  A worker whose
    process is no longer alive is replaced in `procs' by a new one made
    with newworker(); its old pipe is not read.
    """
    for name, pipe, p, storagenode in procs.values():
        if p.is_alive():
            while pipe.poll():
                received = pipe.recv()
                messages.append((name, received))
        else:
            logging.warning("%s is gone, restarting", name)
            replacement = newworker(name, [args, localconfig, storagenode])
            procs[name] = replacement + (storagenode,)

def process_worker_message(name, msg, fetch_dict, fetch_set, chainsdb, newentry,
                           logorder, entries_in_log):
    """Handle one message `msg' received from the worker called `name'.

    ('NEWENTRY', hash): add the hash to `fetch_set' unless it is
    already a key in `fetch_dict'.

    ('FETCHED', hash, entry): add the entry to `chainsdb', append the
    hash to `newentry' and `logorder', add it to `entries_in_log' and
    drop it from `fetch_dict' if present.

    Messages with any other command are ignored.
    """
    cmd, ehash = msg[0], msg[1]

    if cmd == 'NEWENTRY':
        logging.info("NEWENTRY at %s: %s", name, hexencode(ehash))
        # Don't fetch twice.
        if ehash not in fetch_dict:
            fetch_set.add(ehash)
        return

    if cmd != 'FETCHED':
        return

    entry = msg[2]
    logging.info("FETCHED from %s: %s", name, hexencode(ehash))
    chainsdb.add(ehash, entry)
    # The caller writes the hashes in 'newentry' to the logorder file.
    newentry.append(ehash)
    logorder.append(ehash)
    entries_in_log.add(ehash)
    fetch_dict.pop(ehash, None)

def merge_fetch_parallel(args, config, localconfig):
    """Fetch entries continuously using one worker process per storage node.

    Starts a worker (see newworker()) for every node in
    config["storagenodes"], installs term() as the SIGTERM handler and
    then loops:

      1. Drain and process the workers' messages (NEWENTRY / FETCHED).
      2. Commit the chains database and append the newly fetched
         hashes to the 'logorder' file.
      3. Ask a randomly picked worker to FETCH every announced entry
         that is not yet in the log.
      4. Rewrite the 'fetched' file when the last index or hash of the
         log has changed.
      5. Sleep one second if no worker messages are waiting.

    Returns 0 once the loop ends, i.e. when `procs' is empty.
    """
    paths = localconfig["paths"]
    storagenodes = config["storagenodes"]
    mergedb = paths["mergedb"]
    chainsdb = perm(localconfig.get("dbbackend", "filedb"), mergedb + "/chains")
    logorderfile = mergedb + "/logorder"
    currentsizefile = mergedb + "/fetched"

    rand = Random()
    signal.signal(signal.SIGTERM, term)

    # Worker name -> (name, pipe, process, storagenode).
    procs = {}
    for storagenode in storagenodes:
        name = storagenode['name']
        procs[name] = newworker(name,
                                [args, localconfig, storagenode]) + (storagenode,)

    # Last content written to the 'fetched' file; the initial "" never
    # equals a dict, so the file is written on the first iteration.
    currentsizefilecontent = ""
    # Entries in log, kept in both a set and a list.
    logorder = get_logorder(logorderfile)
    entries_in_log = set(logorder)

    # Entries to fetch, kept in both a set and a dict. The dict is
    # keyed on hashes (binary) and contains randomised lists of nodes
    # to fetch from. Note that the dict keeps entries until they're
    # successfully fetched while the set is temporary within one
    # iteration of the loop.
    fetch_set = set()
    fetch_dict = {}

    messages = []

    while procs:
        ## Poll worker processes and handle messages.
        assert not fetch_set
        newentry = []

        # Drain pipe, then process messages.
        read_worker_messages(procs, messages, args, localconfig)
        for name, msg in messages:
            process_worker_message(name, msg, fetch_dict, fetch_set, chainsdb,
                                   newentry, logorder, entries_in_log)
        messages = []

        # Commit to chains database and update 'logorder' file.
        chainsdb.commit()
        for ehash in newentry:
            add_to_logorder(logorderfile, ehash)
        fsync_logorder(logorderfile)

        ## Ask workers to fetch new entries.
        logging.debug("nof entries to fetch including entries in log: %d",
                      len(fetch_set))
        fetch_set -= entries_in_log
        logging.info("entries to fetch: %d", len(fetch_set))
        # Add entries to be fetched to fetch_dict, with the hash as
        # key and value being a cyclic iterator of list of storage
        # nodes, in randomised order. Ask next node to fetch the
        # entry.
        # NOTE(review): process_worker_message() only puts a hash in
        # fetch_set while it is absent from fetch_dict, so each iterator
        # looks like it is advanced just once -- confirm whether
        # retrying on another node is intended.
        while fetch_set:
            e = fetch_set.pop()
            if e not in fetch_dict:
                workers = list(procs.values())
                rand.shuffle(workers)
                fetch_dict[e] = cycle(workers)
            # Builtin next() instead of the Python-2-only .next() method.
            name, pipe, _, _ = next(fetch_dict[e])
            logging.info("asking %s to FETCH %s", name, hexencode(e))
            pipe.send(('FETCH', e))
            # Drain pipe; these messages are processed in the next
            # iteration of the outer loop.
            read_worker_messages(procs, messages, args, localconfig)

        ## Update the 'fetched' file.
        logsize = len(logorder)
        if logsize == 0:
            last_hash = ''
        else:
            last_hash = logorder[logsize - 1]
        newcontent = {"index": logsize - 1, "hash": hexencode(last_hash)}
        if newcontent != currentsizefilecontent:
            logging.info("updating 'fetched' file: %d %s", logsize - 1,
                         hexencode(last_hash))
            currentsizefilecontent = newcontent
            write_file(currentsizefile, currentsizefilecontent)

        ## Wait some if nothing to do.
        if not messages:
            sleep(1)

    return 0

def main():
    """
    Entry point.  Returns the process exit status: 1 if the merge_fetch
    lock cannot be taken, otherwise the result of the selected mode.

    If no `--mergeinterval':
      Fetch new entries from all storage nodes, in sequence, updating
      the 'logorder' file and the 'chains' database.

      Write 'fetched' to reflect how far in 'logorder' we've successfully
      fetched and verified.

    If `--mergeinterval':
      Start one worker process per storage node and read from the pipe
      to each of them to learn about two things: (i) new entries ready
      for fetching ("NEWENTRY") and (ii) new entries being successfully
      fetched ("FETCHED").

      Write to a worker's pipe ("FETCH") when it should fetch another
      entry.  Update 'logorder' and the 'chains' database as we see new
      FETCHED messages.

      Write 'fetched' to reflect how far in 'logorder' we've successfully
      fetched and verified.

      Keep doing this forever.

    NOTE: The point of having 'fetched' is that it can be atomically
    written while 'logorder' cannot (unless we're fine with rewriting it
    for each and every update, which we're not).

    TODO: Deduplicate some code.
    """
    args, config, localconfig = parse_args()
    paths = localconfig["paths"]
    mergedb = paths["mergedb"]
    currentsizefile = mergedb + "/fetched"
    lockfile = mergedb + "/.merge_fetch.lock"

    loginit(args, "merge_fetch.log")

    have_lock = flock_ex_or_fail(lockfile)
    if not have_lock:
        logging.critical("unable to take lock %s", lockfile)
        return 1

    create_ssl_context(cafile=paths["https_cacertfile"])

    # Parallel mode: runs the worker-based loop.
    if args.mergeinterval:
        return merge_fetch_parallel(args, config, localconfig)

    # Sequenced mode: one pass, then record how far we got.
    logsize, last_hash = merge_fetch_sequenced(args, config, localconfig)
    currentsize = {"index": logsize - 1, "hash": hexencode(last_hash)}
    logging.debug("writing to %s: %s", currentsizefile, currentsize)
    write_file(currentsizefile, currentsize)
    return 0

if __name__ == '__main__':
    # Run as a script: main()'s return value becomes the exit status.
    exit_status = main()
    sys.exit(exit_status)