diff options
Diffstat (limited to 'tools')
-rw-r--r-- | tools/certtools.py | 8 | ||||
-rwxr-xr-x | tools/merge | 73 | ||||
-rwxr-xr-x | tools/merge_backup.py | 121 | ||||
-rwxr-xr-x | tools/merge_dist.py | 117 | ||||
-rwxr-xr-x | tools/merge_fetch.py | 247 | ||||
-rwxr-xr-x | tools/merge_sth.py | 81 | ||||
-rw-r--r-- | tools/mergetools.py | 100 | ||||
-rwxr-xr-x | tools/testcase1.py | 16 |
8 files changed, 584 insertions, 179 deletions
diff --git a/tools/certtools.py b/tools/certtools.py index 0009d5d..e9ee99b 100644 --- a/tools/certtools.py +++ b/tools/certtools.py @@ -11,12 +11,12 @@ import struct import sys import hashlib import ecdsa -import datetime import cStringIO import zipfile import shutil import requests import warnings +from datetime import datetime from certkeys import publickeys @@ -336,8 +336,8 @@ def check_sth_signature(baseurl, sth, publickey=None): signature_type = struct.pack(">b", 1) timestamp = struct.pack(">Q", sth["timestamp"]) tree_size = struct.pack(">Q", sth["tree_size"]) - hash = base64.decodestring(sth["sha256_root_hash"]) - tree_head = version + signature_type + timestamp + tree_size + hash + ehash = base64.decodestring(sth["sha256_root_hash"]) + tree_head = version + signature_type + timestamp + tree_size + ehash check_signature(baseurl, signature, tree_head, publickey=publickey) @@ -426,7 +426,7 @@ def get_leaf_hash(merkle_tree_leaf): return leaf_hash.digest() def timing_point(timer_dict=None, name=None): - t = datetime.datetime.now() + t = datetime.now() if timer_dict: starttime = timer_dict["lasttime"] stoptime = t diff --git a/tools/merge b/tools/merge index b5a50d5..4ba0438 100755 --- a/tools/merge +++ b/tools/merge @@ -1,10 +1,69 @@ -#! /bin/sh +#! /usr/bin/env python +"""merge""" -set -o errexit +import os +import sys +import signal +from time import sleep +from mergetools import parse_args, terminate_child_procs +from multiprocessing import Process +import merge_fetch, merge_backup, merge_sth, merge_dist -BINDIR=$(dirname $0) +def run_once(): + """Merge once.""" + ret = merge_fetch.main() + if ret == 0: + ret = merge_backup.main() + if ret == 0: + ret = merge_sth.main() + if ret == 0: + ret = merge_dist.main() + return ret -$BINDIR/merge_fetch.py "$@" -$BINDIR/merge_backup.py "$@" -$BINDIR/merge_sth.py "$@" -$BINDIR/merge_dist.py "$@" +def term(signal, arg): + terminate_child_procs() + sys.exit(1) + +def run_continuously(pidfilepath): + """Run continuously.""" + parts = (('fetch', merge_fetch), + ('backup', merge_backup), + ('sth', merge_sth), + ('dist', merge_dist)) + procs = {} + for part, mod in parts: + procs[part] = Process(target=mod.main, name='merge_%s' % part) + procs[part].start() + #print >>sys.stderr, "DEBUG:", part, "started, pid", procs[part].pid + + if pidfilepath: + open(pidfilepath, 'w').write(str(os.getpid()) + '\n') + + signal.signal(signal.SIGTERM, term) + retval = 0 + keep_going = True + while keep_going: + sleep(1) + for name, p in procs.items(): + if not p.is_alive(): + print >>sys.stderr, "\nERROR:", name, "process is gone; exiting" + retval = 1 # Fail. + keep_going = False + break + + terminate_child_procs() + return retval + +def main(): + """Main""" + args, _, _ = parse_args() + + if args.mergeinterval is None: + ret = run_once() + else: + ret = run_continuously(args.pidfile) + + return ret + +if __name__ == '__main__': + sys.exit(main()) diff --git a/tools/merge_backup.py b/tools/merge_backup.py index e7cce26..f25b22a 100755 --- a/tools/merge_backup.py +++ b/tools/merge_backup.py @@ -8,17 +8,19 @@ # See catlfish/doc/merge.txt for more about the merge process. # import sys -import base64 import select import requests +import errno +import logging from time import sleep from base64 import b64encode, b64decode +from os import stat from certtools import timing_point, build_merkle_tree, write_file, \ create_ssl_context from mergetools import chunks, backup_sendlog, get_logorder, \ get_verifiedsize, get_missingentriesforbackup, \ hexencode, setverifiedsize, sendentries_merge, verifyroot, \ - get_nfetched, parse_args, perm + get_nfetched, parse_args, perm, waitforfile, flock_ex_or_fail, Status def backup_loop(nodename, nodeaddress, own_key, paths, verifiedsize, chunk): for trynumber in range(5, 0, -1): @@ -29,57 +31,49 @@ def backup_loop(nodename, nodeaddress, own_key, paths, verifiedsize, chunk): if trynumber == 1: return None select.select([], [], [], 10.0) - print >>sys.stderr, "tries left:", trynumber - sys.stderr.flush() + logging.info("tries left: %d", trynumber) continue return sendlogresult sendlog_discover_chunksize = 100000 -def sendlog_helper(entries, verifiedsize, nodename, nodeaddress, own_key, paths): - print >>sys.stderr, "sending log:", - sys.stderr.flush() +def sendlog_helper(entries, verifiedsize, nodename, nodeaddress, own_key, paths, + statusupdates): + logging.info("sending log") for chunk in chunks(entries, 1000): sendlogresult = backup_loop(nodename, nodeaddress, own_key, paths, verifiedsize, chunk) if sendlogresult == None: sys.exit(1) if sendlogresult["result"] != "ok": - print >>sys.stderr, "backup_sendlog:", sendlogresult + logging.error("backup_sendlog: %s", sendlogresult) sys.exit(1) verifiedsize += len(chunk) - print >>sys.stderr, verifiedsize, - sys.stderr.flush() - print >>sys.stderr - print >>sys.stderr, "log sent" - sys.stderr.flush() + statusupdates.status("PROG sending log: %d" % verifiedsize) + logging.info("log sent") -def fill_in_missing_entries(nodename, nodeaddress, own_key, paths, chainsdb, timing): +def fill_in_missing_entries(nodename, nodeaddress, own_key, paths, chainsdb, + timing, statusupdates): missingentries = get_missingentriesforbackup(nodename, nodeaddress, own_key, paths) timing_point(timing, "get missing") while missingentries: - print >>sys.stderr, "missing entries:", len(missingentries) - sys.stderr.flush() + logging.info("about to send %d missing entries", len(missingentries)) fetched_entries = 0 - print >>sys.stderr, "sending missing entries", - sys.stderr.flush() with requests.sessions.Session() as session: for missingentry_chunk in chunks(missingentries, 100): - missingentry_hashes = [base64.b64decode(missingentry) for missingentry in missingentry_chunk] - hashes_and_entries = [(hash, chainsdb.get(hash)) for hash in missingentry_hashes] + missingentry_hashes = [b64decode(missingentry) for missingentry in missingentry_chunk] + hashes_and_entries = [(ehash, chainsdb.get(ehash)) for ehash in missingentry_hashes] sendentryresult = sendentries_merge(nodename, nodeaddress, own_key, paths, hashes_and_entries, session) if sendentryresult["result"] != "ok": - print >>sys.stderr, "sendentries_merge:", sendentryresult + logging.error("sendentries_merge: %s", sendentryresult) sys.exit(1) fetched_entries += len(missingentry_hashes) - #print >>sys.stderr, fetched_entries, - #sys.stderr.flush() - print >>sys.stderr - sys.stderr.flush() + statusupdates.status("PROG sending missing entries: %d" % + fetched_entries) timing_point(timing, "send missing") missingentries = get_missingentriesforbackup(nodename, nodeaddress, @@ -93,17 +87,16 @@ def check_root(logorder, nodename, nodeaddress, own_key, paths, tree_size, timin verifyrootresult = verifyroot(nodename, nodeaddress, own_key, paths, tree_size) if verifyrootresult["result"] != "ok": - print >>sys.stderr, "verifyroot:", verifyrootresult + logging.error("verifyroot: %s", verifyrootresult) sys.exit(1) - secondary_root_hash = base64.b64decode(verifyrootresult["root_hash"]) + secondary_root_hash = b64decode(verifyrootresult["root_hash"]) if root_hash != secondary_root_hash: - print >>sys.stderr, "secondary root hash was", \ - hexencode(secondary_root_hash) - print >>sys.stderr, " expected", hexencode(root_hash) + logging.error("secondary root hash was %s while expected was %s", + hexencode(secondary_root_hash), hexencode(root_hash)) sys.exit(1) timing_point(timing, "verifyroot") return root_hash - + def merge_backup(args, config, localconfig, secondaries): maxwindow = localconfig.get("maxwindow", 1000) paths = localconfig["paths"] @@ -114,6 +107,8 @@ def merge_backup(args, config, localconfig, secondaries): chainsdb = perm(localconfig.get("dbbackend", "filedb"), mergedb + "/chains") logorderfile = mergedb + "/logorder" currentsizefile = mergedb + "/fetched" + statusfile = mergedb + "/merge_backup.status" + s = Status(statusfile) timing = timing_point() nfetched = get_nfetched(currentsizefile, logorderfile) @@ -128,12 +123,10 @@ def merge_backup(args, config, localconfig, secondaries): nodeaddress = "https://%s/" % secondary["address"] nodename = secondary["name"] timing = timing_point() - print >>sys.stderr, "backing up to node", nodename - sys.stderr.flush() + logging.info("backing up to node %s", nodename) verifiedsize = get_verifiedsize(nodename, nodeaddress, own_key, paths) timing_point(timing, "get verified size") - print >>sys.stderr, "verified size", verifiedsize - sys.stderr.flush() + logging.info("verified size %d", verifiedsize) if verifiedsize == tree_size: root_hash = check_root(logorder, nodename, nodeaddress, own_key, paths, tree_size, timing) @@ -142,10 +135,10 @@ def merge_backup(args, config, localconfig, secondaries): uptopos = min(verifiedsize + maxwindow, tree_size) entries = [b64encode(entry) for entry in logorder[verifiedsize:uptopos]] - sendlog_helper(entries, verifiedsize, nodename, nodeaddress, own_key, paths) + sendlog_helper(entries, verifiedsize, nodename, nodeaddress, own_key, paths, s) timing_point(timing, "sendlog") - fill_in_missing_entries(nodename, nodeaddress, own_key, paths, chainsdb, timing) + fill_in_missing_entries(nodename, nodeaddress, own_key, paths, chainsdb, timing, s) root_hash = check_root(logorder, nodename, nodeaddress, own_key, paths, uptopos, timing) @@ -155,29 +148,48 @@ def merge_backup(args, config, localconfig, secondaries): backuppath = mergedb + "/verified." + nodename backupdata = {"tree_size": tree_size, "sha256_root_hash": hexencode(root_hash)} - #print >>sys.stderr, "DEBUG: writing to", backuppath, ":", backupdata + logging.debug("writing to %s: %s", backuppath, backupdata) write_file(backuppath, backupdata) if args.timing: - print >>sys.stderr, "timing: merge_backup:", timing["deltatimes"] - sys.stderr.flush() + logging.debug("timing: merge_backup: %s", timing["deltatimes"]) + + return 0 def main(): """ - Read logorder file up until what's indicated by fetched file and - build the tree. + Wait until 'fetched' exists and read it. + + Read 'logorder' up until what's indicated by 'fetched' and build the + tree. Distribute entries to all secondaries, write tree size and tree head - to backup.<secondary> files as each secondary is verified to have + to 'backup.<secondary>' files as each secondary is verified to have the entries. - Sleep some and start over. + If `--mergeinterval', wait until 'fetched' is updated and read it + and start over from the point where 'logorder' is read. """ args, config, localconfig = parse_args() + paths = localconfig["paths"] + mergedb = paths["mergedb"] + lockfile = mergedb + "/.merge_backup.lock" + fetched_path = mergedb + "/fetched" + + loglevel = getattr(logging, args.loglevel.upper()) + if args.mergeinterval is None: + logging.basicConfig(level=loglevel) + else: + logging.basicConfig(filename=args.logdir + "/merge_backup.log", + level=loglevel) + + if not flock_ex_or_fail(lockfile): + logging.critical("unable to take lock %s", lockfile) + return 1 + all_secondaries = \ [n for n in config.get('mergenodes', []) if \ n['name'] != config['primarymergenode']] - paths = localconfig["paths"] create_ssl_context(cafile=paths["https_cacertfile"]) if len(args.node) == 0: @@ -185,12 +197,21 @@ def main(): else: nodes = [n for n in all_secondaries if n["name"] in args.node] + if args.mergeinterval is None: + return merge_backup(args, config, localconfig, nodes) + + fetched_statinfo = waitforfile(fetched_path) + while True: - merge_backup(args, config, localconfig, nodes) - if args.interval is None: - break - print >>sys.stderr, "sleeping", args.interval, "seconds" - sleep(args.interval) + err = merge_backup(args, config, localconfig, nodes) + if err: + return err + fetched_statinfo_old = fetched_statinfo + while fetched_statinfo == fetched_statinfo_old: + sleep(max(3, args.mergeinterval / 10)) + fetched_statinfo = stat(fetched_path) + + return 0 if __name__ == '__main__': sys.exit(main()) diff --git a/tools/merge_dist.py b/tools/merge_dist.py index 6582eff..ffddc25 100755 --- a/tools/merge_dist.py +++ b/tools/merge_dist.py @@ -9,19 +9,20 @@ # import sys import json -import base64 import requests +import logging from time import sleep from base64 import b64encode, b64decode -from certtools import timing_point, \ - create_ssl_context +from os import stat +from certtools import timing_point, create_ssl_context from mergetools import get_curpos, get_logorder, chunks, get_missingentries, \ - publish_sth, sendlog, sendentries, parse_args, perm, get_frontend_verifiedsize, \ - frontend_verify_entries + publish_sth, sendlog, sendentries, parse_args, perm, \ + get_frontend_verifiedsize, frontend_verify_entries, \ + waitforfile, flock_ex_or_fail, Status -def sendlog_helper(entries, curpos, nodename, nodeaddress, own_key, paths): - print >>sys.stderr, "sending log:", - sys.stderr.flush() +def sendlog_helper(entries, curpos, nodename, nodeaddress, own_key, paths, + statusupdates): + logging.info("sending log") for chunk in chunks(entries, 1000): for trynumber in range(5, 0, -1): sendlogresult = sendlog(nodename, nodeaddress, @@ -31,53 +32,43 @@ def sendlog_helper(entries, curpos, nodename, nodeaddress, own_key, paths): if trynumber == 1: sys.exit(1) sleep(10) - print >>sys.stderr, "tries left:", trynumber - sys.stderr.flush() + logging.warning("tries left: %d", trynumber) continue break if sendlogresult["result"] != "ok": - print >>sys.stderr, "sendlog:", sendlogresult + logging.error("sendlog: %s", sendlogresult) sys.exit(1) curpos += len(chunk) - print >>sys.stderr, curpos, - sys.stderr.flush() - print >>sys.stderr - print >>sys.stderr, "log sent" - sys.stderr.flush() + statusupdates.status("PROG sending log: %d" % curpos) + logging.info("log sent") -def fill_in_missing_entries(nodename, nodeaddress, own_key, paths, chainsdb, timing): +def fill_in_missing_entries(nodename, nodeaddress, own_key, paths, chainsdb, + timing, statusupdates): missingentries = get_missingentries(nodename, nodeaddress, own_key, paths) timing_point(timing, "get missing") while missingentries: - print >>sys.stderr, "missing entries:", len(missingentries) - sys.stderr.flush() - + logging.info("about to send %d missing entries", len(missingentries)) sent_entries = 0 - print >>sys.stderr, "sending missing entries", - sys.stderr.flush() with requests.sessions.Session() as session: for missingentry_chunk in chunks(missingentries, 100): - missingentry_hashes = [base64.b64decode(missingentry) for missingentry in missingentry_chunk] - hashes_and_entries = [(hash, chainsdb.get(hash)) for hash in missingentry_hashes] + missingentry_hashes = [b64decode(missingentry) for missingentry in missingentry_chunk] + hashes_and_entries = [(ehash, chainsdb.get(ehash)) for ehash in missingentry_hashes] sendentryresult = sendentries(nodename, nodeaddress, own_key, paths, hashes_and_entries, session) if sendentryresult["result"] != "ok": - print >>sys.stderr, "sendentries:", sendentryresult + logging.error("sendentries: %s", sendentryresult) sys.exit(1) sent_entries += len(missingentry_hashes) - print >>sys.stderr, sent_entries, - sys.stderr.flush() - print >>sys.stderr - sys.stderr.flush() + statusupdates.status( + "PROG sending missing entries: %d" % sent_entries) timing_point(timing, "send missing") missingentries = get_missingentries(nodename, nodeaddress, own_key, paths) timing_point(timing, "get missing") - def merge_dist(args, localconfig, frontendnodes, timestamp): maxwindow = localconfig.get("maxwindow", 1000) @@ -89,17 +80,19 @@ def merge_dist(args, localconfig, frontendnodes, timestamp): chainsdb = perm(localconfig.get("dbbackend", "filedb"), mergedb + "/chains") logorderfile = mergedb + "/logorder" sthfile = mergedb + "/sth" + statusfile = mergedb + "/merge_dist.status" + s = Status(statusfile) create_ssl_context(cafile=paths["https_cacertfile"]) timing = timing_point() try: sth = json.loads(open(sthfile, 'r').read()) except (IOError, ValueError): - print >>sys.stderr, "No valid STH file found in", sthfile + logging.warning("No valid STH file found in %s", sthfile) return timestamp if sth['timestamp'] < timestamp: - print >>sys.stderr, "New STH file older than the previous one:", \ - sth['timestamp'], "<", timestamp + logging.warning("New STH file older than the previous one: %d < %d", + sth['timestamp'], timestamp) return timestamp if sth['timestamp'] == timestamp: return timestamp @@ -113,16 +106,14 @@ def merge_dist(args, localconfig, frontendnodes, timestamp): nodename = frontendnode["name"] timing = timing_point() - print >>sys.stderr, "distributing for node", nodename - sys.stderr.flush() + logging.info("distributing for node %s", nodename) curpos = get_curpos(nodename, nodeaddress, own_key, paths) timing_point(timing, "get curpos") - print >>sys.stderr, "current position", curpos - sys.stderr.flush() + logging.info("current position %d", curpos) verifiedsize = get_frontend_verifiedsize(nodename, nodeaddress, own_key, paths) timing_point(timing, "get verified size") - print >>sys.stderr, "verified size", verifiedsize + logging.info("verified size %d", verifiedsize) assert verifiedsize >= curpos @@ -130,45 +121,71 @@ def merge_dist(args, localconfig, frontendnodes, timestamp): uptopos = min(verifiedsize + maxwindow, len(logorder)) entries = [b64encode(entry) for entry in logorder[verifiedsize:uptopos]] - sendlog_helper(entries, verifiedsize, nodename, nodeaddress, own_key, paths) + sendlog_helper(entries, verifiedsize, nodename, nodeaddress, own_key, paths, s) timing_point(timing, "sendlog") - fill_in_missing_entries(nodename, nodeaddress, own_key, paths, chainsdb, timing) + fill_in_missing_entries(nodename, nodeaddress, own_key, paths, chainsdb, timing, s) verifiedsize = frontend_verify_entries(nodename, nodeaddress, own_key, paths, uptopos) - print >>sys.stderr, "sending sth to node", nodename - sys.stderr.flush() + logging.info("sending sth to node %s", nodename) publishsthresult = publish_sth(nodename, nodeaddress, own_key, paths, sth) if publishsthresult["result"] != "ok": - print >>sys.stderr, "publishsth:", publishsthresult + logging.info("publishsth: %s", publishsthresult) sys.exit(1) timing_point(timing, "send sth") if args.timing: - print >>sys.stderr, "timing: merge_dist:", timing["deltatimes"] - sys.stderr.flush() + logging.debug("timing: merge_dist: %s", timing["deltatimes"]) return timestamp def main(): """ + Wait until 'sth' exists and read it. + Distribute missing entries and the STH to all frontend nodes. + + If `--mergeinterval', wait until 'sth' is updated and read it and + start distributing again. """ args, config, localconfig = parse_args() + paths = localconfig["paths"] + mergedb = paths["mergedb"] + lockfile = mergedb + "/.merge_dist.lock" timestamp = 0 + loglevel = getattr(logging, args.loglevel.upper()) + if args.mergeinterval is None: + logging.basicConfig(level=loglevel) + else: + logging.basicConfig(filename=args.logdir + "/merge_dist.log", + level=loglevel) + + if not flock_ex_or_fail(lockfile): + logging.critical("unable to take lock %s", lockfile) + return 1 + if len(args.node) == 0: nodes = config["frontendnodes"] else: nodes = [n for n in config["frontendnodes"] if n["name"] in args.node] + if args.mergeinterval is None: + if merge_dist(args, localconfig, nodes, timestamp) < 0: + return 1 + return 0 + + sth_path = localconfig["paths"]["mergedb"] + "/sth" + sth_statinfo = waitforfile(sth_path) while True: - timestamp = merge_dist(args, localconfig, nodes, timestamp) - if args.interval is None: - break - print >>sys.stderr, "sleeping", args.interval, "seconds" - sleep(args.interval) + if merge_dist(args, localconfig, nodes, timestamp) < 0: + return 1 + sth_statinfo_old = sth_statinfo + while sth_statinfo == sth_statinfo_old: + sleep(args.mergeinterval / 30) + sth_statinfo = stat(sth_path) + return 0 if __name__ == '__main__': sys.exit(main()) diff --git a/tools/merge_fetch.py b/tools/merge_fetch.py index 8c3a997..7e0dfd8 100755 --- a/tools/merge_fetch.py +++ b/tools/merge_fetch.py @@ -11,17 +11,24 @@ import sys import struct import subprocess import requests +import signal +import logging from time import sleep +from multiprocessing import Process, Pipe +from random import Random from mergetools import get_logorder, verify_entry, get_new_entries, \ chunks, fsync_logorder, get_entries, add_to_logorder, \ - hexencode, parse_args, perm + hexencode, hexdecode, parse_args, perm, flock_ex_or_fail, Status, \ + terminate_child_procs from certtools import timing_point, write_file, create_ssl_context -def merge_fetch(args, config, localconfig): +def merge_fetch_sequenced(args, config, localconfig): paths = localconfig["paths"] storagenodes = config["storagenodes"] mergedb = paths["mergedb"] logorderfile = mergedb + "/logorder" + statusfile = mergedb + "/merge_fetch.status" + s = Status(statusfile) chainsdb = perm(localconfig.get("dbbackend", "filedb"), mergedb + "/chains") own_key = (localconfig["nodename"], "%s/%s-private.pem" % (paths["privatekeys"], @@ -38,8 +45,7 @@ def merge_fetch(args, config, localconfig): entries_to_fetch = {} for storagenode in storagenodes: - print >>sys.stderr, "getting new entries from", storagenode["name"] - sys.stderr.flush() + logging.info("getting new entries from %s", storagenode["name"]) new_entries_per_node[storagenode["name"]] = \ set(get_new_entries(storagenode["name"], "https://%s/" % storagenode["address"], @@ -49,8 +55,7 @@ def merge_fetch(args, config, localconfig): timing_point(timing, "get new entries") new_entries -= certsinlog - print >>sys.stderr, "adding", len(new_entries), "entries" - sys.stderr.flush() + logging.info("adding %d entries", len(new_entries)) for ehash in new_entries: for storagenode in storagenodes: @@ -64,9 +69,8 @@ def merge_fetch(args, config, localconfig): added_entries = 0 for storagenode in storagenodes: - print >>sys.stderr, "getting %d entries from %s:" % \ - (len(entries_to_fetch[storagenode["name"]]), storagenode["name"]), - sys.stderr.flush() + nentries = len(entries_to_fetch[storagenode["name"]]) + logging.info("getting %d entries from %s", nentries, storagenode["name"]) with requests.sessions.Session() as session: for chunk in chunks(entries_to_fetch[storagenode["name"]], 100): entries = get_entries(storagenode["name"], @@ -80,21 +84,17 @@ def merge_fetch(args, config, localconfig): logorder.append(ehash) certsinlog.add(ehash) added_entries += 1 - print >>sys.stderr, added_entries, - sys.stderr.flush() - print >>sys.stderr - sys.stderr.flush() + s.status("PROG getting %d entries from %s: %d" % + (nentries, storagenode["name"], added_entries)) chainsdb.commit() fsync_logorder(logorderfile) timing_point(timing, "add entries") - print >>sys.stderr, "added", added_entries, "entries" - sys.stderr.flush() + logging.info("added %d entries", added_entries) verifycert.communicate(struct.pack("I", 0)) if args.timing: - print >>sys.stderr, "timing: merge_fetch:", timing["deltatimes"] - sys.stderr.flush() + logging.debug("timing: merge_fetch: %s", timing["deltatimes"]) tree_size = len(logorder) if tree_size == 0: @@ -102,30 +102,221 @@ def merge_fetch(args, config, localconfig): else: return (tree_size, logorder[tree_size-1]) +def merge_fetch_worker(args, localconfig, storagenode, pipe): + paths = localconfig["paths"] + mergedb = paths["mergedb"] + chainsdb = perm(localconfig.get("dbbackend", "filedb"), mergedb + "/chains") + own_key = (localconfig["nodename"], + "%s/%s-private.pem" % (paths["privatekeys"], + localconfig["nodename"])) + to_fetch = set() + timeout = max(3, args.mergeinterval / 10) + while True: + if pipe.poll(timeout): + msg = pipe.recv().split() + if len(msg) < 2: + continue + cmd = msg[0] + ehash = msg[1] + if cmd == 'FETCH': + to_fetch.add(hexdecode(ehash)) + else: + logging.warning("%s: unknown command from parent: %s", + storagenode["name"], msg) + + if len(to_fetch) > 0: + logging.info("%s: fetching %d entries", storagenode["name"], + len(to_fetch)) + # TODO: Consider running the verifycert process longer. + verifycert = subprocess.Popen( + [paths["verifycert_bin"], paths["known_roots"]], + stdin=subprocess.PIPE, stdout=subprocess.PIPE) + # Chunking for letting other workers take the chainsdb lock. + for chunk in chunks(list(to_fetch), 100): + chainsdb.lock_ex() + with requests.sessions.Session() as session: + entries = get_entries(storagenode["name"], + "https://%s/" % storagenode["address"], + own_key, paths, chunk, session=session) + for ehash in chunk: + entry = entries[ehash] + verify_entry(verifycert, entry, ehash) + chainsdb.add(ehash, entry) + chainsdb.commit() + chainsdb.release_lock() + for ehash in chunk: + pipe.send('FETCHED %s' % hexencode(ehash)) + to_fetch.remove(ehash) + verifycert.communicate(struct.pack("I", 0)) + + new_entries = get_new_entries(storagenode["name"], + "https://%s/" % storagenode["address"], + own_key, paths) + if len(new_entries) > 0: + logging.info("%s: got %d new entries", storagenode["name"], + len(new_entries)) + for ehash in new_entries: + pipe.send('NEWENTRY %s' % hexencode(ehash)) + +def term(signal, arg): + terminate_child_procs() + sys.exit(1) + +def newworker(name, args): + my_conn, child_conn = Pipe() + p = Process(target=merge_fetch_worker, + args=tuple(args + [child_conn]), + name='merge_fetch_%s' % name) + p.daemon = True + p.start() + logging.debug("%s started, pid %d", name, p.pid) + return (name, my_conn, p) + +def merge_fetch_parallel(args, config, localconfig): + paths = localconfig["paths"] + storagenodes = config["storagenodes"] + mergedb = paths["mergedb"] + logorderfile = mergedb + "/logorder" + currentsizefile = mergedb + "/fetched" + + rand = Random() + signal.signal(signal.SIGTERM, term) + + procs = {} + for storagenode in storagenodes: + name = storagenode['name'] + procs[name] = newworker(name, [args, localconfig, storagenode]) + + logorder = get_logorder(logorderfile) # List of entries in log. + entries_in_log = set(logorder) # Set of entries in log. + entries_to_fetch = set() # Set of entries to fetch. + fetch = {} # Dict with entries to fetch. + while procs: + assert(not entries_to_fetch) + # Poll worker processes. + for name, pipe, p in procs.values(): + if not p.is_alive(): + logging.warning("%s is gone, restarting", name) + procs[name] = newworker(name, [args, localconfig, + storagenodes[name]]) + continue + logging.info("polling %s", name) + if pipe.poll(1): + msg = pipe.recv().split() + if len(msg) < 2: + logging.warning("unknown command from %s: %s", name, msg) + continue + cmd = msg[0] + ehash = msg[1] + if cmd == 'NEWENTRY': + logging.info("NEWENTRY at %s: %s", name, ehash) + entries_to_fetch.add(ehash) + logging.debug("entries_to_fetch: %s", entries_to_fetch) + elif cmd == 'FETCHED': + logging.info("FETCHED from %s: %s", name, ehash) + logorder.append(ehash) + add_to_logorder(logorderfile, hexdecode(ehash)) + fsync_logorder(logorderfile) + entries_in_log.add(ehash) + if ehash in entries_to_fetch: + entries_to_fetch.remove(ehash) + del fetch[ehash] + else: + logging.warning("unknown command from %s: %s", name, msg) + + # Ask workers to fetch entries. + logging.debug("nof entries to fetch including entries in log: %d", + len(entries_to_fetch)) + entries_to_fetch -= entries_in_log + logging.info("entries to fetch: %d", len(entries_to_fetch)) + # Add entries in entries_to_fetch as keys in dictionary fetch, + # values being a list of storage nodes, in randomised order. + for e in entries_to_fetch: + if not e in fetch: + l = procs.values() + rand.shuffle(l) + fetch[e] = l + # For each entry to fetch, treat its list of nodes as a + # circular list and ask the one in the front to fetch the + # entry. + while entries_to_fetch: + ehash = entries_to_fetch.pop() + nodes = fetch[ehash] + node = nodes.pop(0) + fetch[ehash] = nodes.append(node) + name, pipe, p = node + logging.info("asking %s to FETCH %s", name, ehash) + pipe.send("FETCH %s" % ehash) + + # Update the 'fetched' file. + logsize = len(logorder) + if logsize == 0: + last_hash = '' + else: + last_hash = logorder[logsize - 1] + logging.info("updating 'fetched' file: %d %s", logsize-1, last_hash) + currentsize = {"index": logsize - 1, "hash": last_hash} + logging.debug("writing to %s: %s", currentsizefile, currentsize) + write_file(currentsizefile, currentsize) + + return 0 + def main(): """ - Fetch new entries from all storage nodes. + If no `--mergeinterval': + Fetch new entries from all storage nodes, in sequence, updating + the 'logorder' file and the 'chains' database. + + Write 'fetched' to reflect how far in 'logorder' we've succesfully + fetched and verified. + + If `--mergeinterval': + Start one process per storage node, read their stdout for learning + about two things: (i) new entries ready for fetching ("NEWENTRY") and + (ii) new entries being succesfully fetched ("FETCHED"). - Indicate current position by writing the index in the logorder file - (0-based) to the 'fetched' file. + Write to their stdin ("FETCH") when they should fetch another entry. + Update 'logorder' and the 'chains' database as we see new FETCHED + messages. - Sleep some and start over. + Write 'fetched' to reflect how far in 'logorder' we've succesfully + fetched and verified. + + Keep doing this forever. + + NOTE: The point of having 'fetched' is that it can be atomically + written while 'logorder' cannot (unless we're fine with rewriting it + for each and every update, which we're not). + + TODO: Deduplicate some code. """ args, config, localconfig = parse_args() paths = localconfig["paths"] mergedb = paths["mergedb"] currentsizefile = mergedb + "/fetched" + lockfile = mergedb + "/.merge_fetch.lock" + + loglevel = getattr(logging, args.loglevel.upper()) + if args.mergeinterval is None: + logging.basicConfig(level=loglevel) + else: + logging.basicConfig(filename=args.logdir + "/merge_fetch.log", + level=loglevel) + + if not flock_ex_or_fail(lockfile): + logging.critical("unable to take lock %s", lockfile) + return 1 + create_ssl_context(cafile=paths["https_cacertfile"]) - while True: - logsize, last_hash = merge_fetch(args, config, localconfig) + if args.mergeinterval: + return merge_fetch_parallel(args, config, localconfig) + else: + logsize, last_hash = merge_fetch_sequenced(args, config, localconfig) currentsize = {"index": logsize - 1, "hash": hexencode(last_hash)} - #print >>sys.stderr, "DEBUG: writing to", currentsizefile, ":", currentsize + logging.debug("writing to %s: %s", currentsizefile, currentsize) write_file(currentsizefile, currentsize) - if args.interval is None: - break - print >>sys.stderr, "sleeping", args.interval, "seconds" - sleep(args.interval) + return 0 if __name__ == '__main__': sys.exit(main()) diff --git a/tools/merge_sth.py b/tools/merge_sth.py index f4aec53..97f6e24 100755 --- a/tools/merge_sth.py +++ b/tools/merge_sth.py @@ -9,12 +9,13 @@ # import sys import json -import urllib2 import time import requests +import logging from base64 import b64encode +from datetime import datetime, timedelta from mergetools import parse_args, get_nfetched, hexencode, hexdecode, \ - get_logorder, get_sth + get_logorder, get_sth, flock_ex_or_fail from certtools import create_ssl_context, get_public_key_from_file, \ timing_point, create_sth_signature, write_file, check_sth_signature, \ build_merkle_tree @@ -39,6 +40,7 @@ def merge_sth(args, config, localconfig): trees = [{'tree_size': get_nfetched(currentsizefile, logorderfile), 'sha256_root_hash': ''}] + logging.debug("starting point, trees: %s", trees) for mergenode in mergenodes: if mergenode["name"] == config["primarymergenode"]: continue @@ -49,28 +51,29 @@ def merge_sth(args, config, localconfig): tree = {'tree_size': 0, "sha256_root_hash": ''} trees.append(tree) trees.sort(key=lambda e: e['tree_size'], reverse=True) - #print >>sys.stderr, "DEBUG: trees:", trees + logging.debug("trees: %s", trees) if backupquorum > len(trees) - 1: - print >>sys.stderr, "backup quorum > number of secondaries:", \ - backupquorum, ">", len(trees) - 1 - return + logging.error("backup quorum > number of secondaries: %d > %d", + backupquorum, len(trees) - 1) + return -1 tree_size = trees[backupquorum]['tree_size'] root_hash = hexdecode(trees[backupquorum]['sha256_root_hash']) - #print >>sys.stderr, "DEBUG: tree size candidate at backupquorum", backupquorum, ":", tree_size + logging.debug("tree size candidate at backupquorum %d: %d", backupquorum, + tree_size) cur_sth = get_sth(sthfile) if tree_size < cur_sth['tree_size']: - print >>sys.stderr, "candidate tree < current tree:", \ - tree_size, "<", cur_sth['tree_size'] - return + logging.info("candidate tree < current tree: %d < %d", + tree_size, cur_sth['tree_size']) + return 0 assert tree_size >= 0 # Don't read logorder without limit. logorder = get_logorder(logorderfile, tree_size) timing_point(timing, "get logorder") if tree_size == -1: tree_size = len(logorder) - print >>sys.stderr, "new tree size will be", tree_size + logging.info("new tree size will be %d", tree_size) root_hash_calc = build_merkle_tree(logorder)[-1][0] assert root_hash == '' or root_hash == root_hash_calc @@ -87,11 +90,10 @@ def merge_sth(args, config, localconfig): key=own_key) break except requests.exceptions.HTTPError, e: - print >>sys.stderr, e.response - sys.stderr.flush() + logging.warning("create_sth_signature error: %s", e.response) if tree_head_signature == None: - print >>sys.stderr, "Could not contact any signing nodes" - sys.exit(1) + logging.error("Could not contact any signing nodes") + return 0 sth = {"tree_size": tree_size, "timestamp": timestamp, "sha256_root_hash": b64encode(root_hash), @@ -100,34 +102,59 @@ def merge_sth(args, config, localconfig): check_sth_signature(ctbaseurl, sth, publickey=logpublickey) timing_point(timing, "build sth") - print hexencode(root_hash), timestamp, tree_size - sys.stdout.flush() + logging.info("new root: %s %d %d", hexencode(root_hash), timestamp, tree_size) write_file(sthfile, sth) if args.timing: - print >>sys.stderr, "timing: merge_sth:", timing["deltatimes"] - sys.stderr.flush() + logging.debug("timing: merge_sth: %s", timing["deltatimes"]) + + return 0 def main(): """ - Read file 'sth' to get current tree size, assuming zero if file not + Read 'sth' to get the current tree size, assuming zero if file not found. Read tree sizes from the backup.<secondary> files, put them in a - list and sort it. Let new tree size equal list[backup-quorum]. Barf - on a new tree size smaller than the currently published tree size. + list and sort the list. Let new tree size be list[backup-quorum]. If + the new tree size is smaller than the currently published tree size, + stop here. + + Decide on a timestamp, build an STH and write it to 'sth'. - Decide on a timestamp, build an STH and write it to file 'sth'. + Sleep some and start over, or exit if there's no `--mergeinterval'. """ args, config, localconfig = parse_args() + paths = localconfig["paths"] + mergedb = paths["mergedb"] + lockfile = mergedb + "/.merge_sth.lock" + + loglevel = getattr(logging, args.loglevel.upper()) + if args.mergeinterval is None: + logging.basicConfig(level=loglevel) + else: + logging.basicConfig(filename=args.logdir + "/merge_sth.log", + level=loglevel) + + if not flock_ex_or_fail(lockfile): + logging.critical("unable to take lock %s", lockfile) + return 1 while True: - merge_sth(args, config, localconfig) - if args.interval is None: + merge_start_time = datetime.now() + ret = merge_sth(args, config, localconfig) + if ret < 0: + return 1 + if args.mergeinterval is None: break - print >>sys.stderr, "sleeping", args.interval, "seconds" - time.sleep(args.interval) + sleep = (merge_start_time + timedelta(seconds=args.mergeinterval) - + datetime.now()).seconds + if sleep > 0: + logging.debug("sleeping %d seconds", sleep) + time.sleep(sleep) + + return 0 if __name__ == '__main__': sys.exit(main()) diff --git a/tools/mergetools.py b/tools/mergetools.py index 94901a9..d5d5f75 100644 --- a/tools/mergetools.py +++ b/tools/mergetools.py @@ -10,11 +10,16 @@ import json import yaml import argparse import requests +import time +import fcntl +import errno +import multiprocessing +import logging try: import permdb except ImportError: pass -from certtools import get_leaf_hash, http_request, get_leaf_hash +from certtools import get_leaf_hash, http_request def parselogrow(row): return base64.b16decode(row, casefold=True) @@ -33,7 +38,7 @@ def get_nfetched(currentsizefile, logorderfile): try: limit = json.loads(open(currentsizefile).read()) except (IOError, ValueError): - return -1 + return 0 if limit['index'] >= 0: with open(logorderfile, 'r') as f: f.seek(limit['index']*65) @@ -292,7 +297,7 @@ def backup_sendlog(node, baseurl, own_key, paths, submission): def sendentries(node, baseurl, own_key, paths, entries, session=None): try: - json_entries = [{"entry":base64.b64encode(entry), "treeleafhash":base64.b64encode(hash)} for hash, entry in entries] + json_entries = [{"entry":base64.b64encode(entry), "treeleafhash":base64.b64encode(ehash)} for ehash, entry in entries] result = http_request( baseurl + "plop/v1/frontend/sendentry", json.dumps(json_entries), @@ -310,13 +315,13 @@ def sendentries(node, baseurl, own_key, paths, entries, session=None): print >>sys.stderr, "========================" sys.stderr.flush() raise e - except requests.exceptions.ConnectionError, e: - print >>sys.stderr, "ERROR: sendentries", baseurl, e.request, e.response + except requests.exceptions.ConnectionError, e2: + print >>sys.stderr, "ERROR: sendentries", baseurl, e2.request, e2.response sys.exit(1) def sendentries_merge(node, baseurl, own_key, paths, entries, session=None): try: - json_entries = [{"entry":base64.b64encode(entry), "treeleafhash":base64.b64encode(hash)} for hash, entry in entries] + json_entries = [{"entry":base64.b64encode(entry), "treeleafhash":base64.b64encode(ehash)} for ehash, entry in entries] result = http_request( baseurl + "plop/v1/merge/sendentry", json.dumps(json_entries), @@ -334,8 +339,8 @@ def sendentries_merge(node, baseurl, own_key, paths, entries, session=None): print >>sys.stderr, "========================" sys.stderr.flush() raise e - except requests.exceptions.ConnectionError, e: - print >>sys.stderr, "ERROR: sendentries_merge", baseurl, e.request, e.response + except requests.exceptions.ConnectionError, e2: + print >>sys.stderr, "ERROR: sendentries_merge", baseurl, e2.request, e2.response sys.exit(1) def publish_sth(node, baseurl, own_key, paths, submission): @@ -430,10 +435,17 @@ def parse_args(): required=True) parser.add_argument('--localconfig', help="Local configuration", required=True) - parser.add_argument('--interval', type=int, metavar="n", - help="Repeate every N seconds") + # FIXME: verify that 0 < N < 1d + parser.add_argument('--mergeinterval', type=int, metavar="n", + help="Merge every N seconds") parser.add_argument("--timing", action='store_true', help="Print timing information") + parser.add_argument("--pidfile", type=str, metavar="file", + help="Store PID in FILE") + parser.add_argument("--logdir", type=str, default=".", metavar="dir", + help="Write logfiles in DIR [default: .]") + parser.add_argument("--loglevel", type=str, default="DEBUG", metavar="level", + help="Log level, one of DEBUG, INFO, WARNING, ERROR, CRITICAL [default: DEBUG]") args = parser.parse_args() config = yaml.load(open(args.config)) @@ -448,13 +460,74 @@ def perm(dbtype, path): return PermDB(path) assert False +def waitforfile(path): + statinfo = None + while statinfo is None: + try: + statinfo = os.stat(path) + except OSError, e: + if e.errno != errno.ENOENT: + raise + time.sleep(1) + return statinfo + +def flock_ex_or_fail(path): + """ + To be used at most once per process. Will otherwise leak file + descriptors. + """ + try: + fcntl.flock(os.open(path, os.O_CREAT), fcntl.LOCK_EX + fcntl.LOCK_NB) + except IOError, e: + if e.errno != errno.EWOULDBLOCK: + raise + return False + return True + +def flock_ex_wait(path): + fd = os.open(path, os.O_CREAT) + logging.debug("waiting for exclusive lock on %s (%s)", fd, path) + fcntl.flock(fd, fcntl.LOCK_EX) + logging.debug("taken exclusive lock on %s", fd) + return fd + +def flock_sh_wait(path): + fd = os.open(path, os.O_CREAT) + logging.debug("waiting for shared lock on %s (%s)", fd, path) + fcntl.flock(fd, fcntl.LOCK_SH) + logging.debug("taken shared lock on %s", fd) + return fd + +def flock_release(fd): + logging.debug("releasing lock on %s", fd) + fcntl.flock(fd, fcntl.LOCK_UN) + os.close(fd) + +def terminate_child_procs(): + for p in multiprocessing.active_children(): + #print >>sys.stderr, "DEBUG: terminating pid", p.pid + p.terminate() + +class Status: + def __init__(self, path): + self.path = path + def status(self, s): + open(self.path, 'w').write(s) + class FileDB: def __init__(self, path): self.path = path + self.lockfile = None def get(self, key): return read_chain(self.path, key) def add(self, key, value): return write_chain(key, value, self.path) + def lock_sh(self): + self.lockfile = flock_sh_wait(self.path + "/.lock") + def lock_ex(self): + self.lockfile = flock_ex_wait(self.path + "/.lock") + def release_lock(self): + flock_release(self.lockfile) def commit(self): pass @@ -465,5 +538,12 @@ class PermDB: return permdb.getvalue(self.permdbobj, key) def add(self, key, value): return permdb.addvalue(self.permdbobj, key, value) + def lock_sh(self): + assert False # NYI + def lock_ex(self): + assert False # NYI + def release_lock(self): + assert False # NYI def commit(self): permdb.committree(self.permdbobj) + diff --git a/tools/testcase1.py b/tools/testcase1.py index 81d589a..885c24d 100755 --- a/tools/testcase1.py +++ b/tools/testcase1.py @@ -13,6 +13,7 @@ import struct import hashlib import itertools import os.path +from time import sleep from certtools import * baseurls = [sys.argv[1]] @@ -20,6 +21,9 @@ logpublickeyfile = sys.argv[2] cacertfile = sys.argv[3] toolsdir = os.path.dirname(sys.argv[0]) testdir = sys.argv[4] +do_merge = True +if len(sys.argv) > 5 and sys.argv[5] == '--nomerge': + do_merge = False certfiles = [toolsdir + ("/testcerts/cert%d.txt" % e) for e in range(1, 6)] @@ -121,7 +125,7 @@ def get_and_check_entry(timestamp, chain, leaf_index, baseurl): assert_equal(len(entries), 1, "get_entries", quiet=True) fetched_entry = entries["entries"][0] merkle_tree_leaf = pack_mtl(timestamp, chain[0]) - leaf_input = base64.decodestring(fetched_entry["leaf_input"]) + leaf_input = base64.decodestring(fetched_entry["leaf_input"]) assert_equal(leaf_input, merkle_tree_leaf, "entry", nodata=True, quiet=True) extra_data = base64.decodestring(fetched_entry["extra_data"]) certchain = decode_certificate_chain(extra_data) @@ -148,8 +152,14 @@ def get_and_check_entry(timestamp, chain, leaf_index, baseurl): len(submittedcertchain)) def merge(): - return subprocess.call([toolsdir + "/merge", "--config", testdir + "/catlfish-test.cfg", - "--localconfig", testdir + "/catlfish-test-local-merge.cfg"]) + if do_merge: + return subprocess.call([toolsdir + "/merge", "--config", testdir + "/catlfish-test.cfg", + "--localconfig", testdir + "/catlfish-test-local-merge.cfg"]) + else: + n = 40 + print "testcase1.py: sleeping", n, "seconds waiting for merge" + sleep(n) + return 0 mergeresult = merge() assert_equal(mergeresult, 0, "merge", quiet=True, fatal=True) |