%%% Copyright (c) 2014, NORDUnet A/S.
%%% See LICENSE for licensing information.

-module(catlfish).
-export([add_chain/2, entries/2, entry_and_proof/2]).

%% Version number used by add_chain/2: it is the first byte of the data
%% passed to plop:spt/1 and the sct_version field of the JSON reply.
-define(PROTOCOL_VERSION, 0).

%%-type signature_type() :: certificate_timestamp | tree_hash | test. % uint8
%% Symbolic names for the enumerated fields of a serialised leaf. The
%% trailing comments give the width of each field on the wire.
%% NOTE(review): the `test' alternatives have no matching clause in the
%% serialise_* functions of this module, so serialising them crashes with
%% function_clause -- confirm whether that is intended.
-type entry_type() :: x509_entry | precert_entry | test. % uint16
-type leaf_type() :: timestamped_entry | test.           % uint8
-type leaf_version() :: v1 | v2.                         % uint8

%% Leaf record handed to serialise/1, which emits the leaf version byte,
%% the leaf type byte and then the serialised timestamped entry.
%% NOTE(review): "mtl" presumably abbreviates the RFC 6962 MerkleTreeLeaf
%% structure -- confirm.
-record(mtl, {leaf_version :: leaf_version(),
              leaf_type :: leaf_type(),
              entry :: timestamped_entry()}).
-type mtl() :: #mtl{}.

%% Timestamped log entry. serialise/1 writes the timestamp as a 64-bit
%% integer, the entry type as 16 bits, signed_entry as a vector with a
%% 3-byte length prefix and extensions as a vector with a 2-byte length
%% prefix. `extensions' defaults to the empty binary.
-record(timestamped_entry, {timestamp :: integer(),
                            entry_type :: entry_type(),
                            signed_entry :: binary(),
                            extensions = <<>> :: binary()}).
-type timestamped_entry() :: #timestamped_entry{}.

%% @doc Turn an #mtl{} or a #timestamped_entry{} record into its binary
%% wire form.
%%
%% An #mtl{} becomes the leaf version byte, the leaf type byte and the
%% serialised entry. A #timestamped_entry{} becomes the 64-bit timestamp,
%% the 16-bit entry type, the signed entry with a 3-byte length prefix and
%% the extensions with a 2-byte length prefix.
-spec serialise(mtl() | timestamped_entry()) -> binary().
serialise(#mtl{leaf_version = Version, leaf_type = Type, entry = Entry}) ->
    VersionBin = serialise_leaf_version(Version),
    TypeBin = serialise_leaf_type(Type),
    EntryBin = serialise(Entry),
    <<VersionBin/binary, TypeBin/binary, EntryBin/binary>>;
serialise(#timestamped_entry{timestamp = Timestamp,
                             entry_type = EntryType,
                             signed_entry = SignedEntry,
                             extensions = Extensions}) ->
    TimestampBin = <<Timestamp:64>>,
    EntryTypeBin = serialise_entry_type(EntryType),
    SignedEntryVector = encode_tls_vector(SignedEntry, 3),
    ExtensionsVector = encode_tls_vector(Extensions, 2),
    <<TimestampBin/binary,
      EntryTypeBin/binary,
      SignedEntryVector/binary,
      ExtensionsVector/binary>>.

%% @doc Map a leaf version atom to its one-byte encoding: v1 is 0, v2 is 1.
serialise_leaf_version(v1) -> <<0>>;
serialise_leaf_version(v2) -> <<1>>.

%% @doc Map a leaf type atom to its one-byte encoding.
%% Only `timestamped_entry' (encoded as 0) is handled; any other value
%% fails with function_clause.
serialise_leaf_type(timestamped_entry) -> <<0>>.

%% @doc Map an entry type atom to its two-byte big-endian encoding:
%% x509_entry is 0 and precert_entry is 1.
serialise_entry_type(x509_entry)    -> <<0, 0>>;
serialise_entry_type(precert_entry) -> <<0, 1>>.

%% @doc Map a signature type atom to its one-byte encoding:
%% certificate_timestamp is 0 and tree_hash is 1.
serialise_signature_type(certificate_timestamp) -> <<0>>;
serialise_signature_type(tree_hash)             -> <<1>>.

%% @doc Build and serialise a v1 leaf of type `timestamped_entry' that
%% wraps an x509_entry holding LeafCert, stamped with Timestamp.
%% Returns the binary produced by serialise/1.
build_mtl(Timestamp, LeafCert) ->
    serialise(
      #mtl{leaf_version = v1,
           leaf_type = timestamped_entry,
           entry = #timestamped_entry{timestamp = Timestamp,
                                      entry_type = x509_entry,
                                      signed_entry = LeafCert}}).

