From f19b41fd1948a38d0c8f1ea81e3613ea4335869c Mon Sep 17 00:00:00 2001 From: Evgeny Khramtsov Date: Tue, 9 Jul 2019 16:42:24 +0300 Subject: [PATCH] Improve type specs for ejabberd_s2s Also minor code cleanup --- src/ejabberd_s2s.erl | 202 +++++++++++++++++++++---------------------- 1 file changed, 101 insertions(+), 101 deletions(-) diff --git a/src/ejabberd_s2s.erl b/src/ejabberd_s2s.erl index 2128d6b6a..a1937da7e 100644 --- a/src/ejabberd_s2s.erl +++ b/src/ejabberd_s2s.erl @@ -54,51 +54,38 @@ -include("logger.hrl"). -include("xmpp.hrl"). -include("ejabberd_commands.hrl"). --include_lib("public_key/include/public_key.hrl"). +-include_lib("stdlib/include/ms_transform.hrl"). -include("ejabberd_stacktrace.hrl"). -include("translate.hrl"). --define(PKIXEXPLICIT, 'OTP-PUB-KEY'). - --define(PKIXIMPLICIT, 'OTP-PUB-KEY'). - --include("XmppAddr.hrl"). - -define(DEFAULT_MAX_S2S_CONNECTIONS_NUMBER, 1). - -define(DEFAULT_MAX_S2S_CONNECTIONS_NUMBER_PER_NODE, 1). - -define(S2S_OVERLOAD_BLOCK_PERIOD, 60). %% once a server is temporarly blocked, it stay blocked for 60 seconds --record(s2s, {fromto = {<<"">>, <<"">>} :: {binary(), binary()} | '_', - pid = self() :: pid() | '_' | '$1'}). +-record(s2s, {fromto :: {binary(), binary()}, + pid :: pid()}). -record(state, {}). --record(temporarily_blocked, {host = <<"">> :: binary(), - timestamp :: integer()}). +-record(temporarily_blocked, {host :: binary(), + timestamp :: integer()}). -type temporarily_blocked() :: #temporarily_blocked{}. - start_link() -> - gen_server:start_link({local, ?MODULE}, ?MODULE, [], - []). + gen_server:start_link({local, ?MODULE}, ?MODULE, [], []). clean_temporarily_blocked_table() -> mnesia:clear_table(temporarily_blocked). -spec list_temporarily_blocked_hosts() -> [temporarily_blocked()]. - list_temporarily_blocked_hosts() -> ets:tab2list(temporarily_blocked). -spec external_host_overloaded(binary()) -> {aborted, any()} | {atomic, ok}. - external_host_overloaded(Host) -> - ?INFO_MSG("Disabling connections from ~s for ~p " - "seconds", + ?INFO_MSG("Disabling s2s connections to ~s for ~p seconds", [Host, ?S2S_OVERLOAD_BLOCK_PERIOD]), mnesia:transaction(fun () -> Time = erlang:monotonic_time(), @@ -107,21 +94,20 @@ external_host_overloaded(Host) -> end). -spec is_temporarly_blocked(binary()) -> boolean(). - is_temporarly_blocked(Host) -> case mnesia:dirty_read(temporarily_blocked, Host) of - [] -> false; - [#temporarily_blocked{timestamp = T} = Entry] -> - Diff = erlang:monotonic_time() - T, - case erlang:convert_time_unit(Diff, native, microsecond) of - N when N > (?S2S_OVERLOAD_BLOCK_PERIOD) * 1000 * 1000 -> - mnesia:dirty_delete_object(Entry), false; - _ -> true - end + [] -> false; + [#temporarily_blocked{timestamp = T} = Entry] -> + Diff = erlang:monotonic_time() - T, + case erlang:convert_time_unit(Diff, native, microsecond) of + N when N > (?S2S_OVERLOAD_BLOCK_PERIOD) * 1000 * 1000 -> + mnesia:dirty_delete_object(Entry), false; + _ -> true + end end. -spec remove_connection({binary(), binary()}, pid()) -> ok. -remove_connection(FromTo, Pid) -> +remove_connection({From, To} = FromTo, Pid) -> case mnesia:dirty_match_object(s2s, #s2s{fromto = FromTo, pid = Pid}) of [#s2s{pid = Pid}] -> F = fun() -> @@ -130,25 +116,24 @@ remove_connection(FromTo, Pid) -> case mnesia:transaction(F) of {atomic, _} -> ok; {aborted, Reason} -> - ?ERROR_MSG("Failed to unregister s2s connection: " - "Mnesia failure: ~p", [Reason]) + ?ERROR_MSG("Failed to unregister s2s connection ~s -> ~s: " + "Mnesia failure: ~p", + [From, To, Reason]) end; _ -> ok end. -spec have_connection({binary(), binary()}) -> boolean(). - have_connection(FromTo) -> case catch mnesia:dirty_read(s2s, FromTo) of - [_] -> + [_] -> true; _ -> false end. -spec get_connections_pids({binary(), binary()}) -> [pid()]. - get_connections_pids(FromTo) -> case catch mnesia:dirty_read(s2s, FromTo) of L when is_list(L) -> @@ -158,8 +143,7 @@ get_connections_pids(FromTo) -> end. -spec try_register({binary(), binary()}) -> boolean(). - -try_register(FromTo) -> +try_register({From, To} = FromTo) -> MaxS2SConnectionsNumber = max_s2s_connections_number(FromTo), MaxS2SConnectionsNumberPerNode = max_s2s_connections_number_per_node(FromTo), @@ -169,18 +153,21 @@ try_register(FromTo) -> MaxS2SConnectionsNumber, MaxS2SConnectionsNumberPerNode), if NeededConnections > 0 -> - mnesia:write(#s2s{fromto = FromTo, pid = self()}), - true; + mnesia:write(#s2s{fromto = FromTo, pid = self()}), + true; true -> false end end, case mnesia:transaction(F) of - {atomic, Res} -> Res; - _ -> false + {atomic, Res} -> Res; + {aborted, Reason} -> + ?ERROR_MSG("Failed to register s2s connection ~s -> ~s: " + "Mnesia failure: ~p", + [From, To, Reason]), + false end. -spec dirty_get_connections() -> [{binary(), binary()}]. - dirty_get_connections() -> mnesia:dirty_all_keys(s2s). @@ -276,10 +263,12 @@ init([]) -> {stop, Reason} end. -handle_call(_Request, _From, State) -> - {reply, ok, State}. +handle_call(Request, From, State) -> + ?WARNING_MSG("Unexpected call from ~p: ~p", [From, Request]), + {noreply, State}. -handle_cast(_Msg, State) -> +handle_cast(Msg, State) -> + ?WARNING_MSG("Unexpected cast: ~p", [Msg]), {noreply, State}. handle_info({mnesia_system_event, {mnesia_down, Node}}, State) -> @@ -294,14 +283,15 @@ handle_info({route, Packet}, State) -> misc:format_exception(2, Class, Reason, StackTrace)]) end, {noreply, State}; -handle_info(_Info, State) -> {noreply, State}. +handle_info(Info, State) -> + ?WARNING_MSG("Unexpected info: ~p", [Info]), + {noreply, State}. terminate(_Reason, _State) -> ejabberd_commands:unregister_commands(get_commands_spec()), lists:foreach(fun host_down/1, ejabberd_option:hosts()), ejabberd_hooks:delete(host_up, ?MODULE, host_up, 50), - ejabberd_hooks:delete(host_down, ?MODULE, host_down, 60), - ok. + ejabberd_hooks:delete(host_down, ?MODULE, host_down, 60). code_change(_OldVsn, State, _Extra) -> {ok, State}. @@ -309,10 +299,12 @@ code_change(_OldVsn, State, _Extra) -> %%-------------------------------------------------------------------- %%% Internal functions %%-------------------------------------------------------------------- +-spec host_up(binary()) -> ok. host_up(Host) -> ejabberd_s2s_in:host_up(Host), ejabberd_s2s_out:host_up(Host). +-spec host_down(binary()) -> ok. host_down(Host) -> lists:foreach( fun(#s2s{fromto = {From, _}, pid = Pid}) when node(Pid) == node() -> @@ -334,12 +326,11 @@ clean_table_from_bad_node(Node) -> F = fun() -> Es = mnesia:select( s2s, - [{#s2s{pid = '$1', _ = '_'}, - [{'==', {node, '$1'}, Node}], - ['$_']}]), - lists:foreach(fun(E) -> - mnesia:delete_object(E) - end, Es) + ets:fun2ms( + fun(#s2s{pid = Pid} = E) when node(Pid) == Node -> + E + end)), + lists:foreach(fun mnesia:delete_object/1, Es) end, mnesia:async_dirty(F). @@ -350,12 +341,12 @@ route(Packet) -> To = xmpp:get_to(Packet), case start_connection(From, To) of {ok, Pid} when is_pid(Pid) -> - ?DEBUG("Sending to process ~p~n", [Pid]), - #jid{lserver = MyServer} = From, + ?DEBUG("Sending to process ~p~n", [Pid]), + #jid{lserver = MyServer} = From, ejabberd_hooks:run(s2s_send_packet, MyServer, [Packet]), ejabberd_s2s_out:route(Pid, Packet); {error, Reason} -> - Lang = xmpp:get_lang(Packet), + Lang = xmpp:get_lang(Packet), Err = case Reason of forbidden -> xmpp:err_forbidden(?T("Access denied by service policy"), Lang); @@ -366,12 +357,12 @@ route(Packet) -> end. -spec start_connection(jid(), jid()) - -> {ok, pid()} | {error, forbidden | internal_server_error}. + -> {ok, pid()} | {error, forbidden | internal_server_error}. start_connection(From, To) -> start_connection(From, To, []). -spec start_connection(jid(), jid(), [proplists:property()]) - -> {ok, pid()} | {error, forbidden | internal_server_error}. + -> {ok, pid()} | {error, forbidden | internal_server_error}. start_connection(From, To, Opts) -> #jid{lserver = MyServer} = From, #jid{lserver = Server} = To, @@ -382,11 +373,11 @@ start_connection(From, To, Opts) -> max_s2s_connections_number_per_node(FromTo), ?DEBUG("Finding connection for ~p~n", [FromTo]), case mnesia:dirty_read(s2s, FromTo) of - [] -> - %% We try to establish all the connections if the host is not a - %% service and if the s2s host is not blacklisted or - %% is in whitelist: - LServer = ejabberd_router:host_of_route(MyServer), + [] -> + %% We try to establish all the connections if the host is not a + %% service and if the s2s host is not blacklisted or + %% is in whitelist: + LServer = ejabberd_router:host_of_route(MyServer), case allow_host(LServer, Server) of true -> NeededConnections = needed_connections_number( @@ -400,20 +391,20 @@ start_connection(From, To, Opts) -> false -> {error, forbidden} end; - L when is_list(L) -> - NeededConnections = needed_connections_number(L, - MaxS2SConnectionsNumber, - MaxS2SConnectionsNumberPerNode), - if NeededConnections > 0 -> - %% We establish the missing connections for this pair. - open_several_connections(NeededConnections, MyServer, - Server, From, FromTo, - MaxS2SConnectionsNumber, - MaxS2SConnectionsNumberPerNode, Opts); - true -> - %% We choose a connexion from the pool of opened ones. - {ok, choose_connection(From, L)} - end + L when is_list(L) -> + NeededConnections = needed_connections_number(L, + MaxS2SConnectionsNumber, + MaxS2SConnectionsNumberPerNode), + if NeededConnections > 0 -> + %% We establish the missing connections for this pair. + open_several_connections(NeededConnections, MyServer, + Server, From, FromTo, + MaxS2SConnectionsNumber, + MaxS2SConnectionsNumberPerNode, Opts); + true -> + %% We choose a connexion from the pool of opened ones. + {ok, choose_connection(From, L)} + end end. -spec choose_connection(jid(), [#s2s{}]) -> pid(). @@ -423,8 +414,8 @@ choose_connection(From, Connections) -> -spec choose_pid(jid(), [pid()]) -> pid(). choose_pid(From, Pids) -> Pids1 = case [P || P <- Pids, node(P) == node()] of - [] -> Pids; - Ps -> Ps + [] -> Pids; + Ps -> Ps end, Pid = lists:nth(erlang:phash(jid:remove_resource(From), @@ -433,13 +424,17 @@ choose_pid(From, Pids) -> ?DEBUG("Using ejabberd_s2s_out ~p~n", [Pid]), Pid. +-spec open_several_connections(pos_integer(), binary(), binary(), + jid(), {binary(), binary()}, + integer(), integer(), [proplists:property()]) -> + {ok, pid()} | {error, internal_server_error}. open_several_connections(N, MyServer, Server, From, FromTo, MaxS2SConnectionsNumber, MaxS2SConnectionsNumberPerNode, Opts) -> case lists:flatmap( fun(_) -> new_connection(MyServer, Server, - From, FromTo, MaxS2SConnectionsNumber, + From, FromTo, MaxS2SConnectionsNumber, MaxS2SConnectionsNumberPerNode, Opts) end, lists:seq(1, N)) of [] -> @@ -448,6 +443,8 @@ open_several_connections(N, MyServer, Server, From, {ok, choose_pid(From, PIDs)} end. +-spec new_connection(binary(), binary(), jid(), {binary(), binary()}, + integer(), integer(), [proplists:property()]) -> [pid()]. new_connection(MyServer, Server, From, FromTo, MaxS2SConnectionsNumber, MaxS2SConnectionsNumberPerNode, Opts) -> {ok, Pid} = ejabberd_s2s_out:start(MyServer, Server, Opts), @@ -457,22 +454,23 @@ new_connection(MyServer, Server, From, FromTo, MaxS2SConnectionsNumber, MaxS2SConnectionsNumberPerNode), if NeededConnections > 0 -> - mnesia:write(#s2s{fromto = FromTo, pid = Pid}), - Pid; + mnesia:write(#s2s{fromto = FromTo, pid = Pid}), + Pid; true -> choose_connection(From, L) end end, TRes = mnesia:transaction(F), case TRes of - {atomic, Pid1} -> + {atomic, Pid1} -> if Pid1 == Pid -> ejabberd_s2s_out:connect(Pid); true -> ejabberd_s2s_out:stop(Pid) end, [Pid1]; - {aborted, Reason} -> - ?ERROR_MSG("Failed to register connection ~s -> ~s: ~p", + {aborted, Reason} -> + ?ERROR_MSG("Failed to register s2s connection ~s -> ~s: " + "Mnesia failure: ~p", [MyServer, Server, Reason]), ejabberd_s2s_out:stop(Pid), [] @@ -529,11 +527,13 @@ incoming_s2s_number() -> outgoing_s2s_number() -> supervisor_count(ejabberd_s2s_out_sup). +-spec supervisor_count(atom()) -> non_neg_integer(). supervisor_count(Supervisor) -> - case catch supervisor:which_children(Supervisor) of - {'EXIT', _} -> 0; - Result -> - length(Result) + try supervisor:count_children(Supervisor) of + Props -> + proplists:get_value(workers, Props, 0) + catch _:_ -> + 0 end. -spec stop_s2s_connections() -> ok. @@ -557,10 +557,12 @@ update_tables() -> ok. %% Check if host is in blacklist or white list +-spec allow_host(binary(), binary()) -> boolean(). allow_host(MyServer, S2SHost) -> allow_host1(MyServer, S2SHost) andalso - not is_temporarly_blocked(S2SHost). + not is_temporarly_blocked(S2SHost). +-spec allow_host1(binary(), binary()) -> boolean(). allow_host1(MyHost, S2SHost) -> Rule = ejabberd_option:s2s_access(MyHost), JID = jid:make(S2SHost), @@ -570,8 +572,7 @@ allow_host1(MyHost, S2SHost) -> case ejabberd_hooks:run_fold(s2s_allow_host, MyHost, allow, [MyHost, S2SHost]) of deny -> false; - allow -> true; - _ -> true + allow -> true end end. @@ -581,8 +582,8 @@ allow_host1(MyHost, S2SHost) -> %% Info = [{InfoName::atom(), InfoValue::any()}] get_info_s2s_connections(Type) -> ChildType = case Type of - in -> ejabberd_s2s_in_sup; - out -> ejabberd_s2s_out_sup + in -> ejabberd_s2s_in_sup; + out -> ejabberd_s2s_out_sup end, Connections = supervisor:which_children(ChildType), get_s2s_info(Connections, Type). @@ -597,13 +598,12 @@ complete_s2s_info([Connection | T], Type, Result) -> complete_s2s_info(T, Type, [State | Result]). -spec get_s2s_state(pid()) -> [{status, open | closed | error} | {s2s_pid, pid()}]. - get_s2s_state(S2sPid) -> Infos = case p1_fsm:sync_send_all_state_event(S2sPid, - get_state_infos) - of - {state_infos, Is} -> [{status, open} | Is]; - {noproc, _} -> [{status, closed}]; %% Connection closed - {badrpc, _} -> [{status, error}] + get_state_infos) + of + {state_infos, Is} -> [{status, open} | Is]; + {noproc, _} -> [{status, closed}]; %% Connection closed + {badrpc, _} -> [{status, error}] end, [{s2s_pid, S2sPid} | Infos].