#!/usr/bin/env python

# Copyright (c) 2014, NORDUnet A/S.
# See LICENSE for licensing information.

import argparse
import urllib2
import urllib
import json
import base64
import sys
import struct
import hashlib
import itertools
from certtools import *
import os
import signal
import select
import zipfile

parser = argparse.ArgumentParser(description='Verify SCT signatures and inclusion proofs against a CT log')
parser.add_argument('baseurl', help="Base URL for CT server")
# Required: the SCT file is always opened later on, so there is no usable
# default (None would only fail with a TypeError after contacting the server).
parser.add_argument('--sct-file', required=True, metavar="file", help='file of SCTs to verify, one JSON object per line')
parser.add_argument('--parallel', type=int, default=16, metavar="n", help="Number of parallel verifications")
args = parser.parse_args()

from multiprocessing import Pool

baseurl = args.baseurl

# Fetch the log's STH once; every inclusion proof is requested for, and
# checked against, this STH's tree size and root hash.
sth = get_sth(baseurl)

def verifysct(sctentry):
    timing = timing_point()

    leafcert = base64.b64decode(sctentry["leafcert"])
    try:
        check_sct_signature(baseurl, leafcert, sctentry["sct"])
        timing_point(timing, "checksig")
    except AssertionError, e:
        print "ERROR:", e
        return (None, None)
    except urllib2.HTTPError, e:
        print "ERROR:", e
        return (None, None)
    except ecdsa.keys.BadSignatureError, e:
        print "ERROR: bad signature"
        return (None, None)

    merkle_tree_leaf = pack_mtl(sctentry["sct"]["timestamp"], leafcert)

    leaf_hash = get_leaf_hash(merkle_tree_leaf)

    try:
        proof = get_proof_by_hash(baseurl, leaf_hash, sth["tree_size"])
    except SystemExit:
        return (None, None)

    #print proof

    leaf_index = proof["leaf_index"]
    inclusion_proof = [base64.b64decode(e) for e in proof["audit_path"]]

    calc_root_hash = verify_inclusion_proof(inclusion_proof, leaf_index, sth["tree_size"], leaf_hash)

    root_hash = base64.b64decode(sth["sha256_root_hash"])
    if root_hash != calc_root_hash:
        print "sth"
        print base64.b16encode(root_hash)
        print base64.b16encode(calc_root_hash)
    assert root_hash == calc_root_hash

    timing_point(timing, "lookup")
    return (True, timing["deltatimes"])

# Main: verify every SCT in --sct-file in parallel and print progress.
# Abort with exit status 1 on the first failed verification.
import datetime  # used below; previously only reachable via the certtools star import

# Workers ignore SIGINT so that only the parent reacts to Ctrl-C; the parent
# then tears the pool down in the KeyboardInterrupt handler below.
p = Pool(args.parallel, lambda: signal.signal(signal.SIGINT, signal.SIG_IGN))

with open(args.sct_file) as sctfile:
    # One JSON object per line; blank lines are skipped.
    scts = [json.loads(row) for row in sctfile if row.strip()]

nverified = 0
lastprinted = 0

starttime = datetime.datetime.now()

try:
    for result, timing in p.imap_unordered(verifysct, scts):
        if timing is None:
            # verifysct has already printed the reason; stop everything.
            print "error"
            print "verified", nverified
            p.terminate()
            p.join()
            sys.exit(1)
        if result is not None:
            nverified += 1
        # total_seconds() also counts the days component that .seconds drops.
        deltatime_f = (datetime.datetime.now() - starttime).total_seconds()
        rate = nverified / deltatime_f if deltatime_f > 0 else 0.0
        if nverified > lastprinted + 100:
            print nverified, "rate %.1f" % rate
            lastprinted = nverified
    print "verified", nverified
    # All work is done: let the workers exit cleanly.
    p.close()
    p.join()
except KeyboardInterrupt:
    p.terminate()
    p.join()