%%% Copyright (c) 2014, NORDUnet A/S.
%%% See LICENSE for licensing information.

-module(x509).
-export([normalise_chain/2, cert_string/1, read_pemfiles_from_dir/1,
         self_signed/1]).

-include_lib("public_key/include/public_key.hrl").
-include_lib("eunit/include/eunit.hrl").

%% Reasons valid_chain_p/3, and hence normalise_chain/2, give for
%% rejecting a chain. Each reason is a bare atom. encoding_invalid is
%% not currently produced anywhere in this module.
-type reason() :: chain_too_long |
                  root_unknown |
                  signature_mismatch |
                  encoding_invalid.

%% Maximum number of certificates normalise_chain/2 accepts in a chain,
%% counting a root that valid_chain_p/3 appends from the trust store.
-define(MAX_CHAIN_LENGTH, 10).

%% @doc Validate CertChain (leaf first, then the rest of the chain)
%% against AcceptableRootCerts. On success return the chain in
%% normalised form: the leaf passed through detox_precert/1, then the
%% rest of the chain, then the trust-store cert that signs the chain's
%% top cert when that top cert is not itself an acceptable root.
-spec normalise_chain([binary()], [binary()]) -> {ok, [binary()]} |
                                                 {error, reason()}.
normalise_chain(AcceptableRootCerts, CertChain) ->
    Verdict = valid_chain_p(AcceptableRootCerts, CertChain, ?MAX_CHAIN_LENGTH),
    case Verdict of
        {true, ExtraRoot} ->
            [Leaf | Intermediates] = CertChain,
            {ok, [detox_precert(Leaf) | Intermediates ++ ExtraRoot]};
        {false, Reason} ->
            {error, Reason}
    end.

%%%%%%%%%%%%%%%%%%%%
%% @doc Verify that the leaf cert or precert has a valid chain back to
%% an acceptable root cert. The second argument holds the leaf cert
%% first, followed by the rest of the chain in order. The order of the
%% first argument (the acceptable root certs) is irrelevant.
%%
%% MaxChainLength bounds the number of certs in the chain, counting a
%% root appended from the acceptable root certs. A value =< 0 always
%% fails.
%%
%% Return {false, Reason} or {true, ListWithRoot}. ListWithRoot is the
%% empty list when the top of the chain is itself one of the acceptable
%% root certs. Otherwise it contains exactly one element: a cert from
%% the acceptable root certs that signs the top of the chain.
-spec valid_chain_p([binary()], [binary()], integer()) ->
                           {false, reason()} | {true, list()}.
valid_chain_p(_Roots, _Chain, MaxChainLength) when MaxChainLength =< 0 ->
    {false, chain_too_long};
valid_chain_p(Roots, [TopCert], MaxChainLength) ->
    %% Top of the chain: either it is a trusted cert itself, or there
    %% must be room for one more cert and a trusted cert must sign it.
    IsTrusted = lists:member(TopCert, Roots),
    if
        IsTrusted ->
            {true, []};
        MaxChainLength =< 1 ->
            {false, chain_too_long};
        true ->
            case signer(TopCert, Roots) of
                notfound -> {false, root_unknown};
                Root -> {true, [Root]}
            end
    end;
valid_chain_p(Roots, [Cert | [Issuer | _] = Rest], MaxChainLength) ->
    %% Every cert must be signed by the next one up the chain.
    case signed_by_p(Cert, Issuer) of
        false -> {false, signature_mismatch};
        true -> valid_chain_p(Roots, Rest, MaxChainLength - 1)
    end.

%% @doc Return the first cert in the list whose key verifies the
%% signature on Cert, or notfound. NOTE: This is potentially expensive,
%% since a full signature check is attempted against each candidate in
%% turn. It'd be more efficient to first look for candidates whose
%% subject matches Cert's issuer. If so, maybe make that matching
%% somewhat fuzzy unless that too is expensive.
-spec signer(binary(), [binary()]) -> notfound | binary().
signer(_Cert, []) ->
    notfound;
