# Copyright (c) 2014, NORDUnet A/S.
# See LICENSE for licensing information.

import subprocess
import json
import base64
import urllib
import urllib2
import struct
import sys
import hashlib
import ecdsa
import datetime
import cStringIO
import zipfile

# Known CT log public keys, keyed by log base URL.  Base URLs keep their
# trailing slash because request paths such as "ct/v1/get-sth" are appended
# to them directly.  Each value is base64 text split across adjacent string
# literals; once decoded it is passed to ecdsa.VerifyingKey.from_der() for
# signature checks, and its SHA-256 digest is compared with the SCT log id.
publickeys = {
    "https://ct.googleapis.com/pilot/":
    "MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEfahLEimAoz2t01p3uMziiLOl/fHTD"
    "M0YDOhBRuiBARsV4UvxG2LdNgoIGLrtCzWE0J5APC2em4JlvR8EEEFMoA==",

    # NOTE(review): this entry and the flimsy entry below carry the same key —
    # presumably the same test log reached via two addresses; confirm.
    "https://127.0.0.1:8080/":
    "MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE4qWq6afhBUi0OdcWUYhyJLNXTkGqQ9"
    "PMS5lqoCgkV2h1ZvpNjBH2u8UbgcOQwqDo66z6BWQJGolozZYmNHE2kQ==",

    "https://flimsy.ct.nordu.net/":
    "MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE4qWq6afhBUi0OdcWUYhyJLNXTkGqQ9"
    "PMS5lqoCgkV2h1ZvpNjBH2u8UbgcOQwqDo66z6BWQJGolozZYmNHE2kQ==",
}

def get_cert_info(s):
    p = subprocess.Popen(
        ["openssl", "x509", "-noout", "-subject", "-issuer", "-inform", "der"],
        stdin=subprocess.PIPE, stdout=subprocess.PIPE,
        stderr=subprocess.PIPE)
    parsed = p.communicate(s)
    if parsed[1]:
        print "ERROR:", parsed[1]
        sys.exit(1)
    result = {}
    for line in parsed[0].split("\n"):
        (key, sep, value) = line.partition("=")
        if sep == "=":
            result[key] = value
    return result


def get_pemlike(filename, marker):
    """Read *filename* and return the decoded bodies of its PEM-style blocks.

    Only blocks delimited by "-----BEGIN <marker>-----" and
    "-----END <marker>-----" are returned, in file order.  The file is
    closed before returning, even if parsing raises.
    """
    with open(filename) as f:
        return get_pemlike_from_file(f, marker)

def get_pemlike_from_file(f, marker):
    entries = []
    entry = ""
    inentry = False

    for line in f:
        line = line.strip()
        if line == "-----BEGIN " + marker + "-----":
            entry = ""
            inentry = True
        elif line == "-----END " + marker + "-----":
            entries.append(base64.decodestring(entry))
            inentry = False
        elif inentry:
            entry += line
    return entries

def get_certs_from_file(certfile):
    """Return the DER bytes of every PEM "CERTIFICATE" block in *certfile*."""
    marker = "CERTIFICATE"
    return get_pemlike(certfile, marker)

def get_certs_from_string(s):
    """Return the DER bytes of every PEM "CERTIFICATE" block in string *s*."""
    return get_pemlike_from_file(cStringIO.StringIO(s), "CERTIFICATE")

def get_eckey_from_file(keyfile):
    """Return the decoded "EC PRIVATE KEY" block from *keyfile*.

    Asserts that the file contains exactly one such block.
    """
    found = get_pemlike(keyfile, "EC PRIVATE KEY")
    assert len(found) == 1
    return found[0]

def get_root_cert(issuer):
    """Look up an accepted root certificate by subject.

    Reads "googlelog-accepted-certs.txt" from the current directory.  It is
    parsed as JSON, and its "certificates" list holds base64 DER
    certificates.  Returns the DER bytes of the certificate whose subject
    (as reported by get_cert_info) equals *issuer*.  If several match, the
    last one wins.  Returns None if none match.
    """
    with open("googlelog-accepted-certs.txt") as f:
        accepted_certs = json.loads(f.read())["certificates"]

    root_cert = None
    for accepted_cert in accepted_certs:
        # Decode once and reuse the bytes for both the lookup and the result.
        der = base64.decodestring(accepted_cert)
        if get_cert_info(der)["subject"] == issuer:
            root_cert = der
    return root_cert

