# Copyright (c) 2014, NORDUnet A/S.
# See LICENSE for licensing information.

import subprocess
import json
import base64
import urllib
import urllib2
import struct
import sys
import hashlib
import ecdsa
import datetime

# Known log public keys, keyed by log base URL (the same baseurl string that
# the request helpers below prepend to "ct/v1/..."). Each value is a
# base64-encoded DER public key: check_signature() loads it with
# ecdsa.VerifyingKey.from_der(), and check_sct_signature() expects an SCT's
# log id to equal SHA-256 of the decoded key.
# NOTE: the local test log (127.0.0.1:8080) and flimsy.ct.nordu.net are
# configured with the same key.
publickeys = {
    "https://ct.googleapis.com/pilot/":
    "MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEfahLEimAoz2t01p3uMziiLOl/fHTD"
    "M0YDOhBRuiBARsV4UvxG2LdNgoIGLrtCzWE0J5APC2em4JlvR8EEEFMoA==",

    "https://127.0.0.1:8080/":
    "MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE4qWq6afhBUi0OdcWUYhyJLNXTkGqQ9"
    "PMS5lqoCgkV2h1ZvpNjBH2u8UbgcOQwqDo66z6BWQJGolozZYmNHE2kQ==",

    "https://flimsy.ct.nordu.net/":
    "MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE4qWq6afhBUi0OdcWUYhyJLNXTkGqQ9"
    "PMS5lqoCgkV2h1ZvpNjBH2u8UbgcOQwqDo66z6BWQJGolozZYmNHE2kQ==",
}

def get_cert_info(s):
    p = subprocess.Popen(
        ["openssl", "x509", "-noout", "-subject", "-issuer", "-inform", "der"],
        stdin=subprocess.PIPE, stdout=subprocess.PIPE,
        stderr=subprocess.PIPE)
    parsed = p.communicate(s)
    if parsed[1]:
        print "ERROR:", parsed[1]
        sys.exit(1)
    result = {}
    for line in parsed[0].split("\n"):
        (key, sep, value) = line.partition("=")
        if sep == "=":
            result[key] = value
    return result

def get_certs_from_file(certfile):
    """Read all PEM "CERTIFICATE" blocks from ``certfile``.

    Each line is whitespace-stripped before matching. Lines between a
    ``-----BEGIN CERTIFICATE-----`` and an ``-----END CERTIFICATE-----`` marker
    are joined and base64-decoded. Everything outside such blocks is ignored.

    Returns:
        A list of the decoded (DER) certificates, in file order.
    """
    certs = []
    b64_lines = []
    incert = False

    # "with" guarantees the file is closed even if decoding raises.
    with open(certfile) as f:
        for line in f:
            line = line.strip()
            if line == "-----BEGIN CERTIFICATE-----":
                b64_lines = []
                incert = True
            elif line == "-----END CERTIFICATE-----":
                # Only close a block that was actually opened. A stray END
                # marker would otherwise emit an empty certificate or re-emit
                # the previous one.
                if incert:
                    # b64decode: decodestring is deprecated and gone in 3.9+.
                    certs.append(base64.b64decode("".join(b64_lines)))
                incert = False
            elif incert:
                b64_lines.append(line)
    return certs

def get_root_cert(issuer, accepted_certs_file="googlelog-accepted-certs.txt"):
    """Find the accepted certificate whose subject equals ``issuer``.

    Args:
        issuer: issuer string to match against each certificate's "subject"
            field, as returned by get_cert_info().
        accepted_certs_file: path to a JSON file with a top-level
            "certificates" list of base64-encoded DER certificates. The default
            keeps the previously hard-coded file name.

    Returns:
        The decoded DER certificate whose subject matches, or None if none
        match. If several match, the last one in the list wins (as before).
    """
    # Close the file deterministically instead of leaking the handle.
    with open(accepted_certs_file) as f:
        accepted_certs = json.load(f)["certificates"]

    root_cert = None
    for accepted_cert in accepted_certs:
        # Decode once and reuse the result for both lookup and return value.
        der_cert = base64.b64decode(accepted_cert)
        if get_cert_info(der_cert)["subject"] == issuer:
            # No break: preserves the original "last match wins" behavior.
            root_cert = der_cert

    return root_cert

