#!/usr/bin/env python

# Copyright (c) 2014, NORDUnet A/S.
# See LICENSE for licensing information.

import argparse
import urllib2
import urllib
import json
import base64
import sys
import struct
import hashlib
import itertools
from certtools import *
from certtools import *
from precerttools import *
import os
import signal
import select
import zipfile

def readfile(filename):
    """Read *filename* and split its contents into certificates and precerts.

    The whole file is read once and handed to the ``certtools`` /
    ``precerttools`` helpers ``get_certs_from_string`` and
    ``get_precerts_from_string`` (presumably PEM-encoded input -- confirm
    against those helpers).

    Returns a ``(certchain, precerts)`` tuple in the shape expected by
    ``testcerts``.
    """
    # Use a context manager so the file handle is closed deterministically
    # instead of relying on garbage collection.
    with open(filename) as f:
        contents = f.read()
    certchain = get_certs_from_string(contents)
    precerts = get_precerts_from_string(contents)
    return (certchain, precerts)

def testcerts(template, test):
    (certchain1, precerts1) = template
    (certchain2, precerts2) = test

    if precerts1 != precerts2:
        return (False, "precerts are different")

    if certchain1 == certchain2:
        return (True, "")

    if len(certchain2) == len(certchain1) + 1:
        if certchain2[:-1] != certchain1:
            return (False, "certchains are different")
        last_issuer = get_cert_info(certchain1[-1])["issuer"]
        root_subject = get_cert_info(certchain2[-1])["subject"]
        if last_issuer == root_subject:
            return (True, "fetched chain has an appended root cert")
        else:
            return (False, "fetched chain has an extra entry")

    return (False, "certchains are different")

parser = argparse.ArgumentParser(description='')
parser.add_argument('templates', help="Test templates, separated with colon")
parser.add_argument('test', help="Files to test, separated with colon")
args = parser.parse_args()

# Both lists hold (certchain, precerts) pairs produced by readfile().
# Templates are loaded before tests, matching the command-line order.
templates = [readfile(name) for name in args.templates.split(":")]
tests = [readfile(name) for name in args.test.split(":")]

# Every test bundle must be matched by a distinct template: a template that
# matches is consumed so it cannot satisfy a later test.
for test in tests:
    errors = []
    for template in templates:
        ok, message = testcerts(template, test)
        if not ok:
            errors.append(message)
            continue
        print(message)
        # Safe despite iterating over `templates`: we break right after.
        templates.remove(template)
        break
    else:
        # No template matched (or none were left): report every mismatch
        # reason and fail.
        print("Matching template not found for test")
        for error in errors:
            print(error)
        sys.exit(1)
sys.exit(0)