diff options
Diffstat (limited to 'tools/merge_dist.py')
-rwxr-xr-x | tools/merge_dist.py | 141 |
1 files changed, 86 insertions, 55 deletions
diff --git a/tools/merge_dist.py b/tools/merge_dist.py index a9b5c60..7a13bfa 100755 --- a/tools/merge_dist.py +++ b/tools/merge_dist.py @@ -9,12 +9,15 @@ # import sys import json +import logging from time import sleep from base64 import b64encode, b64decode +from os import stat from certtools import timing_point, \ create_ssl_context from mergetools import get_curpos, get_logorder, chunks, get_missingentries, \ - sendsth, sendlog, sendentry, parse_args, perm + sendsth, sendlog, sendentry, parse_args, perm, waitforfile, \ + flock_ex_or_fail, Status def merge_dist(args, localconfig, frontendnodes, timestamp): paths = localconfig["paths"] @@ -25,17 +28,19 @@ def merge_dist(args, localconfig, frontendnodes, timestamp): chainsdb = perm(localconfig.get("dbbackend", "filedb"), mergedb + "/chains") logorderfile = mergedb + "/logorder" sthfile = mergedb + "/sth" + statusfile = mergedb + "/merge_dist.status" + s = Status(statusfile) create_ssl_context(cafile=paths["https_cacertfile"]) timing = timing_point() try: sth = json.loads(open(sthfile, 'r').read()) except (IOError, ValueError): - print >>sys.stderr, "No valid STH file found in", sthfile + logging.warning("No valid STH file found in %s", sthfile) return timestamp if sth['timestamp'] < timestamp: - print >>sys.stderr, "New STH file older than the previous one:", \ - sth['timestamp'], "<", timestamp + logging.warning("New STH file older than the previous one: %d < %d", + sth['timestamp'], timestamp) return timestamp if sth['timestamp'] == timestamp: return timestamp @@ -49,96 +54,122 @@ def merge_dist(args, localconfig, frontendnodes, timestamp): nodename = frontendnode["name"] timing = timing_point() - print >>sys.stderr, "distributing for node", nodename - sys.stderr.flush() - curpos = get_curpos(nodename, nodeaddress, own_key, paths) + logging.info("distributing for node %s", nodename) + ok, curpos = get_curpos(nodename, nodeaddress, own_key, paths) + if not ok: + logging.error("get_curpos: %s", curpos) + continue timing_point(timing, "get curpos") - print >>sys.stderr, "current position", curpos - sys.stderr.flush() + logging.info("current position %d", curpos) entries = [b64encode(entry) for entry in logorder[curpos:]] - print >>sys.stderr, "sending log:", - sys.stderr.flush() + logging.info("sending log: %d", len(entries)) + sendlog_fail = False for chunk in chunks(entries, 1000): for trynumber in range(5, 0, -1): - sendlogresult = sendlog(nodename, nodeaddress, - own_key, paths, - {"start": curpos, "hashes": chunk}) - if sendlogresult == None: - if trynumber == 1: - sys.exit(1) - sleep(10) - print >>sys.stderr, "tries left:", trynumber - sys.stderr.flush() - continue + ok, sendlogresult = sendlog(nodename, nodeaddress, own_key, paths, + {"start": curpos, "hashes": chunk}) + if ok: + break + sleep(10) + logging.warning("tries left: %d", trynumber) + if not ok or sendlogresult.get("result") != "ok": + logging.error("sendlog: %s", sendlogresult) + sendlog_fail = True break - if sendlogresult["result"] != "ok": - print >>sys.stderr, "sendlog:", sendlogresult - sys.exit(1) curpos += len(chunk) - print >>sys.stderr, curpos, - sys.stderr.flush() - print >>sys.stderr + s.status("INFO: sendlog %d" % curpos) timing_point(timing, "sendlog") - print >>sys.stderr, "log sent" - sys.stderr.flush() + if sendlog_fail: + logging.error("sendlog failed for %s", nodename) + continue missingentries = get_missingentries(nodename, nodeaddress, own_key, paths) timing_point(timing, "get missing") - print >>sys.stderr, "missing entries:", len(missingentries) - sys.stderr.flush() + logging.info("sending missing entries: %d", len(missingentries)) sent_entries = 0 - print >>sys.stderr, "send missing entries", - sys.stderr.flush() + sendentry_fail = False for missingentry in missingentries: ehash = b64decode(missingentry) - sendentryresult = sendentry(nodename, nodeaddress, own_key, paths, - chainsdb.get(ehash), ehash) - if sendentryresult["result"] != "ok": - print >>sys.stderr, "sendentry:", sendentryresult - sys.exit(1) + ok, sendentryresult = sendentry(nodename, nodeaddress, own_key, paths, + chainsdb.get(ehash), ehash) + if not ok or sendentryresult.get("result") != "ok": + logging.error("sendentry: %s", sendentryresult) + sendentry_fail = True + break sent_entries += 1 if sent_entries % 1000 == 0: - print >>sys.stderr, sent_entries, - sys.stderr.flush() - print >>sys.stderr - sys.stderr.flush() + s.status("INFO: sendentry %d" % sent_entries) timing_point(timing, "send missing") + if sendentry_fail: + logging.error("sendentry failed for %s", nodename) + continue - print >>sys.stderr, "sending sth to node", nodename - sys.stderr.flush() - sendsthresult = sendsth(nodename, nodeaddress, own_key, paths, sth) - if sendsthresult["result"] != "ok": - print >>sys.stderr, "sendsth:", sendsthresult - sys.exit(1) + logging.info("sending sth to node %s", nodename) + sendsth_fail = False + ok, sendsthresult = sendsth(nodename, nodeaddress, own_key, paths, sth) + if not ok or sendsthresult.get("result") != "ok": + logging.error("sendsth: %s", sendsthresult) + sendsth_fail = True timing_point(timing, "send sth") if args.timing: - print >>sys.stderr, "timing: merge_dist:", timing["deltatimes"] - sys.stderr.flush() + logging.debug("timing: merge_dist: %s", timing["deltatimes"]) + + if sendsth_fail: + logging.error("sendsth failed for %s", nodename) + continue return timestamp def main(): """ + Wait until 'sth' exists and read it. + Distribute missing entries and the STH to all frontend nodes. + + If `--mergeinterval', wait until 'sth' is updated and read it and + start distributing again. """ args, config, localconfig = parse_args() + paths = localconfig["paths"] + mergedb = paths["mergedb"] + lockfile = mergedb + "/.merge_dist.lock" timestamp = 0 + loglevel = getattr(logging, args.loglevel.upper()) + if args.mergeinterval is None: + logging.basicConfig(level=loglevel) + else: + logging.basicConfig(filename=args.logdir + "/merge_dist.log", + level=loglevel) + + if not flock_ex_or_fail(lockfile): + logging.critical("unable to take lock %s", lockfile) + return 1 + if len(args.node) == 0: nodes = config["frontendnodes"] else: nodes = [n for n in config["frontendnodes"] if n["name"] in args.node] + if args.mergeinterval is None: + if merge_dist(args, localconfig, nodes, timestamp) < 0: + return 1 + return 0 + + sth_path = localconfig["paths"]["mergedb"] + "/sth" + sth_statinfo = waitforfile(sth_path) while True: - timestamp = merge_dist(args, localconfig, nodes, timestamp) - if args.interval is None: - break - print >>sys.stderr, "sleeping", args.interval, "seconds" - sleep(args.interval) + if merge_dist(args, localconfig, nodes, timestamp) < 0: + return 1 + sth_statinfo_old = sth_statinfo + while sth_statinfo == sth_statinfo_old: + sleep(args.mergeinterval / 30) + sth_statinfo = stat(sth_path) + return 0 if __name__ == '__main__': sys.exit(main()) |