%%% Copyright (c) 2019, Sunet.
%%% See LICENSE for licensing information.

%% TODO: deduplicate code in parse/2 once there's proper tests

%% @doc Framing and (de)serialisation of RPC messages kept in
%% #p11rpc_msg{} records (declared in p11p_rpc.hrl).
%%
%% On the wire a message is three big-endian 32-bit words (call code,
%% options length, data length) followed by the options bytes and the
%% data bytes.
-module(p11p_rpc).

-export([call_code/1,
         dump/1,
         msg_error/2,
         msg_ok/1,
         new/0,
         new/1,
         parse/2,
         req_id/1,
         serialise/1]).

-include("p11p_rpc.hrl").

%% @doc Return the call code from the message header.
-spec call_code(p11rpc_msg()) -> integer().
call_code(Msg) ->
    Msg#p11rpc_msg.call_code.

%% @doc Render a human-readable description of Msg for debugging.
%%
%% The data part is decoded as a 32-bit request id followed by the
%% argument description byte array; whatever remains is printed with
%% ~p. A request id with no entry in ?REQIDS is shown as "unknown"
%% instead of crashing, since the id is decoded from message data.
-spec dump(p11rpc_msg()) -> io_lib:chars().
dump(Msg = #p11rpc_msg{data = Data}) ->
    {ReqId, Data2} = parse_req_id(Data),
    {ArgsDesc, Data3} = parse_args_desc(Data2),
    io_lib:format("RPC [~B]: ~s (~B), args \"~s\":~n~p",
                  [Msg#p11rpc_msg.call_code, req_name(ReqId), ReqId,
                   ArgsDesc, Data3]).

%% Symbolic name of ReqId, looked up by 0-based position in ?REQIDS.
req_name(ReqId) ->
    ReqIds = ?REQIDS,
    case ReqId < length(ReqIds) of
        true ->
            {Name, _ReqArgs, _RespArgs} = lists:nth(ReqId + 1, ReqIds),
            Name;
        false ->
            "unknown"
    end.

%% @doc Build a complete (state done) error message with call code
%% CallCode whose data carries ErrorCode.
-spec msg_error(integer(), integer()) -> p11rpc_msg().
msg_error(CallCode, ErrorCode) ->
    DataBuf = serialise_error(ErrorCode),
    #p11rpc_msg{state = done,
                call_code = CallCode,
                opt_len = 0,
                data_len = byte_size(DataBuf),
                data = DataBuf}.

%% @doc Build a complete (state done) message with call code CallCode,
%% no options and no data.
-spec msg_ok(integer()) -> p11rpc_msg().
msg_ok(CallCode) ->
    #p11rpc_msg{state = done,
                call_code = CallCode,
                opt_len = 0,
                data_len = 0}.

%% @doc Advance parsing using only the bytes already in M's buffer.
-spec parse(p11rpc_msg()) ->
          {done, p11rpc_msg()} | {needmore, p11rpc_msg()} | {err, term()}.
parse(M) ->
    parse(M, <<>>).

%% @doc Append DataIn to the message's pending buffer and advance the
%% parse state header -> opts -> data -> done as far as the buffered
%% bytes allow.
%%
%% Returns {needmore, Msg} when more input is required (feed it back in
%% with the next chunk) or {done, Msg} once all data bytes are present.
%% Bytes beyond the end of the message are left in Msg's buffer.
%% A message already in state done has no matching clause.
-spec parse(p11rpc_msg(), binary()) ->
          {done, p11rpc_msg()} | {needmore, p11rpc_msg()} | {err, term()}.
