diff --git a/src/ejabberd_s2s.erl b/src/ejabberd_s2s.erl index 8057c9a35..f3259b75c 100644 --- a/src/ejabberd_s2s.erl +++ b/src/ejabberd_s2s.erl @@ -33,8 +33,8 @@ %% API -export([start_link/0, stop/0, route/1, have_connection/1, - get_connections_pids/1, try_register/1, - remove_connection/2, start_connection/2, start_connection/3, + get_connections_pids/1, + start_connection/2, start_connection/3, dirty_get_connections/0, allow_host/2, incoming_s2s_number/0, outgoing_s2s_number/0, stop_s2s_connections/0, @@ -112,24 +112,6 @@ is_temporarly_blocked(Host) -> end end. --spec remove_connection({binary(), binary()}, pid()) -> ok. -remove_connection({From, To} = FromTo, Pid) -> - case mnesia:dirty_match_object(s2s, #s2s{fromto = FromTo, pid = Pid}) of - [#s2s{pid = Pid}] -> - F = fun() -> - mnesia:delete_object(#s2s{fromto = FromTo, pid = Pid}) - end, - case mnesia:transaction(F) of - {atomic, _} -> ok; - {aborted, Reason} -> - ?ERROR_MSG("Failed to unregister s2s connection ~ts -> ~ts: " - "Mnesia failure: ~p", - [From, To, Reason]) - end; - _ -> - ok - end. - -spec have_connection({binary(), binary()}) -> boolean(). have_connection(FromTo) -> case catch mnesia:dirty_read(s2s, FromTo) of @@ -148,31 +130,6 @@ get_connections_pids(FromTo) -> [] end. --spec try_register({binary(), binary()}) -> boolean(). -try_register({From, To} = FromTo) -> - MaxS2SConnectionsNumber = max_s2s_connections_number(FromTo), - MaxS2SConnectionsNumberPerNode = - max_s2s_connections_number_per_node(FromTo), - F = fun () -> - L = mnesia:read({s2s, FromTo}), - NeededConnections = needed_connections_number(L, - MaxS2SConnectionsNumber, - MaxS2SConnectionsNumberPerNode), - if NeededConnections > 0 -> - mnesia:write(#s2s{fromto = FromTo, pid = self()}), - true; - true -> false - end - end, - case mnesia:transaction(F) of - {atomic, Res} -> Res; - {aborted, Reason} -> - ?ERROR_MSG("Failed to register s2s connection ~ts -> ~ts: " - "Mnesia failure: ~p", - [From, To, Reason]), - false - end. - -spec dirty_get_connections() -> [{binary(), binary()}]. dirty_get_connections() -> mnesia:dirty_all_keys(s2s). @@ -269,6 +226,8 @@ init([]) -> {stop, Reason} end. +handle_call({new_connection, Args}, _From, State) -> + {reply, erlang:apply(fun new_connection_int/7, Args), State}; handle_call(Request, From, State) -> ?WARNING_MSG("Unexpected call from ~p: ~p", [From, Request]), {noreply, State}. @@ -289,6 +248,21 @@ handle_info({route, Packet}, State) -> misc:format_exception(2, Class, Reason, StackTrace)]) end, {noreply, State}; +handle_info({'DOWN', _Ref, process, Pid, _Reason}, State) -> + case mnesia:dirty_match_object(s2s, #s2s{pid = Pid, fromto = '_'}) of + [#s2s{pid = Pid, fromto = {From, To}} = Obj] -> + F = fun() -> mnesia:delete_object(Obj) end, + case mnesia:transaction(F) of + {atomic, _} -> ok; + {aborted, Reason} -> + ?ERROR_MSG("Failed to unregister s2s connection for pid ~p (~ts -> ~ts):" + "Mnesia failure: ~p", + [Pid, From, To, Reason]) + end, + {noreply, State}; + _ -> + {noreply, State} + end; handle_info(Info, State) -> ?WARNING_MSG("Unexpected info: ~p", [Info]), {noreply, State}. @@ -458,6 +432,18 @@ open_several_connections(N, MyServer, Server, From, integer(), integer(), [proplists:property()]) -> [pid()]. new_connection(MyServer, Server, From, FromTo, MaxS2SConnectionsNumber, MaxS2SConnectionsNumberPerNode, Opts) -> + case whereis(ejabberd_s2s) == self() of + true -> + new_connection_int(MyServer, Server, From, FromTo, + MaxS2SConnectionsNumber, MaxS2SConnectionsNumberPerNode, Opts); + false -> + gen_server:call(ejabberd_s2s, {new_connection, [MyServer, Server, From, FromTo, + MaxS2SConnectionsNumber, + MaxS2SConnectionsNumberPerNode, Opts]}) + end. + +new_connection_int(MyServer, Server, From, FromTo, + MaxS2SConnectionsNumber, MaxS2SConnectionsNumberPerNode, Opts) -> {ok, Pid} = ejabberd_s2s_out:start(MyServer, Server, Opts), F = fun() -> L = mnesia:read({s2s, FromTo}), @@ -474,6 +460,7 @@ new_connection(MyServer, Server, From, FromTo, case TRes of {atomic, Pid1} -> if Pid1 == Pid -> + erlang:monitor(process, Pid), ejabberd_s2s_out:connect(Pid); true -> ejabberd_s2s_out:stop_async(Pid) diff --git a/src/ejabberd_s2s_out.erl b/src/ejabberd_s2s_out.erl index d58396533..f057705ed 100644 --- a/src/ejabberd_s2s_out.erl +++ b/src/ejabberd_s2s_out.erl @@ -318,7 +318,6 @@ handle_info(Info, #{server_host := ServerHost} = State) -> terminate(Reason, #{server := LServer, remote_server := RServer} = State) -> - ejabberd_s2s:remove_connection({LServer, RServer}, self()), State1 = case Reason of normal -> State; _ -> State#{stop_reason => internal_failure} @@ -351,21 +350,12 @@ bounce_queue(State) -> end, State). -spec bounce_message_queue({binary(), binary()}, state()) -> state(). -bounce_message_queue({LServer, RServer} = FromTo, State) -> - Pids = ejabberd_s2s:get_connections_pids(FromTo), - case lists:member(self(), Pids) of - true -> - ?WARNING_MSG("Outgoing s2s connection ~ts -> ~ts is supposed " - "to be unregistered, but pid ~p still presents " - "in 's2s' table", [LServer, RServer, self()]), - State; - false -> - receive {route, Pkt} -> - State1 = bounce_packet(Pkt, State), - bounce_message_queue(FromTo, State1) - after 0 -> - State - end +bounce_message_queue(FromTo, State) -> + receive {route, Pkt} -> + State1 = bounce_packet(Pkt, State), + bounce_message_queue(FromTo, State1) + after 0 -> + State end. -spec bounce_packet(xmpp_element(), state()) -> state().