#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2014, NORDUnet A/S.
# See LICENSE for licensing information.
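#
# Fetches new log entries from the storage nodes, merges them into the
# local merge database, builds and signs a new tree head, and distributes
# the updated log and STH to the frontend nodes.
#
# Example invocation (URLs, paths and key names are illustrative; note
# that frontend and storage URLs must end with "/", since API paths such
# as ct/storage/fetchnewentries are appended directly):
#
#   merge.py --baseurl https://ctlog.example.com/ \
#            --frontend https://frontend-1.example.com/ \
#            --storage https://storage-1.example.com/ \
#            --mergedb /var/db/mergedb \
#            --keyfile logkey.pem \
#            --own-keyname merge-1 \
#            --own-keyfile merge-1-private.pem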

import argparse
import json
import base64
import urllib
import urllib2
import sys
import time
import ecdsa
import hashlib
import urlparse
from certtools import build_merkle_tree, create_sth_signature, check_sth_signature, get_eckey_from_file, timing_point

parser = argparse.ArgumentParser(description="Fetch new entries from the storage nodes, merge them into the log, and distribute the updated log and a newly signed STH to the frontend nodes.")
parser.add_argument("--baseurl", metavar="url", help="Base URL for CT server", required=True)
parser.add_argument("--frontend", action="append", metavar="url", help="Base URL for frontend server", required=True)
parser.add_argument("--storage", action="append", metavar="url", help="Base URL for storage server", required=True)
parser.add_argument("--mergedb", metavar="dir", help="Merge database directory", required=True)
parser.add_argument("--keyfile", metavar="keyfile", help="File containing log key", required=True)
parser.add_argument("--own-keyname", metavar="keyname", help="The key name of the merge node", required=True)
parser.add_argument("--own-keyfile", metavar="keyfile", help="The file containing the private key of the merge node", required=True)
parser.add_argument("--nomerge", action='store_true', help="Fetch and count new entries, but don't actually merge or distribute them")
args = parser.parse_args()

ctbaseurl = args.baseurl
frontendnodes = args.frontend
storagenodes = args.storage

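# Merge database layout: "chains" holds one file per log entry, named by
# the hex-encoded entry hash; "logorder" lists the hex-encoded hashes of
# all merged entries, one per line, in log order.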
chainsdir = args.mergedb + "/chains"
logorderfile = args.mergedb + "/logorder"

def parselogrow(row):
    return base64.b16decode(row)

def get_logorder():
    with open(logorderfile, "r") as f:
        return [parselogrow(row.rstrip()) for row in f]

def write_chain(key, value):
    with open(chainsdir + "/" + base64.b16encode(key), "w") as f:
        f.write(value)

def read_chain(key):
    with open(chainsdir + "/" + base64.b16encode(key), "r") as f:
        return f.read()

def add_to_logorder(key):
    with open(logorderfile, "a") as f:
        f.write(base64.b16encode(key) + "\n")

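# Perform an authenticated HTTP request (GET if data is None, else POST),
# signed with the merge node's private key.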
def http_request(url, data=None):
    req = urllib2.Request(url, data)
    keyname = args.own_keyname
    privatekey = get_eckey_from_file(args.own_keyfile)
    sk = ecdsa.SigningKey.from_der(privatekey)
    parsed_url = urlparse.urlparse(url)
    if data is None:
        # GET request: sign the query string instead of a request body.
        data = parsed_url.query
        method = "GET"
    else:
        method = "POST"
    # Sign "<method>\0<path>\0<data>" so the receiving node can verify
    # that the request comes from this merge node.
    signature = sk.sign("%s\0%s\0%s" % (method, parsed_url.path, data), hashfunc=hashlib.sha256,
                        sigencode=ecdsa.util.sigencode_der)
    req.add_header('X-Catlfish-Auth', base64.b64encode(signature) + ";key=" + keyname)
    return urllib2.urlopen(req).read()

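# Ask a storage node for all entries that have not yet been merged;
# returns their decoded entry hashes.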
def get_new_entries(baseurl):
    try:
        result = http_request(baseurl + "ct/storage/fetchnewentries")
        parsed_result = json.loads(result)
        if parsed_result.get(u"result") == u"ok":
            return [base64.b64decode(entry) for entry in parsed_result[u"entries"]]
        print "ERROR: fetchnewentries", parsed_result
        sys.exit(1)
    except urllib2.HTTPError, e:
        print "ERROR: fetchnewentries", e.read()
        sys.exit(1)

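# Fetch the entries with the given hashes from a storage node and check
# that exactly the requested entries came back.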
def get_entries(baseurl, hashes):
    try:
        params = urllib.urlencode({"hash":[base64.b64encode(hash) for hash in hashes]}, doseq=True)
        result = http_request(baseurl + "ct/storage/getentry?" + params)
        parsed_result = json.loads(result)
        if parsed_result.get(u"result") == u"ok":
            entries = dict([(base64.b64decode(entry["hash"]), base64.b64decode(entry["entry"])) for entry in parsed_result[u"entries"]])
            assert len(entries) == len(hashes)
            assert set(entries.keys()) == set(hashes)
            return entries
        print "ERROR: getentry", parsed_result
        sys.exit(1)
    except urllib2.HTTPError, e:
        print "ERROR: getentry", e.read()
        sys.exit(1)

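# Ask a frontend node for its current position in the log, i.e. how many
# entries of the log order it already has.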
def get_curpos(baseurl):
    try:
        result = http_request(baseurl + "ct/frontend/currentposition")
        parsed_result = json.loads(result)
        if parsed_result.get(u"result") == u"ok":
            return parsed_result[u"position"]
        print "ERROR: currentposition", parsed_result
        sys.exit(1)
    except urllib2.HTTPError, e:
        print "ERROR: currentposition", e.read()
        sys.exit(1)

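# Send a consecutive slice of the log order (entry hashes) to a frontend
# node, starting at the given position.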
def sendlog(baseurl, submission):
    try:
        result = http_request(baseurl + "ct/frontend/sendlog",
            json.dumps(submission))
        return json.loads(result)
    except urllib2.HTTPError, e:
        print "ERROR: sendlog", e.read()
        sys.exit(1)
    except ValueError, e:
        print "==== FAILED REQUEST ===="
        print submission
        print "======= RESPONSE ======="
        print result
        print "========================"
        raise e

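# Upload a single entry, together with its leaf hash, to a frontend node.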
def sendentry(baseurl, entry, hash):
    try:
        result = http_request(baseurl + "ct/frontend/sendentry",
            json.dumps({"entry":base64.b64encode(entry), "treeleafhash":base64.b64encode(hash)}))
        return json.loads(result)
    except urllib2.HTTPError, e:
        print "ERROR: sendentry", e.read()
        sys.exit(1)
    except ValueError, e:
        print "==== FAILED REQUEST ===="
        print hash
        print "======= RESPONSE ======="
        print result
        print "========================"
        raise e

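# Publish a new signed tree head to a frontend node.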
def sendsth(baseurl, submission):
    try:
        result = http_request(baseurl + "ct/frontend/sendsth",
            json.dumps(submission))
        return json.loads(result)
    except urllib2.HTTPError, e:
        print "ERROR: sendsth", e.read()
        sys.exit(1)
    except ValueError, e:
        print "==== FAILED REQUEST ===="
        print submission
        print "======= RESPONSE ======="
        print result
        print "========================"
        raise e

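# Ask a frontend node which entry hashes in its log order it is still
# missing the entry data for.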
def get_missingentries(baseurl):
    try:
        result = http_request(baseurl + "ct/frontend/missingentries")
        parsed_result = json.loads(result)
        if parsed_result.get(u"result") == u"ok":
            return parsed_result[u"entries"]
        print "ERROR: missingentries", parsed_result
        sys.exit(1)
    except urllib2.HTTPError, e:
        print "ERROR: missingentries", e.read()
        sys.exit(1)

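# Split l into chunks of at most n items, e.g.
# chunks([1, 2, 3, 4, 5], 2) -> [[1, 2], [3, 4], [5]].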
def chunks(l, n):
    return [l[i:i+n] for i in range(0, len(l), n)]

timing = timing_point()

logorder = get_logorder()

timing_point(timing, "get logorder")

certsinlog = set(logorder)

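# Collect the hashes of new entries from every storage node, remembering
# which nodes hold each entry so it only has to be fetched once.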
new_entries_per_node = {}
new_entries = set()
entries_to_fetch = {}

for storagenode in storagenodes:
    print "getting new entries from", storagenode
    new_entries_per_node[storagenode] = set(get_new_entries(storagenode))
    new_entries.update(new_entries_per_node[storagenode])
    entries_to_fetch[storagenode] = []

timing_point(timing, "get new entries")

new_entries -= certsinlog

print "adding", len(new_entries), "entries"

if args.nomerge:
    sys.exit(0)

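# Assign each new entry to the first storage node that has it.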
for hash in new_entries:
    for storagenode in storagenodes:
        if hash in new_entries_per_node[storagenode]:
            entries_to_fetch[storagenode].append(hash)
            break


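# Fetch the assigned entries in chunks of 100, write each entry to the
# chains directory and append its hash to the log order.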
added_entries = 0
for storagenode in storagenodes:
    print "getting", len(entries_to_fetch[storagenode]), "entries from", storagenode
    for chunk in chunks(entries_to_fetch[storagenode], 100):
        entries = get_entries(storagenode, chunk)
        for hash in chunk:
            entry = entries[hash]
            write_chain(hash, entry)
            add_to_logorder(hash)
            logorder.append(hash)
            certsinlog.add(hash)
            added_entries += 1
timing_point(timing, "add entries")
print "added", added_entries, "entries"

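# Build the Merkle tree over the full log order, sign a new tree head
# with the log key, and sanity-check the signature before distribution.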
tree = build_merkle_tree(logorder)
tree_size = len(logorder)
root_hash = tree[-1][0]
timestamp = int(time.time() * 1000)
privatekey = get_eckey_from_file(args.keyfile)

tree_head_signature = create_sth_signature(tree_size, timestamp,
                                           root_hash, privatekey)

sth = {"tree_size": tree_size, "timestamp": timestamp,
       "sha256_root_hash": base64.b64encode(root_hash),
       "tree_head_signature": base64.b64encode(tree_head_signature)}

check_sth_signature(ctbaseurl, sth)

timing_point(timing, "build sth")

print timing["deltatimes"]

print "root hash", base64.b16encode(root_hash)

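# Distribute to each frontend node: send the part of the log order the
# node is missing, upload any entries it lacks data for, then publish
# the new STH.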
for frontendnode in frontendnodes:
    timing = timing_point()
    print "distributing for node", frontendnode
    curpos = get_curpos(frontendnode)
    timing_point(timing, "get curpos")
    print "current position", curpos
    hashes = [base64.b64encode(entry_hash) for entry_hash in logorder[curpos:]]
    for chunk in chunks(hashes, 1000):
        sendlog(frontendnode, {"start": curpos, "hashes": chunk})
        curpos += len(chunk)
        print curpos,
        sys.stdout.flush()
    timing_point(timing, "sendlog")
    print "log sent"
    missingentries = get_missingentries(frontendnode)
    timing_point(timing, "get missing")
    print "missing entries:", len(missingentries)
    for missingentry in missingentries:
        hash = base64.b64decode(missingentry)
        sendentry(frontendnode, read_chain(hash), hash)
    timing_point(timing, "send missing")
    sendsth(frontendnode, sth)
    timing_point(timing, "send sth")
    print timing["deltatimes"]