# Copyright (c) 2015, NORDUnet A/S.
# See LICENSE for licensing information.

import sys
import hashlib
import rfc2459
from pyasn1.type import univ, tag
from pyasn1.codec.der import encoder, decoder

def cleanextensions(extensions):
    """Return a new Extensions list without the CT precertificate poison.

    Every extension of *extensions* whose extnID is not
    1.3.6.1.4.1.11129.2.4.3 (the RFC 6962 precertificate poison OID) is
    appended, in its original order, to a fresh ``rfc2459.Extensions`` value
    carrying the explicit [3] context tag used for the TBSCertificate
    ``extensions`` field.  The Extension objects themselves are reused, not
    copied.
    """
    poison_oid = univ.ObjectIdentifier("1.3.6.1.4.1.11129.2.4.3")
    cleaned = rfc2459.Extensions().subtype(
        explicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 3))
    for position in range(len(extensions)):
        ext = extensions.getComponentByPosition(position)
        if ext.getComponentByName("extnID") == poison_oid:
            # Drop the poison extension; keep everything else.
            continue
        cleaned.setComponentByPosition(len(cleaned), ext)
    return cleaned

def decode_any(anydata, asn1Spec=None):
    """DER-decode *anydata* twice and return the inner value.

    The outer value is decoded without a schema (presumably an OCTET STRING
    wrapper, as with extension values - confirm for new callers); the
    resulting object is DER-decoded again, using *asn1Spec* when one is
    given.  Any trailing bytes left by either decode are ignored.
    """
    outer = decoder.decode(anydata)[0]
    return decoder.decode(outer, asn1Spec=asn1Spec)[0]

def get_subject(cert):
    """Return the subject name and subject key identifier of a certificate.

    :param cert: DER-encoded X.509 certificate.
    :returns: ``(subject, keyid)`` where *subject* is the TBSCertificate
        ``subject`` component and *keyid* is the subjectKeyIdentifier
        extension value decoded as ``rfc2459.KeyIdentifier``.
    :raises ValueError: if the ``extensions`` component is None or no
        subjectKeyIdentifier extension is present.
    """
    (asn1, rest) = decoder.decode(cert, asn1Spec=rfc2459.Certificate())
    assert rest == ''
    tbsCertificate = asn1.getComponentByName("tbsCertificate")
    subject = tbsCertificate.getComponentByName("subject")
    extensions = tbsCertificate.getComponentByName("extensions")
    if extensions is None:
        # Previously this surfaced as a TypeError from len(None).
        raise ValueError("certificate has no extensions; "
                         "cannot read subjectKeyIdentifier")
    keyid_wrapper = get_extension(extensions, rfc2459.id_ce_subjectKeyIdentifier)
    if keyid_wrapper is None:
        # Previously decode_any(None) failed deep inside the DER decoder.
        raise ValueError("certificate has no subjectKeyIdentifier extension")
    keyid = decode_any(keyid_wrapper, asn1Spec=rfc2459.KeyIdentifier())
    return (subject, keyid)

def cleanprecert(precert, issuer=None):
    """Return the DER-encoded TBSCertificate of *precert* with the poison removed.

    The precertificate is decoded and its extensions are replaced by the
    output of ``cleanextensions``.  When *issuer* is given, the TBSCertificate
    ``issuer`` name is replaced by the issuer certificate's subject.  Any
    existing authorityKeyIdentifier extension is rewritten to carry the
    issuer's subjectKeyIdentifier; it is not added if absent.

    :param precert: DER-encoded precertificate.
    :param issuer: optional DER-encoded issuer certificate.
    :returns: DER encoding of the modified TBSCertificate only (the outer
        Certificate wrapper and its signature are not included).
    """
    (asn1, rest) = decoder.decode(precert, asn1Spec=rfc2459.Certificate())
    assert rest == ''
    tbsCertificate = asn1.getComponentByName("tbsCertificate")

    extensions = tbsCertificate.getComponentByName("extensions")
    tbsCertificate.setComponentByName("extensions", cleanextensions(extensions))
    # Re-read the component so later edits target the list that is actually
    # encoded below.  Do not rely on the cleaned list happening to share
    # Extension objects with the original one.
    extensions = tbsCertificate.getComponentByName("extensions")

    if issuer:
        (issuer_subject, keyid) = get_subject(issuer)
        tbsCertificate.setComponentByName("issuer", issuer_subject)
        authkeyid = rfc2459.AuthorityKeyIdentifier()
        # keyIdentifier is [0] IMPLICIT inside AuthorityKeyIdentifier.
        authkeyid.setComponentByName("keyIdentifier",
            rfc2459.KeyIdentifier(str(keyid)).subtype(implicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 0)))
        # extnValue is stored as the DER of an OCTET STRING wrapping the DER
        # of the extension payload (the inverse of decode_any).
        authkeyid_octets = univ.OctetString(encoder.encode(authkeyid))
        authkeyid_der = encoder.encode(authkeyid_octets)
        set_extension(extensions, rfc2459.id_ce_authorityKeyIdentifier, authkeyid_der)
    return encoder.encode(tbsCertificate)