def get_sth(baseurl):
    """Fetch the log's current signed tree head (ct/v1/get-sth) as parsed JSON."""
    response = urllib2.urlopen(baseurl + "ct/v1/get-sth")
    body = response.read()
    return json.loads(body)

def get_proof_by_hash(baseurl, hash, tree_size):
    try:
        params = urllib.urlencode({"hash":base64.b64encode(hash),
                                   "tree_size":tree_size})
        result = \
          urllib2.urlopen(baseurl + "ct/v1/get-proof-by-hash?" + params).read()
        return json.loads(result)
    except urllib2.HTTPError, e:
        print "ERROR:", e.read()
        sys.exit(1)

def tls_array(data, length_len):
    """Encode *data* as a TLS variable-length vector.

    The result is *data* prefixed with its length as a big-endian unsigned
    integer occupying *length_len* bytes (1 <= length_len <= 8).

    Raises ValueError if *length_len* is out of range, or if *data* is too
    long for a *length_len*-byte prefix.  The length is never silently
    truncated into a corrupt encoding.
    """
    if not 1 <= length_len <= 8:
        raise ValueError("length_len must be between 1 and 8, got %d"
                         % (length_len,))
    if len(data) >= 1 << (8 * length_len):
        raise ValueError("data is %d bytes long, too long for a %d-byte "
                         "length prefix" % (len(data), length_len))
    length_bytes = struct.pack(">Q", len(data))[-length_len:]
    return length_bytes + data

def unpack_tls_array(packed_data, length_len):
    """Decode one TLS variable-length vector from the front of *packed_data*.

    Reads a big-endian length field of *length_len* bytes (1..8), then that
    many payload bytes.  Returns ``(payload, remaining_bytes)``.

    Raises AssertionError if the buffer is too short for the length field
    or for the announced payload.
    """
    assert 1 <= length_len <= 8, \
      "length_len must be between 1 and 8, got %d" % (length_len,)
    assert len(packed_data) >= length_len, \
      "data is only %d bytes long, but the length field needs %d bytes" % \
      (len(packed_data), length_len)
    # Left-pad the length field with zero bytes so it reads as a uint64.
    padded_length = b"\x00" * (8 - length_len) + packed_data[:length_len]
    (length,) = struct.unpack(">Q", padded_length)
    unpacked_data = packed_data[length_len:length_len+length]
    assert len(unpacked_data) == length, \
      "data is only %d bytes long, but length is %d bytes" % \
      (len(unpacked_data), length)
    rest_data = packed_data[length_len+length:]
    return (unpacked_data, rest_data)

def add_chain(baseurl, submission):
    try:
        result = urllib2.urlopen(baseurl + "ct/v1/add-chain",
            json.dumps(submission)).read()
        return json.loads(result)
    except urllib2.HTTPError, e:
        print "ERROR", e.code,":", e.read()
        if e.code == 400:
            return None
        sys.exit(1)
    except ValueError, e:
        print "==== FAILED REQUEST ===="
        print submission
        print "======= RESPONSE ======="
        print result
        print "========================"
        raise e

def get_entries(baseurl, start, end):
    try:
        params = urllib.urlencode({"start":start, "end":end})
        result = urllib2.urlopen(baseurl + "ct/v1/get-entries?" + params).read()
        return json.loads(result)
    except urllib2.HTTPError, e:
        print "ERROR:", e.read()
        sys.exit(1)

def extract_precertificate(precert_chain_entry):
    """Split a precert chain entry into ``(precertificate, rest)``.

    The entry starts with the precertificate as a vector with a 3-byte
    length prefix.  Everything after it is returned unparsed as the second
    element.
    """
    return unpack_tls_array(precert_chain_entry, 3)

def decode_certificate_chain(packed_certchain):
    """Decode a certificate chain into a list of certificate byte strings.

    The input is one vector with a 3-byte length prefix that must span the
    whole input (asserted).  Its payload is a sequence of certificates, each
    with its own 3-byte length prefix.
    """
    (remaining, trailing) = unpack_tls_array(packed_certchain, 3)
    assert len(trailing) == 0
    certs = []
    while remaining:
        cert, remaining = unpack_tls_array(remaining, 3)
        certs.append(cert)
    return certs

