#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright (c) 2014, NORDUnet A/S.
# See LICENSE for licensing information.

import argparse
import urllib2
import urllib
import json
import base64
import sys
import struct
import hashlib
import itertools
from certtools import *

# Command-line interface: the log's base URL is required; --store optionally
# names a directory that receives one PEM file per log entry.
parser = argparse.ArgumentParser(
    description="Fetch every entry from a Certificate Transparency log, "
                "verify that the entries hash to the log's published root "
                "hash, and optionally store each entry's certificate chain.")
parser.add_argument('baseurl', help="Base URL for CT server")
parser.add_argument('--store', default=None, metavar="dir",
                    help='Store certificates in directory dir')
args = parser.parse_args()

def extract_original_entry(entry):
    """Rebuild the certificate chain carried by one get-entries result.

    *entry* is a dict with base64-encoded "leaf_input" and "extra_data"
    fields.  The leaf certificate is taken from the decoded leaf input via
    certtools' unpack_mtl; the remaining certificates are decoded from
    extra_data via decode_certificate_chain.

    Returns a list: the leaf certificate followed by the decoded chain.
    """
    raw_leaf = base64.decodestring(entry["leaf_input"])
    # unpack_mtl also yields a timestamp, which is not needed here.
    leaf_cert, _timestamp = unpack_mtl(raw_leaf)
    raw_extra = base64.decodestring(entry["extra_data"])
    return [leaf_cert] + decode_certificate_chain(raw_extra)

def get_entries_wrapper(baseurl, start, end):
    fetched_entries = []
    while start + len(fetched_entries) < (end + 1):
        print "fetching from", start + len(fetched_entries)
        entries = get_entries(baseurl, start + len(fetched_entries), end)["entries"]
        if len(entries) == 0:
            break
        fetched_entries.extend(entries)
    return fetched_entries

def print_layer(layer):
    for entry in layer:
        print base64.b16encode(entry)

sth = get_sth(args.baseurl)
tree_size = sth["tree_size"]
root_hash = base64.decodestring(sth["sha256_root_hash"])

print "tree size", tree_size
print "root hash", base64.b16encode(root_hash)

entries = get_entries_wrapper(args.baseurl, 0, tree_size - 1)

print "fetched", len(entries), "entries"

layer0 = [get_leaf_hash(base64.decodestring(entry["leaf_input"])) for entry in entries]

tree = build_merkle_tree(layer0)

calculated_root_hash = tree[-1][0]

print "calculated root hash", base64.b16encode(calculated_root_hash)

if calculated_root_hash != root_hash:
    print "fetched root hash and calculated root hash different, aborting"
    sys.exit(1)

if args.store:
    for entry, i in zip(entries, range(0, len(entries))):
        chain = extract_original_entry(entry)
        f = open(args.store + "/" + ("%06d" % i), "w")
        for cert in chain:
            print >> f, "-----BEGIN CERTIFICATE-----"
            print >> f, base64.encodestring(cert).rstrip()
            print >> f, "-----END CERTIFICATE-----"
            print >> f, ""