def get_sth(baseurl):
    """Fetch the log's current signed tree head and return the parsed JSON.

    HTTP errors from urllib2 propagate to the caller.
    """
    sth_url = baseurl + "ct/v1/get-sth"
    response_body = urllib2.urlopen(sth_url).read()
    return json.loads(response_body)

def get_proof_by_hash(baseurl, hash, tree_size):
    try:
        params = urllib.urlencode({"hash":base64.b64encode(hash),
                                   "tree_size":tree_size})
        result = \
          urllib2.urlopen(baseurl + "ct/v1/get-proof-by-hash?" + params).read()
        return json.loads(result)
    except urllib2.HTTPError, e:
        print "ERROR:", e.read()
        sys.exit(1)

def tls_array(data, length_len):
    """Encode ``data`` as a TLS variable-length vector.

    The result is ``data`` prefixed by its length as a big-endian unsigned
    integer of ``length_len`` bytes (e.g. 2 or 3 in this module).

    Raises:
        ValueError: if ``length_len`` is not in 1..8, or ``data`` is too long
            for its length to fit in ``length_len`` bytes. Previously such
            input silently produced a truncated, corrupt length prefix.
    """
    if not 1 <= length_len <= 8:
        raise ValueError("length_len must be between 1 and 8, got %d"
                         % (length_len,))
    if len(data) >= 1 << (8 * length_len):
        raise ValueError("data is %d bytes long, which does not fit in a "
                         "%d-byte length prefix" % (len(data), length_len))
    # Pack as uint64, then keep only the low-order length_len bytes.
    length_bytes = struct.pack(">Q", len(data))[-length_len:]
    return length_bytes + data

def unpack_tls_array(packed_data, length_len):
    """Decode one TLS variable-length vector from the front of ``packed_data``.

    The first ``length_len`` bytes are a big-endian length; that many bytes
    follow as the payload.

    Returns:
        ``(unpacked_data, rest_data)``: the payload and whatever follows it.

    Raises:
        ValueError: if ``length_len`` is not in 1..8.
        struct.error: if ``packed_data`` is shorter than the length prefix.
        AssertionError: if ``packed_data`` holds fewer payload bytes than the
            prefix announces (same exception type as before).
    """
    if not 1 <= length_len <= 8:
        raise ValueError("length_len must be between 1 and 8, got %d"
                         % (length_len,))
    # Left-pad the prefix with zero bytes to 8 bytes so it reads as a uint64.
    # Byte-string concatenation works for both Py2 str and Py3 bytes. The
    # previous list-of-characters join did not work for bytes.
    header = packed_data[:length_len]
    (length,) = struct.unpack(">Q", b"\x00" * (8 - length_len) + header)
    unpacked_data = packed_data[length_len:length_len + length]
    if len(unpacked_data) != length:
        # Raised explicitly rather than via "assert" so the check is not
        # stripped under "python -O".
        raise AssertionError(
            "data is only %d bytes long, but length is %d bytes" %
            (len(unpacked_data), length))
    rest_data = packed_data[length_len + length:]
    return (unpacked_data, rest_data)

def add_chain(baseurl, submission):
    try:
        result = urllib2.urlopen(baseurl + "ct/v1/add-chain",
            json.dumps(submission)).read()
        return json.loads(result)
    except urllib2.HTTPError, e:
        print "ERROR:", e.read()
        sys.exit(1)
    except ValueError, e:
        print "==== FAILED REQUEST ===="
        print submission
        print "======= RESPONSE ======="
        print result
        print "========================"
        raise e