def decode_signature(signature):
    """Split a TLS DigitallySigned structure into its parts.

    Returns ``(hash_alg, signature_alg, signature_bytes)``.  The two
    algorithm identifiers are single unsigned bytes (TLS HashAlgorithm and
    SignatureAlgorithm, 0..255).  The signature is a vector with a 2-byte
    length prefix that must consume the rest of the input (asserted).
    """
    # Unsigned ("B"): the TLS algorithm codes are uint8, and a signed read
    # would turn codes >= 128 into negative numbers.
    (hash_alg, signature_alg) = struct.unpack(">BB", signature[0:2])
    (unpacked_signature, rest) = unpack_tls_array(signature[2:], 2)
    assert not rest, \
        "%d unexpected trailing bytes after signature" % (len(rest),)
    return (hash_alg, signature_alg, unpacked_signature)

def encode_signature(hash_alg, signature_alg, unpacked_signature):
    """Build a TLS DigitallySigned structure.

    Packs *hash_alg* and *signature_alg* as unsigned bytes (0..255), then
    appends *unpacked_signature* with a 2-byte length prefix.
    """
    # Unsigned ("B"): the TLS algorithm codes are uint8; a signed pack would
    # reject valid codes >= 128.
    header = struct.pack(">BB", hash_alg, signature_alg)
    return header + tls_array(unpacked_signature, 2)

def check_signature(baseurl, signature, data):
    """Verify a DigitallySigned *signature* over *data* with a log's key.

    The log's public key is looked up in ``publickeys`` by *baseurl*.  Only
    hash_alg 4 (SHA-256) with signature_alg 3 (ECDSA) is accepted, and this
    is checked with assertions.  A bad signature surfaces as whatever
    ``ecdsa.VerifyingKey.verify`` raises.  Returns None on success.
    """
    der_key = base64.decodestring(publickeys[baseurl])
    hash_alg, signature_alg, raw_signature = decode_signature(signature)
    assert hash_alg == 4, "hash_alg is %d, expected 4" % (hash_alg,)
    assert signature_alg == 3, \
        "signature_alg is %d, expected 3" % (signature_alg,)

    verifier = ecdsa.VerifyingKey.from_der(der_key)
    verifier.verify(raw_signature, data, hashfunc=hashlib.sha256,
                    sigdecode=ecdsa.util.sigdecode_der)

def create_signature(privatekey, data):
    """Sign *data* with a DER EC private key and return a DigitallySigned blob.

    The data is signed with SHA-256 and a DER-encoded ECDSA signature, and
    the result is tagged as hash_alg 4 / signature_alg 3.
    """
    signer = ecdsa.SigningKey.from_der(privatekey)
    der_signature = signer.sign(data, hashfunc=hashlib.sha256,
                                sigencode=ecdsa.util.sigencode_der)
    sha256_id, ecdsa_id = 4, 3
    return encode_signature(sha256_id, ecdsa_id, der_signature)

def check_sth_signature(baseurl, sth):
    """Verify the tree_head_signature of a parsed get-sth response *sth*.

    Rebuilds the signed tree head and checks it against the log key for
    *baseurl*.  The tree head is: version 0, signature type 1, a 64-bit
    timestamp, a 64-bit tree size, then the raw root hash.
    """
    signature = base64.decodestring(sth["tree_head_signature"])
    header = struct.pack(">bbQQ", 0, 1, sth["timestamp"], sth["tree_size"])
    tree_head = header + base64.decodestring(sth["sha256_root_hash"])
    check_signature(baseurl, signature, tree_head)

def create_sth_signature(tree_size, timestamp, root_hash, privatekey):
    """Sign a tree head (version 0, signature type 1) with *privatekey*.

    *timestamp* and *tree_size* are packed as big-endian unsigned 64-bit
    integers, followed by the raw *root_hash* bytes.
    """
    header = struct.pack(">bbQQ", 0, 1, timestamp, tree_size)
    return create_signature(privatekey, header + root_hash)

def check_sct_signature(baseurl, leafcert, sct):
    """Verify that *sct* is a valid SCT for *leafcert* from the log at *baseurl*.

    First asserts that the SCT's "id" equals the SHA-256 of the log's public
    key.  Then it rebuilds the signed structure (entry type 0, carrying
    *leafcert*) and verifies the SCT's signature over it.
    """
    der_key = base64.decodestring(publickeys[baseurl])
    expected_id = hashlib.sha256(der_key).digest()
    actual_id = base64.decodestring(sct["id"])
    assert expected_id == actual_id, \
        "log id is incorrect:\n  should be %s\n        got %s" % \
        (expected_id.encode("hex_codec"), actual_id.encode("hex_codec"))

    signature = base64.decodestring(sct["signature"])
    signed_struct = (struct.pack(">b", sct["sct_version"]) +
                     struct.pack(">b", 0) +          # signature_type
                     struct.pack(">Q", sct["timestamp"]) +
                     struct.pack(">H", 0) +          # entry_type 0
                     tls_array(leafcert, 3) +
                     tls_array(base64.decodestring(sct["extensions"]), 2))
    check_signature(baseurl, signature, signed_struct)

