%%% Copyright (c) 2014, NORDUnet A/S.
%%% See LICENSE for licensing information.

-module(catlfish).
-export([add_chain/2, entries/2, entry_and_proof/2]).
-export([known_roots/0, update_known_roots/0]).
-include_lib("eunit/include/eunit.hrl").
-include("catlfish.hrl").

-define(PROTOCOL_VERSION, 0).

%% Wire-format enumerations. The trailing comment on each type gives the
%% integer width its values are encoded with by the serialise_* functions.
%% NOTE(review): the `test' alternatives have no serialise_* clause, so
%% serialising them crashes with function_clause.
%%-type signature_type() :: certificate_timestamp | tree_hash | test. % uint8
-type entry_type() :: x509_entry | precert_entry | test. % uint16
-type leaf_type() :: timestamped_entry | test.           % uint8
-type leaf_version() :: v1 | v2.                         % uint8

%% Merkle tree leaf (presumably RFC 6962 MerkleTreeLeaf). serialise/1
%% writes it as: leaf_version byte, leaf_type byte, then the serialised
%% timestamped entry.
-record(mtl, {leaf_version :: leaf_version(),
              leaf_type :: leaf_type(),
              entry :: timestamped_entry()}).
-type mtl() :: #mtl{}.

%% Timestamped log entry. serialise/1 writes it as: 64-bit timestamp,
%% 16-bit entry type, signed_entry as a TLS vector with a 3-byte length,
%% and extensions as a TLS vector with a 2-byte length.
-record(timestamped_entry, {timestamp :: integer(),
                            entry_type :: entry_type(),
                            signed_entry :: binary(),
                            extensions = <<>> :: binary()}).
-type timestamped_entry() :: #timestamped_entry{}.

%% Serialise a Merkle tree leaf or a timestamped entry to its binary
%% wire form.
%%   #timestamped_entry{}: Timestamp:64, entry type (uint16),
%%                         signed_entry (3-byte-length vector),
%%                         extensions (2-byte-length vector).
%%   #mtl{}:               leaf version (uint8), leaf type (uint8),
%%                         then the serialised timestamped entry.
-spec serialise(mtl() | timestamped_entry()) -> binary().
serialise(#timestamped_entry{timestamp = Timestamp,
                             entry_type = EntryType,
                             signed_entry = SignedEntry,
                             extensions = Extensions}) ->
    iolist_to_binary([<<Timestamp:64>>,
                      serialise_entry_type(EntryType),
                      encode_tls_vector(SignedEntry, 3),
                      encode_tls_vector(Extensions, 2)]);
serialise(#mtl{} = Leaf) ->
    iolist_to_binary([serialise_leaf_version(Leaf#mtl.leaf_version),
                      serialise_leaf_type(Leaf#mtl.leaf_type),
                      serialise(Leaf#mtl.entry)]).

%% Encode a leaf_version() as its uint8 wire value: v1 = 0, v2 = 1.
serialise_leaf_version(v1) -> <<0>>;
serialise_leaf_version(v2) -> <<1>>.

%% Encode a leaf_type() as its uint8 wire value. Only timestamped_entry
%% (0) is encodable; any other leaf type, including `test', crashes with
%% function_clause.
serialise_leaf_type(timestamped_entry) -> <<0>>.

%% Encode an entry_type() as its big-endian uint16 wire value:
%% x509_entry = 0, precert_entry = 1.
serialise_entry_type(x509_entry)    -> <<0, 0>>;
serialise_entry_type(precert_entry) -> <<0, 1>>.

%% Encode a signature type as its uint8 wire value:
%% certificate_timestamp = 0, tree_hash = 1.
serialise_signature_type(certificate_timestamp) -> <<0>>;
serialise_signature_type(tree_hash)             -> <<1>>.

%% Build the serialised v1 Merkle tree leaf wrapping an x509_entry
%% timestamped entry for LeafCert (the raw certificate, without a TLS
%% length prefix) stamped with Timestamp.
build_mtl(Timestamp, LeafCert) ->
    serialise(#mtl{leaf_version = v1,
                   leaf_type = timestamped_entry,
                   entry = #timestamped_entry{timestamp = Timestamp,
                                              entry_type = x509_entry,
                                              signed_entry = LeafCert}}).