signer(Cert, [Candidate | Remaining]) ->
    lager:debug("Is ~p signed by ~p?",
                [cert_string(Cert), cert_string(Candidate)]),
    case signed_by_p(Cert, Candidate) of
        false ->
            signer(Cert, Remaining);
        true ->
            lager:debug("~p is signed by ~p",
                        [cert_string(Cert), cert_string(Candidate)]),
            Candidate
    end.

%% Adapted from pubkey_cert:encoded_tbs_cert/1.
%% @doc Return the still-encoded TBSCertificate part of DerCert, i.e.
%% the bytes the issuer's signature was computed over. Crashes with a
%% badmatch if DerCert cannot be decoded.
encoded_tbs_cert(DerCert) ->
    {ok, {'Certificate', {'Certificate_tbsCertificate', TbsDer}, _, _}} =
        'OTP-PUB-KEY':decode_TBSCert_exclusive(DerCert),
    TbsDer.

%% Adapted from pubkey_cert:extract_verify_data/2.
%% @doc Collect what is needed to check the signature on a certificate:
%% the encoded TBSCertificate taken from DerCert, the digest type of
%% Cert's signature algorithm, and the signature bytes. Cert is expected
%% to be the plain-decoded form of DerCert. Returns error when
%% public_key:pkix_sign_types/1 does not know the signature algorithm.
-spec verifydata_from_cert(#'Certificate'{}, binary()) -> {ok, tuple()} | error.
verifydata_from_cert(Cert, DerCert) ->
    PlainText = encoded_tbs_cert(DerCert),
    %% Older OTP releases represent a decoded BIT STRING as
    %% {UnusedBits, Bytes}, newer ones as a plain binary. Accept both.
    Sig = case Cert#'Certificate'.signature of
              {_UnusedBits, SigBytes} -> SigBytes;
              SigBytes when is_binary(SigBytes) -> SigBytes
          end,
    SigAlgRecord = Cert#'Certificate'.signatureAlgorithm,
    SigAlg = SigAlgRecord#'AlgorithmIdentifier'.algorithm,
    lager:debug("SigAlg: ~p", [SigAlg]),
    %% Only the lookup is protected. pkix_sign_types/1 has no clause
    %% for algorithms it doesn't support.
    try public_key:pkix_sign_types(SigAlg) of
        {DigestType, _} ->
            {ok, {PlainText, DigestType, Sig}}
    catch
        error:function_clause ->
            lager:debug("signature algorithm not supported: ~p", [SigAlg]),
            error
    end.

%% @doc Check the signature on a certificate against the public key of
%% an issuer certificate. Cert and DerCert are the same certificate in
%% plain-decoded and DER form. The third argument is the plain-decoded
%% issuer. Returns false when the signature algorithm is unsupported.
-spec verify_sig(#'Certificate'{}, binary(), #'Certificate'{}) -> boolean().
verify_sig(Cert, DerCert,
           #'Certificate'{
              tbsCertificate = #'TBSCertificate'{
                                  subjectPublicKeyInfo = IssuerSPKI}}) ->
    VerifyData = verifydata_from_cert(Cert, DerCert),
    case VerifyData of
        {ok, SigData} -> verify_sig2(IssuerSPKI, SigData);
        error -> false
    end.

