/*
 * Copyright (c) 2016, NORDUnet A/S.
 * See LICENSE for licensing information.
 *
 * Invocation: dnssec [<path-to-trust-anchor-file>]
 * (without a trust anchor file, the getdns root trust anchor is used)
 *
 * Once running: Read DNSSEC RR's from stdin, canonicalise RR's
 * (RFC4034 6.2), validate DS RR (todo:ref) and write the result to
 * stdout.
 *
 * All length fields in the input and output denote the length of the
 * piece of data that follows, in number of octets. All integers are
 * transferred in network byte order (a.k.a. big-endian).
 *
 * Input format:
 * - Length of data in number of octets (integer, 4 octets)
 * - Validation time in seconds since the epoch (integer, 4 octets)
 * - Validation time skew in seconds (integer, 4 octets)
 * - DNSSEC RR's as a DNSSEC_key_chain, specified in
 *   draft-zhang-trans-ct-dnssec-03 section 4.1 but without the TLS
 *   data structure encoding.
 *
 * Output format:
 * - Length of data (integer, 4 octets)
 * - Status code -- the getdns_return_t value (integer, 2 octets)
 * - (RR's)* -- if validation succeeded: the DS+RRSIG and the full
 *   chain up to and including the trust anchor; if validation failed:
 *   nothing
 *
 * (RR's)* denotes zero or more RR's.
 */

#include <errno.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <arpa/inet.h>
#include <getopt.h>
#include <getdns/getdns.h>
#include <getdns/getdns_extra.h>
#include "erlport.h"
#include "dnssec_test.h"

static int debug = 0;           /* DEBUG */

/*
 * hd(b, l): hex-dump the first L octets of the byte buffer B to stderr.
 *
 * Output is 16 octets per line, each line prefixed with the offset of its
 * first octet as 8 hex digits, with an extra space after the 8th octet.
 *
 * B and L are each evaluated exactly once.  The macro's own variables use
 * names a caller is unlikely to pass in, so e.g. hd(buf, n) works even
 * though the caller's length variable is called `n`.  The do/while (0)
 * wrapper makes `hd(b, l);` a single statement, also in an unbraced
 * if/else.
 */
#define hd(b, l) do {                                                   \
    const unsigned char *hd_buf_ = (const unsigned char *) (b);         \
    size_t hd_len_ = (l);                                               \
    for (size_t hd_i_ = 0; hd_i_ < hd_len_; hd_i_++) {                  \
      if (hd_i_ % 16 == 0) {                                            \
        if (hd_i_ != 0)                                                 \
          fprintf(stderr, "\n");                                        \
        fprintf(stderr, "%08zx  ", hd_i_);                              \
      } else if (hd_i_ % 8 == 0) {                                      \
        fprintf(stderr, " ");                                           \
      }                                                                 \
      fprintf(stderr, "%02x ", (unsigned int) hd_buf_[hd_i_]);          \
    }                                                                   \
    fprintf(stderr, "\n");                                              \
  } while (0)

#if defined(TEST)
/* Value of --testmode, handed to test_validate(); NULL when not given. */
static char *testmode = NULL;
#endif

/*
 * Write a human-readable rendering of TREE to FP, preceded by a
 * "* NAME" heading line when NAME is non-NULL.
 *
 * If getdns_pretty_print_list() returns NULL, a short placeholder line is
 * printed instead, rather than handing NULL to fputs().
 */
static void
print_tree(FILE *fp, const getdns_list *tree, const char *name)
{
  if (name)
    fprintf(fp, "* %s\n", name);

  char *s = getdns_pretty_print_list(tree);
  if (s == NULL)
    {
      fputs("(unable to pretty-print list)\n", fp);
      return;
    }
  fputs(s, fp);
  free(s);
}

