#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2014, NORDUnet A/S.
# See LICENSE for licensing information.

import argparse
import json
import base64
import urllib
import urllib2
import sys
import time
from certtools import build_merkle_tree, create_sth_signature, check_sth_signature, get_eckey_from_file

# Command-line configuration. Every option is mandatory; --frontend and
# --storage may be repeated, each occurrence appending one node URL to a list.
parser = argparse.ArgumentParser(description="")
parser.add_argument("--baseurl", metavar="url", required=True,
                    help="Base URL for CT server")
parser.add_argument("--frontend", action="append", metavar="url", required=True,
                    help="Base URL for frontend server")
parser.add_argument("--storage", action="append", metavar="url", required=True,
                    help="Base URL for storage server")
parser.add_argument("--mergedb", metavar="dir", required=True,
                    help="Merge database directory")
parser.add_argument("--keyfile", metavar="keyfile", required=True,
                    help="File containing log key")
args = parser.parse_args()

# Node URLs are used as prefixes for "ct/..." paths, so they are expected to
# end with "/".
ctbaseurl = args.baseurl
frontendnodes = args.frontend
storagenodes = args.storage

# Merge database layout: a directory with one file per stored entry, and a
# text file listing hex-encoded hashes, one per line, in log order.
chainsdir = "%s/chains" % args.mergedb
logorderfile = "%s/logorder" % args.mergedb

def parselogrow(row):
    """Turn one line of the logorder file back into the raw hash bytes.

    *row* must be uppercase base16 text without the trailing newline.
    """
    raw = base64.b16decode(row)
    return raw

def get_logorder():
    """Read the logorder file and return its hashes, in file order.

    Each line is stripped of trailing whitespace and hex-decoded with
    parselogrow. The file is closed before returning; the original left the
    handle open for the rest of the run.
    """
    with open(logorderfile, "r") as f:
        return [parselogrow(row.rstrip()) for row in f]

def write_chain(key, value):
    """Store *value* in the chains directory under the hex name of *key*.

    The file is ``chainsdir/<base16(key)>`` and is overwritten if it exists.
    """
    # *value* is raw (base64-decoded) entry data, so write it in binary mode
    # to store it byte-for-byte on every platform. ``with`` closes the file
    # even if the write fails.
    with open(chainsdir + "/" + base64.b16encode(key), "wb") as f:
        f.write(value)

def read_chain(key):
    """Return the bytes stored for *key* in ``chainsdir/<base16(key)>``."""
    # Binary mode mirrors write_chain, so the data comes back unchanged.
    # ``with`` closes the file even if the read fails.
    with open(chainsdir + "/" + base64.b16encode(key), "rb") as f:
        return f.read()

def add_to_logorder(key):
    """Append *key* to the logorder file as one uppercase-hex line."""
    # Append mode creates the file if it does not exist yet.
    logfile = open(logorderfile, "a")
    logfile.write("%s\n" % base64.b16encode(key))
    logfile.close()

def get_new_entries(baseurl):
    try:
        result = urllib2.urlopen(baseurl + "ct/storage/fetchnewentries").read()
        parsed_result = json.loads(result)
        if parsed_result.get(u"result") == u"ok":
            return parsed_result[u"entries"]
        print "ERROR: fetchnewentries", parsed_result
        sys.exit(1)
    except urllib2.HTTPError, e:
        print "ERROR: fetchnewentries", e.read()
        sys.exit(1)

def get_entry(baseurl, hash):
    try:
        params = urllib.urlencode({"hash":base64.b64encode(hash)})
        result = urllib2.urlopen(baseurl + "ct/storage/getentry?" + params).read()
        parsed_result = json.loads(result)
        if parsed_result.get(u"result") == u"ok":
            entries = parsed_result[u"entries"]
            assert len(entries) == 1
            assert base64.b64decode(entries[0]["hash"]) == hash
            return base64.b64decode(entries[0]["entry"])
        print "ERROR: getentry", parsed_result
        sys.exit(1)
    except urllib2.HTTPError, e:
        print "ERROR: getentry", e.read()
        sys.exit(1)

