#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright (c) 2015, NORDUnet A/S.
# See LICENSE for licensing information.

import argparse
import urllib2
import urllib
import json
import base64
import sys
import struct
import hashlib
import itertools
from certtools import *
import zipfile
import os
import time

# Command-line interface.  Exactly one mode is normally chosen: --head,
# --dot FILE or --printnode; --treesize/--level/--index name the node used by
# the last two.
parser = argparse.ArgumentParser(description='Inspect the Merkle tree of a certificate store: print the tree head, print a node and its path to the root, or draw that path in dot format.')
parser.add_argument('--store', default=None, metavar="dir", help='Certificates directory')

parser.add_argument('--head', action='store_true', help="Calculate tree head")
parser.add_argument('--printnode', action='store_true', help="Print tree node (and, with --follow, its path to the root)")

parser.add_argument('--treesize', type=int, default=None, metavar="treesize", help="Tree size (number of leaves)")
# Default 0 (the leaf level) matches index_to_root()'s own default; with a
# None default the --dot and --printnode modes built (index, None) nodes.
parser.add_argument('--level', type=int, default=0, metavar="level", help="Level of the node, 0 = leaves (default: 0)")
parser.add_argument('--index', type=int, default=None, metavar="index", help="Index of the node within its level")

parser.add_argument('--follow', action='store_true', help="With --printnode, print every node and its sibling up to the root")

parser.add_argument('--dot', default=None, metavar="file", help='Output data in dot format')

args = parser.parse_args()

def index_to_root(index, treesize, level=0):
    """Return the list of nodes from ``(index, level)`` up towards the root.

    Every node is an ``(index, level)`` tuple.  Starting with the given node,
    ``node_above`` is applied repeatedly for as long as ``node_level`` of the
    current node is below ``merkle_height(treesize)``.  The node at that
    height (which callers treat as the root) is therefore not included.
    ``merkle_height``, ``node_level`` and ``node_above`` are not defined in
    this file; presumably they come from the ``certtools`` star import.
    """
    def _climb(node, top):
        # Yield each node strictly below level `top`, walking upwards.
        while node_level(node) < top:
            yield node
            node = node_above(node)

    return list(_climb((index, level), merkle_height(treesize)))

def set_tree_node(tree, level, index, value, overwrite=True):
    """Store *value* for the node at (*level*, *index*) in *tree*.

    *tree* is a dict mapping level -> {index: value}; the inner dict for a
    level is created on first use.  When *overwrite* is false and the node
    already has an entry, that entry is left untouched (the inner dict is
    still created if missing).
    """
    # Operate on the dict passed in, not on a module-level global, so the
    # function works regardless of which command-line mode ran.
    row = tree.setdefault(level, {})
    if overwrite or index not in row:
        row[index] = value

def draw_path(tree, startlevel, startindex, treesize, colors):
    """Mark in *tree* the path from (startindex, startlevel) up to the root.

    Path nodes are always (re)set: level-0 nodes get ``colors[0]``, higher
    nodes get ``colors[1]``.  For every path node, its sibling (``index ^ 1``)
    is added with an empty colour, but only if that sibling's first leaf lies
    inside a tree of *treesize* leaves, and never replacing an existing
    entry.  Finally the node at (merkle_height(treesize), 0) is set to
    ``colors[1]``.
    """
    root_level = merkle_height(treesize)

    for idx, lvl in index_to_root(startindex, treesize, level=startlevel):
        shade = colors[0] if lvl == 0 else colors[1]
        set_tree_node(tree, lvl, idx, shade)

        # A node at level `lvl` covers 2**lvl leaves starting at
        # sibling * 2**lvl; it exists only if that first leaf is in the tree.
        sibling = idx ^ 1
        if sibling * 2 ** lvl < treesize:
            set_tree_node(tree, lvl, sibling, "", overwrite=False)

    set_tree_node(tree, root_level, 0, colors[1])
    

if args.head:
    treehead = get_tree_head(args.store, args.treesize)
    print base64.b16encode(treehead)
elif args.dot:
    levels = {}
    if args.index >= args.treesize:
        sys.exit(1)
    dotfile = open(args.dot, "w")
    print >>dotfile, 'graph "" {'
    print >>dotfile, 'ordering=out;'
    print >>dotfile, 'node [style=filled];'

    height = merkle_height(args.treesize)

    draw_path(levels, 0, args.treesize - 1, args.treesize, ["0.600 0.500 0.900", "0.600 0.300 0.900"])

    draw_path(levels, args.level, args.index, args.treesize, ["0.300 0.500 0.900", "0.300 0.200 0.900"])

    for l in sorted(levels.keys(), reverse=True):
        for i in sorted(levels[l].keys()):
            print >>dotfile, "l%di%d [color=\"%s\" label=\"%s\"];" % (l, i, levels[l][i], path_as_string(i, l, args.treesize))
            if height != l:
                print >>dotfile, "l%di%d -- l%di%d;" % (l + 1, i / 2, l, i)
                if i & 1 == 0:
                    print >>dotfile, "ml%di%d [shape=point style=invis];" % (l, i)
                    print >>dotfile, "l%di%d -- ml%di%d [weight=100 style=invis];" % (l + 1, i / 2, l, i)
    print >>dotfile, "}"
    dotfile.close()
elif args.printnode:
    index = args.index
    level = args.level
    if args.index >= args.treesize:
        sys.exit(1)
    height = merkle_height(args.treesize)
    nodes = index_to_root(index, args.treesize, level=level)

    for (index, level) in nodes:
        print level, index
        if args.store:
            print base64.b16encode(get_intermediate_hash(args.store, args.treesize, level, index))

        if not args.follow:
            sys.exit(0)

        index ^= 1
        levelsize = 2 ** level

        firstleaf = index * levelsize
        if firstleaf < args.treesize:
            print level, index
            if args.store:
                print base64.b16encode(get_intermediate_hash(args.store, args.treesize, level, index))

    print height, 0
    if args.store:
        print base64.b16encode(get_intermediate_hash(args.store, args.treesize, height, 0))