/* TODO: Replace read_file() and wire_rrs2list() with getdns_fp2rr_list()? */
/*
 * Read INFP until end of file (or a read error) into a malloc'd buffer.
 *
 * SIZE_HINT, when non-zero, is the initial buffer size; otherwise 4096
 * octets are allocated up front.  The buffer grows by doubling, so the
 * total amount of copying stays linear in the file size.
 *
 * Returns the number of octets read.  If BUFP_OUT is non-NULL, *BUFP_OUT
 * receives the buffer, which the caller must free().  If BUFP_OUT is
 * NULL, the buffer is freed here.  On allocation failure, *BUFP_OUT is
 * set to NULL and 0 is returned.  A read error is not distinguished from
 * end of file; callers that care should check ferror(INFP).
 */
size_t
read_file(FILE *infp, uint8_t **bufp_out, size_t size_hint)
{
  enum { DEFAULT_CHUNKSIZE = 4096 };
  size_t capacity = size_hint > 0 ? size_hint : DEFAULT_CHUNKSIZE;
  size_t nread = 0;
  uint8_t *buf = malloc(capacity);

  while (buf != NULL)
    {
      size_t room = capacity - nread;
      size_t n = fread(buf + nread, 1, room, infp);
      nread += n;
      if (n < room)
        break;                  /* Done: EOF or read error. */

      /* The buffer is full and there may be more to read. */
      uint8_t *grown = NULL;
      if (capacity <= SIZE_MAX / 2)
        grown = realloc(buf, capacity * 2);
      if (grown == NULL)
        {
          /* realloc() failure leaves buf valid; free it instead of
             leaking it and reporting octets in a NULL buffer. */
          free(buf);
          buf = NULL;
          nread = 0;
          break;
        }
      buf = grown;
      capacity *= 2;
    }

  if (bufp_out != NULL)
    *bufp_out = buf;
  else
    free(buf);
  return nread;
}

/*
 * Parse BUF_LEN octets of back-to-back wire-format RRs in BUF into a newly
 * created getdns list of RR dicts, in input order.
 *
 * On success, GETDNS_RETURN_GOOD is returned.  If LIST_OUT is non-NULL,
 * *LIST_OUT receives the list, which the caller must release with
 * getdns_list_destroy(); otherwise the list is destroyed here, and the
 * call only checks that BUF parses.  An empty BUF yields an empty list.
 *
 * On failure, the error of the failing getdns call is returned, the
 * partially built list is destroyed and *LIST_OUT (if non-NULL) is set to
 * NULL.
 */
static getdns_return_t
wire_rrs2list(const uint8_t *buf, size_t buf_len, getdns_list **list_out)
{
  getdns_return_t r = GETDNS_RETURN_GOOD;
  getdns_list *list = getdns_list_create();
  size_t rr_count = 0;

  if (list_out)
    *list_out = NULL;
  if (list == NULL)
    return GETDNS_RETURN_MEMORY_ERROR;

  while (buf_len > 0)
    {
      getdns_dict *dict = NULL;

      /* Presumably consumes one RR, advancing buf and shrinking buf_len;
         the loop relies on that to terminate. */
      if ((r = getdns_wire2rr_dict_scan(&buf, &buf_len, &dict)))
        break;
      r = getdns_list_set_dict(list, rr_count, dict);
      getdns_dict_destroy(dict); /* The list has a copy. */
      if (r)
        break;
      rr_count++;
    }

  if (r || list_out == NULL)
    {
      /* Nobody will own the list: either parsing failed or the caller
         did not ask for it. */
      getdns_list_destroy(list);
      list = NULL;
    }
  if (list_out)
    *list_out = list;
  return r;
}

/*
 * Load the trust anchors in FNAME, a file of back-to-back wire-format
 * RRs, into *LIST_OUT (see wire_rrs2list() for ownership).
 *
 * Returns:
 * - 0 on success;
 * - a negated errno value if the file cannot be opened or read (errno is
 *   left holding the same value);
 * - a positive getdns_return_t if memory runs out or the RRs cannot be
 *   parsed.
 */
