#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2015, NORDUnet A/S.
# See LICENSE for licensing information.

import sys
import argparse
import json
import errno
import shutil
import base64
from datetime import datetime, timedelta, tzinfo
import time
from certtools import get_sth, create_ssl_context, check_sth_signature, get_public_key_from_file, get_consistency_proof, verify_consistency_proof

# Nagios plugin exit statuses.
NAGIOS_OK = 0
NAGIOS_WARN = 1
NAGIOS_CRIT = 2
NAGIOS_UNKNOWN = 3

# Default file holding the last STH that passed all checks; it is compared
# against on the next run.
DEFAULT_CUR_FILE = 'cur-sth.json'

parser = argparse.ArgumentParser(
    description="Nagios check for a Certificate Transparency log: fetch the "
                "log's signed tree head (STH), verify it and compare it with "
                "the previously stored STH.")
parser.add_argument('--cur-sth',
                    metavar='file',
                    default=DEFAULT_CUR_FILE,
                    help="File containing current STH (default=%s)" % DEFAULT_CUR_FILE)
parser.add_argument('baseurl', help="Base URL for CT log")
parser.add_argument('--publickey', default=None, metavar='file', help='Public key for the CT log')
parser.add_argument('--cafile', default=None, metavar='file', help='File containing the CA cert')
parser.add_argument('--allow-lag', action='store_true', help='Allow node to lag behind previous STH')
parser.add_argument('--quiet-ok', action='store_true', help="Don't print status if OK")

def print_sth(sth):
    """Debug helper: dump the main fields of an STH, one per line.

    Prints "NONE" when sth is None.
    """
    if sth is None:
        print("NONE")
        return
    for field in ('timestamp', 'sha256_root_hash', 'tree_size', 'tree_head_signature'):
        print(sth[field])

def get_new_sth(baseurl):
    """Fetch the current STH from the CT log at baseurl.

    Any exception raised while fetching is reported as a Nagios UNKNOWN
    status: the error is printed and the process exits with NAGIOS_UNKNOWN.
    """
    try:
        return get_sth(baseurl)
    except Exception as e:
        # Broad catch is deliberate: any fetch failure leaves the log's state
        # unknown. Prefix the message with the status like every other
        # status line this check prints.
        print("UNKNOWN: %s" % e)
        sys.exit(NAGIOS_UNKNOWN)
        
def read_sth(fn):
    """Load a previously stored STH from the JSON file fn.

    Returns None if fn does not exist. Any other I/O error is re-raised.
    The file is closed after reading.
    """
    try:
        f = open(fn)
    except IOError as e:
        if e.errno == errno.ENOENT:
            return None
        # Bare raise keeps the original traceback.
        raise
    with f:
        return json.load(f)

def mv_file(fromfn, tofn):
    """Move the file at fromfn to tofn using shutil.move; returns None."""
    source, destination = fromfn, tofn
    shutil.move(source, destination)

def write_file(fn, sth):
    """Store sth as JSON in fn by way of a temporary "<fn>.new" file.

    The JSON is written to the temporary file, which is closed (and thus
    flushed) before being moved over fn with mv_file.
    """
    tempname = fn + ".new"
    # Close the temp file explicitly before moving it; relying on garbage
    # collection to close/flush it is not guaranteed outside CPython.
    with open(tempname, 'w') as f:
        f.write(json.dumps(sth))
    mv_file(tempname, fn)

class UTC(tzinfo):
    """Minimal tzinfo for UTC: zero offset, no DST, named "UTC"."""

    def utcoffset(self, dt):
        return timedelta(hours=0)

    def dst(self, dt):
        return timedelta(0)

    def tzname(self, dt):
        # The tzinfo base class raises NotImplementedError here, which
        # would break strftime("%Z") on datetimes using this zone.
        return "UTC"

def check_age(sth):
    """Check how old an STH is and describe its age.

    sth["timestamp"] is treated as milliseconds since the epoch. If the STH
    is older than 6 hours, print a CRITICAL line and exit with NAGIOS_CRIT;
    if older than 2 hours, print a WARNING line and exit with NAGIOS_WARN.
    Otherwise return "<YYYY-mm-dd HH:MM:SS> UTC, <n> minutes ago".
    """
    stamp_seconds = sth["timestamp"] / 1000
    age = time.time() - stamp_seconds
    sth_time = datetime.fromtimestamp(stamp_seconds, UTC()).strftime("%Y-%m-%d %H:%M:%S")
    roothash = b64_to_b16(sth['sha256_root_hash'])

    # Most severe threshold first so a very old STH reports CRITICAL.
    thresholds = ((6, "CRITICAL", NAGIOS_CRIT), (2, "WARNING", NAGIOS_WARN))
    for hours, label, status in thresholds:
        if age > hours * 3600:
            print("%s: %s is older than %dh: %s UTC" % (label, roothash, hours, sth_time))
            sys.exit(status)

    return "%s UTC, %d minutes ago" % (sth_time, age / 60)