def pack_mtl(timestamp, leafcert):
    """Build a MerkleTreeLeaf for a plain certificate entry.

    Layout: version 0, leaf type 0, then the timestamped entry.  That entry
    is the 64-bit timestamp, entry type 0, *leafcert* with a 3-byte length
    prefix, and an empty extensions vector with a 2-byte length prefix.
    """
    header = struct.pack(">bb", 0, 0)          # version, leaf_type
    timestamped_entry = "".join([
        struct.pack(">QH", timestamp, 0),      # timestamp, entry_type
        tls_array(leafcert, 3),
        tls_array("", 2),                      # no extensions
    ])
    return header + timestamped_entry

def unpack_mtl(merkle_tree_leaf):
    """Parse a MerkleTreeLeaf into ``(leafcert, timestamp, issuer_key_hash)``.

    The first two bytes (version and leaf type) are skipped without being
    checked.  For entry type 0, issuer_key_hash is None.  For entry type 1,
    it is the 32 bytes that precede the certificate vector.  Bytes after the
    certificate vector (e.g. extensions) are ignored.

    Raises ValueError for any other entry type.
    """
    timestamped_entry = merkle_tree_leaf[2:]
    (timestamp, entry_type) = struct.unpack(">QH", timestamped_entry[0:10])
    if entry_type == 0:
        issuer_key_hash = None
        cert_offset = 10
    elif entry_type == 1:
        issuer_key_hash = timestamped_entry[10:42]
        cert_offset = 42
    else:
        # Previously fell through and failed with UnboundLocalError.
        raise ValueError("unsupported entry_type %d" % (entry_type,))
    (leafcert, _rest) = unpack_tls_array(timestamped_entry[cert_offset:], 3)
    return (leafcert, timestamp, issuer_key_hash)

def get_leaf_hash(merkle_tree_leaf):
    """Return SHA-256(0x00 || merkle_tree_leaf), the Merkle leaf hash."""
    return hashlib.sha256(struct.pack(">b", 0) + merkle_tree_leaf).digest()

def timing_point(timer_dict=None, name=None):
    """Record elapsed wall-clock time between successive calls.

    Called without a (non-empty) timer dict, returns a new one of the form
    ``{"deltatimes": [], "lasttime": now}``.  Called with such a dict, it
    appends ``(name, microseconds since the previous call)`` to
    ``timer_dict["deltatimes"]``, sets ``timer_dict["lasttime"]`` to now and
    returns None.
    """
    t = datetime.datetime.now()
    if timer_dict:
        deltatime = t - timer_dict["lasttime"]
        # Include .days: timedelta.seconds alone wraps every 24 hours and
        # gives nonsense for negative deltas.
        microseconds = ((deltatime.days * 86400 + deltatime.seconds) * 1000000
                        + deltatime.microseconds)
        timer_dict["deltatimes"].append((name, microseconds))
        timer_dict["lasttime"] = t
        return None
    return {"deltatimes": [], "lasttime": t}

def internal_hash(pair):
    """Combine one or two child nodes into their parent node.

    A lone node is passed up unchanged.  Otherwise the parent is
    SHA-256(0x01 || pair[0] || pair[1]).
    """
    if len(pair) == 1:
        return pair[0]
    left, right = pair[0], pair[1]
    return hashlib.sha256(struct.pack(">b", 1) + left + right).digest()

def chunks(l, n):
    """Split sequence *l* into consecutive slices of length *n*.

    The final slice may be shorter.
    """
    pieces = []
    for start in range(0, len(l), n):
        pieces.append(l[start:start + n])
    return pieces

def next_merkle_layer(layer):
    """Return the parent layer of *layer*, hashing adjacent nodes pairwise."""
    parents = []
    for pair in chunks(layer, 2):
        parents.append(internal_hash(pair))
    return parents

def build_merkle_tree(layer0):
    """Build every layer of a Merkle tree from the leaf hashes *layer0*.

    Returns a list of layers, with the leaves first and the one-element root
    layer last.  An empty input yields ``[[SHA-256 of the empty string]]``.
    """
    if len(layer0) == 0:
        return [[hashlib.sha256().digest()]]
    layers = [layer0]
    while len(layers[-1]) > 1:
        layers.append(next_merkle_layer(layers[-1]))
    return layers