static int
read_trust_anchors(const char *fname, getdns_list **list_out)
{
  if (debug)
    fprintf(stderr, "reading trust anchors from file %s\n", fname);

  FILE *fp = fopen(fname, "rb");   /* Wire format: binary data. */
  if (fp == NULL)
    return -errno;

  uint8_t *buf = NULL;
  size_t n = read_file(fp, &buf, 0);
  int read_error = ferror(fp);     /* Must be sampled before fclose(). */
  if (fclose(fp))
    fprintf(stderr, "unable to close trust anchors file: %d\n", errno);

  if (read_error)
    {
      free(buf);
      errno = EIO;
      return -EIO;
    }
  if (buf == NULL)                 /* read_file() ran out of memory. */
    return GETDNS_RETURN_MEMORY_ERROR;

  int r = wire_rrs2list(buf, n, list_out);
  free(buf);
  return r;
}

static getdns_return_t
list2wire(const getdns_list *list, uint8_t *out_buf, size_t *out_buf_len)
{
  size_t list_len;
  getdns_return_t r = getdns_list_get_length(list, &list_len);
  if (r)
    return r;

  getdns_dict *rr = NULL;
  uint8_t *buf = NULL;
  size_t buf_len = 0;
  *out_buf_len = 0;
  for (int i = 0; i < list_len; i++)
    {
      if ((r = getdns_list_get_dict(list, i , &rr)))
        return r;
      if ((r = getdns_rr_dict2wire(rr, &buf, &buf_len)))
        return r;               /* FIXME: Risk of leaking buf? */
      memcpy(out_buf, buf, buf_len);
      free(buf);
      out_buf += buf_len;
      *out_buf_len += buf_len;
    }

  return r;
}

#if !defined(TEST)
/*
 * Fetch element IDX of SRC, check that it is an RR of type WANT_TYPE, and
 * store a copy of it at position DST_IDX of DST.
 *
 * Returns GETDNS_RETURN_INVALID_PARAMETER on a type mismatch, otherwise
 * the result of the getdns calls.
 */
static getdns_return_t
copy_rr_of_type(const getdns_list *src, size_t idx, uint32_t want_type,
                getdns_list *dst, size_t dst_idx)
{
  getdns_dict *rr = NULL;
  uint32_t rrtype = 0;
  getdns_return_t r;

  if ((r = getdns_list_get_dict(src, idx, &rr)))
    return r;
  if ((r = getdns_dict_get_int(rr, "type", &rrtype)))
    return r;
  if (rrtype != want_type)
    return GETDNS_RETURN_INVALID_PARAMETER;
  return getdns_list_set_dict(dst, dst_idx, rr);
}

/*
 * Validate the DS RR carried in the wire-format RRs in BUF (BUF_LEN
 * octets) against TRUST_ANCHORS, using getdns_validate_dnssec2() with
 * VALIDATION_TIME and a clock skew allowance of SKEW seconds.
 *
 * The RRs in BUF must be, in order:
 * - the DS RR;
 * - an RRSIG RR, which should cover the DS (only its type is checked
 *   here);
 * - any number of support records.
 *
 * Returns the getdns_validate_dnssec2() result (e.g. GETDNS_DNSSEC_SECURE).
 * Returns a getdns error instead if BUF cannot be parsed, does not start
 * with a DS followed by an RRSIG, or the output cannot be serialised.
 *
 * If OUT_BUF and OUT_BUF_LEN are both non-NULL and validation was
 * attempted, the following are written to OUT_BUF in wire format, in
 * order:
 * - the DS + RRSIG;
 * - the support records;
 * - the trust anchors.
 * Their total length is stored in *OUT_BUF_LEN, or 0 if serialisation
 * fails.  This happens whatever the validation outcome.
 *
 * NOTE(review): the file header says nothing is returned when validation
 * fails, but the records are written regardless; confirm which the peer
 * expects.
 * FIXME: OUT_BUF's capacity is not passed in and is not checked.
 */