def get_extension(extensions, id):
    """Return the extnValue of the first extension whose extnID equals *id*.

    *extensions* is an Extensions sequence (anything supporting ``len()`` and
    ``getComponentByPosition()``).  Returns None when no extension matches.
    """
    all_extensions = (extensions.getComponentByPosition(position)
                      for position in range(len(extensions)))
    match = next((candidate for candidate in all_extensions
                  if candidate.getComponentByName("extnID") == id),
                 None)
    if match is None:
        return None
    return match.getComponentByName("extnValue")

def set_extension(extensions, id, value):
    """Overwrite, in place, the extnValue of every extension whose extnID is *id*.

    No extension is added when none matches.

    :param extensions: an Extensions sequence (anything supporting ``len()``
        and ``getComponentByPosition()``).
    :param id: extension OID to look for.
    :param value: new extnValue to store.
    :returns: True if at least one extension was updated, False otherwise
        (the original returned None, which callers ignored).
    """
    updated = False
    for idx in range(len(extensions)):
        extension = extensions.getComponentByPosition(idx)
        if extension.getComponentByName("extnID") == id:
            extension.setComponentByName("extnValue", value)
            updated = True
    return updated

def get_cert_key_hash(cert):
    """Return the SHA-256 digest of a certificate's SubjectPublicKeyInfo.

    :param cert: DER-encoded X.509 certificate.
    :returns: the raw 32-byte digest of the DER re-encoding of the
        TBSCertificate ``subjectPublicKeyInfo`` component.
    """
    decoded, remainder = decoder.decode(cert, asn1Spec=rfc2459.Certificate())
    assert remainder == ''
    spki = decoded.getComponentByName("tbsCertificate").getComponentByName("subjectPublicKeyInfo")
    return hashlib.sha256(encoder.encode(spki)).digest()

def printcert(cert, outfile=None):
    """Write a pretty-printed dump of a DER-encoded certificate.

    :param cert: DER-encoded X.509 certificate.
    :param outfile: file-like object to write to.  It defaults to the
        ``sys.stdout`` in effect at call time, not at import time, so
        redirecting ``sys.stdout`` after import is honoured.
    """
    if outfile is None:
        outfile = sys.stdout
    (asn1, rest) = decoder.decode(cert, asn1Spec=rfc2459.Certificate())
    assert rest == ''
    outfile.write(asn1.prettyPrint() + "\n")

def printtbscert(cert, outfile=None):
    """Write a pretty-printed dump of a DER-encoded TBSCertificate.

    :param cert: DER-encoded TBSCertificate (e.g. the output of
        ``cleanprecert``), not a full Certificate.
    :param outfile: file-like object to write to.  It defaults to the
        ``sys.stdout`` in effect at call time, not at import time, so
        redirecting ``sys.stdout`` after import is honoured.
    """
    if outfile is None:
        outfile = sys.stdout
    (asn1, rest) = decoder.decode(cert, asn1Spec=rfc2459.TBSCertificate())
    assert rest == ''
    outfile.write(asn1.prettyPrint() + "\n")

# Extended Key Usage OID 1.3.6.1.4.1.11129.2.4.4, which marks a Certificate
# Transparency Precertificate Signing Certificate (RFC 6962, section 3.1).
# NOTE(review): presumably compared against get_ext_key_usage() results by
# callers elsewhere - confirm.
ext_key_usage_precert_signing_cert = univ.ObjectIdentifier("1.3.6.1.4.1.11129.2.4.4")

def get_ext_key_usage(cert):
    """Return the entries of a certificate's extKeyUsage extension.

    :param cert: DER-encoded X.509 certificate.
    :returns: list of the elements of the decoded extKeyUsage SEQUENCE
        (key-purpose OIDs for a well-formed extension).  Returns an empty
        list when the ``extensions`` component is None or no extKeyUsage
        extension is present.
    """
    (asn1, rest) = decoder.decode(cert, asn1Spec=rfc2459.Certificate())
    assert rest == ''
    tbsCertificate = asn1.getComponentByName("tbsCertificate")
    extensions = tbsCertificate.getComponentByName("extensions")
    if extensions is None:
        # Previously len(None) raised TypeError here.
        return []
    ext_key_usage_wrapper = get_extension(extensions, rfc2459.id_ce_extKeyUsage)
    if ext_key_usage_wrapper is None:
        return []
    # Unwrap the OCTET STRING and decode the payload without a schema, as
    # the original inline code did.
    return list(decode_any(ext_key_usage_wrapper))