import subprocess
import json
import base64
import urllib
import urllib2
import struct
import sys
import hashlib
import ecdsa

# Known logs: maps a log's base URL (including the trailing slash, since
# request paths such as "ct/v1/get-sth" are appended directly) to that log's
# base64-encoded public key.  check_signature() base64-decodes the value,
# takes its SHA-256 digest as the expected log id, and loads it as a DER key.
publickeys = {
    "https://ct.googleapis.com/pilot/":
    "MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEfahLEimAoz2t01p3uMziiLOl/fHTD"
    "M0YDOhBRuiBARsV4UvxG2LdNgoIGLrtCzWE0J5APC2em4JlvR8EEEFMoA==",

    # NOTE(review): this entry and the next one carry the same key --
    # presumably a locally running instance of the same log; confirm.
    "https://127.0.0.1:8080/":
    "MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE4qWq6afhBUi0OdcWUYhyJLNXTkGqQ9"
    "PMS5lqoCgkV2h1ZvpNjBH2u8UbgcOQwqDo66z6BWQJGolozZYmNHE2kQ==",

    "https://flimsy.ct.nordu.net/":
    "MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE4qWq6afhBUi0OdcWUYhyJLNXTkGqQ9"
    "PMS5lqoCgkV2h1ZvpNjBH2u8UbgcOQwqDo66z6BWQJGolozZYmNHE2kQ==",
}

def get_cert_info(s):
    p = subprocess.Popen(
        ["openssl", "x509", "-noout", "-subject", "-issuer", "-inform", "der"],
        stdin=subprocess.PIPE, stdout=subprocess.PIPE,
        stderr=subprocess.PIPE)
    parsed = p.communicate(s)
    if parsed[1]:
        print "ERROR:", parsed[1]
        sys.exit(1)
    result = {}
    for line in parsed[0].split("\n"):
        (key, sep, value) = line.partition("=")
        if sep == "=":
            result[key] = value
    return result

def get_certs_from_file(certfile):
    """Return the certificates found in the PEM file `certfile`.

    Each block delimited by a "-----BEGIN CERTIFICATE-----" line and a
    "-----END CERTIFICATE-----" line is base64-decoded and appended, in
    file order, to the returned list.  Lines outside such blocks are
    ignored, as is surrounding whitespace on every line.

    A block that is still open at end of file is dropped, and an END
    marker with no preceding BEGIN marker is ignored.
    """
    certs = []
    body_lines = []
    incert = False

    # `with` guarantees the file is closed, even if decoding raises.
    with open(certfile) as pemfile:
        for line in pemfile:
            line = line.strip()
            if line == "-----BEGIN CERTIFICATE-----":
                body_lines = []
                incert = True
            elif line == "-----END CERTIFICATE-----":
                # Only emit a certificate for a properly opened block; a
                # stray END marker must not re-append stale data.
                if incert:
                    # b64decode is used instead of decodestring, which is
                    # deprecated and absent from newer Python versions.
                    certs.append(base64.b64decode("".join(body_lines)))
                incert = False
            elif incert:
                body_lines.append(line)
    return certs

def get_root_cert(issuer):
    """Look up the accepted certificate whose subject equals `issuer`.

    Reads the JSON file "googlelog-accepted-certs.txt" from the current
    directory and examines every base64-encoded entry of its
    "certificates" list, comparing the "subject" reported by
    get_cert_info() with `issuer`.

    Returns the decoded certificate of the last matching entry, or None
    if no entry matches.
    """
    # `with` closes the file handle instead of leaving it to the GC.
    with open("googlelog-accepted-certs.txt") as accepted_file:
        accepted_certs = json.loads(accepted_file.read())["certificates"]

    root_cert = None

    for accepted_cert in accepted_certs:
        # Decode once and reuse the result for both the lookup and return.
        der_cert = base64.decodestring(accepted_cert)
        if get_cert_info(der_cert)["subject"] == issuer:
            # Deliberately no `break`: every entry is still inspected and
            # the last match wins, exactly as before.
            root_cert = der_cert

    return root_cert