def get_entries(baseurl, start, end):
    try:
        params = urllib.urlencode({"start":start, "end":end})
        result = urllib2.urlopen(baseurl + "ct/v1/get-entries?" + params).read()
        return json.loads(result)
    except urllib2.HTTPError, e:
        print "ERROR:", e.read()
        sys.exit(1)

def decode_certificate_chain(packed_certchain):
    """Split a TLS-encoded certificate chain into its individual certificates.

    The outer vector and each certificate inside it carry a 3-byte length
    prefix. Trailing bytes after the outer vector fail an assertion.
    """
    (remaining, trailing) = unpack_tls_array(packed_certchain, 3)
    assert len(trailing) == 0
    certs = []
    while remaining:
        (cert, remaining) = unpack_tls_array(remaining, 3)
        certs.append(cert)
    return certs

def decode_signature(signature):
    """Split a TLS digitally-signed blob into its components.

    Layout: one hash-algorithm byte, one signature-algorithm byte, then the
    signature as a vector with a 2-byte length prefix.

    Returns:
        ``(hash_alg, signature_alg, unpacked_signature)``.

    Raises:
        AssertionError: if bytes remain after the signature vector.
    """
    # TLS algorithm identifiers are uint8. Use ">BB" rather than ">bb" so
    # values above 127 are not reported as negative numbers.
    (hash_alg, signature_alg) = struct.unpack(">BB", signature[0:2])
    (unpacked_signature, rest) = unpack_tls_array(signature[2:], 2)
    # A length check works for Py2 str and Py3 bytes alike. Comparing to ""
    # is always False for bytes.
    assert len(rest) == 0
    return (hash_alg, signature_alg, unpacked_signature)

def encode_signature(hash_alg, signature_alg, unpacked_signature):
    """Build a TLS digitally-signed blob.

    Output layout: hash-algorithm byte, signature-algorithm byte, then
    ``unpacked_signature`` as a vector with a 2-byte length prefix. This is
    the inverse of decode_signature().
    """
    # Algorithm identifiers are uint8 (0..255). ">bb" wrongly rejected 128..255
    # and disagreed with the unsigned decoder.
    signature = struct.pack(">BB", hash_alg, signature_alg)
    signature += tls_array(unpacked_signature, 2)
    return signature

def check_signature(baseurl, signature, data):
    """Verify a log signature over ``data``.

    Args:
        baseurl: log base URL; selects the verification key from ``publickeys``.
        signature: TLS digitally-signed blob (see decode_signature()).
        data: the exact bytes that were signed.

    Raises:
        KeyError: if ``baseurl`` has no entry in ``publickeys``.
        AssertionError: if the blob does not declare hash_alg 4 (sha256) and
            signature_alg 3 (ecdsa).
        Whatever ``ecdsa.VerifyingKey.verify`` raises for a bad signature
            (the ecdsa library documents BadSignatureError).
    """
    publickey = base64.decodestring(publickeys[baseurl])
    (hash_alg, signature_alg, unpacked_signature) = decode_signature(signature)
    # Explicit raises instead of "assert": these checks guard signature
    # verification and must not vanish under "python -O". AssertionError is
    # kept so callers see the same exception type as before.
    if hash_alg != 4:  # sha256
        raise AssertionError("hash_alg is %d, expected 4" % (hash_alg,))
    if signature_alg != 3:  # ecdsa
        raise AssertionError(
            "signature_alg is %d, expected 3" % (signature_alg,))

    vk = ecdsa.VerifyingKey.from_der(publickey)
    vk.verify(unpacked_signature, data, hashfunc=hashlib.sha256,
              sigdecode=ecdsa.util.sigdecode_der)

def create_signature(privatekey, data):
    """Sign ``data`` with a DER-encoded ECDSA private key.

    The DER signature is computed over SHA-256 and wrapped as a TLS
    digitally-signed blob with hash_alg 4 and signature_alg 3.
    """
    signing_key = ecdsa.SigningKey.from_der(privatekey)
    der_signature = signing_key.sign(data,
                                     hashfunc=hashlib.sha256,
                                     sigencode=ecdsa.util.sigencode_der)
    sha256_id, ecdsa_id = 4, 3
    return encode_signature(sha256_id, ecdsa_id, der_signature)

