#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2014-2015, NORDUnet A/S.
# See LICENSE for licensing information.
#
# Distribute the 'sth' file and all missing entries to all frontend nodes.
# See catlfish/doc/merge.txt for more about the merge process.
#
import sys
import json
import logging
from time import sleep
from base64 import b64encode, b64decode
from os import stat
from certtools import timing_point, \
     create_ssl_context
from mergetools import get_curpos, get_logorder, chunks, get_missingentries, \
     sendsth, sendlog, sendentry, parse_args, perm, waitforfile, \
     flock_ex_or_fail, Status

def merge_dist(args, localconfig, frontendnodes, timestamp):
    paths = localconfig["paths"]
    own_key = (localconfig["nodename"],
               "%s/%s-private.pem" % (paths["privatekeys"],
                                      localconfig["nodename"]))
    mergedb = paths["mergedb"]
    chainsdb = perm(localconfig.get("dbbackend", "filedb"), mergedb + "/chains")
    logorderfile = mergedb + "/logorder"
    sthfile = mergedb + "/sth"
    statusfile = mergedb + "/merge_dist.status"
    s = Status(statusfile)
    create_ssl_context(cafile=paths["https_cacertfile"])
    timing = timing_point()

    # Read the current STH; bail out if there is none, or if it is not
    # newer than the one we distributed last time.
    try:
        sth = json.loads(open(sthfile, 'r').read())
    except (IOError, ValueError):
        logging.warning("No valid STH file found in %s", sthfile)
        return timestamp
    if sth['timestamp'] < timestamp:
        logging.warning("New STH file older than the previous one: %d < %d",
                        sth['timestamp'], timestamp)
        return timestamp
    if sth['timestamp'] == timestamp:
        return timestamp
    timestamp = sth['timestamp']

    logorder = get_logorder(logorderfile, sth['tree_size'])
    timing_point(timing, "get logorder")

    # Distribute to each frontend node in turn.
    for frontendnode in frontendnodes:
        nodeaddress = "https://%s/" % frontendnode["address"]
        nodename = frontendnode["name"]
        timing = timing_point()
        logging.info("distributing for node %s", nodename)

        # Ask the node how far it has come, then send it the part of the
        # log order it has not seen yet.
        ok, curpos = get_curpos(nodename, nodeaddress, own_key, paths)
        if not ok:
            logging.error("get_curpos: %s", curpos)
            continue
        timing_point(timing, "get curpos")
        logging.info("current position %d", curpos)

        entries = [b64encode(entry) for entry in logorder[curpos:]]
        logging.info("sending log: %d", len(entries))
        sendlog_fail = False
        # Send hashes in chunks of 1000, retrying each chunk up to five times.
        for chunk in chunks(entries, 1000):
            for trynumber in range(5, 0, -1):
                ok, sendlogresult = sendlog(nodename, nodeaddress, own_key,
                                            paths,
                                            {"start": curpos, "hashes": chunk})
                if ok:
                    break
                sleep(10)
                logging.warning("tries left: %d", trynumber)
            if not ok or sendlogresult.get("result") != "ok":
                logging.error("sendlog: %s", sendlogresult)
                sendlog_fail = True
                break
            curpos += len(chunk)
            s.status("INFO: sendlog %d" % curpos)
        timing_point(timing, "sendlog")
        if sendlog_fail:
            logging.error("sendlog failed for %s", nodename)
            continue

        # Fetch the list of entries the node is missing and send them
        # one by one from the chains database.
        missingentries = get_missingentries(nodename, nodeaddress, own_key,
                                            paths)
        timing_point(timing, "get missing")

        logging.info("sending missing entries: %d", len(missingentries))
        sent_entries = 0
        sendentry_fail = False
        for missingentry in missingentries:
            ehash = b64decode(missingentry)
            ok, sendentryresult = sendentry(nodename, nodeaddress, own_key,
                                            paths, chainsdb.get(ehash), ehash)
            if not ok or sendentryresult.get("result") != "ok":
                logging.error("sendentry: %s", sendentryresult)
                sendentry_fail = True
                break
            sent_entries += 1
            if sent_entries % 1000 == 0:
                s.status("INFO: sendentry %d" % sent_entries)
        timing_point(timing, "send missing")
        if sendentry_fail:
            logging.error("sendentry failed for %s", nodename)
            continue

        # Finally, send the new STH to the node.
        logging.info("sending sth to node %s", nodename)
        sendsth_fail = False
        ok, sendsthresult = sendsth(nodename, nodeaddress, own_key, paths, sth)
        if not ok or sendsthresult.get("result") != "ok":
            logging.error("sendsth: %s", sendsthresult)
            sendsth_fail = True
        timing_point(timing, "send sth")

        if args.timing:
            logging.debug("timing: merge_dist: %s", timing["deltatimes"])

        if sendsth_fail:
            logging.error("sendsth failed for %s", nodename)
            continue

    return timestamp

def main():
    """
    Wait until 'sth' exists and read it.

    Distribute missing entries and the STH to all frontend nodes.

    If `--mergeinterval', wait until 'sth' is updated and read it and
    start distributing again.
    """
    args, config, localconfig = parse_args()
    paths = localconfig["paths"]
    mergedb = paths["mergedb"]
    lockfile = mergedb + "/.merge_dist.lock"
    timestamp = 0

    loglevel = getattr(logging, args.loglevel.upper())
    if args.mergeinterval is None:
        logging.basicConfig(level=loglevel)
    else:
        logging.basicConfig(filename=args.logdir + "/merge_dist.log",
                            level=loglevel)

    if not flock_ex_or_fail(lockfile):
        logging.critical("unable to take lock %s", lockfile)
        return 1

    # Distribute to all frontend nodes unless specific nodes were named
    # on the command line.
    if len(args.node) == 0:
        nodes = config["frontendnodes"]
    else:
        nodes = [n for n in config["frontendnodes"] if n["name"] in args.node]

    if args.mergeinterval is None:
        if merge_dist(args, localconfig, nodes, timestamp) < 0:
            return 1
        return 0

    # Wait for the 'sth' file to appear, then redistribute every time it
    # changes on disk.
    sth_path = localconfig["paths"]["mergedb"] + "/sth"
    sth_statinfo = waitforfile(sth_path)
    while True:
        if merge_dist(args, localconfig, nodes, timestamp) < 0:
            return 1
        sth_statinfo_old = sth_statinfo
        while sth_statinfo == sth_statinfo_old:
            sleep(args.mergeinterval / 30)
            sth_statinfo = stat(sth_path)
    return 0

if __name__ == '__main__':
    sys.exit(main())