%% Add the chain LeafCert + CertChain to the log unless an entry for
%% LeafCert already exists, and return the SCT as a {Proplist} term
%% (presumably JSON-encoded by the caller).
%%
%% The entry is looked up in plop by the SHA-256 hash of LeafCert.
%% - New entry: a fresh timestamp is generated. The serialised log
%%   entry, the leaf hash of the serialised Merkle tree leaf, and the
%%   SHA-256 key are then handed to plop:add/3.
%% - Existing entry: the stored timestamp is reused, so the returned
%%   SCT carries the original timestamp.
-spec add_chain(binary(), [binary()]) -> {[{atom(), term()}]}.
add_chain(LeafCert, CertChain) ->
    EntryHash = crypto:hash(sha256, LeafCert),
    Timestamp =
        case plop:get(EntryHash) of
            notfound ->
                NewTimestamp = plop:generate_timestamp(),
                ok = plop:add(
                       serialise_logentry(NewTimestamp, LeafCert, CertChain),
                       ht:leaf_hash(build_mtl(NewTimestamp, LeafCert)),
                       EntryHash),
                NewTimestamp;
            {_Index, _MTLHash, Entry} ->
                <<StoredTimestamp:64, _LogEntry/binary>> = Entry,
                %% TODO: Perform a costly db consistency check against
                %% unpacked LogEntry (w/ LeafCert and CertChain)
                StoredTimestamp
        end,
    TimestampedEntry = #timestamped_entry{timestamp = Timestamp,
                                          entry_type = x509_entry,
                                          signed_entry = LeafCert},
    SCT_sig =
        plop:spt(list_to_binary([<<?PROTOCOL_VERSION:8>>,
                                 serialise_signature_type(certificate_timestamp),
                                 serialise(TimestampedEntry)])),
    {[{sct_version, ?PROTOCOL_VERSION},
      {id, base64:encode(plop:get_logid())},
      {timestamp, Timestamp},
      {extensions, base64:encode(<<>>)},
      {signature, base64:encode(plop:serialise(SCT_sig))}]}.

%% Serialise a log entry as:
%%   Timestamp:64,
%%   LeafCert as a 3-byte-length TLS vector,
%%   the chain as a 3-byte-length TLS vector whose payload is each
%%   certificate of CertChain, in order, as its own 3-byte-length vector.
-spec serialise_logentry(integer(), binary(), [binary()]) -> binary().
serialise_logentry(Timestamp, LeafCert, CertChain) ->
    ChainPayload = << <<(encode_tls_vector(Cert, 3))/binary>>
                      || Cert <- CertChain >>,
    iolist_to_binary([<<Timestamp:64>>,
                      encode_tls_vector(LeafCert, 3),
                      encode_tls_vector(ChainPayload, 3)]).

%% Fetch the log entries for the range Start, End (as interpreted by
%% plop:get/2). Return them as {[{entries, Objects}]}, where each object
%% is built by x_entries/1 and carries leaf_input and extra_data.
-spec entries(non_neg_integer(), non_neg_integer()) -> {[{entries, list()}]}.
entries(Start, End) ->
    {[{entries, x_entries(plop:get(Start, End))}]}.

%% Look up the log entry at Index and its inclusion path in a tree of
%% size TreeSize, returned as a {Proplist} term.
%% - On success: leaf_input is the base64 of the serialised Merkle tree
%%   leaf, extra_data is the base64 of the stored certificate chain
%%   vector, and audit_path is the list of base64-encoded path nodes.
%% - When plop reports {notfound, Msg}: returns success = false with Msg
%%   (an iolist/string) as a binary error_message.
-spec entry_and_proof(non_neg_integer(), non_neg_integer()) ->
                             {[{atom(), term()}]}.
entry_and_proof(Index, TreeSize) ->
    case plop:inclusion_and_entry(Index, TreeSize) of
        {ok, Entry, Path} ->
            {Timestamp, LeafCert, CertChainVector} = unpack_entry(Entry),
            {[{leaf_input, base64:encode(build_mtl(Timestamp, LeafCert))},
              {extra_data, base64:encode(CertChainVector)},
              {audit_path, [base64:encode(Node) || Node <- Path]}]};
        {notfound, Msg} ->
            {[{success, false},
              {error_message, list_to_binary(Msg)}]}
    end.