def get_curpos(baseurl):
    try:
        result = urllib2.urlopen(baseurl + "ct/frontend/currentposition").read()
        parsed_result = json.loads(result)
        if parsed_result.get(u"result") == u"ok":
            return parsed_result[u"position"]
        print "ERROR: currentposition", parsed_result
        sys.exit(1)
    except urllib2.HTTPError, e:
        print "ERROR: currentposition", e.read()
        sys.exit(1)

def sendlog(baseurl, submission):
    try:
        result = urllib2.urlopen(baseurl + "ct/frontend/sendlog",
            json.dumps(submission)).read()
        return json.loads(result)
    except urllib2.HTTPError, e:
        print "ERROR: sendlog", e.read()
        sys.exit(1)
    except ValueError, e:
        print "==== FAILED REQUEST ===="
        print submission
        print "======= RESPONSE ======="
        print result
        print "========================"
        raise e

def sendentry(baseurl, entry, hash):
    try:
        result = urllib2.urlopen(baseurl + "ct/frontend/sendentry",
            json.dumps({"entry":base64.b64encode(entry), "treeleafhash":base64.b64encode(hash)})).read()
        return json.loads(result)
    except urllib2.HTTPError, e:
        print "ERROR: sendentry", e.read()
        sys.exit(1)
    except ValueError, e:
        print "==== FAILED REQUEST ===="
        print hash
        print "======= RESPONSE ======="
        print result
        print "========================"
        raise e

def sendsth(baseurl, submission):
    try:
        result = urllib2.urlopen(baseurl + "ct/frontend/sendsth",
            json.dumps(submission)).read()
        return json.loads(result)
    except urllib2.HTTPError, e:
        print "ERROR: sendsth", e.read()
        sys.exit(1)
    except ValueError, e:
        print "==== FAILED REQUEST ===="
        print submission
        print "======= RESPONSE ======="
        print result
        print "========================"
        raise e

def get_missingentries(baseurl):
    try:
        result = urllib2.urlopen(baseurl + "ct/frontend/missingentries").read()
        parsed_result = json.loads(result)
        if parsed_result.get(u"result") == u"ok":
            return parsed_result[u"entries"]
        print "ERROR: missingentries", parsed_result
        sys.exit(1)
    except urllib2.HTTPError, e:
        print "ERROR: missingentries", e.read()
        sys.exit(1)


logorder = get_logorder()
certsinlog = set(logorder)

new_entries = [entry for storagenode in storagenodes for entry in get_new_entries(storagenode)]

print "adding entries"
added_entries = 0
for new_entry in new_entries:
    hash = base64.b64decode(new_entry)
    if hash not in certsinlog:
        entry = get_entry(storagenode, hash)
        write_chain(hash, entry)
        add_to_logorder(hash)
        logorder.append(hash)
        certsinlog.add(hash)
        added_entries += 1
print "added", added_entries, "entries"

tree = build_merkle_tree(logorder)
tree_size = len(logorder)
root_hash = tree[-1][0]
timestamp = int(time.time() * 1000)
privatekey = get_eckey_from_file(args.keyfile)

tree_head_signature = create_sth_signature(tree_size, timestamp,
                                           root_hash, privatekey)

sth = {"tree_size": tree_size, "timestamp": timestamp,
       "sha256_root_hash": base64.b64encode(root_hash),
       "tree_head_signature": base64.b64encode(tree_head_signature)}

check_sth_signature(ctbaseurl, sth)

print "root hash", base64.b16encode(root_hash)

for frontendnode in frontendnodes:
    print "distributing for node", frontendnode
    curpos = get_curpos(frontendnode)
    print "current position", curpos
    entries = [base64.b64encode(entry) for entry in logorder[curpos:]]
    sendlog(frontendnode, {"start": curpos, "hashes": entries})
    print "log sent"
    missingentries = get_missingentries(frontendnode)
    print "missing entries:", missingentries
    for missingentry in missingentries:
        hash = base64.b64decode(missingentry)
        sendentry(frontendnode, read_chain(hash), hash)
    sendsth(frontendnode, sth)