summaryrefslogtreecommitdiff
path: root/src/ratelimit.erl
blob: c840cc31f480ef252fd32664123f32cddab87cfc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
%%% Copyright (c) 2014-2015, NORDUnet A/S.
%%% See LICENSE for licensing information.
%%%

-module(ratelimit).
-behaviour(gen_server).

-export([start_link/0, stop/0]).
-export([get_token/1]).
%% gen_server callbacks.
-export([init/1, handle_call/3, terminate/2, handle_cast/2, handle_info/2,
         code_change/3]).

%% Start the rate limiter as a locally registered gen_server named after
%% this module. The init argument is the `ratelimits' environment value of
%% the `catlfish' application, or the empty list when it is not set.
start_link() ->
    RateLimits = application:get_env(catlfish, ratelimits, []),
    ServerName = {local, ?MODULE},
    gen_server:start_link(ServerName, ?MODULE, RateLimits, []).

%% Synchronously ask the registered server to stop. The server answers
%% `stopped' and then terminates with reason `normal'.
stop() ->
    Request = stop,
    gen_server:call(?MODULE, Request).

%% Ask the registered server for one token of the given Type.
%% The reply is `ok' when the request may proceed (this is also the reply
%% for a Type the server has no entry for) or `overload' when the
%% configured limit for Type has been reached.
get_token(Type) ->
    Request = {get_token, Type},
    gen_server:call(?MODULE, Request).

%% gen_server init callback.
%% Types is a list of {Type, Rules} pairs; list elements of any other shape
%% are skipped by the comprehension. The server state is a dict mapping
%% each Type to {Rules, Queue}, where Queue starts out empty.
init(Types) ->
    lager:debug("starting ratelimit service"),
    EmptyQueue = queue:new(),
    Entries = [{Type, {Rules, EmptyQueue}} || {Type, Rules} <- Types],
    {ok, dict:from_list(Entries)}.

%% Return the interval element of a single-rule list [{Times, Interval}]
%% unchanged. Any other argument shape raises function_clause.
rule_interval_atom([{_Times, IntervalAtom}]) ->
    IntervalAtom.

%% Length of the rule's interval in milliseconds.
%% Only single-rule lists whose interval is `second', `minute' or `hour'
%% are accepted; anything else raises function_clause.
rule_interval([{_Times, second}]) ->
    timer:seconds(1);
rule_interval([{_Times, minute}]) ->
    timer:minutes(1);
rule_interval([{_Times, hour}]) ->
    timer:hours(1).

%% Return the count element of a single-rule list [{Times, Interval}].
%% Times must be an integer; otherwise the call raises function_clause.
rule_times([{MaxTimes, _Interval}]) when is_integer(MaxTimes) ->
    MaxTimes.

%% Drop from the front of Queue every timestamp that is more than Interval
%% older than the current time, and return the remaining queue.
%%
%% The current time is read once and reused for the whole pass so that all
%% entries are compared against the same instant, instead of calling
%% plop:generate_timestamp/0 again for every dropped element.
%% NOTE(review): Interval comes from rule_interval/1 (milliseconds), so
%% plop:generate_timestamp/0 is presumably in milliseconds too - confirm.
clean_queue(Interval, Queue) ->
    clean_queue(Interval, Queue, plop:generate_timestamp()).

%% Worker for clean_queue/2: Now is the fixed reference time for this pass.
%% Stops at the first entry that is still inside the interval (or when the
%% queue is empty); entries behind it are not inspected.
clean_queue(Interval, Queue, Now) ->
    case queue:peek(Queue) of
        {value, Item} when Item + Interval < Now ->
            clean_queue(Interval, queue:drop(Queue), Now);
        _ ->
            Queue
    end.

%% Try to take one token for a single type's state {Rules, Queue}.
%%
%% Returns {Result, NewTypeState}:
%%   - Rules =:= none: always {ok, ...} with the state unchanged.
%%   - otherwise the queue is first pruned with clean_queue/2; if fewer
%%     than rule_times(Rules) timestamps remain, the current timestamp is
%%     appended and the result is `ok'; if not, the result is `overload'
%%     and only the pruning is kept.
get_token_for_type({none, _Queue} = Unlimited) ->
    {ok, Unlimited};
get_token_for_type({Rules, Queue}) ->
    Pruned = clean_queue(rule_interval(Rules), Queue),
    Limit = rule_times(Rules),
    case queue:len(Pruned) < Limit of
        true ->
            Timestamp = plop:generate_timestamp(),
            {ok, {Rules, queue:in(Timestamp, Pruned)}};
        false ->
            {overload, {Rules, Pruned}}
    end.

%% gen_server call handler.
%%
%% stop              -> replies `stopped' and stops the server normally.
%% {get_token, Type} -> replies `ok' or `overload' for a known Type and
%%                      stores the updated per-type state; replies `ok'
%%                      without touching the state for a Type that has no
%%                      entry in the state dict.
handle_call(stop, _From, State) ->
    {stop, normal, stopped, State};
handle_call({get_token, Type}, _From, State) ->
    case dict:find(Type, State) of
        {ok, TypeState} ->
            {Result, NewTypeState} = get_token_for_type(TypeState),
            log_current_rate(NewTypeState),
            {reply, Result, dict:store(Type, NewTypeState, State)};
        error ->
            {reply, ok, State}
    end.

%% Debug-log the number of tokens currently counted for a type.
%% A type whose rules are `none' is skipped: rule_interval_atom/1 only
%% accepts a single-rule list, so formatting the message for `none' would
%% raise function_clause and take the server down instead of replying.
log_current_rate({none, _Queue}) ->
    ok;
log_current_rate({Rules, Queue}) ->
    lager:debug("current rate: ~p per ~p", [queue:len(Queue), rule_interval_atom(Rules)]).

%% gen_server cast handler: every cast is ignored and the state is kept.
handle_cast(_Cast, Unchanged) ->
    {noreply, Unchanged}.
%% gen_server info handler: every other message is ignored and the state
%% is kept.
handle_info(_Message, Unchanged) ->
    {noreply, Unchanged}.
%% gen_server code upgrade callback: the state is carried over unchanged.
code_change(_FromVersion, Unchanged, _ExtraArgs) ->
    {ok, Unchanged}.
%% gen_server terminate callback: nothing to clean up.
terminate(_Why, _FinalState) ->
    ok.