#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2014-2015, NORDUnet A/S.
# See LICENSE for licensing information.
#
# Distribute the 'sth' file and all missing entries to all frontend nodes.
# See catlfish/doc/merge.txt for more about the merge process.
#

import sys
import json
import requests
import logging
from time import sleep
from base64 import b64encode, b64decode
from os import stat
from multiprocessing import Process, Pipe
from certtools import timing_point, create_ssl_context
from mergetools import get_curpos, get_logorder, chunks, get_missingentries, \
     publish_sth, sendlog, sendentries, parse_args, perm, \
     get_frontend_verifiedsize, frontend_verify_entries, \
     waitforfile, flock_ex_or_fail, Status, loginit, start_worker

def sendlog_helper(entries, curpos, nodename, nodeaddress, own_key, paths,
                   statusupdates):
    # Send the ordered entry hashes to the frontend in chunks of 1000,
    # retrying each chunk up to five times before giving up.
    logging.info("sending log")
    for chunk in chunks(entries, 1000):
        for trynumber in range(5, 0, -1):
            sendlogresult = sendlog(nodename, nodeaddress, own_key, paths,
                                    {"start": curpos, "hashes": chunk})
            if sendlogresult is None:
                if trynumber == 1:
                    sys.exit(1)
                sleep(10)
                logging.warning("tries left: %d", trynumber)
                continue
            break
        if sendlogresult["result"] != "ok":
            logging.error("sendlog: %s", sendlogresult)
            sys.exit(1)
        curpos += len(chunk)
        statusupdates.status("PROG sending log: %d" % curpos)
    logging.info("log sent")

def fill_in_missing_entries(nodename, nodeaddress, own_key, paths, chainsdb,
                            timing, statusupdates):
    # Ask the frontend which entries it is missing and send them, in chunks
    # of 100, until nothing is reported missing.
    missingentries = get_missingentries(nodename, nodeaddress, own_key, paths)
    timing_point(timing, "get missing")

    while missingentries:
        logging.info("about to send %d missing entries", len(missingentries))

        sent_entries = 0
        with requests.sessions.Session() as session:
            for missingentry_chunk in chunks(missingentries, 100):
                missingentry_hashes = [b64decode(missingentry)
                                       for missingentry in missingentry_chunk]
                hashes_and_entries = [(ehash, chainsdb.get(ehash))
                                      for ehash in missingentry_hashes]
                sendentryresult = sendentries(nodename, nodeaddress, own_key,
                                              paths, hashes_and_entries,
                                              session)
                if sendentryresult["result"] != "ok":
                    logging.error("sendentries: %s", sendentryresult)
                    sys.exit(1)
                sent_entries += len(missingentry_hashes)
                statusupdates.status(
                    "PROG sending missing entries: %d" % sent_entries)
        timing_point(timing, "send missing")

        missingentries = get_missingentries(nodename, nodeaddress, own_key,
                                            paths)
        timing_point(timing, "get missing")

def do_send(args, localconfig, frontendnode, logorder, sth, chainsdb, s):
    # Bring one frontend node up to date: send new entry hashes in windows
    # of `maxwindow', fill in missing entries, have the frontend verify
    # them, and finally publish the STH.
    timing = timing_point()
    paths = localconfig["paths"]
    own_key = (localconfig["nodename"],
               "%s/%s-private.pem" % (paths["privatekeys"],
                                      localconfig["nodename"]))
    maxwindow = localconfig.get("maxwindow", 1000)
    nodename = frontendnode["name"]
    nodeaddress = "https://%s/" % frontendnode["address"]

    logging.info("distributing for node %s", nodename)
    curpos = get_curpos(nodename, nodeaddress, own_key, paths)
    timing_point(timing, "get curpos")
    logging.info("current position %d", curpos)

    verifiedsize = \
        get_frontend_verifiedsize(nodename, nodeaddress, own_key, paths)
    timing_point(timing, "get verified size")
    logging.info("verified size %d", verifiedsize)

    assert verifiedsize >= curpos

    while verifiedsize < len(logorder):
        uptopos = min(verifiedsize + maxwindow, len(logorder))
        entries = [b64encode(entry)
                   for entry in logorder[verifiedsize:uptopos]]
        sendlog_helper(entries, verifiedsize, nodename, nodeaddress, own_key,
                       paths, s)
        timing_point(timing, "sendlog")
        fill_in_missing_entries(nodename, nodeaddress, own_key, paths,
                                chainsdb, timing, s)
        verifiedsize = \
            frontend_verify_entries(nodename, nodeaddress, own_key, paths,
                                    uptopos)

    logging.info("sending sth to node %s", nodename)
    publishsthresult = publish_sth(nodename, nodeaddress, own_key, paths, sth)
    if publishsthresult["result"] != "ok":
        logging.error("publishsth: %s", publishsthresult)
        sys.exit(1)
    timing_point(timing, "send sth")

    if args.timing:
        logging.debug("timing: merge_dist: %s", timing["deltatimes"])

