mod_stream_mgmt: Don't kill new PID on resumption

During XEP-0198 resumption, the ejabberd_c2s process that handles the
new connection reopens the ejabberd_sm session of the old one.  Since
commit b4770815c0, the new process adds
the new session table entry before the old process removes the old one.
While adding the new one, ejabberd_sm checks for old sessions to
replace.  This check assumes old SIDs compare lower than new ones.  This
assumption didn't necessarily hold for the session resumption case,
where the old SID's timestamp was copied over to the new SID and only
the PID was updated.  Therefore, the new process was killed if the new
PID happened to be smaller than the old one.

Fix this by having mod_stream_mgmt use its own SM-ID rather than copying
over the old SID's timestamp to the new SID.

Thanks to Thilo Molitor and Friedrich Altheide for reporting the issue,
and to Thomas Leister for his help with debugging it.
This commit is contained in:
Holger Weiss 2020-06-01 21:33:55 +02:00
parent c62956ab7b
commit cd336369a5
2 changed files with 54 additions and 40 deletions

View File

@ -180,10 +180,9 @@ host_down(Host) ->
%% Copies content of one c2s state to another.
%% This is needed for session migration from one pid to another.
-spec copy_state(state(), state()) -> state().
copy_state(#{owner := Owner} = NewState,
#{jid := JID, resource := Resource, sid := {Time, _},
auth_module := AuthModule, lserver := LServer,
pres_a := PresA} = OldState) ->
copy_state(NewState,
#{jid := JID, resource := Resource, auth_module := AuthModule,
lserver := LServer, pres_a := PresA} = OldState) ->
State1 = case OldState of
#{pres_last := Pres, pres_timestamp := PresTS} ->
NewState#{pres_last => Pres, pres_timestamp => PresTS};
@ -193,7 +192,6 @@ copy_state(#{owner := Owner} = NewState,
Conn = get_conn_type(State1),
State2 = State1#{jid => JID, resource => Resource,
conn => Conn,
sid => {Time, Owner},
auth_module => AuthModule,
pres_a => PresA},
ejabberd_hooks:run_fold(c2s_copy_session, LServer, State2, [OldState]).

View File