parse(#p11rpc_msg{state = header, buffer = MsgBuf} = M, DataIn) ->
    Buf = <<MsgBuf/binary, DataIn/binary>>,
    Msg = M#p11rpc_msg{buffer = Buf},
    %% The header is three 32-bit words, i.e. 12 bytes.
    case byte_size(Buf) < 12 of
        true -> {needmore, Msg};
        false -> parse(consume_header(Msg))
    end;
parse(#p11rpc_msg{state = opts, buffer = MsgBuf} = M, DataIn) ->
    Msg = consume_opts(M#p11rpc_msg{buffer = <<MsgBuf/binary, DataIn/binary>>}),
    case Msg#p11rpc_msg.state of
        opts -> {needmore, Msg};
        data -> parse(Msg)
    end;
parse(#p11rpc_msg{state = data, buffer = MsgBuf} = M, DataIn) ->
    Msg = consume_data(M#p11rpc_msg{buffer = <<MsgBuf/binary, DataIn/binary>>}),
    case Msg#p11rpc_msg.state of
        data -> {needmore, Msg};
        done -> {done, Msg}
    end.

%% @doc Return the request id stored in the first four data bytes.
-spec req_id(p11rpc_msg()) -> non_neg_integer().
req_id(#p11rpc_msg{data_len = DataLen, data = Data}) when DataLen >= 4 ->
    {ReqId, _} = parse_req_id(Data),
    ReqId.

-spec serialise(p11rpc_msg()) -> binary().
%% @doc Encode a complete message to its wire form: the three 32-bit
%% big-endian header words (call code, options length, data length)
%% followed by the options bytes and the data bytes. Only messages in
%% state done with non-negative header fields are accepted.
serialise(#p11rpc_msg{state = done,
                      call_code = CallCode,
                      opt_len = OptLen,
                      data_len = DataLen,
                      options = Options,
                      data = Data})
  when CallCode > -1, OptLen > -1, DataLen > -1 ->
    <<CallCode:32, OptLen:32, DataLen:32, Options/binary, Data/binary>>.

%% @doc Create a message in its initial state with an empty buffer.
new() ->
    new(<<>>).

%% @doc Create a message in its initial state whose pending input
%% buffer is Buffer.
new(Buffer) ->
    #p11rpc_msg{buffer = Buffer}.

%% Private

%% Decode the 12-byte header at the front of the buffer and move on to
%% the opts state. The caller guarantees at least 12 bytes are buffered.
consume_header(#p11rpc_msg{buffer = <<CallCode:32, OptLen:32, DataLen:32,
                                      Rest/binary>>} = Msg) ->
    Msg#p11rpc_msg{call_code = CallCode,
                   opt_len = OptLen,
                   data_len = DataLen,
                   state = opts,
                   buffer = Rest}.

%% Move buffered bytes into the options field until opt_len bytes have
%% been collected; then switch to the data state.
consume_opts(#p11rpc_msg{opt_len = Len, options = Opts, buffer = MsgBuf} = M) ->
    {NewOpts, Buf} = move_between_binaries(Opts, MsgBuf, Len - byte_size(Opts)),
    State = case byte_size(NewOpts) == Len of
                true -> data;
                false -> opts
            end,
    M#p11rpc_msg{options = NewOpts, buffer = Buf, state = State}.

%% Move buffered bytes into the data field until data_len bytes have
%% been collected; then the message is done. Surplus bytes stay in the
%% buffer.
consume_data(#p11rpc_msg{data_len = Len, data = DataIn, buffer = MsgBuf} = M) ->
    {Data, Buf} = move_between_binaries(DataIn, MsgBuf, Len - byte_size(DataIn)),
    State = case byte_size(Data) == Len of
                true -> done;
                false -> data
            end,
    M#p11rpc_msg{data = Data, buffer = Buf, state = State}.

%% Move up to NBytes from the front of SrcIn onto the end of DstIn and
%% return {Dst, Src}. Fewer bytes are moved when SrcIn is shorter.
move_between_binaries(DstIn, SrcIn, 0) ->
    {DstIn, SrcIn};
move_between_binaries(DstIn, SrcIn, NBytes) ->
    {FromSrc, Src} = split_binary(SrcIn, min(NBytes, byte_size(SrcIn))),
    {<<DstIn/binary, FromSrc/binary>>, Src}.

%% Encode Bin as a 32-bit big-endian length followed by its bytes.
serialise_byte_array(Bin) ->
    Len = byte_size(Bin),
    <<Len:32, Bin/binary>>.

%% Build the data part of an error reply: the error call id, the
%% argument description "u" and ErrCode as a 64-bit integer.
serialise_error(ErrCode) ->
    ReqId = ?P11_RPC_CALL_ERROR,
    ArgsDescString = "u",               % TODO: look this up and generalise.
    ReqIdBin = serialise_uint32(ReqId),
    ArgsDescBin = serialise_byte_array(list_to_binary(ArgsDescString)),
    ArgBin = serialise_uint64(ErrCode),
    <<ReqIdBin/binary, ArgsDescBin/binary, ArgBin/binary>>.

%% Big-endian fixed-width unsigned encoders.
serialise_uint32(U32) ->
    <<U32:32>>.

serialise_uint64(U64) ->
    <<U64:64>>.

%% Split a 32-bit big-endian request id off the front of Data.
-spec parse_req_id(binary()) -> {integer(), binary()}.
parse_req_id(<<ReqId:32, Rest/binary>>) ->
    {ReqId, Rest}.
%% Parse the argument description string that follows the request id.
parse_args_desc(Data) ->
    parse_byte_array(Data).

%% Split a length-prefixed byte array off the front of Data and return
%% {Bytes, Rest}. The length is a 32-bit big-endian word. 16#ffffffff is
%% decoded as an empty array. Lengths >= 16#7fffffff are refused with an
%% error. A length longer than the remaining bytes fails to match.
-spec parse_byte_array(binary()) -> {binary(), binary()}.
parse_byte_array(<<16#ffffffff:32, Rest/binary>>) ->
    {<<>>, Rest};
parse_byte_array(<<Len:32, _/binary>>) when Len >= 16#7fffffff ->
    error({bad_byte_array_length, Len});
parse_byte_array(<<Len:32, Bytes:Len/binary, Rest/binary>>) ->
    {Bytes, Rest}.

%%%%%%%%%%%%%%
%% Unit tests.

-include_lib("eunit/include/eunit.hrl").

consume_data_test_() ->
    {setup,
     fun() ->
             Msg = #p11rpc_msg{data_len = 3,
                               data = <<"a">>,
                               buffer = <<"bcde">>},
             consume_data(Msg)
     end,
     fun(_) -> ok end,
     fun(Msg) ->
             [?_assertEqual(
                 {p11rpc_msg, -1, -1, 3, <<>>, <<"abc">>, <<"de">>, done},
                 Msg)]
     end}.

%% TODO: generate these, to make the tests exhaustive.

%% Header plus one data byte, then the remaining data over two chunks.
parse1_test_() ->
    {setup,
     fun() ->
             Buf0 = <<0:32/integer, 0:32/integer, 3:32/integer, "a">>,
             {needmore, #p11rpc_msg{buffer = Buf1} = M1} = parse(new(Buf0)),
             {needmore, #p11rpc_msg{buffer = Buf2} = M2} =
                 parse(M1#p11rpc_msg{buffer = <<Buf1/binary, "b">>}),
             {done, MDone} =
                 parse(M2#p11rpc_msg{buffer = <<Buf2/binary, "cde">>}),
             MDone
     end,
     fun(_) -> ok end,
     fun(Msg) ->
             [?_assertEqual(
                 {p11rpc_msg, 0, 0, 3, <<>>, <<"abc">>, <<"de">>, done},
                 Msg)]
     end}.

%% Header, all options and part of the data, then the rest plus surplus.
parse2_test_() ->
    {setup,
     fun() ->
             Buf0 = <<47:32/integer, 2:32/integer, 3:32/integer, "o1d">>,
             {needmore, #p11rpc_msg{buffer = Buf1} = M1} = parse(new(Buf0)),
             {done, MDone} =
                 parse(M1#p11rpc_msg{buffer = <<Buf1/binary, "12rest">>}),
             MDone
     end,
     fun(_) -> ok end,
     fun(Msg) ->
             [?_assertEqual(
                 {p11rpc_msg, 47, 2, 3, <<"o1">>, <<"d12">>, <<"rest">>, done},
                 Msg)]
     end}.
%% Feed the same message as parse2_test_ in small chunks: the header one
%% word at a time, then options plus a data byte, then the rest.
parse3_test_() ->
    {setup,
     fun() ->
             Buf0 = <<47:32/integer>>,
             {needmore, #p11rpc_msg{buffer = Buf1} = M1} = parse(new(Buf0)),
             {needmore, #p11rpc_msg{buffer = Buf2} = M2} =
                 parse(M1#p11rpc_msg{buffer = <<Buf1/binary, 2:32/integer>>}),
             %% Header complete here; no option bytes yet.
             {needmore, #p11rpc_msg{buffer = Buf3} = M3} =
                 parse(M2#p11rpc_msg{buffer = <<Buf2/binary, 3:32/integer>>}),
             {needmore, #p11rpc_msg{buffer = Buf4} = M4} =
                 parse(M3#p11rpc_msg{buffer = <<Buf3/binary, "o1d">>}),
             {done, MDone} =
                 parse(M4#p11rpc_msg{buffer = <<Buf4/binary, "12rest">>}),
             MDone
     end,
     fun(_) -> ok end,
     fun(Msg) ->
             [?_assertEqual(
                 {p11rpc_msg, 47, 2, 3, <<"o1">>, <<"d12">>, <<"rest">>, done},
                 Msg)]
     end}.