%% @doc Submit a leaf certificate together with its chain and return the
%% resulting signed certificate timestamp as a JSON string.
%%
%% The entry is looked up in plop by the SHA-256 hash of LeafCert. If it
%% is not found, a fresh timestamp is generated and the entry is added;
%% otherwise the timestamp stored in the first 8 bytes of the existing
%% entry is reused. The data handed to plop:spt/1 is the protocol version
%% byte, the `certificate_timestamp' signature type byte and the
%% serialised timestamped entry.
%%
%% Crashes (badmatch) if plop:add/3 does not return `ok' or if a stored
%% entry is shorter than 8 bytes.
-spec add_chain(binary(), [binary()]) -> nonempty_string().
add_chain(LeafCert, CertChain) ->
    EntryHash = crypto:hash(sha256, LeafCert),
    Timestamp =
        case plop:get(EntryHash) of
            notfound ->
                NewTimestamp = plop:generate_timestamp(),
                %% build_mtl/2 produces the serialised v1 leaf for this
                %% certificate, and EntryHash is reused instead of
                %% hashing LeafCert a second time.
                ok = plop:add(
                       serialise_logentry(NewTimestamp, LeafCert, CertChain),
                       ht:leaf_hash(build_mtl(NewTimestamp, LeafCert)),
                       EntryHash),
                NewTimestamp;
            {_Index, _MTLHash, Entry} ->
                <<StoredTimestamp:64, _LogEntry/binary>> = Entry,
                %% TODO: Perform a costly db consistency check against
                %% unpacked LogEntry (w/ LeafCert and CertChain)
                StoredTimestamp
        end,
    TimestampedEntry = #timestamped_entry{timestamp = Timestamp,
                                          entry_type = x509_entry,
                                          signed_entry = LeafCert},
    SCT_sig =
        plop:spt(list_to_binary([<<?PROTOCOL_VERSION:8>>,
                                 serialise_signature_type(certificate_timestamp),
                                 serialise(TimestampedEntry)])),
    %% Report the same extensions that were covered by the signature
    %% (the record default, i.e. the empty binary).
    Extensions = TimestampedEntry#timestamped_entry.extensions,
    %% jiffy:encode/1 is documented to return iodata, so flatten it
    %% before converting to a string.
    binary_to_list(
      iolist_to_binary(
        jiffy:encode(
          {[{sct_version, ?PROTOCOL_VERSION},
            {id, base64:encode(plop:get_logid())},
            {timestamp, Timestamp},
            {extensions, base64:encode(Extensions)},
            {signature, base64:encode(plop:serialise(SCT_sig))}]}))).

%% @doc Serialise an entry for storage: the 64-bit timestamp, the leaf
%% certificate as a vector with a 3-byte length prefix, and the chain as
%% a vector (3-byte length prefix) whose body is the concatenation of
%% each chain certificate encoded as a vector with a 3-byte length prefix.
-spec serialise_logentry(integer(), binary(), [binary()]) -> binary().
serialise_logentry(Timestamp, LeafCert, CertChain) ->
    TimestampBin = <<Timestamp:64>>,
    LeafVector = encode_tls_vector(LeafCert, 3),
    ChainBody = << <<(encode_tls_vector(Cert, 3))/binary>> || Cert <- CertChain >>,
    ChainVector = encode_tls_vector(ChainBody, 3),
    <<TimestampBin/binary, LeafVector/binary, ChainVector/binary>>.

%% @doc Fetch what plop:get/2 returns for Start and End and return it as
%% a JSON string of the form {"entries": [...]}, where each element holds
%% base64-encoded `leaf_input' and `extra_data' (see x_entries/1).
-spec entries(non_neg_integer(), non_neg_integer()) -> list().
entries(Start, End) ->
    Stored = plop:get(Start, End),
    Json = jiffy:encode({[{entries, x_entries(Stored)}]}),
    %% jiffy:encode/1 is documented to return iodata, which may be an
    %% iolist; flatten it so binary_to_list/1 cannot fail on a list.
    binary_to_list(iolist_to_binary(Json)).