def print_inclusion_proof(proof):
    audit_path = proof[u'audit_path']
    n = proof[u'leaf_index']
    level = 0
    for s in audit_path:
        entry = base64.b16encode(base64.b64decode(s))
        n ^= 1
        print level, n, entry
        n >>= 1
        level += 1

def get_one_cert(store, i):
    """Return the raw stored entry number *i* from the zip archives in *store*.

    Entries are grouped 10000 per archive.  Entry *i* is read from
    "<store>/%04d.zip" (archive number ``i // 10000``) under the member name
    ``"%08d" % i``.  The archive is closed even if the read fails.
    """
    zf = zipfile.ZipFile("%s/%04d.zip" % (store, i // 10000))
    try:
        return zf.read("%08d" % i)
    finally:
        zf.close()

def get_hash_from_certfile(cert):
    """Return the binary leaf hash recorded in a stored cert file, or None.

    Looks for a "Leafhash: <hex>" line and hex-decodes its value.  The
    search gives up (returning None) at the first line starting with "-----"
    (presumably where the PEM body begins).
    """
    prefix = "Leafhash: "
    for line in cert.split("\n"):
        if line.startswith("-----"):
            break
        if line.startswith(prefix):
            return base64.b16decode(line[len(prefix):])
    return None

def get_proof(store, tree_size, n, baseurl=None):
    """Fetch the inclusion proof for stored entry *n* at tree size *tree_size*.

    The leaf hash is taken from the stored cert file for entry *n*.
    *baseurl* selects the log.  If it is omitted, ``args.baseurl`` from a
    module-level ``args`` object is used.
    NOTE(review): no ``args`` global is visible in this module, so a script
    presumably injects it; passing *baseurl* explicitly avoids depending on
    that.
    """
    if baseurl is None:
        # Backward-compatible fallback for callers that rely on a global.
        baseurl = args.baseurl
    leaf_hash = get_hash_from_certfile(get_one_cert(store, n))
    return get_proof_by_hash(baseurl, leaf_hash, tree_size)

def get_certs_from_zipfiles(zipfiles, firstleaf, lastleaf):
    """Yield stored entries firstleaf..lastleaf (inclusive) from open archives.

    *zipfiles* maps archive number (entry index // 10000) to an open
    ZipFile.  Iteration stops silently at the first entry whose archive or
    member is missing (KeyError).
    """
    index = firstleaf
    while index <= lastleaf:
        try:
            data = zipfiles[index // 10000].read("%08d" % index)
        except KeyError:
            return
        yield data
        index += 1

def get_merkle_hash_64k(store, blocknumber, write_to_cache=False):
    """Return the Merkle tree hash of one 65536-entry block of the store.

    A cached hash in "<store>/%04x.64khash" (hex text, keyed by
    *blocknumber*) is used when it can be opened.  Otherwise the hash is
    computed from the leaf hashes of entries
    blocknumber*65536 .. blocknumber*65536+65535, read from the zip
    archives in *store*.

    Returns ("hash", root_hash) for a complete block, or
    ("incomplete", (number_of_leaves, root_hash)) if fewer than 65536
    entries were available.  When *write_to_cache* is true, the hash of a
    complete block is written to the cache file.
    """
    hashfilename = "%s/%04x.64khash" % (store, blocknumber)
    try:
        with open(hashfilename) as cachefile:
            cached_hash = base64.b16decode(cachefile.read())
    except IOError:
        pass
    else:
        assert len(cached_hash) == 32
        return ("hash", cached_hash)

    firstleaf = blocknumber * 65536
    lastleaf = firstleaf + 65535
    zipfiles = {}
    try:
        for i in range(firstleaf // 10000, lastleaf // 10000 + 1):
            try:
                zipfiles[i] = zipfile.ZipFile("%s/%04d.zip" % (store, i))
            except IOError:
                # Stop at the first archive that cannot be opened.
                break
        # The generator reads lazily, so it must be fully consumed before
        # the archives are closed in the finally clause.
        certs = get_certs_from_zipfiles(zipfiles, firstleaf, lastleaf)
        layer0 = [get_hash_from_certfile(cert) for cert in certs]
    finally:
        for zf in zipfiles.values():
            zf.close()

    calculated_hash = build_merkle_tree(layer0)[-1][0]
    if len(layer0) != 65536:
        return ("incomplete", (len(layer0), calculated_hash))
    if write_to_cache:
        with open(hashfilename, "w") as cachefile:
            cachefile.write(base64.b16encode(calculated_hash))
    return ("hash", calculated_hash)