def merge_dist_sequenced(args, localconfig, frontendnodes, chainsdb, s):
    # Distribute the current STH to each frontend node in turn, once.
    paths = localconfig["paths"]
    mergedb = paths["mergedb"]
    logorderfile = mergedb + "/logorder"
    sthfile = mergedb + "/sth"
    timestamp = 0
    timing = timing_point()

    try:
        sth = json.loads(open(sthfile, 'r').read())
    except (IOError, ValueError):
        logging.warning("No valid STH file found in %s", sthfile)
        return timestamp
    if sth['timestamp'] < timestamp:
        logging.warning("New STH file older than the previous one: %d < %d",
                        sth['timestamp'], timestamp)
        return timestamp
    if sth['timestamp'] == timestamp:
        return timestamp
    timestamp = sth['timestamp']

    logorder = get_logorder(logorderfile, sth['tree_size'])
    timing_point(timing, "get logorder")

    for frontendnode in frontendnodes:
        do_send(args, localconfig, frontendnode, logorder, sth, chainsdb, s)

    return timestamp

def dist_worker(_, argv):
    # Per-node worker process: poll the 'sth' file and redistribute to this
    # frontend node whenever a newer STH appears.
    args, localconfig, frontendnode, chainsdb, s = argv
    paths = localconfig["paths"]
    mergedb = paths["mergedb"]
    sthfile = mergedb + "/sth"
    logorderfile = mergedb + "/logorder"
    nodename = frontendnode["name"]

    wait = max(3, args.mergeinterval / 10)
    timestamp = 0
    while True:
        try:
            sth = json.loads(open(sthfile, 'r').read())
        except (IOError, ValueError):
            logging.error("%s: No valid STH file found in %s", nodename,
                          sthfile)
            sleep(wait)
            continue

        if sth['timestamp'] < timestamp:
            logging.error(
                "%s: New STH file older than the previous one: %d < %d",
                nodename, sth['timestamp'], timestamp)
            sleep(wait)
            continue

        if sth['timestamp'] == timestamp:
            logging.info("%s: sth still at %d (%d), sleeping %s seconds",
                         nodename, sth['tree_size'], timestamp, wait)
            sleep(wait)
            continue

        timestamp = sth['timestamp']
        logorder = get_logorder(logorderfile, sth['tree_size'])
        do_send(args, localconfig, frontendnode, logorder, sth, chainsdb, s)

def merge_dist_parallel(args, localconfig, frontendnodes, chainsdb, s):
    # Start one dist_worker per frontend node and restart any worker that
    # exits.
    procs = {}
    for frontendnode in frontendnodes:
        nodename = frontendnode["name"]
        procname = 'dist_%s' % nodename
        p, pipe = start_worker(procname, dist_worker,
                               (args, localconfig, frontendnode, chainsdb, s))
        procs[p] = (frontendnode, pipe)

    while True:
        for p in list(procs):
            if not p.is_alive():
                p.join()
                frontendnode, _ = procs[p]
                nodename = frontendnode["name"]
                logging.warning("%s exited with %d, restarting", nodename,
                                p.exitcode)
                procname = 'dist_%s' % nodename
                newproc, pipe = \
                    start_worker(procname, dist_worker,
                                 (args, localconfig, frontendnode, chainsdb,
                                  s))
                # Track the replacement process and forget the dead one.
                procs[newproc] = (frontendnode, pipe)
                del procs[p]
        sleep(1)

    return -1

def main():
    """
    Distribute missing entries and the STH to all frontend nodes, in
    parallel if `--mergeinterval' is given.

    If `--mergeinterval', re-read 'sth' when it changes and keep
    distributing.
    """
    args, config, localconfig = parse_args()
    paths = localconfig["paths"]
    mergedb = paths["mergedb"]
    chainsdb = perm(localconfig.get("dbbackend", "filedb"),
                    mergedb + "/chains", write_enabled=False)
    lockfile = mergedb + "/.merge_dist.lock"

    loginit(args, "merge_dist.log")

    if not flock_ex_or_fail(lockfile):
        logging.critical("unable to take lock %s", lockfile)
        return 1

    statusfile = mergedb + "/merge_dist.status"
    s = Status(statusfile)

    create_ssl_context(cafile=paths["https_cacertfile"])

    if len(args.node) == 0:
        nodes = config["frontendnodes"]
    else:
        nodes = [n for n in config["frontendnodes"] if n["name"] in args.node]

    if args.mergeinterval:
        return merge_dist_parallel(args, localconfig, nodes, chainsdb, s)
    else:
        merge_dist_sequenced(args, localconfig, nodes, chainsdb, s)
        return 0

if __name__ == '__main__':
    sys.exit(main())