def get_sth(baseurl):
    """Fetch ``baseurl + "ct/v1/get-sth"`` and return the decoded JSON."""
    sth_url = baseurl + "ct/v1/get-sth"
    response_body = urllib2.urlopen(sth_url).read()
    return json.loads(response_body)

def get_proof_by_hash(baseurl, hash, tree_size):
    try:
        params = urllib.urlencode({"hash":base64.b64encode(hash),
                                   "tree_size":tree_size})
        result = \
          urllib2.urlopen(baseurl + "ct/v1/get-proof-by-hash?" + params).read()
        return json.loads(result)
    except urllib2.HTTPError, e:
        print "ERROR:", e.read()
        sys.exit(1)

def tls_array(data, length_len):
    """Prefix `data` with its length as a `length_len`-byte big-endian integer.

    Returns the length prefix followed by `data`.

    Raises ValueError if `length_len` is not between 1 and 8, or if
    len(data) does not fit in `length_len` bytes.  Without these checks
    the prefix would silently lose its high-order bytes, or (for
    `length_len` 0) come out 8 bytes long.
    """
    if not 1 <= length_len <= 8:
        raise ValueError("length_len must be between 1 and 8, got %d" %
                         length_len)
    if len(data) >= 1 << (8 * length_len):
        raise ValueError(
            "data is %d bytes long, which does not fit in a %d-byte length" %
            (len(data), length_len))
    # Pack as a 64-bit big-endian value and keep only the low-order bytes.
    length_bytes = struct.pack(">Q", len(data))[-length_len:]
    return length_bytes + data

def unpack_tls_array(packed_data, length_len):
    """Split one length-prefixed array off the front of `packed_data`.

    The first `length_len` bytes are read as a big-endian length; that
    many bytes following the prefix form the array.

    Returns a tuple ``(unpacked_data, rest_data)`` where `rest_data` is
    whatever follows the array.

    Raises AssertionError if `packed_data` is too short to hold the
    length prefix or the announced number of data bytes.
    """
    # The checks raise AssertionError explicitly rather than through
    # `assert`, so they keep the exception type callers already see while
    # still being enforced under `python -O`.
    length_bytes = packed_data[:length_len]
    if len(length_bytes) != length_len:
        raise AssertionError(
            "length prefix is only %d bytes long, but should be %d bytes" %
            (len(length_bytes), length_len))
    # Left-pad the prefix to 8 bytes so it decodes as one 64-bit integer.
    padded_length = b"\x00" * (8 - length_len) + length_bytes
    (length,) = struct.unpack(">Q", padded_length)
    unpacked_data = packed_data[length_len:length_len+length]
    if len(unpacked_data) != length:
        raise AssertionError(
            "data is only %d bytes long, but length is %d bytes" %
            (len(unpacked_data), length))
    rest_data = packed_data[length_len+length:]
    return (unpacked_data, rest_data)

def add_chain(baseurl, submission):
    try:
        return json.loads(urllib2.urlopen(baseurl + "ct/v1/add-chain",
                                          json.dumps(submission)).read())
    except urllib2.HTTPError, e:
        print "ERROR:", e.read()
        sys.exit(1)

def get_entries(baseurl, start, end):
    try:
        params = urllib.urlencode({"start":start, "end":end})
        result = urllib2.urlopen(baseurl + "ct/v1/get-entries?" + params).read()
        return json.loads(result)
    except urllib2.HTTPError, e:
        print "ERROR:", e.read()
        sys.exit(1)

def decode_certificate_chain(packed_certchain):
    """Unpack a packed certificate chain into a list of certificates.

    `packed_certchain` must be exactly one array with a 3-byte length
    prefix (trailing bytes trip an assertion).  Its contents are a
    sequence of arrays, each again with a 3-byte length prefix; these
    are returned in order.
    """
    (remaining, trailing) = unpack_tls_array(packed_certchain, 3)
    assert len(trailing) == 0
    chain = []
    # Peel one certificate off the front until nothing is left.
    while remaining:
        (cert, remaining) = unpack_tls_array(remaining, 3)
        chain.append(cert)
    return chain

