diff options
Diffstat (limited to 'tools')
-rwxr-xr-x | tools/submitcert.py | 62 | ||||
-rwxr-xr-x | tools/verifysct.py | 101 |
2 files changed, 144 insertions, 19 deletions
diff --git a/tools/submitcert.py b/tools/submitcert.py index 04b6ebe..9f0be67 100755 --- a/tools/submitcert.py +++ b/tools/submitcert.py @@ -3,6 +3,7 @@ # Copyright (c) 2014, NORDUnet A/S. # See LICENSE for licensing information. +import argparse import urllib2 import urllib import json @@ -17,19 +18,29 @@ import signal import select import zipfile +parser = argparse.ArgumentParser(description='') +parser.add_argument('baseurl', help="Base URL for CT server") +parser.add_argument('--store', default=None, metavar="dir", help='Get certificates from directory dir') +parser.add_argument('--sct-file', default=None, metavar="file", help='Store SCT:s in file') +parser.add_argument('--parallel', type=int, default=16, metavar="n", help="Number of parallel submits") +parser.add_argument('--check-sct', action='store_true', help="Check SCT signature") +parser.add_argument('--pre-warm', action='store_true', help="Wait 3 seconds after first submit") +args = parser.parse_args() + from multiprocessing import Pool -baseurl = sys.argv[1] -certfilepath = sys.argv[2] +baseurl = args.baseurl +certfilepath = args.store lookup_in_log = False -check_sig = False if certfilepath[-1] == "/": - certfiles = [certfilepath + filename for filename in sorted(os.listdir(certfilepath))] + certfiles = [certfilepath + filename for filename in sorted(os.listdir(certfilepath)) if os.path.isfile(certfilepath + filename)] else: certfiles = [certfilepath] +sth = get_sth(baseurl) + def submitcert((certfile, cert)): timing = timing_point() certchain = get_certs_from_string(cert) @@ -40,27 +51,27 @@ def submitcert((certfile, cert)): except SystemExit: print "EXIT:", certfile select.select([], [], [], 1.0) - return None + return (None, None) timing_point(timing, "addchain") if result == None: print "ERROR for certfile", certfile - return timing["deltatimes"] + return (None, timing["deltatimes"]) try: - if check_sig: + if args.check_sct: check_sct_signature(baseurl, certchain[0], result) timing_point(timing, "checksig") except AssertionError, e: print "ERROR:", certfile, e - return None + return (None, None) except urllib2.HTTPError, e: print "ERROR:", certfile, e - return None + return (None, None) except ecdsa.keys.BadSignatureError, e: print "ERROR: bad signature", certfile - return None + return (None, None) if lookup_in_log: @@ -68,8 +79,6 @@ def submitcert((certfile, cert)): leaf_hash = get_leaf_hash(merkle_tree_leaf) - sth = get_sth(baseurl) - proof = get_proof_by_hash(baseurl, leaf_hash, sth["tree_size"]) leaf_index = proof["leaf_index"] @@ -104,7 +113,7 @@ def submitcert((certfile, cert)): print "and submitted chain has length", len(submittedcertchain) timing_point(timing, "lookup") - return timing["deltatimes"] + return ((certchain[0], result), timing["deltatimes"]) def get_ncerts(certfiles): n = 0 @@ -127,32 +136,47 @@ def get_all_certificates(certfiles): else: yield (certfile, open(certfile).read()) -p = Pool(16, lambda: signal.signal(signal.SIGINT, signal.SIG_IGN)) +def save_sct(sct, sth): + sctlog = open(args.sct_file, "a") + json.dump({"leafcert": base64.b64encode(leafcert), "sct": sct, "sth": sth}, sctlog) + sctlog.write("\n") + sctlog.close() + +p = Pool(args.parallel, lambda: signal.signal(signal.SIGINT, signal.SIG_IGN)) nsubmitted = 0 lastprinted = 0 +print "listing certs" ncerts = get_ncerts(certfiles) print ncerts, "certs" certs = get_all_certificates(certfiles) -submitcert(certs.next()) -nsubmitted += 1 -select.select([], [], [], 3.0) +(result, timing) = submitcert(certs.next()) +if result != None: + nsubmitted += 1 + (leafcert, sct) = result + save_sct(sct, sth) + +if args.pre_warm: + select.select([], [], [], 3.0) starttime = datetime.datetime.now() try: - for timing in p.imap_unordered(submitcert, certs): + for result, timing in p.imap_unordered(submitcert, certs): if timing == None: print "error" print "submitted", nsubmitted p.terminate() p.join() sys.exit(1) - nsubmitted += 1 + if result != None: + nsubmitted += 1 + (leafcert, sct) = result + save_sct(sct, sth) deltatime = datetime.datetime.now() - starttime deltatime_f = deltatime.seconds + deltatime.microseconds / 1000000.0 rate = nsubmitted / deltatime_f diff --git a/tools/verifysct.py b/tools/verifysct.py new file mode 100755 index 0000000..290d471 --- /dev/null +++ b/tools/verifysct.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python + +# Copyright (c) 2014, NORDUnet A/S. +# See LICENSE for licensing information. + +import argparse +import urllib2 +import urllib +import json +import base64 +import sys +import struct +import hashlib +import itertools +from certtools import * +import os +import signal +import select +import zipfile + +parser = argparse.ArgumentParser(description='') +parser.add_argument('baseurl', help="Base URL for CT server") +parser.add_argument('--sct-file', default=None, metavar="dir", help='SCT:s to verify') +parser.add_argument('--parallel', type=int, default=16, metavar="n", help="Number of parallel verifications") +args = parser.parse_args() + +from multiprocessing import Pool + +baseurl = args.baseurl + +def verifysct(sctentry): + timing = timing_point() + + leafcert = base64.b64decode(sctentry["leafcert"]) + try: + check_sct_signature(baseurl, leafcert, sctentry["sct"]) + timing_point(timing, "checksig") + except AssertionError, e: + print "ERROR:", e + return (None, None) + except urllib2.HTTPError, e: + print "ERROR:", e + return (None, None) + except ecdsa.keys.BadSignatureError, e: + print "ERROR: bad signature" + return (None, None) + + merkle_tree_leaf = pack_mtl(sctentry["sct"]["timestamp"], leafcert) + + leaf_hash = get_leaf_hash(merkle_tree_leaf) + + proof = get_proof_by_hash(baseurl, leaf_hash, sctentry["sth"]["tree_size"]) + + #print proof + + leaf_index = proof["leaf_index"] + inclusion_proof = [base64.b64decode(e) for e in proof["audit_path"]] + + calc_root_hash = verify_inclusion_proof(inclusion_proof, leaf_index, sctentry["sth"]["tree_size"], leaf_hash) + + root_hash = base64.b64decode(sctentry["sth"]["sha256_root_hash"]) + if root_hash != calc_root_hash: + print "sth" + print base64.b16encode(root_hash) + print base64.b16encode(calc_root_hash) + assert root_hash == calc_root_hash + + timing_point(timing, "lookup") + return (True, timing["deltatimes"]) + +p = Pool(args.parallel, lambda: signal.signal(signal.SIGINT, signal.SIG_IGN)) + +sctfile = open(args.sct_file) +scts = [json.loads(row) for row in sctfile] + +nverified = 0 +lastprinted = 0 + +starttime = datetime.datetime.now() + +try: + for result, timing in p.imap_unordered(verifysct, scts): + if timing == None: + print "error" + print "verified", nverified + p.terminate() + p.join() + sys.exit(1) + if result != None: + nverified += 1 + deltatime = datetime.datetime.now() - starttime + deltatime_f = deltatime.seconds + deltatime.microseconds / 1000000.0 + rate = nverified / deltatime_f + if nverified > lastprinted + 100: + print nverified, "rate %.1f" % rate + lastprinted = nverified + #print timing, "rate %.1f" % rate + print "verified", nverified +except KeyboardInterrupt: + p.terminate() + p.join() |