%% @doc Return, as a JSON string, the entry at Index together with the
%% path returned by plop:inclusion_and_entry/2 for TreeSize.
%%
%% On success the object holds base64-encoded `leaf_input', `extra_data'
%% and a list of base64-encoded `audit_path' elements. When plop answers
%% {notfound, Msg} the object is {"success": false, "error_message": Msg}.
-spec entry_and_proof(non_neg_integer(), non_neg_integer()) -> list().
entry_and_proof(Index, TreeSize) ->
    Reply =
        case plop:inclusion_and_entry(Index, TreeSize) of
            {ok, Entry, Path} ->
                {Timestamp, LeafCert, ChainVector} = unpack_entry(Entry),
                LeafInput = build_mtl(Timestamp, LeafCert),
                {[{leaf_input, base64:encode(LeafInput)},
                  {extra_data, base64:encode(ChainVector)},
                  {audit_path, [base64:encode(Node) || Node <- Path]}]};
            {notfound, Msg} ->
                {[{success, false},
                  {error_message, list_to_binary(Msg)}]}
        end,
    %% jiffy:encode/1 is documented to return iodata, which may be an
    %% iolist; flatten it so binary_to_list/1 cannot fail on a list.
    binary_to_list(iolist_to_binary(jiffy:encode(Reply))).

%% Private functions.
%% @doc Split a stored entry into {Timestamp, LeafCert, ChainVector}.
%% Timestamp is the leading 64-bit integer, LeafCert is the body of the
%% following vector (3-byte length prefix removed) and ChainVector is
%% whatever bytes remain after it, returned undecoded.
unpack_entry(Entry) ->
    <<Timestamp:64, Payload/binary>> = Entry,
    {LeafCert, ChainVector} = decode_tls_vector(Payload, 3),
    {Timestamp, LeafCert, ChainVector}.

%% @doc Convert a list of {Index, Hash, Entry} tuples into a list of
%% EJSON objects, each holding the base64-encoded serialised leaf
%% (`leaf_input') and the base64-encoded chain vector (`extra_data').
%% The order of the input list is preserved.
-spec x_entries([{non_neg_integer(), binary(), binary()}]) -> list().
x_entries(Entries) ->
    ToObject =
        fun(Item) ->
                %% Matched in the body so that a malformed element
                %% raises badmatch rather than being skipped.
                {_Index, _Hash, Entry} = Item,
                {Timestamp, LeafCert, ChainVector} = unpack_entry(Entry),
                LeafInput = build_mtl(Timestamp, LeafCert),
                {[{leaf_input, base64:encode(LeafInput)},
                  {extra_data, base64:encode(ChainVector)}]}
        end,
    lists:map(ToObject, Entries).

%% @doc Prefix Binary with its byte length, written as a big-endian
%% unsigned integer occupying LengthLen bytes.
%%
%% Raises badarg when the length of Binary cannot be represented in
%% LengthLen bytes. Without this check the bit syntax would silently
%% keep only the low-order bits of the length and emit a corrupt vector.
-spec encode_tls_vector(binary(), non_neg_integer()) -> binary().
encode_tls_vector(Binary, LengthLen) ->
    Length = byte_size(Binary),
    case Length < (1 bsl (8 * LengthLen)) of
        true ->
            <<Length:LengthLen/integer-unit:8, Binary/binary>>;
        false ->
            error(badarg, [Binary, LengthLen])
    end.

%% @doc Read one length-prefixed vector from the front of Binary.
%% The first LengthLen bytes hold the big-endian body length. Returns
%% {Body, Remainder}; fails with badmatch when Binary is too short.
-spec decode_tls_vector(binary(), non_neg_integer()) -> {binary(), binary()}.
decode_tls_vector(Binary, LengthLen) ->
    <<BodyLen:LengthLen/integer-unit:8,
      Body:BodyLen/binary,
      Remainder/binary>> = Binary,
    {Body, Remainder}.