From 8f3f74a6d7063781e475e50737e3976d0f43d801 Mon Sep 17 00:00:00 2001 From: Evgeniy Khramtsov Date: Fri, 16 Nov 2012 20:12:24 +1000 Subject: [PATCH] Only migrate C2S processes with remote sockets Conflicts: src/ejabberd_c2s.erl src/ejabberd_sm.erl --- src/ejabberd_c2s.erl | 56 ++++++++++++++++++++++++++++++++++---------- src/ejabberd_sm.erl | 20 ++++++++++------ 2 files changed, 56 insertions(+), 20 deletions(-) diff --git a/src/ejabberd_c2s.erl b/src/ejabberd_c2s.erl index 68bb7c34c..58a1f1ae2 100644 --- a/src/ejabberd_c2s.erl +++ b/src/ejabberd_c2s.erl @@ -38,7 +38,8 @@ -export([start/2, stop_or_detach/1, start_link/3, send_text/2, send_element/2, socket_type/0, get_presence/1, get_aux_field/2, set_aux_field/3, del_aux_field/2, - get_subscription/2, broadcast/4, get_subscribed/1]). + get_subscription/2, broadcast/4, is_remote_socket/1, + get_subscribed/1]). %% API: -export([add_rosteritem/3, del_rosteritem/2]). @@ -229,6 +230,19 @@ migrate(_FsmRef, _Node, _After) -> migrate_shutdown(FsmRef, Node, After) -> FsmRef ! {migrate_shutdown, Node, After}. +is_remote_socket(Pid) when node(Pid) == node() -> + case catch process_info(Pid, dictionary) of + {dictionary, Dict} -> + SockMod = proplists:get_value(c2s_sockmod, Dict), + XMLSocket = proplists:get_value(c2s_xml_socket, Dict), + Socket = proplists:get_value(c2s_socket, Dict), + is_remote_socket(SockMod, XMLSocket, Socket); + _ -> + false + end; +is_remote_socket(_) -> + false. + %%%---------------------------------------------------------------------- %%% Callback functions from gen_fsm %%%---------------------------------------------------------------------- @@ -292,6 +306,7 @@ init([{SockMod, Socket}, Opts, FSMLimitOpts]) -> fsm_limit_opts = FSMLimitOpts}, erlang:send_after(?C2S_OPEN_TIMEOUT, self(), open_timeout), + update_internal_dict(StateData), case get_jid_from_opts(Opts) of {ok, #jid{user = U, server = Server, resource = R} = JID} -> @@ -305,6 +320,7 @@ init([{SockMod, Socket}, Opts, FSMLimitOpts]) -> init([StateName, StateData, _FSMLimitOpts]) -> MRef = (StateData#state.sockmod):monitor(StateData#state.socket), + update_internal_dict(StateData), if StateName == session_established -> Conn = (StateData#state.sockmod):get_conn_type( StateData#state.socket), @@ -894,9 +910,10 @@ wait_for_feature_request({xmlstreamelement, El}, = []})), fsm_next_state(wait_for_stream, - StateData#state{socket = TLSSocket, - streamid = new_id(), - tls_enabled = true}); + update_internal_dict( + StateData#state{socket = TLSSocket, + streamid = new_id(), + tls_enabled = true})); {?NS_COMPRESS, <<"compress">>} when Zlib == true, (SockMod == gen_tcp) or (SockMod == tls) -> @@ -925,8 +942,9 @@ wait_for_feature_request({xmlstreamelement, El}, = []})), fsm_next_state(wait_for_stream, - StateData#state{socket = ZlibSocket, - streamid = new_id()}); + update_internal_dict( + StateData#state{socket = ZlibSocket, + streamid = new_id()})); _ -> send_element(StateData, #xmlel{name = <<"failure">>, @@ -1889,12 +1907,9 @@ handle_info({migrate, Node}, StateName, StateData) -> end; handle_info({migrate_shutdown, Node, After}, StateName, StateData) -> - case StateData#state.sockmod == ejabberd_frontend_socket - orelse - StateData#state.xml_socket == true orelse - (StateData#state.sockmod):is_remote_receiver( - StateData#state.socket) - of + case is_remote_socket(StateData#state.sockmod, + StateData#state.xml_socket, + StateData#state.socket) of true -> migrate(self(), Node, After); false -> self() ! system_shutdown end, @@ -1920,8 +1935,9 @@ handle_info({change_socket, Socket}, StateName, Socket), MRef = (StateData#state.sockmod):monitor(NewSocket), fsm_next_state(StateName, + update_internal_dict( StateData#state{socket = NewSocket, - socket_monitor = MRef}); + socket_monitor = MRef})); handle_info(Info, StateName, StateData) -> ?ERROR_MSG("Unexpected info: ~p", [Info]), fsm_next_state(StateName, StateData). @@ -3111,6 +3127,7 @@ rebind(StateData, JID, StreamID) -> keepalive_timer = StateData#state.keepalive_timer, ack_timer = undefined}, + update_internal_dict(StateData2), send_element(StateData2, #xmlel{name = <<"rebind">>, attrs = [{<<"xmlns">>, ?NS_P1_REBIND}], @@ -3474,3 +3491,16 @@ get_jid_from_opts(Opts) -> {ok, JID}; _ -> error end. + +update_internal_dict(#state{sockmod = SockMod, + xml_socket = XMLSocket, + socket = Socket} = StateData) -> + put(c2s_sockmod, SockMod), + put(c2s_xml_socket, XMLSocket), + put(c2s_socket, Socket), + StateData. + +is_remote_socket(SockMod, XMLSocket, Socket) -> + SockMod == ejabberd_frontend_socket orelse + XMLSocket == true orelse + SockMod:is_remote_receiver(Socket). diff --git a/src/ejabberd_sm.erl b/src/ejabberd_sm.erl index 899573d45..d75da9354 100644 --- a/src/ejabberd_sm.erl +++ b/src/ejabberd_sm.erl @@ -389,12 +389,13 @@ migrate(InitiatorNode, UpOrDown, After) -> -spec node_up(atom()) -> ok. node_up(_Node) -> - copy_sessions(mnesia:dirty_first(session)). + copy_sessions(mnesia:dirty_first(session), fun(_) -> true end). -spec node_down(atom()) -> ok. node_down(Node) when Node == node() -> - copy_sessions(mnesia:dirty_first(session)); + copy_sessions(mnesia:dirty_first(session), + fun ejabberd_c2s:is_remote_socket/1); node_down(Node) -> ets:select_delete( session, @@ -402,18 +403,23 @@ node_down(Node) -> [{'==', {'node', '$1'}, Node}], [true]}]). -copy_sessions('$end_of_table') -> ok; -copy_sessions(Key) -> +copy_sessions('$end_of_table', _CheckFun) -> ok; +copy_sessions(Key, CheckFun) -> case mnesia:dirty_read(session, Key) of - [#session{us = US} = Session] -> + [#session{us = US, sid = {_, Pid}} = Session] -> case ejabberd_cluster:get_node_new(US) of Node when node() /= Node -> - rpc:cast(Node, mnesia, dirty_write, [Session]); + case CheckFun(Pid) of + true -> + rpc:cast(Node, mnesia, dirty_write, [Session]); + false -> + ok + end; _ -> ok end; _ -> ok end, - copy_sessions(mnesia:dirty_next(session, Key)). + copy_sessions(mnesia:dirty_next(session, Key), CheckFun). %%==================================================================== %% gen_server callbacks