@ -52,6 +52,7 @@
-type state() :: ejabberd_c2s:state().
-type queue() :: p1_queue:queue({non_neg_integer(), erlang:timestamp(), xmpp_element() | xmlel()}).
-type id() :: binary().
-type error_reason() :: session_not_found | session_timed_out |
session_is_dead | session_has_exited |
session_was_killed | session_copy_timed_out |
@ -228,8 +229,8 @@ c2s_handle_send(#{mgmt_state := MgmtState, mod := Mod,
c2s_handle_send(State, _Pkt, _Result) ->
State.
c2s_handle_call(#{sid := {Time, _}, mod := Mod, mgmt_queue := Queue} = State,
{resume_session, Time}, From) ->
c2s_handle_call(#{mgmt_id := MgmtID, mgmt_queue := Queue, mod := Mod} = State,
{resume_session, MgmtID}, From) ->
State1 = State#{mgmt_queue => p1_queue:file_to_ram(Queue)},
Mod:reply(From, {resume, State1}),
{stop, State#{mgmt_state => resumed, mgmt_queue => p1_queue:clear(Queue)}};
@ -288,10 +289,10 @@ c2s_terminated(#{mgmt_state := resumed, sid := SID, jid := JID} = State, _Reason
ejabberd_c2s:bounce_message_queue(SID, JID),
{stop, State};
c2s_terminated(#{mgmt_state := MgmtState, mgmt_stanzas_in := In,
sid := {Time, _}, jid := JID} = State, _Reason) ->
mgmt_id := MgmtID, jid := JID} = State, _Reason) ->
case MgmtState of
timeout ->
store_stanzas_in(jid:tolower(JID), Time, In);
store_stanzas_in(jid:tolower(JID), MgmtID, In);
_ ->
ok
end,
@ -377,6 +378,7 @@ handle_enable(#{mgmt_timeout := DefaultTimeout,
mgmt_max_timeout := MaxTimeout,
mgmt_xmlns := Xmlns, jid := JID} = State,
#sm_enable{resume = Resume, max = Max}) ->
State1 = State#{mgmt_id => make_id()},
Timeout = if Resume == false ->
0;
Max /= undefined, Max > 0, Max*1000 =< MaxTimeout ->
@ -388,7 +390,7 @@ handle_enable(#{mgmt_timeout := DefaultTimeout,
?DEBUG("Stream management with resumption enabled for ~ts",
[jid:encode(JID)]),
#sm_enabled{xmlns = Xmlns,
id = make_resume_id(State),
id = encode_id(State1),
resume = true,
max = Timeout div 1000};
true ->
@ -396,10 +398,10 @@ handle_enable(#{mgmt_timeout := DefaultTimeout,
[jid:encode(JID)]),
#sm_enabled{xmlns = Xmlns}
end,
State1 = State#{mgmt_state => active,
mgmt_queue => p1_queue:new(QueueType),
mgmt_timeout => Timeout},
send(State1, Res).
State2 = State1#{mgmt_state => active,
mgmt_queue => p1_queue:new(QueueType),
mgmt_timeout => Timeout},
send(State2, Res).
-spec handle_r(state()) -> state().
handle_r(#{mgmt_xmlns := Xmlns, mgmt_stanzas_in := H} = State) ->
@ -431,10 +433,9 @@ handle_resume(#{user := User, lserver := LServer,
{ok, #{jid := JID} = ResumedState, NumHandled} ->
State1 = check_h_attribute(ResumedState, NumHandled),
#{mgmt_xmlns := AttrXmlns, mgmt_stanzas_in := AttrH} = State1,
AttrId = make_resume_id(State1),
State2 = send(State1, #sm_resumed{xmlns = AttrXmlns,
h = AttrH,
previd = AttrId}),
previd = PrevID}),
State3 = resend_unacked_stanzas(State2),
State4 = send(State3, #sm_r{xmlns = AttrXmlns}),
State5 = ejabberd_hooks:run_fold(c2s_session_resumed, LServer, State4, []),
@ -649,20 +650,19 @@ route_unacked_stanzas(_State) ->
{error, error_reason()} |
{error, error_reason(), non_neg_integer()}.
inherit_session_state(#{user := U, server := S,
mgmt_queue_type := QueueType} = State, ResumeID) ->
case misc:base64_to_term(ResumeID) of
{term, {R, Time}} ->
case ejabberd_sm:get_session_pid(U, S, R) of
mgmt_queue_type := QueueType} = State, PrevID) ->
case decode_id(PrevID) of
{ok, {R, MgmtID}} ->
case ejabberd_sm:get_session_sid(U, S, R) of
none ->
case pop_stanzas_in({U, S, R}, Time) of
case pop_stanzas_in({U, S, R}, MgmtID) of
error ->
{error, session_not_found};
{ok, H} ->
{error, session_timed_out, H}
end;
OldPID ->
OldSID = {Time, OldPID},
try resume_session(OldSID, State) of
{_, OldPID} = OldSID ->
try resume_session(OldPID, MgmtID, State) of
{resume, #{mgmt_xmlns := Xmlns,
mgmt_queue := Queue,
mgmt_timeout := Timeout,
@ -673,7 +673,9 @@ inherit_session_state(#{user := U, server := S,
ram -> Queue;
_ -> p1_queue:ram_to_file(Queue)
end,
State2 = State1#{mgmt_xmlns => Xmlns,
State2 = State1#{sid => ejabberd_sm:make_sid(),
mgmt_id => MgmtID,
mgmt_xmlns => Xmlns,
mgmt_queue => Queue1,
mgmt_timeout => Timeout,
mgmt_stanzas_in => NumStanzasIn,
@ -698,18 +700,14 @@ inherit_session_state(#{user := U, server := S,
{error, session_copy_timed_out}
end
end;
_ ->
error ->
{error, invalid_previd}
end.
-spec resume_session({erlang:timestamp(), pid()}, state()) -> {resume, state()} |
{error, error_reason()}.
resume_session({Time, Pid}, _State) ->
ejabberd_c2s:call(Pid, {resume_session, Time}, timer:seconds(15)).
-spec make_resume_id(state()) -> binary().
make_resume_id(#{sid := {Time, _}, resource := Resource}) ->
misc:term_to_base64({Resource, Time}).
-spec resume_session(pid(), id(), state()) -> {resume, state()} |
{error, error_reason()}.
resume_session(PID, MgmtID, _State) ->
ejabberd_c2s:call(PID, {resume_session, MgmtID}, timer:seconds(15)).
-spec add_resent_delay_info(state(), stanza(), erlang:timestamp()) -> stanza();
(state(), xmlel(), erlang:timestamp()) -> xmlel().
@ -756,6 +754,24 @@ need_to_enqueue(#{mgmt_force_enqueue := true} = State, #xmlel{}) ->
need_to_enqueue(State, _) ->
{false, State}.
-spec make_id() -> id().
make_id() ->
p1_rand:bytes(8).
-spec encode_id(state()) -> binary().
encode_id(#{mgmt_id := MgmtID, resource := Resource}) ->
misc:term_to_base64({Resource, MgmtID}).
-spec decode_id(binary()) -> {ok, {binary(), id()}} | error.
decode_id(Encoded) ->
case misc:base64_to_term(Encoded) of
{term, {Resource, MgmtID}} when is_binary(Resource),
is_binary(MgmtID) ->
{ok, {Resource, MgmtID}};
_ ->
error
end.
%%%===================================================================
%%% Formatters and Logging
%%%===================================================================
@ -803,14 +819,14 @@ cache_opts(Opts) ->
{life_time, mod_stream_mgmt_opt:cache_life_time(Opts)},
{type, ordered_set}].
-spec store_stanzas_in(ljid(), erlang:timestamp(), non_neg_integer()) -> boolean().
store_stanzas_in(LJID, Time, Num) ->
ets_cache:insert(?STREAM_MGMT_CACHE, {LJID, Time}, Num,
-spec store_stanzas_in(ljid(), id(), non_neg_integer()) -> boolean().
store_stanzas_in(LJID, MgmtID, Num) ->
ets_cache:insert(?STREAM_MGMT_CACHE, {LJID, MgmtID}, Num,
ejabberd_cluster:get_nodes()).
-spec pop_stanzas_in(ljid(), erlang:timestamp()) -> {ok, non_neg_integer()} | error.
pop_stanzas_in(LJID, Time) ->
case ets_cache:lookup(?STREAM_MGMT_CACHE, {LJID, Time}) of
-spec pop_stanzas_in(ljid(), id()) -> {ok, non_neg_integer()} | error.
pop_stanzas_in(LJID, MgmtID) ->
case ets_cache:lookup(?STREAM_MGMT_CACHE, {LJID, MgmtID}) of
{ok, Val} ->
ets_cache:match_delete(?STREAM_MGMT_CACHE, {LJID, '_'},
ejabberd_cluster:get_nodes()),