static getdns_return_t
validate(const uint8_t *buf, size_t buf_len,
         getdns_list *trust_anchors,
         time_t validation_time, uint32_t skew,
         uint8_t *out_buf, size_t *out_buf_len)
{
  getdns_return_t r = GETDNS_RETURN_GOOD;
  getdns_list *list = NULL;
  getdns_list *to_validate = getdns_list_create();
  getdns_list *support_records = getdns_list_create();
  size_t list_len = 0;

  if (to_validate == NULL || support_records == NULL)
    {
      /* Fall through to out: so whichever list was created is released. */
      r = GETDNS_RETURN_MEMORY_ERROR;
      goto out;
    }

  /* Convert RR's in buf to dicts in a list. */
  if ((r = wire_rrs2list(buf, buf_len, &list)))
    goto out;

  /* First record MUST be the DS RR to validate. Second record MUST be
     an RRSIG covering the DS RR. Copy those to to_validate. */
  if ((r = copy_rr_of_type(list, 0, GETDNS_RRTYPE_DS, to_validate, 0)))
    goto out;
  if ((r = copy_rr_of_type(list, 1, GETDNS_RRTYPE_RRSIG, to_validate, 1)))
    goto out;

  /* The rest is "support records". Copy them to support_records. */
  if ((r = getdns_list_get_length(list, &list_len)))
    goto out;
  for (size_t i = 2; i < list_len; i++)
    {
      getdns_dict *rr = NULL;
      if ((r = getdns_list_get_dict(list, i, &rr)))
        goto out;
      if ((r = getdns_list_set_dict(support_records, i - 2, rr)))
        goto out;
    }

  if (0 && debug)               /* Change 0 to 1 to dump the input. */
    {
      print_tree(stderr, to_validate, "to_validate");
      print_tree(stderr, support_records, "support_records");
      print_tree(stderr, trust_anchors, "trust_anchors");
    }

  r = getdns_validate_dnssec2(to_validate,
                              support_records,
                              trust_anchors,
                              validation_time,
                              skew);

  if (out_buf && out_buf_len)
    {
      const getdns_list *const parts[] = {
        to_validate, support_records, trust_anchors
      };
      const getdns_return_t validation_result = r;
      size_t total = 0;

      *out_buf_len = 0;
      for (size_t p = 0; p < sizeof parts / sizeof parts[0]; p++)
        {
          size_t len = 0;
          /* On failure *out_buf_len stays 0, so no partially serialised
             RRs are reported to the caller. */
          if ((r = list2wire(parts[p], out_buf + total, &len)))
            goto out;
          total += len;
        }
      *out_buf_len = total;
      r = validation_result;
    }

out:
  if (list)
    getdns_list_destroy(list);
  if (to_validate)
    getdns_list_destroy(to_validate);
  if (support_records)
    getdns_list_destroy(support_records);
  return r;
}
#endif  /* !TEST */

/*
 * Serve requests until read_command() returns 0 or a negative value.
 *
 * Each request (read with read_command(); the `4` appears to be the size
 * of the length prefix, per the file header) is validated.  A reply is
 * written with write_reply(), holding the 2-octet status code in network
 * byte order followed by OUT_LEN octets of RRs produced by validate().
 */