%% Private functions.
%% Split a stored log entry into {Timestamp, LeafCert, CertChainVector}.
%% Assumes Entry has the layout written by serialise_logentry/3:
%% Timestamp:64, then the leaf certificate as a 3-byte-length vector,
%% then the chain vector.
%% LeafCert is returned with its length prefix stripped. CertChainVector
%% is the remaining bytes, returned as-is, so it still carries its own
%% 3-byte length prefix.
unpack_entry(Entry) ->
    <<Timestamp:64, Body/binary>> = Entry,
    {LeafCert, EncodedChain} = decode_tls_vector(Body, 3),
    {Timestamp, LeafCert, EncodedChain}.

%% Render each {Index, Hash, Entry} tuple, in order, as an object with
%% leaf_input (base64 of the serialised Merkle tree leaf) and extra_data
%% (base64 of the stored certificate chain vector).
-spec x_entries([{non_neg_integer(), binary(), binary()}]) -> list().
x_entries(Entries) ->
    lists:map(
      fun(Item) ->
              {_Index, _Hash, Entry} = Item,
              {Timestamp, LeafCert, CertChainVector} = unpack_entry(Entry),
              {[{leaf_input, base64:encode(build_mtl(Timestamp, LeafCert))},
                {extra_data, base64:encode(CertChainVector)}]}
      end,
      Entries).

%% Encode Binary as a TLS variable-length vector: a big-endian length
%% prefix of LengthLen bytes, followed by the bytes themselves.
%% The guard rejects a Binary too long to be described by a
%% LengthLen-byte prefix, so the caller gets a function_clause crash
%% instead of a silently truncated (corrupt) length field.
-spec encode_tls_vector(binary(), non_neg_integer()) -> binary().
encode_tls_vector(Binary, LengthLen)
  when byte_size(Binary) < 1 bsl (LengthLen * 8) ->
    Length = byte_size(Binary),
    <<Length:LengthLen/integer-unit:8, Binary/binary>>.

%% Decode one TLS variable-length vector from the front of Binary.
%% The vector is a LengthLen-byte big-endian length followed by that
%% many payload bytes.
%% Returns {Payload, Remainder}. Crashes with badmatch if Binary is too
%% short.
-spec decode_tls_vector(binary(), non_neg_integer()) -> {binary(), binary()}.
decode_tls_vector(Binary, LengthLen) ->
    <<Len:LengthLen/integer-unit:8, Payload:Len/binary, Remainder/binary>> =
        Binary,
    {Payload, Remainder}.

-define(ROOTS_CACHE_KEY, roots).

%% Re-read the trusted root certificates from the directory configured
%% as `known_roots_path' in the catlfish application environment. The
%% fresh list is stored in the cache and returned as DER binaries.
%% Returns [] when no directory is configured.
-spec update_known_roots() -> [binary()].
update_known_roots() ->
    case application:get_env(catlfish, known_roots_path) of
        {ok, Dir} -> update_known_roots(Dir);
        undefined -> []
    end.

%% Re-read the root certificates in Dir, ignoring any cached list. The
%% result is stored in the cache and returned.
update_known_roots(Dir) ->
    known_roots(Dir, update_tab).

%% Return the trusted root certificates (DER binaries) from the
%% directory configured as `known_roots_path' in the catlfish
%% application environment, preferring the cached list when one exists.
%% Returns [] when no directory is configured.
-spec known_roots() -> [binary()].
known_roots() ->
    case application:get_env(catlfish, known_roots_path) of
        {ok, Dir} -> known_roots(Dir, use_cache);
        undefined -> []
    end.

%% Return the DER-encoded trusted roots for Directory.
%% - use_cache: return the cached list if present; Directory is read
%%   only on a cache miss.
%% - update_tab: always re-read Directory.
%% Whenever Directory is read, read_pemfiles_from_dir/1 stores the
%% result in the cache.
%% The lookup pattern uses ?ROOTS_CACHE_KEY, the same key that
%% read_pemfiles_from_dir/1 inserts under, rather than a hard-coded atom.
-spec known_roots(file:filename(), use_cache|update_tab) -> list().
known_roots(Directory, use_cache) ->
    case ets:lookup(?CACHE_TABLE, ?ROOTS_CACHE_KEY) of
        [{?ROOTS_CACHE_KEY, DerList}] ->
            DerList;
        [] ->
            read_pemfiles_from_dir(Directory)
    end;
known_roots(Directory, update_tab) ->
    read_pemfiles_from_dir(Directory).

