From bff5d58fcce0534cf4774df386ff448261b28c20 Mon Sep 17 00:00:00 2001
From: Linus Nordberg
Date: Wed, 30 Nov 2016 16:47:23 +0100
Subject: Parallelise merge_dist.

Also deduplicate some code.
---
 tools/merge_backup.py |  16 +++---
 tools/merge_dist.py   | 131 +++++++++++++++++++++++++++++++-------------------
 tools/mergetools.py   |  10 +++-
 3 files changed, 97 insertions(+), 60 deletions(-)

(limited to 'tools')

diff --git a/tools/merge_backup.py b/tools/merge_backup.py
index 41b1014..cadcec7 100755
--- a/tools/merge_backup.py
+++ b/tools/merge_backup.py
@@ -22,7 +22,7 @@ from mergetools import chunks, backup_sendlog, get_logorder, \
      get_verifiedsize, get_missingentriesforbackup, \
      hexencode, setverifiedsize, sendentries_merge, verifyroot, \
      get_nfetched, parse_args, perm, waitforfile, flock_ex_or_fail, \
-     Status, loginit
+     Status, loginit, start_worker
 
 def backup_loop(nodename, nodeaddress, own_key, paths, verifiedsize, chunk):
     for trynumber in range(5, 0, -1):
@@ -166,12 +166,11 @@ def merge_backup(args, config, localconfig, secondaries):
         backupargs = (secondary, localconfig, chainsdb, logorder, s, timing)
         if args.mergeinterval:
-            pipe_mine, pipe_theirs = Pipe()
-            p = Process(target=lambda pipe, argv: pipe.send(do_send(argv)),
-                        args=(pipe_theirs, backupargs),
-                        name='backup_%s' % nodename)
-            p.start()
-            procs[p] = (nodename, pipe_mine)
+            name = 'backup_%s' % nodename
+            p, pipe = start_worker(name,
+                                   lambda cpipe, argv: cpipe.send(do_send(argv)),
+                                   backupargs)
+            procs[p] = (nodename, pipe)
         else:
             root_hash = do_send(backupargs)
             update_backupfile(mergedb, nodename, tree_size, root_hash)
@@ -233,7 +232,6 @@ def main():
         create_ssl_context(cafile=paths["https_cacertfile"])
     fetched_statinfo = waitforfile(fetched_path)
 
-    retval = 0
     while True:
         failures = merge_backup(args, config, localconfig, nodes)
         if not args.mergeinterval:
@@ -245,7 +243,7 @@
                 break
             fetched_statinfo = stat(fetched_path)
 
-    return retval
+    return 0
 
 if __name__ == '__main__':
     sys.exit(main())
diff --git a/tools/merge_dist.py b/tools/merge_dist.py
index d612600..bc9c676 100755
--- a/tools/merge_dist.py
+++ b/tools/merge_dist.py
@@ -14,11 +14,12 @@ import logging
 from time import sleep
 from base64 import b64encode, b64decode
 from os import stat
+from multiprocessing import Process, Pipe
 from certtools import timing_point, create_ssl_context
 from mergetools import get_curpos, get_logorder, chunks, get_missingentries, \
     publish_sth, sendlog, sendentries, parse_args, perm, \
     get_frontend_verifiedsize, frontend_verify_entries, \
-    waitforfile, flock_ex_or_fail, Status, loginit
+    waitforfile, flock_ex_or_fail, Status, loginit, start_worker
 
 def sendlog_helper(entries, curpos, nodename, nodeaddress, own_key, paths,
                    statusupdates):
@@ -70,12 +71,51 @@ def fill_in_missing_entries(nodename, nodeaddress, own_key, paths, chainsdb,
                   own_key, paths)
     timing_point(timing, "get missing")
 
-def merge_dist(args, localconfig, frontendnodes, timestamp):
-    maxwindow = localconfig.get("maxwindow", 1000)
+def do_send(args, localconfig, frontendnode, logorder, sth, chainsdb, s):
+    timing = timing_point()
     paths = localconfig["paths"]
     own_key = (localconfig["nodename"],
                "%s/%s-private.pem" % (paths["privatekeys"],
                                       localconfig["nodename"]))
+    maxwindow = localconfig.get("maxwindow", 1000)
+    nodename = frontendnode["name"]
+    nodeaddress = "https://%s/" % frontendnode["address"]
+
+    logging.info("distributing for node %s", nodename)
+    curpos = get_curpos(nodename, nodeaddress, own_key, paths)
+    timing_point(timing, "get curpos")
+    logging.info("current position %d", curpos)
+
+    verifiedsize = \
+        get_frontend_verifiedsize(nodename, nodeaddress, own_key, paths)
+    timing_point(timing, "get verified size")
+    logging.info("verified size %d", verifiedsize)
+
+    assert verifiedsize >= curpos
+
+    while verifiedsize < len(logorder):
+        uptopos = min(verifiedsize + maxwindow, len(logorder))
+
+        entries = [b64encode(entry) for entry in logorder[verifiedsize:uptopos]]
+        sendlog_helper(entries, verifiedsize, nodename, nodeaddress, own_key, paths, s)
+        timing_point(timing, "sendlog")
+
+        fill_in_missing_entries(nodename, nodeaddress, own_key, paths, chainsdb, timing, s)
+
+        verifiedsize = frontend_verify_entries(nodename, nodeaddress, own_key, paths, uptopos)
+
+    logging.info("sending sth to node %s", nodename)
+    publishsthresult = publish_sth(nodename, nodeaddress, own_key, paths, sth)
+    if publishsthresult["result"] != "ok":
+        logging.info("publishsth: %s", publishsthresult)
+        sys.exit(1)
+    timing_point(timing, "send sth")
+
+    if args.timing:
+        logging.debug("timing: merge_dist: %s", timing["deltatimes"])
+
+def merge_dist(args, localconfig, frontendnodes, timestamp):
+    paths = localconfig["paths"]
     mergedb = paths["mergedb"]
     chainsdb = perm(localconfig.get("dbbackend", "filedb"), mergedb + "/chains")
     logorderfile = mergedb + "/logorder"
@@ -89,56 +129,49 @@ def merge_dist(args, localconfig, frontendnodes, timestamp):
         sth = json.loads(open(sthfile, 'r').read())
     except (IOError, ValueError):
         logging.warning("No valid STH file found in %s", sthfile)
-        return timestamp
+        return timestamp, 0
 
     if sth['timestamp'] < timestamp:
         logging.warning("New STH file older than the previous one: %d < %d",
-                        sth['timestamp'], timestamp)
-        return timestamp
+                        sth['timestamp'], timestamp)
+        return timestamp, 0
     if sth['timestamp'] == timestamp:
-        return timestamp
+        return timestamp, 0
     timestamp = sth['timestamp']
 
     logorder = get_logorder(logorderfile, sth['tree_size'])
     timing_point(timing, "get logorder")
 
+    procs = {}
     for frontendnode in frontendnodes:
-        nodeaddress = "https://%s/" % frontendnode["address"]
         nodename = frontendnode["name"]
-        timing = timing_point()
-
-        logging.info("distributing for node %s", nodename)
-        curpos = get_curpos(nodename, nodeaddress, own_key, paths)
-        timing_point(timing, "get curpos")
-        logging.info("current position %d", curpos)
-
-        verifiedsize = get_frontend_verifiedsize(nodename, nodeaddress, own_key, paths)
-        timing_point(timing, "get verified size")
-        logging.info("verified size %d", verifiedsize)
-
-        assert verifiedsize >= curpos
-        while verifiedsize < len(logorder):
-            uptopos = min(verifiedsize + maxwindow, len(logorder))
-
-            entries = [b64encode(entry) for entry in logorder[verifiedsize:uptopos]]
-            sendlog_helper(entries, verifiedsize, nodename, nodeaddress, own_key, paths, s)
-            timing_point(timing, "sendlog")
+        if args.mergeinterval:
+            name = 'dist_%s' % nodename
+            p, pipe = start_worker(name,
+                                   lambda _, argv: do_send(*(argv)),
+                                   (args, localconfig, frontendnode, logorder, sth, chainsdb, s))
+            procs[p] = (nodename, pipe)
+        else:
+            do_send(args, localconfig, frontendnode, logorder, sth, chainsdb, s)
 
-            fill_in_missing_entries(nodename, nodeaddress, own_key, paths, chainsdb, timing, s)
+    if not args.mergeinterval:
+        return timestamp, 0
 
-            verifiedsize = frontend_verify_entries(nodename, nodeaddress, own_key, paths, uptopos)
+    failures = 0
+    while True:
+        for p in list(procs):
+            if not p.is_alive():
+                p.join()
+                nodename, _ = procs[p]
+                if p.exitcode != 0:
+                    logging.warning("%s failure: %d", nodename, p.exitcode)
+                    failures += 1
+                del procs[p]
+        if not procs:
+            break
+        sleep(1)
 
-        logging.info("sending sth to node %s", nodename)
-        publishsthresult = publish_sth(nodename, nodeaddress, own_key, paths, sth)
-        if publishsthresult["result"] != "ok":
-            logging.info("publishsth: %s", publishsthresult)
-            sys.exit(1)
-        timing_point(timing, "send sth")
-
-        if args.timing:
-            logging.debug("timing: merge_dist: %s", timing["deltatimes"])
-
-    return timestamp
+    return timestamp, failures
 
 def main():
     """
@@ -146,12 +179,12 @@ def main():
 
     Distribute missing entries and the STH to all frontend nodes.
 
-    If `--mergeinterval', wait until 'sth' is updated and read it and
-    start distributing again.
+    If `--mergeinterval', start over again.
     """
     args, config, localconfig = parse_args()
     paths = localconfig["paths"]
     mergedb = paths["mergedb"]
+    sth_path = localconfig["paths"]["mergedb"] + "/sth"
     lockfile = mergedb + "/.merge_dist.lock"
 
     timestamp = 0
@@ -166,20 +199,18 @@ def main():
     else:
         nodes = [n for n in config["frontendnodes"] if n["name"] in args.node]
 
-    if args.mergeinterval is None:
-        if merge_dist(args, localconfig, nodes, timestamp) < 0:
-            return 1
-        return 0
-
-    sth_path = localconfig["paths"]["mergedb"] + "/sth"
     sth_statinfo = waitforfile(sth_path)
     while True:
-        if merge_dist(args, localconfig, nodes, timestamp) < 0:
-            return 1
+        timestamp, failures = merge_dist(args, localconfig, nodes, timestamp)
+        if not args.mergeinterval:
+            break
        sth_statinfo_old = sth_statinfo
         while sth_statinfo == sth_statinfo_old:
-            sleep(args.mergeinterval / 30)
+            sleep(max(3, args.mergeinterval / 10))
+            if failures > 0:
+                break
             sth_statinfo = stat(sth_path)
+    return 0
 
 if __name__ == '__main__':
diff --git a/tools/mergetools.py b/tools/mergetools.py
index 109e9d4..beb41bf 100644
--- a/tools/mergetools.py
+++ b/tools/mergetools.py
@@ -484,13 +484,21 @@ def flock_ex_or_fail(path):
         return False
     return True
 
+def start_worker(name, fun, args):
+    pipe_mine, pipe_theirs = multiprocessing.Pipe()
+    p = multiprocessing.Process(target=fun,
+                                args=(pipe_theirs, args),
+                                name=name)
+    p.start()
+    return (p, pipe_mine)
+
 def terminate_child_procs():
     for p in multiprocessing.active_children():
         #print >>sys.stderr, "DEBUG: terminating pid", p.pid
         p.terminate()
 
 def loginit(args, fname):
-    logfmt = '%(asctime)s %(message)s'
+    logfmt = '%(asctime)s %(name)s %(levelname)s %(message)s'
     loglevel = getattr(logging, args.loglevel.upper())
     if args.logdir is None:
         logging.basicConfig(format=logfmt, level=loglevel)
-- 
cgit v1.1
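
The parallelisation pattern this patch introduces, one worker process per frontend node started through the new start_worker() helper and reaped by polling is_alive() and exitcode, can be sketched in isolation roughly as follows. This is a simplified stand-alone illustration under stated assumptions, not code from the repository: distribute_to_node and the node names are made-up placeholders standing in for do_send() and the configured frontend nodes.

# Stand-alone sketch of the worker pattern: start one child process per node
# with a start_worker()-style helper, then poll until all children have
# exited, counting non-zero exit codes as failures. distribute_to_node and
# the node names below are hypothetical placeholders, not repository code.
import multiprocessing
from time import sleep

def start_worker(name, fun, args):
    # Same shape as the helper added to mergetools.py: hand the child one end
    # of a pipe plus its arguments, return the process and the parent's end.
    pipe_mine, pipe_theirs = multiprocessing.Pipe()
    p = multiprocessing.Process(target=fun, args=(pipe_theirs, args), name=name)
    p.start()
    return (p, pipe_mine)

def distribute_to_node(pipe, nodename):
    # Placeholder for the real per-node work (do_send() in merge_dist.py).
    pipe.send("sent everything to %s" % nodename)

def distribute_all(nodenames):
    procs = {}
    for nodename in nodenames:
        p, pipe = start_worker('dist_%s' % nodename, distribute_to_node, nodename)
        procs[p] = (nodename, pipe)

    failures = 0
    while True:
        # Reap finished workers; a non-zero exit code counts as a failure.
        for p in list(procs):
            if not p.is_alive():
                p.join()
                nodename, pipe = procs[p]
                if p.exitcode != 0:
                    failures += 1
                elif pipe.poll():
                    print(pipe.recv())
                del procs[p]
        if not procs:
            break
        sleep(1)
    return failures

if __name__ == '__main__':
    print(distribute_all(['frontend-1', 'frontend-2']))

Reaping workers by polling rather than joining them one at a time lets a slow or failed node be noticed without holding up the others; in the patch, the resulting failure count is returned from merge_dist() so that main() can break out of its wait loop early and retry.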