xmpp.chapril.org-ejabberd/src/ejabberd_shaper.erl

241 lines
8.3 KiB
Erlang

%%%----------------------------------------------------------------------
%%% ejabberd, Copyright (C) 2002-2021 ProcessOne
%%%
%%% This program is free software; you can redistribute it and/or
%%% modify it under the terms of the GNU General Public License as
%%% published by the Free Software Foundation; either version 2 of the
%%% License, or (at your option) any later version.
%%%
%%% This program is distributed in the hope that it will be useful,
%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
%%% General Public License for more details.
%%%
%%% You should have received a copy of the GNU General Public License along
%%% with this program; if not, write to the Free Software Foundation, Inc.,
%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
%%%
%%%----------------------------------------------------------------------
-module(ejabberd_shaper).
%% Shaper registry: keeps shaper definitions and per-host shaper rules
%% from the configuration in ETS tables, and wraps p1_shaper for
%% creating and updating live shapers.
-behaviour(gen_server).
%% Public API.
-export([start_link/0, new/1, update/2, match/3, get_max_rate/1]).
-export([reload_from_config/0]).
%% Configuration validators for the `shaper' and `shaper_rules' options.
-export([validator/1, shaper_rules_validator/0]).
%% gen_server callbacks
-export([init/1, handle_call/3, handle_cast/2, handle_info/2,
terminate/2, code_change/3]).
-include("logger.hrl").
%% Server state: the hosts whose shaper rules are currently loaded.
-type state() :: #{hosts := [binary()]}.
%% A live shaper: `none' (no shaper) or a p1_shaper state.
-type shaper() :: none | p1_shaper:state().
%% A rate specification: `{Rate, BurstSize}', a plain rate, or `infinity'.
-type shaper_rate() :: {pos_integer(), pos_integer()} | pos_integer() | infinity.
%% A configured rule entry: shaper name (or literal rate) plus its ACL rules.
-type shaper_rule() :: {atom() | pos_integer(), [acl:access_rule()]}.
%% A rule entry whose shaper name has already been resolved to a rate.
-type shaper_rate_rule() :: {shaper_rate(), [acl:access_rule()]}.
-export_type([shaper/0, shaper_rule/0, shaper_rate/0]).
%%%===================================================================
%%% API
%%%===================================================================
%% @doc Start the shaper server, registered locally under the module name.
-spec start_link() -> {ok, pid()} | {error, any()}.
start_link() ->
    RegName = {local, ?MODULE},
    gen_server:start_link(RegName, ?MODULE, [], []).
%% @doc Resolve the shaper rate that applies to `Match' on `Host'.
%%
%% `Shaper' is either the name of a shaper rule (looked up for `Host')
%% or an already-resolved rule list. `Match' may be an ACL match map,
%% an IPv4/IPv6 address tuple, or a JID (lowercased before matching).
%% The values `none' and `infinity' are returned as-is without lookup.
-spec match(global | binary(), atom() | [shaper_rule()],
            jid:jid() | jid:ljid() | inet:ip_address() | acl:match()) -> none | shaper_rate().
match(_Host, none, _Subject) ->
    none;
match(_Host, infinity, _Subject) ->
    infinity;
match(Host, Shaper, Subject) when is_map(Subject) ->
    RuleList = case is_atom(Shaper) of
                   true -> read_shaper_rules(Shaper, Host);
                   false -> Shaper
               end,
    read_shaper(acl:match_rules(Host, RuleList, Subject, none));