def check_sth_signature(baseurl, sth):
    """Verify the ``tree_head_signature`` of a signed tree head (parsed JSON).

    Rebuilds the signed bytes in this order:
    - version byte 0
    - signature-type byte 1
    - timestamp (uint64)
    - tree_size (uint64)
    - the decoded ``sha256_root_hash``

    These are checked with check_signature().
    """
    signature = base64.decodestring(sth["tree_head_signature"])

    tree_head_parts = [
        struct.pack(">b", 0),                  # version
        struct.pack(">b", 1),                  # signature type
        struct.pack(">Q", sth["timestamp"]),
        struct.pack(">Q", sth["tree_size"]),
        base64.decodestring(sth["sha256_root_hash"]),
    ]

    check_signature(baseurl, signature, b"".join(tree_head_parts))

def create_sth_signature(tree_size, timestamp, root_hash, privatekey):
    """Sign a tree head with ``privatekey`` and return the TLS signature blob.

    The signed bytes are, in order:
    - version byte 0
    - signature-type byte 1
    - timestamp (uint64)
    - tree_size (uint64)
    - ``root_hash``

    This mirrors check_sth_signature().
    """
    # ">" disables alignment padding, so this equals the four separate packs.
    fixed_fields = struct.pack(">bbQQ", 0, 1, timestamp, tree_size)
    return create_signature(privatekey, fixed_fields + root_hash)

def check_sct_signature(baseurl, leafcert, sct):
    """Verify an SCT (parsed JSON) issued for ``leafcert`` by the log at ``baseurl``.

    First checks that the SCT's log id equals SHA-256 of the log's public
    key. Then it rebuilds the signed bytes in this order:
    - the SCT version
    - signature-type byte 0
    - timestamp
    - entry type 0
    - the leaf certificate (3-byte length prefix)
    - the decoded extensions (2-byte length prefix)

    These are verified with check_signature().
    """
    publickey = base64.decodestring(publickeys[baseurl])
    expected_logid = hashlib.sha256(publickey).digest()
    sct_logid = base64.decodestring(sct["id"])
    assert expected_logid == sct_logid, \
        "log id is incorrect:\n  should be %s\n        got %s" % \
        (expected_logid.encode("hex_codec"), sct_logid.encode("hex_codec"))

    signature = base64.decodestring(sct["signature"])

    signed_parts = [
        struct.pack(">b", sct["sct_version"]),
        struct.pack(">b", 0),                  # signature type
        struct.pack(">Q", sct["timestamp"]),
        struct.pack(">H", 0),                  # entry type
        tls_array(leafcert, 3),
        tls_array(base64.decodestring(sct["extensions"]), 2),
    ]

    check_signature(baseurl, signature, b"".join(signed_parts))

def pack_mtl(timestamp, leafcert):
    """Serialize a Merkle tree leaf for ``leafcert`` logged at ``timestamp``.

    Layout:
    - version byte 0
    - leaf-type byte 0
    - timestamp (uint64)
    - entry type 0 (uint16)
    - ``leafcert`` with a 3-byte length prefix
    - empty extensions with a 2-byte length prefix
    """
    no_extensions = ""
    timestamped_entry = b"".join([
        # ">" disables alignment padding: same bytes as two separate packs.
        struct.pack(">QH", timestamp, 0),
        tls_array(leafcert, 3),
        tls_array(no_extensions, 2),
    ])
    leaf_header = struct.pack(">bb", 0, 0)     # version, leaf type
    return leaf_header + timestamped_entry

