#!/usr/bin/env python # Copyright (c) 2014, NORDUnet A/S. # See LICENSE for licensing information. import argparse import urllib2 import urllib import json import base64 import sys import struct import hashlib import itertools from certtools import * from precerttools import * import os import signal import select import zipfile parser = argparse.ArgumentParser(description='') parser.add_argument('baseurl', help="Base URL for CT server") parser.add_argument('--store', default=None, metavar="dir", help='Get certificates from directory dir') parser.add_argument('--sct-file', default=None, metavar="file", help='Store SCT:s in file') parser.add_argument('--parallel', type=int, default=16, metavar="n", help="Number of parallel submits") parser.add_argument('--check-sct', action='store_true', help="Check SCT signature") parser.add_argument('--pre-warm', action='store_true', help="Wait 3 seconds after first submit") parser.add_argument('--publickey', default=None, metavar="file", help='Public key for the CT log') parser.add_argument('--cafile', default=None, metavar="file", help='File containing the CA cert') args = parser.parse_args() create_ssl_context(cafile=args.cafile) from multiprocessing import Pool baseurl = args.baseurl certfilepath = args.store logpublickey = get_public_key_from_file(args.publickey) if args.publickey else None lookup_in_log = False if certfilepath[-1] == "/": certfiles = [certfilepath + filename for filename in sorted(os.listdir(certfilepath)) if os.path.isfile(certfilepath + filename)] else: certfiles = [certfilepath] sth = get_sth(baseurl) def submitcert((certfile, cert)): timing = timing_point() certchain = get_certs_from_string(cert) precerts = get_precerts_from_string(cert) assert len(precerts) == 0 or len(precerts) == 1 precert = precerts[0] if precerts else None timing_point(timing, "readcerts") try: if precert: if ext_key_usage_precert_signing_cert in get_ext_key_usage(certchain[0]): issuer_key_hash = get_cert_key_hash(certchain[1]) issuer = certchain[1] else: issuer_key_hash = get_cert_key_hash(certchain[0]) issuer = None cleanedcert = cleanprecert(precert, issuer=issuer) signed_entry = pack_precert(cleanedcert, issuer_key_hash) leafcert = cleanedcert result = add_prechain(baseurl, {"chain":map(base64.b64encode, [precert] + certchain)}) else: signed_entry = pack_cert(certchain[0]) leafcert = certchain[0] issuer_key_hash = None result = add_chain(baseurl, {"chain":map(base64.b64encode, certchain)}) except SystemExit: print "EXIT:", certfile select.select([], [], [], 1.0) return (None, None) timing_point(timing, "addchain") if result == None: print "ERROR for certfile", certfile return (None, timing["deltatimes"]) try: if args.check_sct: check_sct_signature(baseurl, signed_entry, result, precert=precert, publickey=logpublickey) timing_point(timing, "checksig") except AssertionError, e: print "ERROR:", certfile, e return (None, None) except urllib2.HTTPError, e: print "ERROR:", certfile, e return (None, None) except ecdsa.keys.BadSignatureError, e: print "ERROR: bad signature", certfile return (None, None) if lookup_in_log: merkle_tree_leaf = pack_mtl(result["timestamp"], leafcert) leaf_hash = get_leaf_hash(merkle_tree_leaf) proof = get_proof_by_hash(baseurl, leaf_hash, sth["tree_size"]) leaf_index = proof["leaf_index"] entries = get_entries(baseurl, leaf_index, leaf_index) fetched_entry = entries["entries"][0] print "does the leaf_input of the fetched entry match what we calculated:", \ base64.decodestring(fetched_entry["leaf_input"]) == merkle_tree_leaf extra_data = fetched_entry["extra_data"] certchain = decode_certificate_chain(base64.decodestring(extra_data)) submittedcertchain = certchain[1:] for (submittedcert, fetchedcert, i) in zip(submittedcertchain, certchain, itertools.count(1)): print "cert", i, "in chain is the same:", submittedcert == fetchedcert if len(certchain) == len(submittedcertchain) + 1: last_issuer = get_cert_info(certchain[-1])["issuer"] root_subject = get_cert_info(certchain[-1])["subject"] print "issuer of last cert in submitted chain and " \ "subject of last cert in fetched chain is the same:", \ last_issuer == root_subject elif len(certchain) == len(submittedcertchain): print "cert chains are the same length" else: print "ERROR: fetched cert chain has length", len(certchain), print "and submitted chain has length", len(submittedcertchain) timing_point(timing, "lookup") return ((leafcert, issuer_key_hash, result), timing["deltatimes"]) def get_ncerts(certfiles): n = 0 for certfile in certfiles: if certfile.endswith(".zip"): zf = zipfile.ZipFile(certfile) n += len(zf.namelist()) zf.close() else: n += 1 return n def get_all_certificates(certfiles): for certfile in certfiles: if certfile.endswith(".zip"): zf = zipfile.ZipFile(certfile) for name in zf.namelist(): yield (name, zf.read(name)) zf.close() else: yield (certfile, open(certfile).read()) def save_sct(sct, sth, leafcert, issuer_key_hash): sctlog = open(args.sct_file, "a") sctentry = {"leafcert": base64.b64encode(leafcert), "sct": sct, "sth": sth} if issuer_key_hash: sctentry["issuer_key_hash"] = base64.b64encode(issuer_key_hash) json.dump(sctentry, sctlog) sctlog.write("\n") sctlog.close() p = Pool(args.parallel, lambda: signal.signal(signal.SIGINT, signal.SIG_IGN)) nsubmitted = 0 lastprinted = 0 print "listing certs" ncerts = get_ncerts(certfiles) print ncerts, "certs" certs = get_all_certificates(certfiles) (result, timing) = submitcert(certs.next()) if result != None: nsubmitted += 1 (leafcert, issuer_key_hash, sct) = result save_sct(sct, sth, leafcert, issuer_key_hash) if args.pre_warm: select.select([], [], [], 3.0) starttime = datetime.datetime.now() try: for result, timing in p.imap_unordered(submitcert, certs): if timing == None: print "error" print "submitted", nsubmitted p.terminate() p.join() sys.exit(1) if result != None: nsubmitted += 1 (leafcert, issuer_key_hash, sct) = result save_sct(sct, sth, leafcert, issuer_key_hash) deltatime = datetime.datetime.now() - starttime deltatime_f = deltatime.seconds + deltatime.microseconds / 1000000.0 rate = nsubmitted / deltatime_f if nsubmitted > lastprinted + ncerts / 10: print nsubmitted, "rate %.1f" % rate lastprinted = nsubmitted #print timing, "rate %.1f" % rate print "submitted", nsubmitted except KeyboardInterrupt: p.terminate() p.join()