From cd336369a5691da8289574f402fa2311b6dc027c Mon Sep 17 00:00:00 2001 From: Holger Weiss Date: Mon, 1 Jun 2020 21:33:55 +0200 Subject: [PATCH] mod_stream_mgmt: Don't kill new PID on resumption During XEP-0198 resumption, the ejabberd_c2s process that handles the new connection reopens the ejabberd_sm session of the old one. Since commit b4770815c0b0416c21d01507d2908f94c25b3097, the new process adds the new session table entry before the old process removes the old one. While adding the new one, ejabberd_sm checks for old sessions to replace. This check assumes old SIDs compare lower than new ones. This assumption didn't necessarily hold for the session resumption case, where the old SID's timestamp was copied over to the new SID and only the PID was updated. Therefore, the new process was killed if the new PID happened to be smaller than the old one. Fix this by having mod_stream_mgmt use its own SM-ID rather than copying over the old SID's timestamp to the new SID. Thanks to Thilo Molitor and Friedrich Altheide for reporting the issue, and to Thomas Leister for his help with debugging it. --- src/ejabberd_c2s.erl | 8 ++-- src/mod_stream_mgmt.erl | 86 ++++++++++++++++++++++++----------------- 2 files changed, 54 insertions(+), 40 deletions(-) diff --git a/src/ejabberd_c2s.erl b/src/ejabberd_c2s.erl index c88221e27..92517c92a 100644 --- a/src/ejabberd_c2s.erl +++ b/src/ejabberd_c2s.erl @@ -180,10 +180,9 @@ host_down(Host) -> %% Copies content of one c2s state to another. %% This is needed for session migration from one pid to another. -spec copy_state(state(), state()) -> state(). -copy_state(#{owner := Owner} = NewState, - #{jid := JID, resource := Resource, sid := {Time, _}, - auth_module := AuthModule, lserver := LServer, - pres_a := PresA} = OldState) -> +copy_state(NewState, + #{jid := JID, resource := Resource, auth_module := AuthModule, + lserver := LServer, pres_a := PresA} = OldState) -> State1 = case OldState of #{pres_last := Pres, pres_timestamp := PresTS} -> NewState#{pres_last => Pres, pres_timestamp => PresTS}; @@ -193,7 +192,6 @@ copy_state(#{owner := Owner} = NewState, Conn = get_conn_type(State1), State2 = State1#{jid => JID, resource => Resource, conn => Conn, - sid => {Time, Owner}, auth_module => AuthModule, pres_a => PresA}, ejabberd_hooks:run_fold(c2s_copy_session, LServer, State2, [OldState]). diff --git a/src/mod_stream_mgmt.erl b/src/mod_stream_mgmt.erl index 45d52ccb0..ee72e152f 100644 --- a/src/mod_stream_mgmt.erl +++ b/src/mod_stream_mgmt.erl @@ -52,6 +52,7 @@ -type state() :: ejabberd_c2s:state(). -type queue() :: p1_queue:queue({non_neg_integer(), erlang:timestamp(), xmpp_element() | xmlel()}). +-type id() :: binary(). -type error_reason() :: session_not_found | session_timed_out | session_is_dead | session_has_exited | session_was_killed | session_copy_timed_out | @@ -228,8 +229,8 @@ c2s_handle_send(#{mgmt_state := MgmtState, mod := Mod, c2s_handle_send(State, _Pkt, _Result) -> State. -c2s_handle_call(#{sid := {Time, _}, mod := Mod, mgmt_queue := Queue} = State, - {resume_session, Time}, From) -> +c2s_handle_call(#{mgmt_id := MgmtID, mgmt_queue := Queue, mod := Mod} = State, + {resume_session, MgmtID}, From) -> State1 = State#{mgmt_queue => p1_queue:file_to_ram(Queue)}, Mod:reply(From, {resume, State1}), {stop, State#{mgmt_state => resumed, mgmt_queue => p1_queue:clear(Queue)}}; @@ -288,10 +289,10 @@ c2s_terminated(#{mgmt_state := resumed, sid := SID, jid := JID} = State, _Reason ejabberd_c2s:bounce_message_queue(SID, JID), {stop, State}; c2s_terminated(#{mgmt_state := MgmtState, mgmt_stanzas_in := In, - sid := {Time, _}, jid := JID} = State, _Reason) -> + mgmt_id := MgmtID, jid := JID} = State, _Reason) -> case MgmtState of timeout -> - store_stanzas_in(jid:tolower(JID), Time, In); + store_stanzas_in(jid:tolower(JID), MgmtID, In); _ -> ok end, @@ -377,6 +378,7 @@ handle_enable(#{mgmt_timeout := DefaultTimeout, mgmt_max_timeout := MaxTimeout, mgmt_xmlns := Xmlns, jid := JID} = State, #sm_enable{resume = Resume, max = Max}) -> + State1 = State#{mgmt_id => make_id()}, Timeout = if Resume == false -> 0; Max /= undefined, Max > 0, Max*1000 =< MaxTimeout -> @@ -388,7 +390,7 @@ handle_enable(#{mgmt_timeout := DefaultTimeout, ?DEBUG("Stream management with resumption enabled for ~ts", [jid:encode(JID)]), #sm_enabled{xmlns = Xmlns, - id = make_resume_id(State), + id = encode_id(State1), resume = true, max = Timeout div 1000}; true -> @@ -396,10 +398,10 @@ handle_enable(#{mgmt_timeout := DefaultTimeout, [jid:encode(JID)]), #sm_enabled{xmlns = Xmlns} end, - State1 = State#{mgmt_state => active, - mgmt_queue => p1_queue:new(QueueType), - mgmt_timeout => Timeout}, - send(State1, Res). + State2 = State1#{mgmt_state => active, + mgmt_queue => p1_queue:new(QueueType), + mgmt_timeout => Timeout}, + send(State2, Res). -spec handle_r(state()) -> state(). handle_r(#{mgmt_xmlns := Xmlns, mgmt_stanzas_in := H} = State) -> @@ -431,10 +433,9 @@ handle_resume(#{user := User, lserver := LServer, {ok, #{jid := JID} = ResumedState, NumHandled} -> State1 = check_h_attribute(ResumedState, NumHandled), #{mgmt_xmlns := AttrXmlns, mgmt_stanzas_in := AttrH} = State1, - AttrId = make_resume_id(State1), State2 = send(State1, #sm_resumed{xmlns = AttrXmlns, h = AttrH, - previd = AttrId}), + previd = PrevID}), State3 = resend_unacked_stanzas(State2), State4 = send(State3, #sm_r{xmlns = AttrXmlns}), State5 = ejabberd_hooks:run_fold(c2s_session_resumed, LServer, State4, []), @@ -649,20 +650,19 @@ route_unacked_stanzas(_State) -> {error, error_reason()} | {error, error_reason(), non_neg_integer()}. inherit_session_state(#{user := U, server := S, - mgmt_queue_type := QueueType} = State, ResumeID) -> - case misc:base64_to_term(ResumeID) of - {term, {R, Time}} -> - case ejabberd_sm:get_session_pid(U, S, R) of + mgmt_queue_type := QueueType} = State, PrevID) -> + case decode_id(PrevID) of + {ok, {R, MgmtID}} -> + case ejabberd_sm:get_session_sid(U, S, R) of none -> - case pop_stanzas_in({U, S, R}, Time) of + case pop_stanzas_in({U, S, R}, MgmtID) of error -> {error, session_not_found}; {ok, H} -> {error, session_timed_out, H} end; - OldPID -> - OldSID = {Time, OldPID}, - try resume_session(OldSID, State) of + {_, OldPID} = OldSID -> + try resume_session(OldPID, MgmtID, State) of {resume, #{mgmt_xmlns := Xmlns, mgmt_queue := Queue, mgmt_timeout := Timeout, @@ -673,7 +673,9 @@ inherit_session_state(#{user := U, server := S, ram -> Queue; _ -> p1_queue:ram_to_file(Queue) end, - State2 = State1#{mgmt_xmlns => Xmlns, + State2 = State1#{sid => ejabberd_sm:make_sid(), + mgmt_id => MgmtID, + mgmt_xmlns => Xmlns, mgmt_queue => Queue1, mgmt_timeout => Timeout, mgmt_stanzas_in => NumStanzasIn, @@ -698,18 +700,14 @@ inherit_session_state(#{user := U, server := S, {error, session_copy_timed_out} end end; - _ -> + error -> {error, invalid_previd} end. --spec resume_session({erlang:timestamp(), pid()}, state()) -> {resume, state()} | - {error, error_reason()}. -resume_session({Time, Pid}, _State) -> - ejabberd_c2s:call(Pid, {resume_session, Time}, timer:seconds(15)). - --spec make_resume_id(state()) -> binary(). -make_resume_id(#{sid := {Time, _}, resource := Resource}) -> - misc:term_to_base64({Resource, Time}). +-spec resume_session(pid(), id(), state()) -> {resume, state()} | + {error, error_reason()}. +resume_session(PID, MgmtID, _State) -> + ejabberd_c2s:call(PID, {resume_session, MgmtID}, timer:seconds(15)). -spec add_resent_delay_info(state(), stanza(), erlang:timestamp()) -> stanza(); (state(), xmlel(), erlang:timestamp()) -> xmlel(). @@ -756,6 +754,24 @@ need_to_enqueue(#{mgmt_force_enqueue := true} = State, #xmlel{}) -> need_to_enqueue(State, _) -> {false, State}. +-spec make_id() -> id(). +make_id() -> + p1_rand:bytes(8). + +-spec encode_id(state()) -> binary(). +encode_id(#{mgmt_id := MgmtID, resource := Resource}) -> + misc:term_to_base64({Resource, MgmtID}). + +-spec decode_id(binary()) -> {ok, {binary(), id()}} | error. +decode_id(Encoded) -> + case misc:base64_to_term(Encoded) of + {term, {Resource, MgmtID}} when is_binary(Resource), + is_binary(MgmtID) -> + {ok, {Resource, MgmtID}}; + _ -> + error + end. + %%%=================================================================== %%% Formatters and Logging %%%=================================================================== @@ -803,14 +819,14 @@ cache_opts(Opts) -> {life_time, mod_stream_mgmt_opt:cache_life_time(Opts)}, {type, ordered_set}]. --spec store_stanzas_in(ljid(), erlang:timestamp(), non_neg_integer()) -> boolean(). -store_stanzas_in(LJID, Time, Num) -> - ets_cache:insert(?STREAM_MGMT_CACHE, {LJID, Time}, Num, +-spec store_stanzas_in(ljid(), id(), non_neg_integer()) -> boolean(). +store_stanzas_in(LJID, MgmtID, Num) -> + ets_cache:insert(?STREAM_MGMT_CACHE, {LJID, MgmtID}, Num, ejabberd_cluster:get_nodes()). --spec pop_stanzas_in(ljid(), erlang:timestamp()) -> {ok, non_neg_integer()} | error. -pop_stanzas_in(LJID, Time) -> - case ets_cache:lookup(?STREAM_MGMT_CACHE, {LJID, Time}) of +-spec pop_stanzas_in(ljid(), id()) -> {ok, non_neg_integer()} | error. +pop_stanzas_in(LJID, MgmtID) -> + case ets_cache:lookup(?STREAM_MGMT_CACHE, {LJID, MgmtID}) of {ok, Val} -> ets_cache:match_delete(?STREAM_MGMT_CACHE, {LJID, '_'}, ejabberd_cluster:get_nodes()),