%% @doc Verify Signature over DigestOrPlainText with the public key in
%% IssuerSPKI, the issuer's plain-decoded #'SubjectPublicKeyInfo'{}. The
%% tuple is the one verifydata_from_cert/2 returns inside {ok, _}.
%% Returns false, after logging an error, when the issuer key type is
%% not supported.
verify_sig2(IssuerSPKI, {DigestOrPlainText, DigestType, Signature}) ->
    %% Dig out issuer key from issuer cert.
    #'SubjectPublicKeyInfo'{
       algorithm = #'AlgorithmIdentifier'{algorithm = Alg, parameters = Params},
       subjectPublicKey = SubjectPublicKey} = IssuerSPKI,
    %% Older OTP releases represent a decoded BIT STRING as
    %% {UnusedBits, Bytes}, newer ones as a plain binary. A key with
    %% unused bits is still rejected with a badmatch.
    Key0 = case SubjectPublicKey of
               {0, KeyBytes} -> KeyBytes;
               KeyBytes when is_binary(KeyBytes) -> KeyBytes
           end,
    KeyType = pubkey_cert_records:supportedPublicKeyAlgorithms(Alg),
    lager:debug("Alg: ~p", [Alg]),
    lager:debug("Params: ~p", [Params]),
    lager:debug("KeyType: ~p", [KeyType]),
    lager:debug("Key0: ~p", [Key0]),
    case issuer_key(KeyType, Key0, Params) of
        {ok, IssuerKey} ->
            lager:debug("DigestOrPlainText: ~p", [DigestOrPlainText]),
            lager:debug("DigestType: ~p", [DigestType]),
            lager:debug("Signature: ~p", [Signature]),
            lager:debug("IssuerKey: ~p", [IssuerKey]),
            %% Verify the signature.
            public_key:verify(DigestOrPlainText, DigestType, Signature,
                              IssuerKey);
        unsupported ->
            %% FIXME: 'DSAPublicKey'. Previously this path handed the
            %% atom false to public_key:verify/4 as the key, which
            %% crashed instead of reporting "not verified".
            lager:error("NIY: Issuer key type ~p", [KeyType]),
            false
    end.

%% @doc Build the key term public_key:verify/4 expects from the raw key
%% bytes (and algorithm parameters) of a SubjectPublicKeyInfo, or return
%% unsupported for key types not handled yet.
issuer_key('RSAPublicKey', Key0, _Params) ->
    {ok, public_key:der_decode('RSAPublicKey', Key0)};
issuer_key('ECPoint', Key0, Params) ->
    {ok, {#'ECPoint'{point = Key0},
          public_key:der_decode('EcpkParameters', Params)}};
issuer_key(_KeyType, _Key0, _Params) ->
    unsupported.

%% @doc Is the DER-encoded certificate DerCert signed by the key of the
%% DER-encoded certificate IssuerDerCert? Only the signature is checked.
%% Things like Cert.issuer == Issuer.subject are not. Errors from
%% public_key:pkix_decode_cert/2 on undecodable input propagate.
-spec signed_by_p(binary(), binary()) -> boolean().
signed_by_p(DerCert, IssuerDerCert)
  when is_binary(DerCert) andalso is_binary(IssuerDerCert) ->
    Cert = public_key:pkix_decode_cert(DerCert, plain),
    Issuer = public_key:pkix_decode_cert(IssuerDerCert, plain),
    verify_sig(Cert, DerCert, Issuer).

%% @doc Return the SHA-1 digest of Der as a flat, lowercase hex string.
%% It is used to identify certificates in log messages and file names.
cert_string(Der) ->
    Digest = crypto:hash(sha, Der),
    lists:flatten([io_lib:format("~2.16.0b", [Byte]) || <<Byte>> <= Digest]).

%% @doc Return true if Der can be decoded as an X.509 certificate.
%% Otherwise log the problem, tagged with the SHA-1 of Der, and return
%% false.
parsable_cert_p(Der) ->
    case (catch public_key:pkix_decode_cert(Der, plain)) of
        {'EXIT', DecodeError} ->
            lager:info("invalid certificate: ~p: ~p",
                       [cert_string(Der), DecodeError]),
            false;
        #'Certificate'{} ->
            true;
        Unexpected ->
            lager:info("unknown error decoding cert: ~p: ~p",
                       [cert_string(Der), Unexpected]),
            false
    end.

%% @doc Keep only the DER certs in L whose signature verifies against
%% their own public key. The original order is preserved.
-spec self_signed([binary()]) -> [binary()].
self_signed(L) ->
    [Cert || Cert <- L, signed_by_p(Cert, Cert)].

%%%%%%%%%%%%%%%%%%%%
%% Precertificates according to draft-ietf-trans-rfc6962-bis-04.

%% Submitted precerts have a special critical poison extension -- OID
%% 1.3.6.1.4.1.11129.2.4.3, whose extnValue OCTET STRING contains
%% ASN.1 NULL data (0x05 0x00).