def check_consistency(newsth, oldsth, baseurl):
    """Verify that the tree in newsth extends the tree in oldsth.

    Fetches a consistency proof between oldsth["tree_size"] and
    newsth["tree_size"] from the log at baseurl and verifies it against
    oldsth's root hash. If the tree heads it yields do not match the two
    STHs' root hashes, print a CRITICAL line and exit with NAGIOS_CRIT.
    """
    old_size = oldsth["tree_size"]
    new_size = newsth["tree_size"]
    old_root = base64.b64decode(oldsth["sha256_root_hash"])
    new_root = base64.b64decode(newsth["sha256_root_hash"])

    # b64decode instead of the deprecated decodestring (removed in Python 3.9).
    consistency_proof = [base64.b64decode(entry)
                         for entry in get_consistency_proof(baseurl, old_size, new_size)]
    (old_treehead, new_treehead) = verify_consistency_proof(consistency_proof, old_size, new_size, old_root)

    # Explicit check rather than assert: asserts are stripped under -O, and an
    # uncaught AssertionError exits with status 1, which Nagios reads as
    # WARNING instead of CRITICAL.
    if old_treehead != old_root or new_treehead != new_root:
        print("CRITICAL: consistency proof failed between tree sizes %d and %d" % (old_size, new_size))
        sys.exit(NAGIOS_CRIT)

def b64_to_b16(s):
    """Convert a base64-encoded value to upper-case hexadecimal (base16).

    Used to print root hashes in readable form.
    """
    # b64decode rather than decodestring: decodestring is deprecated and was
    # removed in Python 3.9; both decode standard base64 the same way.
    return base64.b16encode(base64.b64decode(s))

def main(args):
    """Run the STH check for args.baseurl and exit with a Nagios status.

    Steps:
    1. Fetch the log's current STH and verify its signature.
    2. Compare it with the STH stored in args.cur_sth, if one exists.
    3. Verify consistency between the two trees.
    4. Check the STH's age.
    5. Store the new STH and exit with NAGIOS_OK.

    Failure paths print a status line and exit with NAGIOS_CRIT; the helpers
    called here may also exit with NAGIOS_WARN or NAGIOS_UNKNOWN.
    """
    if args.cur_sth is None:
        args.cur_sth = DEFAULT_CUR_FILE

    create_ssl_context(cafile=args.cafile)

    logpublickey = get_public_key_from_file(args.publickey) if args.publickey else None

    newsth = get_new_sth(args.baseurl)
    # NOTE(review): the return value is ignored; presumably
    # check_sth_signature raises on a bad signature -- confirm in certtools.
    check_sth_signature(args.baseurl, newsth, publickey=logpublickey)

    oldsth = read_sth(args.cur_sth)

    if oldsth:
        old_size = oldsth["tree_size"]
        new_size = newsth["tree_size"]
        if new_size == old_size:
            if oldsth["sha256_root_hash"] != newsth["sha256_root_hash"]:
                # Same output as before: the old hash ends the first line,
                # the new hash is on a second line.
                print("CRITICAL: root hash is different even though tree size is the same. "
                      "tree size: %s old hash: %s\nnew hash: %s"
                      % (new_size,
                         b64_to_b16(oldsth["sha256_root_hash"]),
                         b64_to_b16(newsth["sha256_root_hash"])))
                sys.exit(NAGIOS_CRIT)
        elif new_size < old_size:
            if not args.allow_lag:
                print("CRITICAL: new tree smaller than previous tree (%d < %d)" % (new_size, old_size))
                sys.exit(NAGIOS_CRIT)
            # Lagging node: a consistency proof must go from the smaller tree
            # to the larger one, so prove the new (smaller) tree is a prefix
            # of the stored one.
            if new_size > 0:
                check_consistency(oldsth, newsth, args.baseurl)
        elif old_size > 0:
            check_consistency(newsth, oldsth, args.baseurl)

    # check_age exits on a stale STH, so a stale STH is never stored.
    age = check_age(newsth)

    write_file(args.cur_sth, newsth)

    if not args.quiet_ok:
        print("OK: size: %d hash: %s, %s" % (newsth["tree_size"], b64_to_b16(newsth["sha256_root_hash"])[:8], age))
    sys.exit(NAGIOS_OK)

if __name__ == '__main__':
    # Script entry point: parse the command line and run the check.
    cli_args = parser.parse_args()
    main(cli_args)