%% Incremental parser and serialiser for RPC messages (#p11rpc_msg{}).
%%
%% Wire format: a 12-octet header made of three 32-bit big-endian unsigned
%% integers (call code, options length, data length), followed by the
%% options octets and then the data octets.
-module(p11p_rpc).
-export([parse/2, new/0, new/1, serialise/1]).
-include("p11p_rpc.hrl").

%% Advance the parser over whatever is already buffered in M, without
%% supplying any new input.
parse(M) ->
    parse(M, <<>>).

%% Append Data to the message buffer and advance the parser as far as the
%% buffered octets allow.
%%
%% Returns {done, Msg} once header, options and data are all complete. Any
%% octets beyond the message stay in Msg's buffer. Returns {needmore, Msg}
%% when more input is required; continue with parse(Msg, NextChunk).
%% NOTE(review): {err, term()} is part of the spec, but no clause in this
%% module currently returns it.
%% Calling parse/2 on a message that is already done fails with
%% function_clause.
-spec parse(p11rpc_msg(), binary()) ->
          {done, p11rpc_msg()} | {needmore, p11rpc_msg()} | {err, term()}.
parse(#p11rpc_msg{buffer = Buf} = M, Data) ->
    advance(M#p11rpc_msg{buffer = <<Buf/binary, Data/binary>>}).

%% Encode a complete message to wire format: the 12-octet header followed
%% by the options and the data. Only accepts messages in the done state
%% whose header fields are all non-negative. The lengths written are the
%% opt_len/data_len fields as stored, not recomputed from the payloads.
-spec serialise(p11rpc_msg()) -> binary().
serialise(#p11rpc_msg{state = done,
                      call_code = CallCode,
                      opt_len = OptLen,
                      data_len = DataLen,
                      options = Options,
                      data = Data})
  when CallCode > -1, OptLen > -1, DataLen > -1 ->
    <<CallCode:32, OptLen:32, DataLen:32, Options/binary, Data/binary>>.

%% Create a fresh message, optionally pre-loaded with buffered input. All
%% other fields take the record defaults from p11p_rpc.hrl. parse/2 expects
%% the default state to be header; verify against the .hrl.
new() ->
    new(<<>>).

new(Buffer) ->
    #p11rpc_msg{buffer = Buffer}.

%% Private

%% Run the parser state machine (header -> opts -> data -> done) over the
%% octets already in the buffer. Stops as soon as a stage lacks input.
advance(#p11rpc_msg{state = header, buffer = Buf} = M) when byte_size(Buf) < 12 ->
    %% Leave the buffer untouched until the whole 12-octet header is there.
    {needmore, M};
advance(#p11rpc_msg{state = header} = M) ->
    advance(consume_header(M));
advance(#p11rpc_msg{state = opts} = M) ->
    case consume_opts(M) of
        #p11rpc_msg{state = opts} = M1 -> {needmore, M1};
        #p11rpc_msg{state = data} = M1 -> advance(M1)
    end;
advance(#p11rpc_msg{state = data} = M) ->
    case consume_data(M) of
        #p11rpc_msg{state = data} = M1 -> {needmore, M1};
        #p11rpc_msg{state = done} = M1 -> {done, M1}
    end.

%% Decode the 12-octet header at the front of the buffer into call_code,
%% opt_len and data_len. Keeps the remainder as the new buffer and moves
%% to the opts state. The caller guarantees at least 12 buffered octets.
consume_header(#p11rpc_msg{buffer = <<CallCode:32, OptLen:32, DataLen:32,
                                      Rest/binary>>} = Msg) ->
    Msg#p11rpc_msg{call_code = CallCode,
                   opt_len = OptLen,
                   data_len = DataLen,
                   state = opts,
                   buffer = Rest}.
%% Move octets from the buffer into the options field until it holds
%% opt_len octets. The state becomes data once the options are complete;
%% otherwise it stays opts.
consume_opts(#p11rpc_msg{opt_len = Len, options = Opts, buffer = Buf} = Msg) ->
    {NewOpts, NewBuf} = move_between_binaries(Opts, Buf, Len - byte_size(Opts)),
    State = case byte_size(NewOpts) of
                Len -> data;
                _ -> opts
            end,
    Msg#p11rpc_msg{options = NewOpts, buffer = NewBuf, state = State}.

%% Move octets from the buffer into the data field until it holds data_len
%% octets. The state becomes done once the data is complete; otherwise it
%% stays data. Surplus octets remain in the buffer.
consume_data(#p11rpc_msg{data_len = Len, data = Data, buffer = Buf} = Msg) ->
    {NewData, NewBuf} = move_between_binaries(Data, Buf, Len - byte_size(Data)),
    State = case byte_size(NewData) of
                Len -> done;
                _ -> data
            end,
    Msg#p11rpc_msg{data = NewData, buffer = NewBuf, state = State}.

%% Move up to N octets from the front of Src onto the end of Dst.
%% Returns {NewDst, NewSrc}. Moves fewer than N octets when Src is shorter.
move_between_binaries(Dst, Src, 0) ->
    {Dst, Src};
move_between_binaries(Dst, Src, NAsk) ->
    N = min(NAsk, byte_size(Src)),
    <<FromSrc:N/binary, NewSrc/binary>> = Src,
    {<<Dst/binary, FromSrc/binary>>, NewSrc}.

%%%%%%%%%%%%%%
%% Unit tests.
-include_lib("eunit/include/eunit.hrl").

consume_data_test_() ->
    {setup,
     fun() ->
             Msg = #p11rpc_msg{data_len = 3, data = <<"a">>, buffer = <<"bcde">>},
             consume_data(Msg)
     end,
     fun(_) -> ok end,
     fun(Msg) ->
             [?_assertEqual(#p11rpc_msg{call_code = -1, opt_len = -1, data_len = 3,
                                        options = <<>>, data = <<"abc">>,
                                        buffer = <<"de">>, state = done},
                            Msg)]
     end}.

%% TODO: generate these, to make them exhaustive
parse1_test_() ->
    {setup,
     fun() ->
             %% No options; the data arrives in three chunks.
             Buf0 = <<0:32/integer, 0:32/integer, 3:32/integer, "a">>,
             {needmore, M1} = parse(new(Buf0)),
             {needmore, M2} = parse(M1, <<"b">>),
             {done, MDone} = parse(M2, <<"cde">>),
             MDone
     end,
     fun(_) -> ok end,
     fun(Msg) ->
             [?_assertEqual(#p11rpc_msg{call_code = 0, opt_len = 0, data_len = 3,
                                        options = <<>>, data = <<"abc">>,
                                        buffer = <<"de">>, state = done},
                            Msg)]
     end}.

parse2_test_() ->
    {setup,
     fun() ->
             %% Header, options and part of the data, then the rest plus
             %% trailing octets.
             Buf0 = <<47:32/integer, 2:32/integer, 3:32/integer, "o1d">>,
             {needmore, M1} = parse(new(Buf0)),
             {done, MDone} = parse(M1, <<"12rest">>),
             MDone
     end,
     fun(_) -> ok end,
     fun(Msg) ->
             [?_assertEqual(#p11rpc_msg{call_code = 47, opt_len = 2, data_len = 3,
                                        options = <<"o1">>, data = <<"d12">>,
                                        buffer = <<"rest">>, state = done},
                            Msg)]
     end}.
%% Header split across three 4-octet chunks, then the options and part of
%% the data, then the rest of the data plus trailing octets.
parse3_test_() ->
    {setup,
     fun() ->
             Buf0 = <<47:32/integer>>,
             {needmore, M1} = parse(new(Buf0)),
             {needmore, M2} = parse(M1, <<2:32/integer>>),
             {needmore, M3} = parse(M2, <<3:32/integer>>),
             {needmore, M4} = parse(M3, <<"o1d">>),
             {done, MDone} = parse(M4, <<"12rest">>),
             MDone
     end,
     fun(_) -> ok end,
     fun(Msg) ->
             [?_assertEqual(#p11rpc_msg{call_code = 47, opt_len = 2, data_len = 3,
                                        options = <<"o1">>, data = <<"d12">>,
                                        buffer = <<"rest">>, state = done},
                            Msg)]
     end}.