%% They are signed either by the CA cert that will sign the final cert
%% or by a Precertificate Signing Certificate that is directly signed
%% by that CA cert. A Precertificate Signing Certificate has CA:true
%% and Extended Key Usage: Certificate Transparency, OID
%% 1.3.6.1.4.1.11129.2.4.4.

%% A PreCert in a SignedCertificateTimestamp does _not_ contain the
%% poison extension, nor a Precertificate Signing Certificate. This
%% means that we might have to 1) remove poison extensions in leaf
%% certs, 2) remove "poisoned signatures", 3) change issuer and
%% Authority Key Identifier of leaf certs.

%% @doc Intended to strip the precertificate-specific parts from a
%% DER-encoded leaf cert, as described in the precertificate notes.
%% Not yet implemented: the cert is returned unchanged. The only caller
%% passes a single DER binary (the chain's leaf), so the spec takes and
%% returns a binary(), not a list of decoded records.
-spec detox_precert(binary()) -> binary().
detox_precert(Cert) ->
    Cert.                                       % NYI

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% @doc Read the certificates in all "*.pem" files directly in Dir and
%% return their DER encodings as one flat list, across all files.
%% Entries that do not yield a parsable certificate are left out. If Dir
%% cannot be listed, the error is logged and [] is returned.
-spec read_pemfiles_from_dir(file:filename()) -> [binary()].
read_pemfiles_from_dir(Dir) ->
    case file:list_dir(Dir) of
        {ok, Filenames} ->
            PemFiles = [F || F <- Filenames,
                             string:equal(".pem", filename:extension(F))],
            ders_from_pemfiles(Dir, PemFiles);
        {error, enoent} ->
            lager:error("directory does not exist: ~p", [Dir]),
            [];
        {error, Reason} ->
            lager:error("unable to read directory ~p: ~p", [Dir, Reason]),
            []
    end.

%% @doc Collect the certificate DERs from each of Filenames (relative to
%% Dir) into a single flat list. The [] placeholders that
%% ders_from_pemfile/1 produces for rejected entries are skipped.
ders_from_pemfiles(Dir, Filenames) ->
    [Der || Filename <- Filenames,
            Der <- ders_from_pemfile(filename:join(Dir, Filename)),
            is_binary(Der)].

%% @doc Decode the PEM file Filename and return, per entry, the result
%% of der_from_pem/1: a certificate DER or []. If
%% public_key:pem_decode/1 fails, the error is logged and [] is
%% returned.
ders_from_pemfile(Filename) ->
    lager:debug("reading PEM from ~s", [Filename]),
    Contents = pems_from_file(Filename),
    Entries = case (catch public_key:pem_decode(Contents)) of
                  {'EXIT', Reason} ->
                      lager:info("~p: invalid PEM-encoding: ~p",
                                 [Filename, Reason]),
                      [];
                  Decoded ->
                      Decoded
              end,
    lists:map(fun der_from_pem/1, Entries).

%% @doc Turn one decoded PEM entry into a certificate DER, or [] if it
%% is not usable.
%%
%% Only unencrypted 'Certificate' entries are considered, which is the
%% same filter read_certs/1 applies. Entries of other types (private
%% keys, CRLs, ...) are skipped. They are not fed to the certificate
%% parser, not written out by dump_unparsable_cert/1, and only their
%% type is logged, so key material never ends up on disk or in logs.
%% A certificate entry that does not parse is handed to
%% dump_unparsable_cert/1.
der_from_pem({'Certificate', Der, not_encrypted}) ->
    case parsable_cert_p(Der) of
        true ->
            Der;
        false ->
            dump_unparsable_cert(Der),
            []
    end;
der_from_pem({Type, _Der, _CipherInfo}) ->
    lager:info("ignoring PEM-encoded data of type ~p", [Type]),
    [].

%% @doc Return the raw contents of Filename. If the file cannot be read,
%% the error is logged and an empty binary (no PEM entries) is returned.
%% This matches how read_pemfiles_from_dir/1 treats an unreadable
%% directory: one bad file does not abort loading the rest.
-spec pems_from_file(file:filename()) -> binary().
pems_from_file(Filename) ->
    case file:read_file(Filename) of
        {ok, Pems} ->
            Pems;
        {error, Reason} ->
            lager:error("unable to read file ~p: ~p", [Filename, Reason]),
            <<>>
    end.

%% @doc Save a certificate DER that failed to parse, if the catlfish
%% application environment sets rejected_certs_path. The file is written
%% into that directory and named "<sha1-hex>:<unix-seconds>.<micros>",
%% with microseconds zero-padded to six digits. Returns the result of
%% file:write_file/2, or not_logged when no directory is configured.
-spec dump_unparsable_cert(binary()) -> ok | {error, atom()} | not_logged.
dump_unparsable_cert(CertDer) ->
    case application:get_env(catlfish, rejected_certs_path) of
        {ok, Directory} ->
            %% os:timestamp/0 gives the same {Mega, Sec, Micro} triple as
            %% the deprecated erlang:now/0.
            {MegaSecs, Secs, MicroSecs} = os:timestamp(),
            %% ~s, not ~p: cert_string/1 returns a string, and ~p would
            %% put literal double quotes into the file name.
            Basename = io_lib:format("~s:~b.~6..0b",
                                     [cert_string(CertDer),
                                      MegaSecs * 1000 * 1000 + Secs,
                                      MicroSecs]),
            Filename = filename:join(Directory, Basename),
            lager:debug("dumping cert to ~p~n", [Filename]),
            file:write_file(Filename, CertDer);
        _ ->
            not_logged
    end.

%%%%%%%%%%%%%%%%%%%%
%% Testing private functions.
-include("x509_test.hrl").
%% Check that ?C0 (from x509_test.hrl) carries a signature that
%% verifies against the key of ?C1.
sign_test_() ->
    {setup,
     fun() -> {?C0, ?C1} end,
     fun(_) -> ok end,
     fun({Leaf, Issuer}) -> [?_assert(signed_by_p(Leaf, Issuer))] end}.

%% Chain tests against fixture files. KnownRoots holds the parsable
%% certs from the .pem files in test/testdata/known_roots. Chains holds,
%% for each .pem file in test/testdata/chains (sorted by file name), the
%% list of DER certs in that file, so lists:nth(N, Chains) is the N-th
%% fixture file.
valid_cert_test_() ->
    {setup,
     fun() -> {read_pemfiles_from_dir("test/testdata/known_roots"),
               read_certs("test/testdata/chains")} end,
     fun(_) -> ok end,
     fun({KnownRoots, Chains}) ->
             [
              %% Chain 1: self-signed, but verified with its own certs
              %% as the trust store, so pass.
              %% Not a valid OTPCertificate:
              %% {error,{asn1,{invalid_choice_tag,{22,<<"US">>}}}}
              %% 'OTP-PUB-KEY':Func('OTP-X520countryname', Value0)
              %% FIXME: This error doesn't make much sense -- is my
              %% environment borked?
              ?_assertMatch({true, _}, valid_chain_p(lists:nth(1, Chains),
                                                     lists:nth(1, Chains), 10)),
              %% Chain 2: self-signed and not among the known roots, so
              %% fail with root_unknown.
              ?_assertMatch({false, root_unknown},
                            valid_chain_p(KnownRoots,
                                          lists:nth(2, Chains), 10)),
              %% Chain 3: leaf signed by a known CA, pass.
              ?_assertMatch({true, _}, valid_chain_p(KnownRoots,
                                                     lists:nth(3, Chains), 10)),
              %% Chain 4: proper 3-depth chain with root in KnownRoots,
              %% pass with a maximum chain length of exactly 3.
              %% Bug CATLFISH-19 --> [info] rejecting "3ee62cb678014c14d22ebf96f44cc899adea72f1": chain_broken
              %% leaf sha1: 3ee62cb678014c14d22ebf96f44cc899adea72f1
              %% leaf Subject: C=KR, O=Government of Korea, OU=Group of Server, OU=\xEA\xB5\x90\xEC\x9C\xA1\xEA\xB3\xBC\xED\x95\x99\xEA\xB8\xB0\xEC\x88\xA0\xEB\xB6\x80, CN=www.berea.ac.kr, CN=haksa.bits.ac.kr
              ?_assertMatch({true, _}, valid_chain_p(KnownRoots,
                                                     lists:nth(4, Chains), 3)),
              %% Chain 5: first cert verified against itself, pass.
              %% Regression test for an issuer with key type ECPoint
              %% (Bug CATLFISH-??).
              %% Issuer sha1: 6969562e4080f424a1e7199f14baf3ee58ab6abb
              ?_assertMatch(true, signed_by_p(hd(lists:nth(5, Chains)),
                                              hd(lists:nth(5, Chains)))),
              %% Chain 6: unsupported signature algorithm MD2-RSA, so
              %% signed_by_p/2 must return false rather than crash.
              %% Signature Algorithm: md2WithRSAEncryption
              %% CA cert with sha1 96974cd6b663a7184526b1d648ad815cf51e801a
              ?_assertMatch(false, signed_by_p(hd(lists:nth(6, Chains)),
                                               hd(lists:nth(6, Chains))))
              ] end}.

%% Run the chain_test/2 cases on the two test certificates ?C0 and ?C1
%% from x509_test.hrl.
chain_test_() ->
    {setup,
     fun() -> {?C0, ?C1} end,
     fun(_Certs) -> ok end,
     fun({Leaf, Signer}) -> chain_test(Leaf, Signer) end}.

%% Test generators for valid_chain_p/3. The cases expect Signer's key
%% to have signed Leaf. They vary the trust store, the chain and the
%% maximum chain length.
chain_test(Leaf, Signer) ->
    [
     %% Chain top not in the trust store but signed by a cert that is:
     %% that cert is returned as the root to append.
     ?_assertEqual({true, [Signer]}, valid_chain_p([Signer], [Leaf], 10)),
     ?_assertEqual({true, [Signer]}, valid_chain_p([Signer], [Leaf], 2)),
     %% Leaf plus appended root is two certs, over the limit of one.
     ?_assertMatch({false, chain_too_long},
                   valid_chain_p([Signer], [Leaf], 1)),
     %% Chain top is itself in the trust store: nothing is appended.
     ?_assertEqual({true, []}, valid_chain_p([Signer], [Leaf, Signer], 2)),
     %% Two-cert chain with a limit of one.
     ?_assertMatch({false, chain_too_long},
                   valid_chain_p([Signer], [Leaf, Signer], 1)),
     %% Empty trust store: the top of the chain cannot be anchored.
     ?_assertMatch({false, root_unknown},
                   valid_chain_p([], [Leaf, Signer], 10)),
     %% A one-cert chain that is itself trusted is accepted. Only
     %% membership is checked here, not a signature.
     ?_assertMatch({true, []}, valid_chain_p([Leaf], [Leaf], 10)),
     ?_assertMatch({true, []}, valid_chain_p([Leaf], [Leaf], 1)),
     %% A maximum chain length of 0 can never be satisfied.
     ?_assertMatch({false, chain_too_long}, valid_chain_p([Leaf], [Leaf], 0))
    ].

%% Test helper: for each "*.pem" file in Dir, in sorted file-name order,
%% return the list of DER-encoded certificates it contains. Files that
%% cannot be read are skipped, as are PEM entries that are not
%% unencrypted certificates. Crashes if Dir cannot be listed.
-spec read_certs(file:filename()) -> [[binary()]].
read_certs(Dir) ->
    {ok, Fnames} = file:list_dir(Dir),
    PemNames = lists:sort([FN || FN <- Fnames,
                                 string:equal(".pem",
                                              filename:extension(FN))]),
    Contents = lists:filtermap(
                 fun(FN) ->
                         case file:read_file(filename:join(Dir, FN)) of
                             {ok, Bin} -> {true, Bin};
                             _ -> false
                         end
                 end,
                 PemNames),
    [[Der || {'Certificate', Der, not_encrypted} <- public_key:pem_decode(Bin)]
     || Bin <- Contents].