match(Host, Shaper, Addr) when tuple_size(Addr) =:= 4; tuple_size(Addr) =:= 8 ->
    match(Host, Shaper, #{ip => Addr});
match(Host, Shaper, Jid) ->
    match(Host, Shaper, #{usr => jid:tolower(Jid)}).
%% @doc Extract the sustained rate from a rate specification.
%% For `{Rate, Burst}' the first element is returned, a positive
%% integer is returned unchanged, and anything else (e.g. `none' or
%% `infinity') yields `none'.
-spec get_max_rate(none | shaper_rate()) -> none | pos_integer().
get_max_rate(Spec) ->
    case Spec of
        {MaxRate, _Burst} -> MaxRate;
        MaxRate when is_integer(MaxRate), MaxRate > 0 -> MaxRate;
        _ -> none
    end.
%% @doc Create a fresh shaper from a rate specification.
%% `{Rate, Burst}' and a positive integer rate produce a p1_shaper
%% state; any other value (e.g. `none' or `infinity') yields `none'.
-spec new(none | shaper_rate()) -> shaper().
new(Spec) ->
    case Spec of
        {Rate, Burst} ->
            p1_shaper:new(Rate, Burst);
        Rate when is_integer(Rate), Rate > 0 ->
            p1_shaper:new(Rate);
        _ ->
            none
    end.
%% @doc Account `Size' units of traffic against `Shaper'.
%% Returns whatever p1_shaper:update/2 returns: per the spec, a pair
%% of the new shaper and a non-negative integer (presumably the pause
%% to apply; NOTE(review): confirm against p1_shaper). The `none'
%% shaper always yields `{none, 0}'.
-spec update(shaper(), non_neg_integer()) -> {shaper(), non_neg_integer()}.
update(none, _Size) ->
    {none, 0};
update(Shaper, Size) ->
    Result = p1_shaper:update(Shaper, Size),
    ?DEBUG("Shaper update:~n~ts =>~n~ts",
           [p1_shaper:pp(Shaper), p1_shaper:pp(Result)]),
    Result.
%% @doc Return the econf validator for the `shaper' or `shaper_rules'
%% option. Keys are arbitrary names except the reserved ones and must
%% be unique; `shaper' definitions are returned as a map.
-spec validator(shaper | shaper_rules) -> econf:validator().
validator(shaper) ->
    wildcard_options(shaper_validator(), [{return, map}, unique]);
validator(shaper_rules) ->
    wildcard_options(shaper_rules_validator(), [unique]).

%% Build an econf options validator that applies `ValueValidator' to
%% every key, rejecting reserved names, with `ExtraOpts' appended.
wildcard_options(ValueValidator, ExtraOpts) ->
    econf:options(#{'_' => ValueValidator},
                  [{disallowed, reserved()} | ExtraOpts]).
%% @doc Validator for the value of a single shaper rule.
%% A list (flattened first) may mix `{ShaperName, AccessRules}' pairs
%% and bare shaper names; a bare name, inside or outside a list, is
%% paired with `[{acl, all}]'. The result is always a list of
%% `{ShaperName, AccessRules}' tuples.
-spec shaper_rules_validator() -> econf:validator().
shaper_rules_validator() ->
    fun(Entries) when is_list(Entries) ->
            ToRule = fun({Name, Access}) ->
                             {(shaper_name())(Name), (acl:access_validator())(Access)};
                        (Name) ->
                             {(shaper_name())(Name), [{acl, all}]}
                     end,
            lists:map(ToRule, lists:flatten(Entries));
       (Name) ->
            [{(shaper_name())(Name), [{acl, all}]}]
    end.
%% @doc Synchronously ask the shaper server to reload shapers and shaper
%% rules from the current configuration, waiting up to one minute.
%% This function is also registered by init/1 as a `config_reloaded'
%% hook.
-spec reload_from_config() -> ok.
reload_from_config() ->
    Timeout = timer:minutes(1),
    gen_server:call(?MODULE, reload_from_config, Timeout).
%%%===================================================================
%%% gen_server callbacks
%%%===================================================================
%% @doc Create the ETS tables, load shapers and shaper rules for all
%% configured hosts, and subscribe to configuration reloads.
init([]) ->
    %% Without trapping exits, gen_server does not call terminate/2 when
    %% the supervisor shuts this process down (unless brutal_kill is
    %% used). terminate/2 would then never remove the config_reloaded
    %% hook registered below.
    process_flag(trap_exit, true),
    create_tabs(),
    Hosts = ejabberd_option:hosts(),
    load_from_config([], Hosts),
    ejabberd_hooks:add(config_reloaded, ?MODULE, reload_from_config, 20),
    {ok, #{hosts => Hosts}}.
%% Only `reload_from_config' is understood. It reloads rules for the
%% current host list, passing the previous list so that rules of
%% removed hosts are dropped, and stores the new list in the state.
%% Any other request is logged and not replied to.
-spec handle_call(term(), term(), state()) -> {reply, ok, state()} | {noreply, state()}.
handle_call(reload_from_config, _Caller, #{hosts := PrevHosts} = State) ->
    CurHosts = ejabberd_option:hosts(),
    load_from_config(PrevHosts, CurHosts),
    {reply, ok, State#{hosts := CurHosts}};
handle_call(Unknown, Caller, State) ->
    ?WARNING_MSG("Unexpected call from ~p: ~p", [Caller, Unknown]),
    {noreply, State}.
%% Casts are not part of this server's protocol: log and ignore them.
-spec handle_cast(term(), state()) -> {noreply, state()}.
handle_cast(Cast, St) ->
    ?WARNING_MSG("Unexpected cast: ~p", [Cast]),
    {noreply, St}.
%% No raw messages are expected: log and ignore anything that arrives.
-spec handle_info(term(), state()) -> {noreply, state()}.
handle_info(Msg, St) ->
    ?WARNING_MSG("Unexpected info: ~p", [Msg]),
    {noreply, St}.
%% Unregister the config_reloaded hook that init/1 installs.
-spec terminate(any(), state()) -> ok.
terminate(_Why, _St) ->
    ejabberd_hooks:delete(config_reloaded, ?MODULE, reload_from_config, 20).
%% The state is carried over unchanged across code upgrades.
-spec code_change(term(), state(), term()) -> {ok, state()}.
code_change(_FromVsn, St, _Extra) ->
    {ok, St}.
%%%===================================================================
%%% Internal functions
%%%===================================================================
%%%===================================================================
%%% Table management
%%%===================================================================
%% @doc Load shaper definitions and per-host shaper rules from the
%% current configuration into the `shaper' and `shaper_rules' tables.
%%
%% New entries are written first. Afterwards, every key not present in
%% the fresh configuration is deleted: removed shapers, removed rules,
%% rules that no longer resolve to any known shaper, and rules of hosts
%% that were dropped. Lookups therefore never miss a current entry, and
%% a reload leaves the tables as a fresh start would.
%% The first argument (previously loaded hosts) is kept for interface
%% compatibility; the stale-key sweep already covers removed hosts.
-spec load_from_config([binary()], [binary()]) -> ok.
load_from_config(_OldHosts, NewHosts) ->
    ?DEBUG("Loading shaper rules from config", []),
    Shapers = ejabberd_option:shaper(),
    Rules = lists:flatmap(
              fun(Host) ->
                      lists:flatmap(
                        fun({Name, List}) ->
                                case resolve_shapers(Name, List, Shapers) of
                                    [] -> [];
                                    List1 -> [{{Name, Host}, List1}]
                                end
                        end, ejabberd_option:shaper_rules(Host))
              end, [global|NewHosts]),
    ets:insert(shaper, maps:to_list(Shapers)),
    ets:insert(shaper_rules, Rules),
    delete_stale_keys(shaper, Shapers),
    delete_stale_keys(shaper_rules, maps:from_list(Rules)),
    ?DEBUG("Shaper rules loaded successfully", []).

%% Delete every object in `Tab' whose key is not a key of the map
%% `Keep'. Keys are collected first so that the table is not modified
%% while it is being traversed.
-spec delete_stale_keys(ets:tab(), map()) -> ok.
delete_stale_keys(Tab, Keep) ->
    Stale = ets:foldl(
              fun({Key, _}, Acc) ->
                      case maps:is_key(Key, Keep) of
                          true -> Acc;
                          false -> [Key|Acc]
                      end
              end, [], Tab),
    lists:foreach(fun(Key) -> ets:delete(Tab, Key) end, Stale).
%% @doc Create the named, read-optimized ETS tables for shaper
%% definitions and shaper rules. First, any Mnesia table named `shaper'
%% is dropped and the result ignored (presumably a leftover from older
%% releases; NOTE(review): confirm).
-spec create_tabs() -> ok.
create_tabs() ->
    _ = mnesia:delete_table(shaper),
    TabOpts = [named_table, {read_concurrency, true}],
    lists:foreach(fun(Tab) -> _ = ets:new(Tab, TabOpts) end,
                  [shaper, shaper_rules]),
    ok.
%% @doc Fetch the resolved rule list stored for shaper rule `Name' on
%% `Host', or `[]' when nothing is stored under that key.
-spec read_shaper_rules(atom(), global | binary()) -> [shaper_rate_rule()].
read_shaper_rules(Name, Host) ->
    Key = {Name, Host},
    case ets:lookup(shaper_rules, Key) of
        [] -> [];
        [{_StoredKey, RateRules}] -> RateRules
    end.
%% @doc Turn a shaper reference into a rate.
%% Non-atoms and the atoms `none' / `infinity' are returned unchanged.
%% Any other atom is treated as a shaper name and looked up in the
%% `shaper' table, giving `none' when it is not defined.
-spec read_shaper(atom() | shaper_rate()) -> none | shaper_rate().
read_shaper(Value) when not is_atom(Value); Value =:= none; Value =:= infinity ->
    Value;
read_shaper(ShaperName) ->
    case ets:lookup(shaper, ShaperName) of
        [] -> none;
        [{_Name, Rate}] -> Rate
    end.
%%%===================================================================
%%% Validators
%%%===================================================================
%% Validator for a shaper name. It accepts an atom, with `infinite'
%% and `unlimited' normalized to `infinity', or a positive integer.
shaper_name() ->
    Normalize = fun(Alias) when Alias =:= infinite; Alias =:= unlimited ->
                        infinity;
                   (Other) ->
                        Other
                end,
    econf:either(econf:and_then(econf:atom(), Normalize),
                 econf:pos_int()).
%% Validator for one shaper definition. It accepts either:
%% - an options map with a required `rate' and an optional
%%   `burst_size' (defaulting to `rate'), converted to
%%   `{Rate, BurstSize}'; or
%% - whatever econf:pos_int(infinity) accepts (presumably a positive
%%   integer or `infinity').
shaper_validator() ->
    ToRatePair = fun(#{rate := Rate} = Opts) ->
                         {Rate, maps:get(burst_size, Opts, Rate)}
                 end,
    MapForm = econf:and_then(
                econf:options(
                  #{rate => econf:pos_int(),
                    burst_size => econf:pos_int()},
                  [unique, {required, [rate]}, {return, map}]),
                ToRatePair),
    econf:either(MapForm, econf:pos_int(infinity)).
%%%===================================================================
%%% Aux
%%%===================================================================
%% Names that may not be used as keys of the `shaper' and
%% `shaper_rules' options: `none' plus the spellings of an unlimited
%% rate.
reserved() ->
    Unlimited = [infinite, unlimited, infinity],
    [none | Unlimited].
%% Replace each shaper name in the rule list `Rules' of shaper rule
%% `ShaperRule' with its rate from `Shapers'. Entries naming an
%% unknown shaper are dropped with a warning. All other entries
%% (literal rates, `none', `infinity') are kept unchanged.
-spec resolve_shapers(atom(), [shaper_rule()], #{atom() => shaper_rate()}) -> [shaper_rate_rule()].
resolve_shapers(ShaperRule, Rules, Shapers) ->
    lists:filtermap(
      fun({Name, Access}) when is_atom(Name), Name /= none, Name /= infinity ->
              case maps:find(Name, Shapers) of
                  {ok, Rate} ->
                      {true, {Rate, Access}};
                  error ->
                      ?WARNING_MSG(
                         "Shaper rule '~ts' refers to unknown shaper: ~ts",
                         [ShaperRule, Name]),
                      false
              end;
         (_Literal) ->
              true
      end, Rules).