#!/usr/bin/env python

# Copyright (c) 2014, NORDUnet A/S.
# See LICENSE for licensing information.

import urllib2
import urllib
import json
import base64
import sys
import struct
import hashlib
import itertools
from certtools import *
import os
import signal
import select

from multiprocessing import Pool

if len(sys.argv) < 3:
    print >>sys.stderr, "usage: %s <baseurl> <certfile | certdir/>" % sys.argv[0]
    sys.exit(1)

baseurl = sys.argv[1]
certfilepath = sys.argv[2]

# Feature switches read by submitcert(): read the submitted entry back from
# the log and compare it, and/or verify the returned SCT signature.
lookup_in_log = False
check_sig = False

# A trailing "/" marks certfilepath as a directory: every entry in it is
# submitted, in sorted name order. Otherwise it names a single file.
# endswith() (unlike certfilepath[-1]) is safe for an empty argument.
if certfilepath.endswith("/"):
    certfiles = [os.path.join(certfilepath, filename)
                 for filename in sorted(os.listdir(certfilepath))]
else:
    certfiles = [certfilepath]

def submitcert(certfile):
    timing = timing_point()
    certs = get_certs_from_file(certfile)
    timing_point(timing, "readcerts")

    try:
        result = add_chain(baseurl, {"chain":map(base64.b64encode, certs)})
    except SystemExit:
        print "EXIT:", certfile
        select.select([], [], [], 1.0)
        return None

    timing_point(timing, "addchain")

    if result == None:
        print "ERROR for certfile", certfile
        return timing["deltatimes"]

    try:
        if check_sig:
            check_sct_signature(baseurl, certs[0], result)
            timing_point(timing, "checksig")
    except AssertionError, e:
        print "ERROR:", certfile, e
        return None
    except urllib2.HTTPError, e:
        print "ERROR:", certfile, e
        return None
    except ecdsa.keys.BadSignatureError, e:
        print "ERROR: bad signature", certfile
        return None

    if lookup_in_log:

        merkle_tree_leaf = pack_mtl(result["timestamp"], certs[0])

        leaf_hash = get_leaf_hash(merkle_tree_leaf)

        sth = get_sth(baseurl)

        proof = get_proof_by_hash(baseurl, leaf_hash, sth["tree_size"])

        leaf_index = proof["leaf_index"]

        entries = get_entries(baseurl, leaf_index, leaf_index)

        fetched_entry = entries["entries"][0]

        print "does the leaf_input of the fetched entry match what we calculated:", \
          base64.decodestring(fetched_entry["leaf_input"]) == merkle_tree_leaf

        extra_data = fetched_entry["extra_data"]

        certchain = decode_certificate_chain(base64.decodestring(extra_data))

        submittedcertchain = certs[1:]

        for (submittedcert, fetchedcert, i) in zip(submittedcertchain,
                                                   certchain, itertools.count(1)):
            print "cert", i, "in chain is the same:", submittedcert == fetchedcert

        if len(certchain) == len(submittedcertchain) + 1:
            last_issuer = get_cert_info(certs[-1])["issuer"]
            root_subject = get_cert_info(certchain[-1])["subject"]
            print "issuer of last cert in submitted chain and " \
                "subject of last cert in fetched chain is the same:", \
                last_issuer == root_subject
        elif len(certchain) == len(submittedcertchain):
            print "cert chains are the same length"
        else:
            print "ERROR: fetched cert chain has length", len(certchain),
            print "and submitted chain has length", len(submittedcertchain)

    timing_point(timing, "lookup")
    return timing["deltatimes"]

p = Pool(16, lambda: signal.signal(signal.SIGINT, signal.SIG_IGN))

nsubmitted = 0
lastprinted = 0
starttime = datetime.datetime.now()

print len(certfiles), "certs"

submitcert(certfiles[0])
nsubmitted += 1
select.select([], [], [], 3.0)

try:
    for timing in p.imap_unordered(submitcert, certfiles[1:]):
        if timing == None:
            print "error"
            print "submitted", nsubmitted
            p.terminate()
            p.join()
            sys.exit(1)
        nsubmitted += 1
        deltatime = datetime.datetime.now() - starttime
        deltatime_f = deltatime.seconds + deltatime.microseconds / 1000000.0
        rate = nsubmitted / deltatime_f
        if nsubmitted > lastprinted + len(certfiles) / 10:
            print nsubmitted, "rate %.1f" % rate
            lastprinted = nsubmitted
        #print timing, "rate %.1f" % rate
    print "submitted", nsubmitted
except KeyboardInterrupt:
    p.terminate()
    p.join()