%% Read every "*.pem" file in Dir and return the DER encodings of the
%% certificates they contain. The list is also stored in the cache
%% under ?ROOTS_CACHE_KEY.
%% If Dir cannot be listed (missing, unreadable, ...), the reason is
%% logged and an empty list is cached and returned. Missing roots are
%% treated as "no roots" rather than a crash.
-spec read_pemfiles_from_dir(file:filename()) -> list().
read_pemfiles_from_dir(Dir) ->
    DerList =
        case file:list_dir(Dir) of
            {ok, Filenames} ->
                PemFiles = [F || F <- Filenames,
                                 string:equal(".pem", filename:extension(F))],
                ders_from_pemfiles(Dir, PemFiles);
            {error, Reason} ->
                error_logger:warning_msg(
                  "catlfish: cannot list known roots directory ~p: ~p~n",
                  [Dir, Reason]),
                []
        end,
    true = ets:insert(?CACHE_TABLE, {?ROOTS_CACHE_KEY, DerList}),
    DerList.

%% Collect, in file order, the DER certificates from each PEM file in
%% Filenames (relative to Dir). Entries that der_from_pem/1 rejected
%% come back as [] and are dropped here.
ders_from_pemfiles(Dir, Filenames) ->
    lists:flatmap(
      fun(Name) ->
              Candidates = ders_from_pemfile(filename:join(Dir, Name)),
              [Der || Der <- Candidates, Der =/= []]
      end,
      Filenames).

%% Return one element per PEM entry in Filename: the DER binary for an
%% accepted certificate, or [] for an entry der_from_pem/1 rejected.
%% A file that cannot be read (badmatch in pems_from_file/1), or that
%% makes public_key:pem_decode/1 raise an error, contributes nothing.
%% Uses try/catch instead of the old-style `catch' expression. The old
%% form also swallowed exits, and it returned a thrown value as if it
%% were the decoded PEM list.
ders_from_pemfile(Filename) ->
    Pems = try public_key:pem_decode(pems_from_file(Filename))
           catch
               error:_ -> []
           end,
    [der_from_pem(Pem) || Pem <- Pems].

-include_lib("public_key/include/public_key.hrl").
%% Return the DER bytes of an unencrypted PEM entry if they decode as an
%% X.509 certificate (#'OTPCertificate'{}), otherwise [].
%% Encrypted entries and anything that is not a 3-tuple PEM entry are
%% rejected with [].
%% Only the decode call sits inside `try', so errors raised elsewhere
%% are not masked.
der_from_pem({_Type, Der, not_encrypted}) ->
    try public_key:pkix_decode_cert(Der, otp) of
        #'OTPCertificate'{} -> Der;
        _Unknown -> []
    catch
        error:_ -> []
    end;
der_from_pem(_) ->
    [].

%% Return the raw contents of Filename. Crashes with
%% {badmatch, {error, Reason}} if the file cannot be read.
pems_from_file(Filename) ->
    {ok, PemBin} = file:read_file(Filename),
    PemBin.

%%%%%%%%%%%%%%%%%%%%
%% Testing internal functions.
-define(PEMFILES_DIR_OK, "../test/testdata/known-roots").
-define(PEMFILES_DIR_NONEXISTENT, "../test/testdata/nonexistent-dir").

read_pemfiles_test_() ->
    %% Read the known-roots test directory twice with use_cache. The
    %% second call should return the same list as the first; the first
    %% call fills the cache, assuming it started empty.
    Setup = fun() ->
                    {known_roots(?PEMFILES_DIR_OK, use_cache),
                     known_roots(?PEMFILES_DIR_OK, use_cache)}
            end,
    Cleanup = fun(_) -> ets:delete(?CACHE_TABLE, ?ROOTS_CACHE_KEY) end,
    Checks = fun({Roots, CachedRoots}) ->
                     [?_assertEqual(7, length(Roots)),
                      ?_assertEqual(Roots, CachedRoots)]
             end,
    {setup, Setup, Cleanup, Checks}.

read_pemfiles_fail_test_() ->
    %% A missing roots directory yields an empty list rather than a
    %% crash, assuming nothing is cached yet.
    Setup = fun() -> known_roots(?PEMFILES_DIR_NONEXISTENT, use_cache) end,
    Cleanup = fun(_) -> ets:delete(?CACHE_TABLE, ?ROOTS_CACHE_KEY) end,
    Checks = fun(Roots) -> [?_assertEqual([], Roots)] end,
    {setup, Setup, Cleanup, Checks}.