def unpack_mtl(merkle_tree_leaf):
    """Extract ``(leafcert, timestamp)`` from a serialized Merkle tree leaf.

    Expects the layout produced by pack_mtl():
    - version byte
    - leaf-type byte
    - timestamp (uint64)
    - entry type (uint16)
    - certificate with a 3-byte length prefix

    NOTE(review): the version, leaf type and entry type are read past but
    not validated.

    Raises:
        struct.error: if the timestamp/entry-type header is truncated.
    """
    # Skip the version and leaf-type bytes.
    timestamped_entry = merkle_tree_leaf[2:]
    # entry_type is still unpacked so a truncated header fails as before.
    (timestamp, _entry_type) = struct.unpack(">QH", timestamped_entry[0:10])
    (leafcert, _rest_entry) = unpack_tls_array(timestamped_entry[10:], 3)
    return (leafcert, timestamp)

def get_leaf_hash(merkle_tree_leaf):
    """Return the SHA-256 leaf hash: digest of a 0x00 byte followed by the leaf."""
    leaf_prefix = struct.pack(">b", 0)
    digest = hashlib.sha256(leaf_prefix)
    digest.update(merkle_tree_leaf)
    return digest.digest()

def timing_point(timer_dict=None, name=None):
    """Record elapsed local wall-clock time between successive calls.

    Called without a (truthy) ``timer_dict``, it creates and returns a new
    timer: ``{"deltatimes": [], "lasttime": now}``.

    Called with such a dict, it appends ``(name, elapsed_microseconds)`` to
    ``timer_dict["deltatimes"]``, where the elapsed time is measured since
    ``timer_dict["lasttime"]``. It then sets "lasttime" to now and returns
    None.
    """
    t = datetime.datetime.now()
    if timer_dict:
        deltatime = t - timer_dict["lasttime"]
        # timedelta normalizes so that only .days may be negative. Ignoring it
        # mis-reported intervals longer than a day. It also turned a negative
        # interval (the local clock stepping back) into ~86400 seconds.
        elapsed_us = ((deltatime.days * 86400 + deltatime.seconds) * 1000000
                      + deltatime.microseconds)
        timer_dict["deltatimes"].append((name, elapsed_us))
        timer_dict["lasttime"] = t
        return None
    else:
        timer_dict = {"deltatimes":[], "lasttime":t}
        return timer_dict

def internal_hash(pair):
    """Combine child hashes into their parent node hash.

    A one-element ``pair`` is an unpaired node and is returned unchanged.
    Otherwise the result is SHA-256 of a 0x01 byte followed by ``pair[0]``
    and ``pair[1]``.
    """
    if len(pair) != 1:
        node_digest = hashlib.sha256(struct.pack(">b", 1))
        left, right = pair[0], pair[1]
        node_digest.update(left)
        node_digest.update(right)
        return node_digest.digest()
    return pair[0]

def chunks(l, n):
    """Split sequence ``l`` into consecutive slices of length ``n``.

    The last slice may be shorter than ``n``.
    """
    pieces = []
    for start in range(0, len(l), n):
        pieces.append(l[start:start + n])
    return pieces

def next_merkle_layer(layer):
    """Compute the Merkle layer above ``layer``.

    Adjacent nodes are hashed pairwise with internal_hash(). An odd trailing
    node is carried up unchanged, because internal_hash() returns a
    one-element group as-is.
    """
    parent_layer = []
    for group in chunks(layer, 2):
        parent_layer.append(internal_hash(group))
    return parent_layer

def build_merkle_tree(layer0):
    """Build every layer of a Merkle tree from its leaf hashes.

    Returns a list of layers, bottom first. ``layer0`` itself (not a copy) is
    element 0, and the last element is the single-node root layer. An empty
    ``layer0`` yields ``[[sha256(b"").digest()]]``.
    """
    if len(layer0) == 0:
        # Root of an empty tree: SHA-256 of empty input.
        return [[hashlib.sha256().digest()]]
    layers = [layer0]
    while len(layers[-1]) > 1:
        layers.append(next_merkle_layer(layers[-1]))
    return layers