def decode_signature(signature):
    """Split a packed signature into its algorithm ids and signature bytes.

    The first two bytes are read as signed 8-bit integers (hash algorithm,
    then signature algorithm).  The remainder must be exactly one array
    with a 2-byte length prefix; trailing bytes trip an assertion.

    Returns ``(hash_alg, signature_alg, unpacked_signature)``.
    """
    header = signature[0:2]
    body = signature[2:]
    (hash_alg, signature_alg) = struct.unpack(">bb", header)
    (unpacked_signature, rest) = unpack_tls_array(body, 2)
    assert rest == ""
    return (hash_alg, signature_alg, unpacked_signature)

def check_signature(baseurl, leafcert, sct):
    """Check the signature in `sct` over `leafcert` for the log at `baseurl`.

    `baseurl` must be a key of `publickeys`.  `sct` is a dict providing
    "id", "signature" and "extensions" (all base64-encoded) as well as
    "sct_version" and "timestamp" (integers).

    Returns None.  Raises KeyError for an unknown `baseurl`, and
    AssertionError if the log id does not match the SHA-256 digest of the
    known public key or if the signature is not sha256/ecdsa.  A
    signature that does not verify is reported by `vk.verify`
    (NOTE(review): python-ecdsa is expected to raise BadSignatureError
    here -- confirm against the installed version).
    """
    publickey = base64.decodestring(publickeys[baseurl])
    calculated_logid = hashlib.sha256(publickey).digest()
    received_logid = base64.decodestring(sct["id"])
    assert calculated_logid == received_logid, \
        "log id is incorrect:\n  should be %s\n        got %s" % \
        (calculated_logid.encode("hex_codec"),
         received_logid.encode("hex_codec"))

    signature = base64.decodestring(sct["signature"])

    # Rebuild the byte string the log signed: version, signature type,
    # timestamp, entry type, then the certificate (3-byte length prefix)
    # and the extensions (2-byte length prefix).
    version = struct.pack(">b", sct["sct_version"])
    # The zero values are presumably certificate_timestamp and x509_entry
    # from RFC 6962 -- TODO confirm.
    signature_type = struct.pack(">b", 0)
    timestamp = struct.pack(">Q", sct["timestamp"])
    entry_type = struct.pack(">H", 0)
    signed_struct = version + signature_type + timestamp + \
      entry_type + tls_array(leafcert, 3) + \
      tls_array(base64.decodestring(sct["extensions"]), 2)

    (hash_alg, signature_alg, unpacked_signature) = decode_signature(signature)
    assert hash_alg == 4 # sha256
    assert signature_alg == 3 # ecdsa

    # No separate digest is computed here: verify() hashes signed_struct
    # itself using `hashfunc`.
    vk = ecdsa.VerifyingKey.from_der(publickey)
    vk.verify(unpacked_signature, signed_struct, hashfunc=hashlib.sha256,
              sigdecode=ecdsa.util.sigdecode_der)

def pack_mtl(timestamp, leafcert):
    """Pack `timestamp` and `leafcert` into a merkle tree leaf byte string.

    The result is the concatenation of: a zero version byte, a zero leaf
    type byte, `timestamp` as a 64-bit big-endian integer, a zero 16-bit
    entry type, `leafcert` with a 3-byte length prefix, and empty
    extensions with a 2-byte length prefix.
    """
    version = struct.pack(">b", 0)
    leaf_type = struct.pack(">b", 0)

    packed_timestamp = struct.pack(">Q", timestamp)
    entry_type = struct.pack(">H", 0)
    packed_cert = tls_array(leafcert, 3)
    packed_extensions = tls_array("", 2)

    timestamped_entry = \
      packed_timestamp + entry_type + packed_cert + packed_extensions
    return version + leaf_type + timestamped_entry

def get_leaf_hash(merkle_tree_leaf):
    """Return the SHA-256 digest of a zero byte followed by `merkle_tree_leaf`.

    The result is the raw 32-byte digest, not a hex string.
    """
    leaf_prefix = struct.pack(">b", 0)
    return hashlib.sha256(leaf_prefix + merkle_tree_leaf).digest()