static void
loop(getdns_list *trust_anchors)
{
  /* Validation time and validation time skew, 4 octets each. */
  enum { REQUEST_HEADER_LEN = 8 };
  unsigned char buf[64 * 1024]; /* FIXME */
  ssize_t len;

  while ((len = read_command(buf, sizeof(buf), 4)) > 0)
    {
      unsigned char reply[2 * 64 * 1024]; /* FIXME */
      size_t out_len = 0;
      getdns_return_t r = GETDNS_RETURN_GENERIC_ERROR;

#if !defined(TEST)
      if (len < REQUEST_HEADER_LEN)
        {
          /* Too short to hold the header: refuse rather than read stale
             bytes and pass a wrapped-around length to validate(). */
          r = GETDNS_RETURN_INVALID_PARAMETER;
        }
      else
        {
          /* memcpy: buf + 4 need not be suitably aligned for uint32_t. */
          uint32_t validation_time_be, validation_time_skew_be;
          memcpy(&validation_time_be, buf, 4);
          memcpy(&validation_time_skew_be, buf + 4, 4);
          r = validate(buf + REQUEST_HEADER_LEN,
                       (size_t) len - REQUEST_HEADER_LEN, trust_anchors,
                       ntohl(validation_time_be),
                       ntohl(validation_time_skew_be),
                       reply + 2, &out_len);
        }
#else
      r = test_validate(buf, len, trust_anchors, testmode);
#endif

      if (debug)
        {
          /* GETDNS_DNSSEC_SECURE is not in enum getdns_return_t */
          if ((int) r == GETDNS_DNSSEC_SECURE)
            fprintf(stderr, "validation successful\n");
          else
            {
              const char *msg = getdns_get_errorstr_by_id(r);
              fprintf(stderr, "validation error %d (%s)\n",
                      (int) r, msg ? msg : "unknown");
            }
        }

      uint16_t status = htons((uint16_t) r);
      memcpy(reply, &status, sizeof status);
      if (debug)
        fprintf(stderr, "writing %zu octets of data, including status code %d\n",
                2 + out_len, (int) r);
      if (write_reply(reply, 2 + out_len, 4))
        fprintf(stderr, "error writing reply\n");
    }
}


/*
 * Parse the command line, load the trust anchors and serve requests.
 *
 * Usage: dnssec [--testmode MODE] [<trust-anchor-file>]
 * --testmode is only accepted in TEST builds.  Without a trust anchor
 * file, the getdns root trust anchor is used.
 */
int
main(int argc, char *argv[])
{
  getdns_list *trust_anchors = NULL;
  time_t trust_anchor_date;

  /* Parse command line. */
  for (;;)
    {
      static struct option long_options[] = {
        {"testmode", required_argument, NULL, 't'},
        {0, 0, 0, 0}};

      int c = getopt_long(argc, argv, "", long_options, NULL);
      if (c == -1)
        break;

      switch (c)
        {
#if defined(TEST)
        case 't':
          testmode = optarg;
          break;
#endif
        default:
          /* argv[optind] may be NULL (optind == argc) here; report the
             argument getopt_long() consumed last instead. */
          fprintf(stderr, "bad option: %s\n",
                  (optind > 0 && optind <= argc) ? argv[optind - 1] : "?");
          return -1;
        }
    }

  /* Read trust anchors file. */
  if (optind < argc)
    {
      int r = read_trust_anchors(argv[optind], &trust_anchors);
      if (r < 0)
        {
          /* r is a negated errno value; don't rely on errno surviving. */
          fprintf(stderr, "read trust anchors: %s\n", strerror(-r));
          return -r;
        }
      else if (r > 0)
        {
          const char *msg = getdns_get_errorstr_by_id(r);
          fprintf(stderr,
                  "unable to read trust anchors file %s: %d (%s)\n",
                  argv[optind], r, msg ? msg : "unknown");
          return r;
        }
    }
  else                         /* DEBUG: Using getdns trust anchor. */
    {
      trust_anchors = getdns_root_trust_anchor(&trust_anchor_date);
      if (trust_anchors == NULL)
        fprintf(stderr, "warning: getdns root trust anchor not available\n");
    }

  /* Serve requests until read_command() reports end of input or an
     error. */
  loop(trust_anchors);

  if (trust_anchors)
    getdns_list_destroy(trust_anchors);
  return 0;
}