25
1
mirror of https://github.com/processone/ejabberd.git synced 2024-11-28 16:34:13 +01:00

More refactoring on session management

This commit is contained in:
Evgeniy Khramtsov 2016-12-30 00:00:36 +03:00
parent 309bdfbe28
commit e7fe4dc474
14 changed files with 569 additions and 401 deletions

View File

@ -188,7 +188,7 @@ try_register(User, Server, Password) ->
true -> {atomic, exists}; true -> {atomic, exists};
false -> false ->
LServer = jid:nameprep(Server), LServer = jid:nameprep(Server),
case lists:member(LServer, ?MYHOSTS) of case ejabberd_router:is_my_host(LServer) of
true -> true ->
Res = lists:foldl(fun (_M, {atomic, ok} = Res) -> Res; Res = lists:foldl(fun (_M, {atomic, ok} = Res) -> Res;
(M, _) -> (M, _) ->

View File

@ -37,17 +37,18 @@
compress_methods/1, bind/2, get_password_fun/1, compress_methods/1, bind/2, get_password_fun/1,
check_password_fun/1, check_password_digest_fun/1, check_password_fun/1, check_password_digest_fun/1,
unauthenticated_stream_features/1, authenticated_stream_features/1, unauthenticated_stream_features/1, authenticated_stream_features/1,
handle_stream_start/2, handle_stream_end/2, handle_stream_close/2, handle_stream_start/2, handle_stream_end/2,
handle_unauthenticated_packet/2, handle_authenticated_packet/2, handle_unauthenticated_packet/2, handle_authenticated_packet/2,
handle_auth_success/4, handle_auth_failure/4, handle_send/3, handle_auth_success/4, handle_auth_failure/4, handle_send/3,
handle_recv/3, handle_cdata/2, handle_unbinded_packet/2]). handle_recv/3, handle_cdata/2, handle_unbinded_packet/2]).
%% Hooks %% Hooks
-export([handle_unexpected_info/2, handle_unexpected_cast/2, -export([handle_unexpected_cast/2,
reject_unauthenticated_packet/2, process_closed/2]). reject_unauthenticated_packet/2, process_closed/2,
process_terminated/2, process_info/2]).
%% API %% API
-export([get_presence/1, get_subscription/2, get_subscribed/1, -export([get_presence/1, get_subscription/2, get_subscribed/1,
open_session/1, call/3, send/2, close/1, close/2, stop/1, establish/1, open_session/1, call/3, send/2, close/1, close/2, stop/1, establish/1,
copy_state/2, add_hooks/0]). reply/2, copy_state/2, set_timeout/2, add_hooks/1]).
-include("ejabberd.hrl"). -include("ejabberd.hrl").
-include("xmpp.hrl"). -include("xmpp.hrl").
@ -76,6 +77,9 @@ socket_type() ->
call(Ref, Msg, Timeout) -> call(Ref, Msg, Timeout) ->
xmpp_stream_in:call(Ref, Msg, Timeout). xmpp_stream_in:call(Ref, Msg, Timeout).
reply(Ref, Reply) ->
xmpp_stream_in:reply(Ref, Reply).
-spec get_presence(pid()) -> presence(). -spec get_presence(pid()) -> presence().
get_presence(Ref) -> get_presence(Ref) ->
call(Ref, get_presence, 1000). call(Ref, get_presence, 1000).
@ -112,28 +116,30 @@ stop(Ref) ->
send(Pid, Pkt) when is_pid(Pid) -> send(Pid, Pkt) when is_pid(Pid) ->
xmpp_stream_in:send(Pid, Pkt); xmpp_stream_in:send(Pid, Pkt);
send(#{lserver := LServer} = State, Pkt) -> send(#{lserver := LServer} = State, Pkt) ->
case ejabberd_hooks:run_fold(c2s_filter_send, LServer, Pkt, [State]) of case ejabberd_hooks:run_fold(c2s_filter_send, LServer, {Pkt, State}, []) of
drop -> State; {drop, State1} -> State1;
Pkt1 -> xmpp_stream_in:send(State, Pkt1) {Pkt1, State1} -> xmpp_stream_in:send(State1, Pkt1)
end. end.
-spec set_timeout(state(), timeout()) -> state().
set_timeout(State, Timeout) ->
xmpp_stream_in:set_timeout(State, Timeout).
-spec establish(state()) -> state(). -spec establish(state()) -> state().
establish(State) -> establish(State) ->
xmpp_stream_in:establish(State). xmpp_stream_in:establish(State).
-spec add_hooks() -> ok. -spec add_hooks(binary()) -> ok.
add_hooks() -> add_hooks(Host) ->
lists:foreach(
fun(Host) ->
ejabberd_hooks:add(c2s_closed, Host, ?MODULE, process_closed, 100), ejabberd_hooks:add(c2s_closed, Host, ?MODULE, process_closed, 100),
ejabberd_hooks:add(c2s_terminated, Host, ?MODULE,
process_terminated, 100),
ejabberd_hooks:add(c2s_unauthenticated_packet, Host, ?MODULE, ejabberd_hooks:add(c2s_unauthenticated_packet, Host, ?MODULE,
reject_unauthenticated_packet, 100), reject_unauthenticated_packet, 100),
ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE, ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE,
handle_unexpected_info, 100), process_info, 100),
ejabberd_hooks:add(c2s_handle_cast, Host, ?MODULE, ejabberd_hooks:add(c2s_handle_cast, Host, ?MODULE,
handle_unexpected_cast, 100) handle_unexpected_cast, 100).
end, ?MYHOSTS).
%% Copies content of one c2s state to another. %% Copies content of one c2s state to another.
%% This is needed for session migration from one pid to another. %% This is needed for session migration from one pid to another.
@ -158,10 +164,46 @@ copy_state(#{owner := Owner} = NewState,
pres_f => PresF}, pres_f => PresF},
ejabberd_hooks:run_fold(c2s_copy_state, LServer, State2, [OldState]). ejabberd_hooks:run_fold(c2s_copy_state, LServer, State2, [OldState]).
-spec open_session(state()) -> {ok, state()} | state().
open_session(#{user := U, server := S, resource := R,
sid := SID, ip := IP, auth_module := AuthModule} = State) ->
JID = jid:make(U, S, R),
change_shaper(State),
Conn = get_conn_type(State),
State1 = State#{conn => Conn, resource => R, jid => JID},
Prio = try maps:get(pres_last, State) of
Pres -> get_priority_from_presence(Pres)
catch _:{badkey, _} ->
undefined
end,
Info = [{ip, IP}, {conn, Conn}, {auth_module, AuthModule}],
ejabberd_sm:open_session(SID, U, S, R, Prio, Info),
xmpp_stream_in:establish(State1).
%%%=================================================================== %%%===================================================================
%%% Hooks %%% Hooks
%%%=================================================================== %%%===================================================================
handle_unexpected_info(State, Info) -> process_info(#{lserver := LServer} = State,
{route, From, To, Packet0}) ->
Packet = xmpp:set_from_to(Packet0, From, To),
{Pass, State1} = case Packet of
#presence{} ->
process_presence_in(State, Packet);
#message{} ->
process_message_in(State, Packet);
#iq{} ->
process_iq_in(State, Packet)
end,
if Pass ->
Packet1 = ejabberd_hooks:run_fold(
user_receive_packet, LServer, Packet, [State1]),
ejabberd_hooks:run(c2s_loop_debug, [{route, From, To, Packet}]),
send(State1, Packet1);
true ->
ejabberd_hooks:run(c2s_loop_debug, [{route, From, To, Packet}]),
State1
end;
process_info(State, Info) ->
?WARNING_MSG("got unexpected info: ~p", [Info]), ?WARNING_MSG("got unexpected info: ~p", [Info]),
State. State.
@ -173,8 +215,22 @@ reject_unauthenticated_packet(State, Pkt) ->
Err = xmpp:err_not_authorized(), Err = xmpp:err_not_authorized(),
xmpp_stream_in:send_error(State, Pkt, Err). xmpp_stream_in:send_error(State, Pkt, Err).
process_closed(State, _Reason) -> process_closed(State, Reason) ->
stop(State). stop(State#{stop_reason => Reason}).
process_terminated(#{socket := Socket, jid := JID} = State,
Reason) ->
Status = format_reason(State, Reason),
?INFO_MSG("(~s) Closing c2s connection for ~s: ~s",
[ejabberd_socket:pp(Socket), jid:to_string(JID), Status]),
Pres = #presence{type = unavailable,
status = xmpp:mk_text(Status),
from = JID, to = jid:remove_resource(JID)},
State1 = broadcast_presence_unavailable(State, Pres),
bounce_message_queue(),
State1;
process_terminated(State, _Reason) ->
State.
%%%=================================================================== %%%===================================================================
%%% xmpp_stream_in callbacks %%% xmpp_stream_in callbacks
@ -248,25 +304,9 @@ bind(R, #{user := U, server := S, access := Access, lang := Lang,
end end
end. end.
-spec open_session(state()) -> {ok, state()} | state().
open_session(#{user := U, server := S, resource := R,
sid := SID, ip := IP, auth_module := AuthModule} = State) ->
JID = jid:make(U, S, R),
change_shaper(State),
Conn = get_conn_type(State),
State1 = State#{conn => Conn, resource => R, jid => JID},
Prio = try maps:get(pres_last, State) of
Pres -> get_priority_from_presence(Pres)
catch _:{badkey, _} ->
undefined
end,
Info = [{ip, IP}, {conn, Conn}, {auth_module, AuthModule}],
ejabberd_sm:open_session(SID, U, S, R, Prio, Info),
State1.
handle_stream_start(StreamStart, handle_stream_start(StreamStart,
#{lserver := LServer, ip := IP, lang := Lang} = State) -> #{lserver := LServer, ip := IP, lang := Lang} = State) ->
case lists:member(LServer, ?MYHOSTS) of case ejabberd_router:is_my_host(LServer) of
false -> false ->
send(State, xmpp:serr_host_unknown()); send(State, xmpp:serr_host_unknown());
true -> true ->
@ -284,10 +324,8 @@ handle_stream_start(StreamStart,
end. end.
handle_stream_end(Reason, #{lserver := LServer} = State) -> handle_stream_end(Reason, #{lserver := LServer} = State) ->
ejabberd_hooks:run_fold(c2s_closed, LServer, State, [Reason]). State1 = State#{stop_reason => Reason},
ejabberd_hooks:run_fold(c2s_closed, LServer, State1, [Reason]).
handle_stream_close(_Reason, #{lserver := LServer} = State) ->
ejabberd_hooks:run_fold(c2s_closed, LServer, State, [normal]).
handle_auth_success(User, Mech, AuthModule, handle_auth_success(User, Mech, AuthModule,
#{socket := Socket, ip := IP, lserver := LServer} = State) -> #{socket := Socket, ip := IP, lserver := LServer} = State) ->
@ -296,8 +334,7 @@ handle_auth_success(User, Mech, AuthModule,
ejabberd_auth:backend_type(AuthModule), ejabberd_auth:backend_type(AuthModule),
ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]), ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]),
State1 = State#{auth_module => AuthModule}, State1 = State#{auth_module => AuthModule},
ejabberd_hooks:run_fold(c2s_auth_result, LServer, ejabberd_hooks:run_fold(c2s_auth_result, LServer, State1, [true, User]).
State1, [true, User]).
handle_auth_failure(User, Mech, Reason, handle_auth_failure(User, Mech, Reason,
#{socket := Socket, ip := IP, lserver := LServer} = State) -> #{socket := Socket, ip := IP, lserver := LServer} = State) ->
@ -307,16 +344,13 @@ handle_auth_failure(User, Mech, Reason,
true -> "" true -> ""
end, end,
ejabberd_config:may_hide_data(jlib:ip_to_list(IP)), Reason]), ejabberd_config:may_hide_data(jlib:ip_to_list(IP)), Reason]),
ejabberd_hooks:run_fold(c2s_auth_result, LServer, ejabberd_hooks:run_fold(c2s_auth_result, LServer, State, [false, User]).
State, [false, User]).
handle_unbinded_packet(Pkt, #{lserver := LServer} = State) -> handle_unbinded_packet(Pkt, #{lserver := LServer} = State) ->
ejabberd_hooks:run_fold(c2s_unbinded_packet, LServer, ejabberd_hooks:run_fold(c2s_unbinded_packet, LServer, State, [Pkt]).
State, [Pkt]).
handle_unauthenticated_packet(Pkt, #{lserver := LServer} = State) -> handle_unauthenticated_packet(Pkt, #{lserver := LServer} = State) ->
ejabberd_hooks:run_fold(c2s_unauthenticated_packet, ejabberd_hooks:run_fold(c2s_unauthenticated_packet, LServer, State, [Pkt]).
LServer, State, [Pkt]).
handle_authenticated_packet(Pkt, #{lserver := LServer} = State) when not ?is_stanza(Pkt) -> handle_authenticated_packet(Pkt, #{lserver := LServer} = State) when not ?is_stanza(Pkt) ->
ejabberd_hooks:run_fold(c2s_authenticated_packet, ejabberd_hooks:run_fold(c2s_authenticated_packet,
@ -366,20 +400,22 @@ init([State, Opts]) ->
zlib => Zlib, zlib => Zlib,
lang => ?MYLANG, lang => ?MYLANG,
server => ?MYNAME, server => ?MYNAME,
lserver => ?MYNAME,
access => Access, access => Access,
shaper => Shaper}, shaper => Shaper},
ejabberd_hooks:run_fold(c2s_init, {ok, State1}, [Opts]). ejabberd_hooks:run_fold(c2s_init, {ok, State1}, [Opts]).
handle_call(get_presence, _From, #{jid := JID} = State) -> handle_call(get_presence, From, #{jid := JID} = State) ->
Pres = try maps:get(pres_last, State) Pres = try maps:get(pres_last, State)
catch _:{badkey, _} -> catch _:{badkey, _} ->
BareJID = jid:remove_resource(JID), BareJID = jid:remove_resource(JID),
#presence{from = JID, to = BareJID, type = unavailable} #presence{from = JID, to = BareJID, type = unavailable}
end, end,
{reply, Pres, State}; reply(From, Pres),
handle_call(get_subscribed, _From, #{pres_f := PresF} = State) -> State;
Subscribed = ?SETS:to_list(PresF), handle_call(get_subscribed, From, #{pres_f := PresF} = State) ->
{reply, Subscribed, State}; reply(From, ?SETS:to_list(PresF)),
State;
handle_call(Request, From, #{lserver := LServer} = State) -> handle_call(Request, From, #{lserver := LServer} = State) ->
ejabberd_hooks:run_fold( ejabberd_hooks:run_fold(
c2s_handle_call, LServer, State, [Request, From]). c2s_handle_call, LServer, State, [Request, From]).
@ -387,30 +423,22 @@ handle_call(Request, From, #{lserver := LServer} = State) ->
handle_cast(Msg, #{lserver := LServer} = State) -> handle_cast(Msg, #{lserver := LServer} = State) ->
ejabberd_hooks:run_fold(c2s_handle_cast, LServer, State, [Msg]). ejabberd_hooks:run_fold(c2s_handle_cast, LServer, State, [Msg]).
handle_info({route, From, To, Packet0}, #{lserver := LServer} = State) ->
Packet = xmpp:set_from_to(Packet0, From, To),
{Pass, NewState} = case Packet of
#presence{} ->
process_presence_in(State, Packet);
#message{} ->
process_message_in(State, Packet);
#iq{} ->
process_iq_in(State, Packet)
end,
if Pass ->
Packet1 = ejabberd_hooks:run_fold(
user_receive_packet, LServer, Packet, [NewState]),
ejabberd_hooks:run(c2s_loop_debug, [{route, From, To, Packet}]),
send(NewState, Packet1);
true ->
ejabberd_hooks:run(c2s_loop_debug, [{route, From, To, Packet}]),
NewState
end;
handle_info(Info, #{lserver := LServer} = State) -> handle_info(Info, #{lserver := LServer} = State) ->
ejabberd_hooks:run_fold(c2s_handle_info, LServer, State, [Info]). ejabberd_hooks:run_fold(c2s_handle_info, LServer, State, [Info]).
terminate(_Reason, _State) -> terminate(Reason, #{sid := SID, jid := _,
ok. user := U, server := S, resource := R,
lserver := LServer} = State) ->
Status = format_reason(State, Reason),
case maps:is_key(pres_last, State) of
true ->
ejabberd_sm:close_session_unset_presence(SID, U, S, R, Status);
false ->
ejabberd_sm:close_session(SID, U, S, R)
end,
ejabberd_hooks:run_fold(c2s_terminated, LServer, State, [Reason]);
terminate(Reason, #{lserver := LServer} = State) ->
ejabberd_hooks:run_fold(c2s_terminated, LServer, State, [Reason]).
code_change(_OldVsn, State, _Extra) -> code_change(_OldVsn, State, _Extra) ->
{ok, State}. {ok, State}.
@ -684,6 +712,15 @@ resource_conflict_action(U, S, R) ->
{accept_resource, Rnew} {accept_resource, Rnew}
end. end.
-spec bounce_message_queue() -> ok.
bounce_message_queue() ->
receive {route, From, To, Pkt} ->
ejabberd_router:route(From, To, Pkt),
bounce_message_queue()
after 0 ->
ok
end.
-spec new_uniq_id() -> binary(). -spec new_uniq_id() -> binary().
new_uniq_id() -> new_uniq_id() ->
iolist_to_binary( iolist_to_binary(
@ -735,6 +772,14 @@ do_some_magic(#{pres_a := PresA, pres_f := PresF} = State, From) ->
end end
end. end.
-spec format_reason(state(), term()) -> binary().
format_reason(#{stop_reason := Reason}, _) ->
xmpp_stream_in:format_error(Reason);
format_reason(_, Reason) when Reason /= normal ->
<<"internal server error">>;
format_reason(_, _) ->
<<"">>.
transform_listen_option(Opt, Opts) -> transform_listen_option(Opt, Opts) ->
[Opt|Opts]. [Opt|Opts].

View File

@ -322,7 +322,7 @@ add_header(Name, Value, State)->
get_host_really_served(undefined, Provided) -> get_host_really_served(undefined, Provided) ->
Provided; Provided;
get_host_really_served(Default, Provided) -> get_host_really_served(Default, Provided) ->
case lists:member(Provided, ?MYHOSTS) of case ejabberd_router:is_my_host(Provided) of
true -> Provided; true -> Provided;
false -> Default false -> Default
end. end.

View File

@ -350,7 +350,7 @@ process_el({xmlstreamelement, #xmlel{name = <<"host">>,
JIDS = fxml:get_attr_s(<<"jid">>, Attrs), JIDS = fxml:get_attr_s(<<"jid">>, Attrs),
case jid:from_string(JIDS) of case jid:from_string(JIDS) of
#jid{lserver = S} -> #jid{lserver = S} ->
case lists:member(S, ?MYHOSTS) of case ejabberd_router:is_my_host(S) of
true -> true ->
process_users(Els, State#state{server = S}); process_users(Els, State#state{server = S});
false -> false ->

View File

@ -34,7 +34,7 @@
-export([tls_options/1, tls_required/1, tls_verify/1, tls_enabled/1, -export([tls_options/1, tls_required/1, tls_verify/1, tls_enabled/1,
compress_methods/1, compress_methods/1,
unauthenticated_stream_features/1, authenticated_stream_features/1, unauthenticated_stream_features/1, authenticated_stream_features/1,
handle_stream_start/2, handle_stream_end/2, handle_stream_close/2, handle_stream_start/2, handle_stream_end/2,
handle_stream_established/1, handle_auth_success/4, handle_stream_established/1, handle_auth_success/4,
handle_auth_failure/4, handle_send/3, handle_recv/3, handle_cdata/2, handle_auth_failure/4, handle_send/3, handle_recv/3, handle_cdata/2,
handle_unauthenticated_packet/2, handle_authenticated_packet/2]). handle_unauthenticated_packet/2, handle_authenticated_packet/2]).
@ -160,9 +160,6 @@ handle_stream_start(_StreamStart, #{lserver := LServer} = State) ->
handle_stream_end(Reason, #{server_host := LServer} = State) -> handle_stream_end(Reason, #{server_host := LServer} = State) ->
ejabberd_hooks:run_fold(s2s_in_closed, LServer, State, [Reason]). ejabberd_hooks:run_fold(s2s_in_closed, LServer, State, [Reason]).
handle_stream_close(_Reason, #{server_host := LServer} = State) ->
ejabberd_hooks:run_fold(s2s_in_closed, LServer, State, [normal]).
handle_stream_established(State) -> handle_stream_established(State) ->
set_idle_timeout(State#{established => true}). set_idle_timeout(State#{established => true}).

View File

@ -15,14 +15,14 @@
%% xmpp_stream_out callbacks %% xmpp_stream_out callbacks
-export([tls_options/1, tls_required/1, tls_verify/1, tls_enabled/1, -export([tls_options/1, tls_required/1, tls_verify/1, tls_enabled/1,
handle_auth_success/2, handle_auth_failure/3, handle_packet/2, handle_auth_success/2, handle_auth_failure/3, handle_packet/2,
handle_stream_end/2, handle_stream_close/2, handle_stream_end/2, handle_stream_downgraded/2,
handle_recv/3, handle_send/4, handle_cdata/2, handle_recv/3, handle_send/4, handle_cdata/2,
handle_stream_established/1, handle_timeout/1]). handle_stream_established/1, handle_timeout/1]).
-export([init/1, handle_call/3, handle_cast/2, handle_info/2, -export([init/1, handle_call/3, handle_cast/2, handle_info/2,
terminate/2, code_change/3]). terminate/2, code_change/3]).
%% Hooks %% Hooks
-export([process_auth_result/2, process_closed/2, handle_unexpected_info/2, -export([process_auth_result/2, process_closed/2, handle_unexpected_info/2,
handle_unexpected_cast/2]). handle_unexpected_cast/2, process_downgraded/2]).
%% API %% API
-export([start/3, start_link/3, connect/1, close/1, stop/1, send/2, -export([start/3, start_link/3, connect/1, close/1, stop/1, send/2,
route/2, establish/1, update_state/2, add_hooks/0]). route/2, establish/1, update_state/2, add_hooks/0]).
@ -83,7 +83,9 @@ add_hooks() ->
ejabberd_hooks:add(s2s_out_handle_info, Host, ?MODULE, ejabberd_hooks:add(s2s_out_handle_info, Host, ?MODULE,
handle_unexpected_info, 100), handle_unexpected_info, 100),
ejabberd_hooks:add(s2s_out_handle_cast, Host, ?MODULE, ejabberd_hooks:add(s2s_out_handle_cast, Host, ?MODULE,
handle_unexpected_cast, 100) handle_unexpected_cast, 100),
ejabberd_hooks:add(s2s_out_downgraded, Host, ?MODULE,
process_downgraded, 100)
end, ?MYHOSTS). end, ?MYHOSTS).
%%%=================================================================== %%%===================================================================
@ -95,25 +97,28 @@ process_auth_result(#{server := LServer, remote_server := RServer} = State,
?INFO_MSG("Closing outbound s2s connection ~s -> ~s: authentication failed;" ?INFO_MSG("Closing outbound s2s connection ~s -> ~s: authentication failed;"
" bouncing for ~p seconds", " bouncing for ~p seconds",
[LServer, RServer, Delay]), [LServer, RServer, Delay]),
State1 = close(State), State1 = State#{on_route => bounce},
State2 = bounce_queue(State1), State2 = close(State1),
xmpp_stream_out:set_timeout(State2, timer:seconds(Delay)); State3 = bounce_queue(State2),
xmpp_stream_out:set_timeout(State3, timer:seconds(Delay));
process_auth_result(State, true) -> process_auth_result(State, true) ->
State. State.
process_closed(#{server := LServer, remote_server := RServer,
on_route := send} = State,
Reason) ->
?INFO_MSG("Closing outbound s2s connection ~s -> ~s: ~s",
[LServer, RServer, xmpp_stream_out:format_error(Reason)]),
stop(State);
process_closed(#{server := LServer, remote_server := RServer} = State, process_closed(#{server := LServer, remote_server := RServer} = State,
_Reason) -> Reason) ->
Delay = get_delay(), Delay = get_delay(),
?INFO_MSG("Closing outbound s2s connection ~s -> ~s: ~s; " ?INFO_MSG("Closing outbound s2s connection ~s -> ~s: ~s; "
"bouncing for ~p seconds", "bouncing for ~p seconds",
[LServer, RServer, [LServer, RServer, xmpp_stream_out:format_error(Reason), Delay]),
try maps:get(stop_reason, State) of State1 = State#{on_route => bounce},
{error, Why} -> xmpp_stream_out:format_error(Why) State2 = bounce_queue(State1),
catch _:undef -> <<"unexplained reason">> xmpp_stream_out:set_timeout(State2, timer:seconds(Delay)).
end,
Delay]),
State1 = bounce_queue(State),
xmpp_stream_out:set_timeout(State1, timer:seconds(Delay)).
handle_unexpected_info(State, Info) -> handle_unexpected_info(State, Info) ->
?WARNING_MSG("got unexpected info: ~p", [Info]), ?WARNING_MSG("got unexpected info: ~p", [Info]),
@ -123,6 +128,9 @@ handle_unexpected_cast(State, Msg) ->
?WARNING_MSG("got unexpected cast: ~p", [Msg]), ?WARNING_MSG("got unexpected cast: ~p", [Msg]),
State. State.
process_downgraded(State, _StreamStart) ->
send(State, xmpp:serr_unsupported_version()).
%%%=================================================================== %%%===================================================================
%%% gen_server callbacks %%% gen_server callbacks
%%%=================================================================== %%%===================================================================
@ -153,21 +161,19 @@ handle_auth_failure(Mech, Reason,
?INFO_MSG("(~s) Failed outbound s2s ~s authentication ~s -> ~s (~s): ~s", ?INFO_MSG("(~s) Failed outbound s2s ~s authentication ~s -> ~s (~s): ~s",
[ejabberd_socket:pp(Socket), Mech, LServer, RServer, [ejabberd_socket:pp(Socket), Mech, LServer, RServer,
ejabberd_config:may_hide_data(jlib:ip_to_list(IP)), Reason]), ejabberd_config:may_hide_data(jlib:ip_to_list(IP)), Reason]),
State1 = State#{on_route => bounce, State1 = State#{stop_reason => {auth, Reason}},
stop_reason => {error, {auth, Reason}}},
ejabberd_hooks:run_fold(s2s_out_auth_result, LServer, State1, [false]). ejabberd_hooks:run_fold(s2s_out_auth_result, LServer, State1, [false]).
handle_packet(Pkt, #{server := LServer} = State) -> handle_packet(Pkt, #{server := LServer} = State) ->
ejabberd_hooks:run_fold(s2s_out_packet, LServer, State, [Pkt]). ejabberd_hooks:run_fold(s2s_out_packet, LServer, State, [Pkt]).
handle_stream_end(Reason, #{server := LServer} = State) -> handle_stream_end(Reason, #{server := LServer} = State) ->
State1 = State#{on_route => bounce, stop_reason => Reason}, State1 = State#{stop_reason => Reason},
ejabberd_hooks:run_fold(s2s_out_closed, LServer, State1, [normal]).
handle_stream_close(Reason, #{server := LServer} = State) ->
State1 = State#{on_route => bounce, stop_reason => Reason},
ejabberd_hooks:run_fold(s2s_out_closed, LServer, State1, [Reason]). ejabberd_hooks:run_fold(s2s_out_closed, LServer, State1, [Reason]).
handle_stream_downgraded(StreamStart, #{server := LServer} = State) ->
ejabberd_hooks:run_fold(s2s_out_downgraded, LServer, State, [StreamStart]).
handle_stream_established(State) -> handle_stream_established(State) ->
State1 = State#{on_route => send}, State1 = State#{on_route => send},
State2 = resend_queue(State1), State2 = resend_queue(State1),
@ -183,15 +189,10 @@ handle_send(Pkt, El, Data, #{server := LServer} = State) ->
ejabberd_hooks:run_fold(s2s_out_handle_send, LServer, ejabberd_hooks:run_fold(s2s_out_handle_send, LServer,
State, [Pkt, El, Data]). State, [Pkt, El, Data]).
handle_timeout(#{server := LServer, remote_server := RServer, handle_timeout(#{on_route := Action} = State) ->
on_route := Action} = State) ->
case Action of case Action of
bounce -> stop(State); bounce -> stop(State);
queue -> send(State, xmpp:serr_connection_timeout()); _ -> send(State, xmpp:serr_connection_timeout())
send ->
?INFO_MSG("Closing outbound s2s connection ~s -> ~s: inactive",
[LServer, RServer]),
stop(State)
end. end.
init([#{server := LServer, remote_server := RServer} = State, Opts]) -> init([#{server := LServer, remote_server := RServer} = State, Opts]) ->
@ -229,7 +230,7 @@ terminate(Reason, #{server := LServer,
ejabberd_s2s:remove_connection({LServer, RServer}, self()), ejabberd_s2s:remove_connection({LServer, RServer}, self()),
State1 = case Reason of State1 = case Reason of
normal -> State; normal -> State;
_ -> State#{stop_reason => {error, internal_failure}} _ -> State#{stop_reason => internal_failure}
end, end,
bounce_queue(State1), bounce_queue(State1),
bounce_message_queue(State1). bounce_message_queue(State1).
@ -258,8 +259,7 @@ bounce_queue(#{queue := Q} = State) ->
-spec bounce_message_queue(state()) -> state(). -spec bounce_message_queue(state()) -> state().
bounce_message_queue(State) -> bounce_message_queue(State) ->
receive receive {route, Pkt} ->
{route, Pkt} ->
State1 = bounce_packet(Pkt, State), State1 = bounce_packet(Pkt, State),
bounce_message_queue(State1) bounce_message_queue(State1)
after 0 -> after 0 ->
@ -278,21 +278,19 @@ bounce_packet(_, State) ->
State. State.
-spec mk_bounce_error(binary(), state()) -> stanza_error(). -spec mk_bounce_error(binary(), state()) -> stanza_error().
mk_bounce_error(Lang, State) -> mk_bounce_error(Lang, #{stop_reason := Why}) ->
try maps:get(stop_reason, State) of
{error, internal_failure} ->
xmpp:err_internal_server_error();
{error, Why} ->
Reason = xmpp_stream_out:format_error(Why), Reason = xmpp_stream_out:format_error(Why),
case Why of case Why of
internal_failure ->
xmpp:err_internal_server_error();
{dns, _} -> {dns, _} ->
xmpp:err_remote_server_timeout(Reason, Lang); xmpp:err_remote_server_not_found(Reason, Lang);
_ -> _ ->
xmpp:err_remote_server_not_found(Reason, Lang) xmpp:err_remote_server_timeout(Reason, Lang)
end end;
catch _:{badkey, _} -> mk_bounce_error(_Lang, _State) ->
xmpp:err_remote_server_not_found() %% We should not be here. Probably :)
end. xmpp:err_remote_server_not_found().
-spec get_delay() -> non_neg_integer(). -spec get_delay() -> non_neg_integer().
get_delay() -> get_delay() ->

View File

@ -99,7 +99,7 @@ handle_stream_start(_StreamStart,
#{remote_server := RemoteServer, #{remote_server := RemoteServer,
lang := Lang, lang := Lang,
host_opts := HostOpts} = State) -> host_opts := HostOpts} = State) ->
case lists:member(RemoteServer, ?MYHOSTS) of case ejabberd_router:is_my_host(RemoteServer) of
true -> true ->
Txt = <<"Unable to register route on existing local domain">>, Txt = <<"Unable to register route on existing local domain">>,
xmpp_stream_in:send(State, xmpp:serr_conflict(Txt, Lang)); xmpp_stream_in:send(State, xmpp:serr_conflict(Txt, Lang));

View File

@ -390,7 +390,8 @@ init([]) ->
ejabberd_hooks:add(offline_message_hook, Host, ejabberd_hooks:add(offline_message_hook, Host,
ejabberd_sm, bounce_offline_message, 100), ejabberd_sm, bounce_offline_message, 100),
ejabberd_hooks:add(remove_user, Host, ejabberd_hooks:add(remove_user, Host,
ejabberd_sm, disconnect_removed_user, 100) ejabberd_sm, disconnect_removed_user, 100),
ejabberd_c2s:add_hooks(Host)
end, ?MYHOSTS), end, ?MYHOSTS),
ejabberd_commands:register_commands(get_commands_spec()), ejabberd_commands:register_commands(get_commands_spec()),
{ok, #state{}}. {ok, #state{}}.

View File

@ -192,7 +192,7 @@ process([<<"server">>, SHost | RPath] = Path,
method = Method} = method = Method} =
Request) -> Request) ->
Host = jid:nameprep(SHost), Host = jid:nameprep(SHost),
case lists:member(Host, ?MYHOSTS) of case ejabberd_router:is_my_host(Host) of
true -> true ->
case get_auth_admin(Auth, HostHTTP, Path, Method) of case get_auth_admin(Auth, HostHTTP, Path, Method) of
{ok, {User, Server}} -> {ok, {User, Server}} ->

View File

@ -133,8 +133,7 @@ open_session(State, IQ, R) ->
case ejabberd_c2s:bind(R, State) of case ejabberd_c2s:bind(R, State) of
{ok, State1} -> {ok, State1} ->
Res = xmpp:make_iq_result(IQ), Res = xmpp:make_iq_result(IQ),
State2 = ejabberd_c2s:send(State1, Res), ejabberd_c2s:send(State1, Res);
ejabberd_c2s:establish(State2);
{error, Err, State1} -> {error, Err, State1} ->
Res = xmpp:make_error(IQ, Err), Res = xmpp:make_error(IQ, Err),
ejabberd_c2s:send(State1, Res) ejabberd_c2s:send(State1, Res)

View File

@ -28,7 +28,8 @@
%% gen_mod API %% gen_mod API
-export([start/2, stop/1, depends/2, mod_opt_type/1]). -export([start/2, stop/1, depends/2, mod_opt_type/1]).
%% Hooks %% Hooks
-export([s2s_out_auth_result/2, s2s_in_packet/2, s2s_out_packet/2, -export([s2s_out_auth_result/2, s2s_out_downgraded/2,
s2s_in_packet/2, s2s_out_packet/2,
s2s_in_features/2, s2s_out_init/2, s2s_out_closed/2]). s2s_in_features/2, s2s_out_init/2, s2s_out_closed/2]).
-include("ejabberd.hrl"). -include("ejabberd.hrl").
@ -57,6 +58,8 @@ start(Host, _Opts) ->
s2s_in_packet, 50), s2s_in_packet, 50),
ejabberd_hooks:add(s2s_out_packet, Host, ?MODULE, ejabberd_hooks:add(s2s_out_packet, Host, ?MODULE,
s2s_out_packet, 50), s2s_out_packet, 50),
ejabberd_hooks:add(s2s_out_downgraded, Host, ?MODULE,
s2s_out_downgraded, 50),
ejabberd_hooks:add(s2s_out_auth_result, Host, ?MODULE, ejabberd_hooks:add(s2s_out_auth_result, Host, ?MODULE,
s2s_out_auth_result, 50) s2s_out_auth_result, 50)
end. end.
@ -74,6 +77,8 @@ stop(Host) ->
s2s_in_packet, 50), s2s_in_packet, 50),
ejabberd_hooks:delete(s2s_out_packet, Host, ?MODULE, ejabberd_hooks:delete(s2s_out_packet, Host, ?MODULE,
s2s_out_packet, 50), s2s_out_packet, 50),
ejabberd_hooks:delete(s2s_out_downgraded, Host, ?MODULE,
s2s_out_downgraded, 50),
ejabberd_hooks:delete(s2s_out_auth_result, Host, ?MODULE, ejabberd_hooks:delete(s2s_out_auth_result, Host, ?MODULE,
s2s_out_auth_result, 50). s2s_out_auth_result, 50).
@ -104,47 +109,56 @@ s2s_out_init(Acc, _Opts) ->
s2s_out_closed(#{server := LServer, s2s_out_closed(#{server := LServer,
remote_server := RServer, remote_server := RServer,
db_verify := {StreamID, _Key, _Pid}} = State, _Reason) -> db_verify := {StreamID, _Key, _Pid}} = State, Reason) ->
%% Outbound s2s verificating connection (created at step 1) is %% Outbound s2s verificating connection (created at step 1) is
%% closed suddenly without receiving the response. %% closed suddenly without receiving the response.
%% Building a response on our own %% Building a response on our own
Response = #db_verify{from = RServer, to = LServer, Response = #db_verify{from = RServer, to = LServer,
id = StreamID, type = error, id = StreamID, type = error,
sub_els = [mk_error(internal_server_error)]}, sub_els = [mk_error(Reason)]},
s2s_out_packet(State, Response); s2s_out_packet(State, Response);
s2s_out_closed(State, _Reason) -> s2s_out_closed(State, _Reason) ->
State. State.
s2s_out_auth_result(#{server := LServer, s2s_out_auth_result(#{db_verify := _} = State, _) ->
remote_server := RServer,
db_verify := {StreamID, Key, _Pid}} = State,
_) ->
%% The temporary outbound s2s connect (intended for verification) %% The temporary outbound s2s connect (intended for verification)
%% has passed authentication state (either successfully or not, no matter) %% has passed authentication state (either successfully or not, no matter)
%% and at this point we can send verification request as described %% and at this point we can send verification request as described
%% in section 2.1.2, step 2 %% in section 2.1.2, step 2
Request = #db_verify{from = LServer, to = RServer, {stop, send_verify_request(State)};
key = Key, id = StreamID},
{stop, ejabberd_s2s_out:send(State, Request)};
s2s_out_auth_result(#{db_enabled := true, s2s_out_auth_result(#{db_enabled := true,
socket := Socket, ip := IP, socket := Socket, ip := IP,
server := LServer, server := LServer,
remote_server := RServer, remote_server := RServer} = State, false) ->
stream_remote_id := StreamID} = State, false) ->
%% SASL authentication has failed, retrying with dialback %% SASL authentication has failed, retrying with dialback
%% Sending dialback request, section 2.1.1, step 1 %% Sending dialback request, section 2.1.1, step 1
?INFO_MSG("(~s) Retrying with s2s dialback authentication: ~s -> ~s (~s)", ?INFO_MSG("(~s) Retrying with s2s dialback authentication: ~s -> ~s (~s)",
[ejabberd_socket:pp(Socket), LServer, RServer, [ejabberd_socket:pp(Socket), LServer, RServer,
ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]), ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]),
Key = make_key(LServer, RServer, StreamID),
State1 = maps:remove(stop_reason, State#{on_route => queue}), State1 = maps:remove(stop_reason, State#{on_route => queue}),
State2 = ejabberd_s2s_out:send(State1, #db_result{from = LServer, {stop, send_db_request(State1)};
to = RServer,
key = Key}),
{stop, State2};
s2s_out_auth_result(State, _) -> s2s_out_auth_result(State, _) ->
State. State.
s2s_out_downgraded(#{db_verify := _} = State, _) ->
%% The verifying outbound s2s connection detected non-RFC compliant
%% server, send verification request immediately without auth phase,
%% section 2.1.2, step 2
{stop, send_verify_request(State)};
s2s_out_downgraded(#{db_enabled := true,
socket := Socket, ip := IP,
server := LServer,
remote_server := RServer} = State, _) ->
%% non-RFC compliant server detected, send dialback request instantly,
%% section 2.1.1, step 1
?INFO_MSG("(~s) Trying s2s dialback authentication with "
"non-RFC compliant server: ~s -> ~s (~s)",
[ejabberd_socket:pp(Socket), LServer, RServer,
ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]),
{stop, send_db_request(State)};
s2s_out_downgraded(State, _) ->
State.
s2s_in_packet(#{stream_id := StreamID} = State, s2s_in_packet(#{stream_id := StreamID} = State,
#db_result{from = From, to = To, key = Key, type = undefined}) -> #db_result{from = From, to = To, key = Key, type = undefined}) ->
%% Received dialback request, section 2.2.1, step 1 %% Received dialback request, section 2.2.1, step 1
@ -220,6 +234,23 @@ make_key(From, To, StreamID) ->
crypto:hmac(sha256, p1_sha:to_hexlist(crypto:hash(sha256, Secret)), crypto:hmac(sha256, p1_sha:to_hexlist(crypto:hash(sha256, Secret)),
[To, " ", From, " ", StreamID])). [To, " ", From, " ", StreamID])).
-spec send_verify_request(ejabberd_s2s_out:state()) -> ejabberd_s2s_out:state().
send_verify_request(#{server := LServer,
remote_server := RServer,
db_verify := {StreamID, Key, _Pid}} = State) ->
Request = #db_verify{from = LServer, to = RServer,
key = Key, id = StreamID},
ejabberd_s2s_out:send(State, Request).
-spec send_db_request(ejabberd_s2s_out:state()) -> ejabberd_s2s_out:state().
send_db_request(#{server := LServer,
remote_server := RServer,
stream_remote_id := StreamID} = State) ->
Key = make_key(LServer, RServer, StreamID),
ejabberd_s2s_out:send(State, #db_result{from = LServer,
to = RServer,
key = Key}).
-spec send_db_result(ejabberd_s2s_in:state(), db_verify()) -> ejabberd_s2s_in:state(). -spec send_db_result(ejabberd_s2s_in:state(), db_verify()) -> ejabberd_s2s_in:state().
send_db_result(State, #db_verify{from = From, to = To, send_db_result(State, #db_verify{from = From, to = To,
type = Type, sub_els = Els}) -> type = Type, sub_els = Els}) ->
@ -255,6 +286,9 @@ mk_error(forbidden) ->
xmpp:err_forbidden(<<"Denied by ACL">>, ?MYLANG); xmpp:err_forbidden(<<"Denied by ACL">>, ?MYLANG);
mk_error(host_unknown) -> mk_error(host_unknown) ->
xmpp:err_not_allowed(<<"Host unknown">>, ?MYLANG); xmpp:err_not_allowed(<<"Host unknown">>, ?MYLANG);
mk_error({_Class, _Reason} = Why) ->
Txt = xmpp_stream_out:format_error(Why),
xmpp:err_remote_server_not_found(Txt, ?MYLANG);
mk_error(_) -> mk_error(_) ->
xmpp:err_internal_server_error(). xmpp:err_internal_server_error().

View File

@ -30,8 +30,9 @@
%% hooks %% hooks
-export([c2s_stream_init/2, c2s_stream_started/2, c2s_stream_features/2, -export([c2s_stream_init/2, c2s_stream_started/2, c2s_stream_features/2,
c2s_authenticated_packet/2, c2s_unauthenticated_packet/2, c2s_authenticated_packet/2, c2s_unauthenticated_packet/2,
c2s_unbinded_packet/2, c2s_closed/2, c2s_unbinded_packet/2, c2s_closed/2, c2s_terminated/2,
c2s_handle_send/3, c2s_filter_send/2, c2s_handle_info/2]). c2s_handle_send/3, c2s_filter_send/1, c2s_handle_info/2,
c2s_handle_call/3, c2s_handle_recv/3]).
-include("xmpp.hrl"). -include("xmpp.hrl").
-include("logger.hrl"). -include("logger.hrl").
@ -60,13 +61,13 @@ start(Host, _Opts) ->
c2s_unbinded_packet, 50), c2s_unbinded_packet, 50),
ejabberd_hooks:add(c2s_authenticated_packet, Host, ?MODULE, ejabberd_hooks:add(c2s_authenticated_packet, Host, ?MODULE,
c2s_authenticated_packet, 50), c2s_authenticated_packet, 50),
ejabberd_hooks:add(c2s_handle_send, Host, ?MODULE, ejabberd_hooks:add(c2s_handle_send, Host, ?MODULE, c2s_handle_send, 50),
c2s_handle_send, 50), ejabberd_hooks:add(c2s_handle_recv, Host, ?MODULE, c2s_handle_recv, 50),
ejabberd_hooks:add(c2s_filter_send, Host, ?MODULE, ejabberd_hooks:add(c2s_filter_send, Host, ?MODULE, c2s_filter_send, 50),
c2s_filter_send, 50), ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE, c2s_handle_info, 50),
ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE, ejabberd_hooks:add(c2s_handle_call, Host, ?MODULE, c2s_handle_call, 50),
c2s_handle_info, 50), ejabberd_hooks:add(c2s_closed, Host, ?MODULE, c2s_closed, 50),
ejabberd_hooks:add(c2s_closed, Host, ?MODULE, c2s_closed, 50). ejabberd_hooks:add(c2s_terminated, Host, ?MODULE, c2s_terminated, 50).
stop(Host) -> stop(Host) ->
%% TODO: do something with global 'c2s_init' hook %% TODO: do something with global 'c2s_init' hook
@ -80,13 +81,13 @@ stop(Host) ->
c2s_unbinded_packet, 50), c2s_unbinded_packet, 50),
ejabberd_hooks:delete(c2s_authenticated_packet, Host, ?MODULE, ejabberd_hooks:delete(c2s_authenticated_packet, Host, ?MODULE,
c2s_authenticated_packet, 50), c2s_authenticated_packet, 50),
ejabberd_hooks:delete(c2s_handle_send, Host, ?MODULE, ejabberd_hooks:delete(c2s_handle_send, Host, ?MODULE, c2s_handle_send, 50),
c2s_handle_send, 50), ejabberd_hooks:delete(c2s_handle_recv, Host, ?MODULE, c2s_handle_recv, 50),
ejabberd_hooks:delete(c2s_filter_send, Host, ?MODULE, ejabberd_hooks:delete(c2s_filter_send, Host, ?MODULE, c2s_filter_send, 50),
c2s_filter_send, 50), ejabberd_hooks:delete(c2s_handle_info, Host, ?MODULE, c2s_handle_info, 50),
ejabberd_hooks:delete(c2s_handle_info, Host, ?MODULE, ejabberd_hooks:delete(c2s_handle_call, Host, ?MODULE, c2s_handle_call, 50),
c2s_handle_info, 50), ejabberd_hooks:delete(c2s_closed, Host, ?MODULE, c2s_closed, 50),
ejabberd_hooks:delete(c2s_closed, Host, ?MODULE, c2s_closed, 50). ejabberd_hooks:delete(c2s_terminated, Host, ?MODULE, c2s_terminated, 50).
depends(_Host, _Opts) -> depends(_Host, _Opts) ->
[]. [].
@ -115,7 +116,10 @@ c2s_stream_started(#{lserver := LServer, mgmt_options := Opts} = State,
mgmt_timeout => ResumeTimeout, mgmt_timeout => ResumeTimeout,
mgmt_max_timeout => MaxResumeTimeout, mgmt_max_timeout => MaxResumeTimeout,
mgmt_ack_timeout => get_ack_timeout(LServer, Opts), mgmt_ack_timeout => get_ack_timeout(LServer, Opts),
mgmt_resend => get_resend_on_timeout(LServer, Opts)}; mgmt_resend => get_resend_on_timeout(LServer, Opts),
mgmt_stanzas_in => 0,
mgmt_stanzas_out => 0,
mgmt_stanzas_req => 0};
c2s_stream_started(State, _StreamStart) -> c2s_stream_started(State, _StreamStart) ->
State. State.
@ -143,8 +147,8 @@ c2s_unbinded_packet(State, #sm_resume{} = Pkt) ->
case handle_resume(State, Pkt) of case handle_resume(State, Pkt) of
{ok, ResumedState} -> {ok, ResumedState} ->
{stop, ResumedState}; {stop, ResumedState};
error -> {error, State1} ->
{stop, State} {stop, State1}
end; end;
c2s_unbinded_packet(State, Pkt) when ?is_sm_packet(Pkt) -> c2s_unbinded_packet(State, Pkt) when ?is_sm_packet(Pkt) ->
c2s_unauthenticated_packet(State, Pkt); c2s_unauthenticated_packet(State, Pkt);
@ -161,12 +165,26 @@ c2s_authenticated_packet(#{mgmt_state := MgmtState} = State, Pkt)
c2s_authenticated_packet(State, Pkt) -> c2s_authenticated_packet(State, Pkt) ->
update_num_stanzas_in(State, Pkt). update_num_stanzas_in(State, Pkt).
c2s_handle_recv(#{lang := Lang} = State, El, {error, Why}) ->
Xmlns = xmpp:get_ns(El),
if Xmlns == ?NS_STREAM_MGMT_2; Xmlns == ?NS_STREAM_MGMT_3 ->
Txt = xmpp:io_format_error(Why),
Err = #sm_failed{reason = 'bad-request',
text = xmpp:mk_text(Txt, Lang),
xmlns = Xmlns},
send(State, Err);
true ->
State
end;
c2s_handle_recv(State, _, _) ->
State.
c2s_handle_send(#{mgmt_state := MgmtState} = State, Pkt, Result) c2s_handle_send(#{mgmt_state := MgmtState} = State, Pkt, Result)
when MgmtState == pending; MgmtState == active -> when MgmtState == pending; MgmtState == active ->
State1 = mgmt_queue_add(State, Pkt), State1 = mgmt_queue_add(State, Pkt),
case Result of case Result of
ok when ?is_stanza(Pkt) -> ok when ?is_stanza(Pkt) ->
send_ack(State1); send_rack(State1);
ok -> ok ->
State1; State1;
{error, _} -> {error, _} ->
@ -175,21 +193,57 @@ c2s_handle_send(#{mgmt_state := MgmtState} = State, Pkt, Result)
c2s_handle_send(State, _Pkt, _Result) -> c2s_handle_send(State, _Pkt, _Result) ->
State. State.
c2s_filter_send(Pkt, _State) -> c2s_filter_send({Pkt, State}) ->
Pkt. {Pkt, State}.
c2s_handle_info(#{mgmt_ack_timer := T, jid := JID} = State, c2s_handle_call(#{sid := {Time, _}} = State,
{timeout, T, ack_timeout}) -> {resume_session, Time}, From) ->
?DEBUG("Timeout waiting for stream management acknowledgement of ~s", ejabberd_c2s:reply(From, {resume, State}),
{stop, State#{mgmt_state => resumed}};
c2s_handle_call(State, {resume_session, _}, From) ->
ejabberd_c2s:reply(From, {error, <<"Previous session not found">>}),
{stop, State};
c2s_handle_call(State, _Call, _From) ->
State.
c2s_handle_info(#{mgmt_ack_timer := TRef, jid := JID} = State,
{timeout, TRef, ack_timeout}) ->
?DEBUG("Timed out waiting for stream management acknowledgement of ~s",
[jid:to_string(JID)]), [jid:to_string(JID)]),
State1 = ejabberd_c2s:close(State, _SendTrailer = false), State1 = ejabberd_c2s:close(State, _SendTrailer = false),
c2s_closed(State1, ack_timeout); {stop, transition_to_pending(State1)};
c2s_handle_info(#{mgmt_state := pending, jid := JID} = State,
{timeout, _, pending_timeout}) ->
?DEBUG("Timed out waiting for resumption of stream for ~s",
[jid:to_string(JID)]),
ejabberd_c2s:stop(State#{mgmt_state => timeout});
c2s_handle_info(State, _) -> c2s_handle_info(State, _) ->
State. State.
c2s_closed(#{mgmt_state := active} = State, Reason) when Reason /= normal -> c2s_closed(State, {stream, _}) ->
{stop, transition_to_pending(State)}; State;
c2s_closed(State, _) -> c2s_closed(#{mgmt_state := active} = State, Reason) ->
{stop, transition_to_pending(State#{stop_reason => Reason})};
c2s_closed(State, _Reason) ->
State.
c2s_terminated(#{mgmt_state := resumed, jid := JID} = State, _Reason) ->
?INFO_MSG("Closing former stream of resumed session for ~s",
[jid:to_string(JID)]),
bounce_message_queue(),
{stop, State};
c2s_terminated(#{mgmt_state := MgmtState, mgmt_stanzas_in := In, sid := SID,
user := U, server := S, resource := R} = State, _Reason) ->
case MgmtState of
timeout ->
Info = [{num_stanzas_in, In}],
ejabberd_sm:set_offline_info(SID, U, S, R, Info);
_ ->
ok
end,
route_unacked_stanzas(State),
State;
c2s_terminated(State, _Reason) ->
State. State.
%%%=================================================================== %%%===================================================================
@ -201,17 +255,14 @@ negotiate_stream_mgmt(Pkt, State) ->
case Pkt of case Pkt of
#sm_enable{} -> #sm_enable{} ->
handle_enable(State#{mgmt_xmlns => Xmlns}, Pkt); handle_enable(State#{mgmt_xmlns => Xmlns}, Pkt);
_ -> _ when is_record(Pkt, sm_a);
Res = if is_record(Pkt, sm_a);
is_record(Pkt, sm_r); is_record(Pkt, sm_r);
is_record(Pkt, sm_resume) -> is_record(Pkt, sm_resume) ->
#sm_failed{reason = 'unexpected-request', Err = #sm_failed{reason = 'unexpected-request', xmlns = Xmlns},
xmlns = Xmlns}; send(State, Err);
true -> _ ->
#sm_failed{reason = 'bad-request', Err = #sm_failed{reason = 'bad-request', xmlns = Xmlns},
xmlns = Xmlns} send(State, Err)
end,
send(State, Res)
end. end.
-spec perform_stream_mgmt(xmpp_element(), state()) -> state(). -spec perform_stream_mgmt(xmpp_element(), state()) -> state().
@ -223,16 +274,13 @@ perform_stream_mgmt(Pkt, #{mgmt_xmlns := Xmlns} = State) ->
handle_r(State); handle_r(State);
#sm_a{} -> #sm_a{} ->
handle_a(State, Pkt); handle_a(State, Pkt);
_ -> _ when is_record(Pkt, sm_enable);
Res = if is_record(Pkt, sm_enable);
is_record(Pkt, sm_resume) -> is_record(Pkt, sm_resume) ->
#sm_failed{reason = 'unexpected-request', send(State, #sm_failed{reason = 'unexpected-request',
xmlns = Xmlns}; xmlns = Xmlns});
true -> _ ->
#sm_failed{reason = 'bad-request', send(State, #sm_failed{reason = 'bad-request',
xmlns = Xmlns} xmlns = Xmlns})
end,
send(State, Res)
end; end;
_ -> _ ->
send(State, #sm_failed{reason = 'unsupported-version', xmlns = Xmlns}) send(State, #sm_failed{reason = 'unsupported-version', xmlns = Xmlns})
@ -241,7 +289,7 @@ perform_stream_mgmt(Pkt, #{mgmt_xmlns := Xmlns} = State) ->
-spec handle_enable(state(), sm_enable()) -> state(). -spec handle_enable(state(), sm_enable()) -> state().
handle_enable(#{mgmt_timeout := DefaultTimeout, handle_enable(#{mgmt_timeout := DefaultTimeout,
mgmt_max_timeout := MaxTimeout, mgmt_max_timeout := MaxTimeout,
xmlns := Xmlns, jid := JID} = State, mgmt_xmlns := Xmlns, jid := JID} = State,
#sm_enable{resume = Resume, max = Max}) -> #sm_enable{resume = Resume, max = Max}) ->
Timeout = if Resume == false -> Timeout = if Resume == false ->
0; 0;
@ -264,7 +312,7 @@ handle_enable(#{mgmt_timeout := DefaultTimeout,
end, end,
State1 = State#{mgmt_state => active, State1 = State#{mgmt_state => active,
mgmt_queue => queue_new(), mgmt_queue => queue_new(),
mgmt_timeout => Timeout * 1000}, mgmt_timeout => Timeout},
send(State1, Res). send(State1, Res).
-spec handle_r(state()) -> state(). -spec handle_r(state()) -> state().
@ -275,23 +323,26 @@ handle_r(#{mgmt_xmlns := Xmlns, mgmt_stanzas_in := H} = State) ->
-spec handle_a(state(), sm_a()) -> state(). -spec handle_a(state(), sm_a()) -> state().
handle_a(State, #sm_a{h = H}) -> handle_a(State, #sm_a{h = H}) ->
State1 = check_h_attribute(State, H), State1 = check_h_attribute(State, H),
resend_ack(State1). resend_rack(State1).
-spec handle_resume(state(), sm_resume()) -> {ok, state()} | {error, state()}. -spec handle_resume(state(), sm_resume()) -> {ok, state()} | {error, state()}.
handle_resume(#{lserver := LServer, jid := JID, socket := Socket} = State, handle_resume(#{user := User, lserver := LServer,
lang := Lang, socket := Socket} = State,
#sm_resume{h = H, previd = PrevID, xmlns = Xmlns}) -> #sm_resume{h = H, previd = PrevID, xmlns = Xmlns}) ->
R = case inherit_session_state(State, PrevID) of R = case inherit_session_state(State, PrevID) of
{ok, InheritedState} -> {ok, InheritedState} ->
{ok, InheritedState, H}; {ok, InheritedState, H};
{error, Err, InH} -> {error, Err, InH} ->
{error, #sm_failed{reason = 'item-not-found', {error, #sm_failed{reason = 'item-not-found',
text = xmpp:mk_text(Err, Lang),
h = InH, xmlns = Xmlns}, Err}; h = InH, xmlns = Xmlns}, Err};
{error, Err} -> {error, Err} ->
{error, #sm_failed{reason = 'item-not-found', {error, #sm_failed{reason = 'item-not-found',
text = xmpp:mk_text(Err, Lang),
xmlns = Xmlns}, Err} xmlns = Xmlns}, Err}
end, end,
case R of case R of
{ok, ResumedState, NumHandled} -> {ok, #{jid := JID} = ResumedState, NumHandled} ->
State1 = check_h_attribute(ResumedState, NumHandled), State1 = check_h_attribute(ResumedState, NumHandled),
#{mgmt_xmlns := AttrXmlns, mgmt_stanzas_in := AttrH} = State1, #{mgmt_xmlns := AttrXmlns, mgmt_stanzas_in := AttrH} = State1,
AttrId = make_resume_id(State1), AttrId = make_resume_id(State1),
@ -307,14 +358,20 @@ handle_resume(#{lserver := LServer, jid := JID, socket := Socket} = State,
[ejabberd_socket:pp(Socket), jid:to_string(JID)]), [ejabberd_socket:pp(Socket), jid:to_string(JID)]),
{ok, State5}; {ok, State5};
{error, El, Msg} -> {error, El, Msg} ->
?INFO_MSG("Cannot resume session for ~s: ~s", [jid:to_string(JID), Msg]), ?INFO_MSG("Cannot resume session for ~s@~s: ~s",
[User, LServer, Msg]),
{error, send(State, El)} {error, send(State, El)}
end. end.
-spec transition_to_pending(state()) -> state(). -spec transition_to_pending(state()) -> state().
transition_to_pending(#{mgmt_state := active} = State) -> transition_to_pending(#{mgmt_state := active, jid := JID,
%% TODO lserver := LServer, mgmt_timeout := Timeout} = State) ->
State; State1 = cancel_ack_timer(State),
?INFO_MSG("Waiting for resumption of stream for ~s", [jid:to_string(JID)]),
State2 = ejabberd_hooks:run_fold(c2s_session_pending, LServer, State1, []),
State3 = ejabberd_c2s:close(State2, _SendTrailer = false),
erlang:start_timer(timer:seconds(Timeout), self(), pending_timeout),
State3#{mgmt_state => pending};
transition_to_pending(State) -> transition_to_pending(State) ->
State. State.
@ -345,25 +402,25 @@ update_num_stanzas_in(#{mgmt_state := MgmtState,
update_num_stanzas_in(State, _El) -> update_num_stanzas_in(State, _El) ->
State. State.
send_ack(#{mgmt_ack_timer := _} = State) -> send_rack(#{mgmt_ack_timer := _} = State) ->
State; State;
send_ack(#{mgmt_xmlns := Xmlns, send_rack(#{mgmt_xmlns := Xmlns,
mgmt_stanzas_out := NumStanzasOut, mgmt_stanzas_out := NumStanzasOut,
mgmt_ack_timeout := AckTimeout} = State) -> mgmt_ack_timeout := AckTimeout} = State) ->
State1 = send(State, #sm_r{xmlns = Xmlns}), State1 = send(State, #sm_r{xmlns = Xmlns}),
TRef = erlang:start_timer(AckTimeout, self(), ack_timeout), TRef = erlang:start_timer(AckTimeout, self(), ack_timeout),
State1#{mgmt_ack_timer => TRef, mgmt_stanzas_req => NumStanzasOut}. State1#{mgmt_ack_timer => TRef, mgmt_stanzas_req => NumStanzasOut}.
resend_ack(#{mgmt_ack_timer := _, resend_rack(#{mgmt_ack_timer := _,
mgmt_queue := Queue, mgmt_queue := Queue,
mgmt_stanzas_out := NumStanzasOut, mgmt_stanzas_out := NumStanzasOut,
mgmt_stanzas_req := NumStanzasReq} = State) -> mgmt_stanzas_req := NumStanzasReq} = State) ->
State1 = cancel_ack_timer(State), State1 = cancel_ack_timer(State),
case NumStanzasReq < NumStanzasOut andalso not queue_is_empty(Queue) of case NumStanzasReq < NumStanzasOut andalso not queue_is_empty(Queue) of
true -> send_ack(State1); true -> send_rack(State1);
false -> State1 false -> State1
end; end;
resend_ack(State) -> resend_rack(State) ->
State. State.
-spec mgmt_queue_add(state(), xmpp_element()) -> state(). -spec mgmt_queue_add(state(), xmpp_element()) -> state().
@ -492,10 +549,22 @@ inherit_session_state(#{user := U, server := S} = State, ResumeID) ->
OldPID -> OldPID ->
OldSID = {Time, OldPID}, OldSID = {Time, OldPID},
try resume_session(OldSID, State) of try resume_session(OldSID, State) of
{resume, OldState} -> {resume, #{mgmt_xmlns := Xmlns,
mgmt_queue := Queue,
mgmt_timeout := Timeout,
mgmt_stanzas_in := NumStanzasIn,
mgmt_stanzas_out := NumStanzasOut} = OldState} ->
State1 = ejabberd_c2s:copy_state(State, OldState), State1 = ejabberd_c2s:copy_state(State, OldState),
State2 = ejabberd_c2s:open_session(State1), State2 = State1#{mgmt_xmlns => Xmlns,
{ok, State2}; mgmt_queue => Queue,
mgmt_timeout => Timeout,
mgmt_stanzas_in => NumStanzasIn,
mgmt_stanzas_out => NumStanzasOut,
mgmt_state => active},
ejabberd_sm:close_session(OldSID, U, S, R),
State3 = ejabberd_c2s:open_session(State2),
ejabberd_c2s:stop(OldPID),
{ok, State3};
{error, Msg} -> {error, Msg} ->
{error, Msg} {error, Msg}
catch exit:{noproc, _} -> catch exit:{noproc, _} ->
@ -591,6 +660,15 @@ cancel_ack_timer(#{mgmt_ack_timer := TRef} = State) ->
cancel_ack_timer(State) -> cancel_ack_timer(State) ->
State. State.
-spec bounce_message_queue() -> ok.
bounce_message_queue() ->
receive {route, From, To, Pkt} ->
ejabberd_router:route(From, To, Pkt),
bounce_message_queue()
after 0 ->
ok
end.
%%%=================================================================== %%%===================================================================
%%% Configuration processing %%% Configuration processing
%%%=================================================================== %%%===================================================================

View File

@ -44,7 +44,8 @@
-type state() :: map(). -type state() :: map().
-type stop_reason() :: {stream, reset | stream_error()} | -type stop_reason() :: {stream, reset | stream_error()} |
{tls, term()} | {tls, term()} |
{socket, inet:posix() | closed | timeout}. {socket, inet:posix() | closed | timeout} |
internal_failure.
-callback init(list()) -> {ok, state()} | {stop, term()} | ignore. -callback init(list()) -> {ok, state()} | {stop, term()} | ignore.
-callback handle_cast(term(), state()) -> state(). -callback handle_cast(term(), state()) -> state().
@ -54,7 +55,6 @@
-callback code_change(term(), state(), term()) -> {ok, state()} | {error, term()}. -callback code_change(term(), state(), term()) -> {ok, state()} | {error, term()}.
-callback handle_stream_start(state()) -> state(). -callback handle_stream_start(state()) -> state().
-callback handle_stream_end(stop_reason(), state()) -> state(). -callback handle_stream_end(stop_reason(), state()) -> state().
-callback handle_stream_close(stop_reason(), state()) -> state().
-callback handle_cdata(binary(), state()) -> state(). -callback handle_cdata(binary(), state()) -> state().
-callback handle_unauthenticated_packet(xmpp_element(), state()) -> state(). -callback handle_unauthenticated_packet(xmpp_element(), state()) -> state().
-callback handle_authenticated_packet(xmpp_element(), state()) -> state(). -callback handle_authenticated_packet(xmpp_element(), state()) -> state().
@ -83,7 +83,6 @@
code_change/3, code_change/3,
handle_stream_start/1, handle_stream_start/1,
handle_stream_end/2, handle_stream_end/2,
handle_stream_close/2,
handle_cdata/2, handle_cdata/2,
handle_authenticated_packet/2, handle_authenticated_packet/2,
handle_unauthenticated_packet/2, handle_unauthenticated_packet/2,
@ -193,6 +192,8 @@ format_error({stream, #stream_error{reason = Reason, text = Txt}}) ->
format("Stream failed: ~s", [format_stream_error(Reason, Txt)]); format("Stream failed: ~s", [format_stream_error(Reason, Txt)]);
format_error({tls, Reason}) -> format_error({tls, Reason}) ->
format("TLS failed: ~w", [Reason]); format("TLS failed: ~w", [Reason]);
format_error(internal_failure) ->
<<"Internal server error">>;
format_error(Err) -> format_error(Err) ->
format("Unrecognized error: ~w", [Err]). format("Unrecognized error: ~w", [Err]).
@ -263,18 +264,19 @@ handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}},
#{stream_state := wait_for_stream, #{stream_state := wait_for_stream,
xmlns := XMLNS, lang := MyLang} = State) -> xmlns := XMLNS, lang := MyLang} = State) ->
El = #xmlel{name = Name, attrs = Attrs}, El = #xmlel{name = Name, attrs = Attrs},
noreply(
try xmpp:decode(El, XMLNS, []) of try xmpp:decode(El, XMLNS, []) of
#stream_start{} = Pkt -> #stream_start{} = Pkt ->
State1 = send_header(State, Pkt), State1 = send_header(State, Pkt),
case is_disconnected(State1) of case is_disconnected(State1) of
true -> State1; true -> State1;
false -> noreply(process_stream(Pkt, State1)) false -> process_stream(Pkt, State1)
end; end;
_ -> _ ->
State1 = send_header(State), State1 = send_header(State),
case is_disconnected(State1) of case is_disconnected(State1) of
true -> State1; true -> State1;
false -> noreply(send_element(State1, xmpp:serr_invalid_xml())) false -> send_element(State1, xmpp:serr_invalid_xml())
end end
catch _:{xmpp_codec, Why} -> catch _:{xmpp_codec, Why} ->
State1 = send_header(State), State1 = send_header(State),
@ -284,11 +286,12 @@ handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}},
Txt = xmpp:io_format_error(Why), Txt = xmpp:io_format_error(Why),
Lang = select_lang(MyLang, xmpp:get_lang(El)), Lang = select_lang(MyLang, xmpp:get_lang(El)),
Err = xmpp:serr_invalid_xml(Txt, Lang), Err = xmpp:serr_invalid_xml(Txt, Lang),
noreply(send_element(State1, Err)) send_element(State1, Err)
end end
end; end);
handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) -> handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) ->
State1 = send_header(State), State1 = send_header(State),
noreply(
case is_disconnected(State1) of case is_disconnected(State1) of
true -> State1; true -> State1;
false -> false ->
@ -298,10 +301,11 @@ handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) ->
_ -> _ ->
xmpp:serr_not_well_formed() xmpp:serr_not_well_formed()
end, end,
noreply(send_element(State1, Err)) send_element(State1, Err)
end; end);
handle_info({'$gen_event', {xmlstreamelement, El}}, handle_info({'$gen_event', {xmlstreamelement, El}},
#{xmlns := NS, lang := MyLang, mod := Mod} = State) -> #{xmlns := NS, lang := MyLang, mod := Mod} = State) ->
noreply(
try xmpp:decode(El, NS, [ignore_els]) of try xmpp:decode(El, NS, [ignore_els]) of
Pkt -> Pkt ->
State1 = try Mod:handle_recv(El, Pkt, State) State1 = try Mod:handle_recv(El, Pkt, State)
@ -309,7 +313,7 @@ handle_info({'$gen_event', {xmlstreamelement, El}},
end, end,
case is_disconnected(State1) of case is_disconnected(State1) of
true -> State1; true -> State1;
false -> noreply(process_element(Pkt, State1)) false -> process_element(Pkt, State1)
end end
catch _:{xmpp_codec, Why} -> catch _:{xmpp_codec, Why} ->
State1 = try Mod:handle_recv(El, {error, Why}, State) State1 = try Mod:handle_recv(El, {error, Why}, State)
@ -320,18 +324,18 @@ handle_info({'$gen_event', {xmlstreamelement, El}},
false -> false ->
Txt = xmpp:io_format_error(Why), Txt = xmpp:io_format_error(Why),
Lang = select_lang(MyLang, xmpp:get_lang(El)), Lang = select_lang(MyLang, xmpp:get_lang(El)),
noreply(send_error(State1, El, xmpp:err_bad_request(Txt, Lang))) send_error(State1, El, xmpp:err_bad_request(Txt, Lang))
end end
end; end);
handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}}, handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}},
#{mod := Mod} = State) -> #{mod := Mod} = State) ->
noreply(try Mod:handle_cdata(Data, State) noreply(try Mod:handle_cdata(Data, State)
catch _:undef -> State catch _:undef -> State
end); end);
handle_info({'$gen_event', {xmlstreamend, _}}, State) -> handle_info({'$gen_event', {xmlstreamend, _}}, State) ->
noreply(process_stream_end({error, {stream, reset}}, State)); noreply(process_stream_end({stream, reset}, State));
handle_info({'$gen_event', closed}, State) -> handle_info({'$gen_event', closed}, State) ->
noreply(process_stream_close({error, {socket, closed}}, State)); noreply(process_stream_end({socket, closed}, State));
handle_info(timeout, #{mod := Mod} = State) -> handle_info(timeout, #{mod := Mod} = State) ->
Disconnected = is_disconnected(State), Disconnected = is_disconnected(State),
noreply(try Mod:handle_timeout(State) noreply(try Mod:handle_timeout(State)
@ -342,7 +346,7 @@ handle_info(timeout, #{mod := Mod} = State) ->
end); end);
handle_info({'DOWN', MRef, _Type, _Object, _Info}, handle_info({'DOWN', MRef, _Type, _Object, _Info},
#{socket_monitor := MRef} = State) -> #{socket_monitor := MRef} = State) ->
noreply(process_stream_close({error, {socket, closed}}, State)); noreply(process_stream_end({socket, closed}, State));
handle_info(Info, #{mod := Mod} = State) -> handle_info(Info, #{mod := Mod} = State) ->
noreply(try Mod:handle_info(Info, State) noreply(try Mod:handle_info(Info, State)
catch _:undef -> State catch _:undef -> State
@ -390,15 +394,6 @@ peername(SockMod, Socket) ->
_ -> SockMod:peername(Socket) _ -> SockMod:peername(Socket)
end. end.
-spec process_stream_close(stop_reason(), state()) -> state().
process_stream_close(_, #{stream_state := disconnected} = State) ->
State;
process_stream_close(Reason, #{mod := Mod} = State) ->
State1 = send_trailer(State),
try Mod:handle_stream_close(Reason, State1)
catch _:undef -> stop(State1)
end.
-spec process_stream_end(stop_reason(), state()) -> state(). -spec process_stream_end(stop_reason(), state()) -> state().
process_stream_end(_, #{stream_state := disconnected} = State) -> process_stream_end(_, #{stream_state := disconnected} = State) ->
State; State;
@ -414,6 +409,8 @@ process_stream(#stream_start{xmlns = XML_NS,
#{xmlns := NS} = State) #{xmlns := NS} = State)
when XML_NS /= NS; STREAM_NS /= ?NS_STREAM -> when XML_NS /= NS; STREAM_NS /= ?NS_STREAM ->
send_element(State, xmpp:serr_invalid_namespace()); send_element(State, xmpp:serr_invalid_namespace());
process_stream(#stream_start{version = {N, _}}, State) when N > 1 ->
send_element(State, xmpp:serr_unsupported_version());
process_stream(#stream_start{lang = Lang}, process_stream(#stream_start{lang = Lang},
#{xmlns := ?NS_CLIENT, lang := DefaultLang} = State) #{xmlns := ?NS_CLIENT, lang := DefaultLang} = State)
when size(Lang) > 35 -> when size(Lang) > 35 ->
@ -520,7 +517,7 @@ process_element(Pkt, #{stream_state := StateName, lang := Lang} = State) ->
#handshake{} -> #handshake{} ->
State; State;
#stream_error{} -> #stream_error{} ->
process_stream_end({error, {stream, Pkt}}, State); process_stream_end({stream, Pkt}, State);
_ when StateName == wait_for_sasl_request; _ when StateName == wait_for_sasl_request;
StateName == wait_for_handshake; StateName == wait_for_handshake;
StateName == wait_for_sasl_response -> StateName == wait_for_sasl_response ->
@ -704,7 +701,7 @@ process_starttls_failure(Why, State) ->
State1 = send_element(State, #starttls_failure{}), State1 = send_element(State, #starttls_failure{}),
case is_disconnected(State1) of case is_disconnected(State1) of
true -> State1; true -> State1;
false -> process_stream_end({error, {tls, Why}}, State1) false -> process_stream_end({tls, Why}, State1)
end. end.
-spec process_sasl_request(sasl_auth(), state()) -> state(). -spec process_sasl_request(sasl_auth(), state()) -> state().
@ -939,8 +936,8 @@ set_from_to(Pkt, #{lang := Lang}) ->
end. end.
-spec send_header(state()) -> state(). -spec send_header(state()) -> state().
send_header(State) -> send_header(#{stream_version := Version} = State) ->
send_header(State, #stream_start{}). send_header(State, #stream_start{version = Version}).
-spec send_header(state(), stream_start()) -> state(). -spec send_header(state(), stream_start()) -> state().
send_header(#{stream_id := StreamID, send_header(#{stream_id := StreamID,
@ -959,8 +956,9 @@ send_header(#{stream_id := StreamID,
undefined -> jid:make(DefaultServer) undefined -> jid:make(DefaultServer)
end, end,
Version = case HisVersion of Version = case HisVersion of
undefined -> MyVersion; undefined -> undefined;
_ -> HisVersion {0,_} -> HisVersion;
_ -> MyVersion
end, end,
Header = xmpp:encode(#stream_start{version = Version, Header = xmpp:encode(#stream_start{version = Version,
lang = Lang, lang = Lang,
@ -969,10 +967,12 @@ send_header(#{stream_id := StreamID,
db_xmlns = NS_DB, db_xmlns = NS_DB,
id = StreamID, id = StreamID,
from = From}), from = From}),
State1 = State#{lang => Lang, stream_header_sent => true}, State1 = State#{lang => Lang,
stream_version => Version,
stream_header_sent => true},
case send_text(State1, fxml:element_to_header(Header)) of case send_text(State1, fxml:element_to_header(Header)) of
ok -> State1; ok -> State1;
{error, Why} -> process_stream_close({error, {socket, Why}}, State1) {error, Why} -> process_stream_end({socket, Why}, State1)
end; end;
send_header(State, _) -> send_header(State, _) ->
State. State.
@ -987,11 +987,11 @@ send_element(#{xmlns := NS, mod := Mod} = State, Pkt) ->
end, end,
case Result of case Result of
_ when is_record(Pkt, stream_error) -> _ when is_record(Pkt, stream_error) ->
process_stream_end({error, {stream, Pkt}}, State1); process_stream_end({stream, Pkt}, State1);
ok -> ok ->
State1; State1;
{error, Why} -> {error, Why} ->
process_stream_close({error, {socket, Why}}, State1) process_stream_end({socket, Why}, State1)
end. end.
-spec send_error(state(), xmpp_element(), stanza_error()) -> state(). -spec send_error(state(), xmpp_element(), stanza_error()) -> state().
@ -1022,7 +1022,7 @@ send_text(#{socket := Sock, sockmod := SockMod,
stream_header_sent := true}, Data) when StateName /= disconnected -> stream_header_sent := true}, Data) when StateName /= disconnected ->
SockMod:send(Sock, Data); SockMod:send(Sock, Data);
send_text(_, _) -> send_text(_, _) ->
{error, einval}. {error, closed}.
-spec close_socket(state()) -> state(). -spec close_socket(state()) -> state().
close_socket(#{sockmod := SockMod, socket := Socket} = State) -> close_socket(#{sockmod := SockMod, socket := Socket} = State) ->

View File

@ -33,6 +33,7 @@
-include_lib("kernel/include/inet.hrl"). -include_lib("kernel/include/inet.hrl").
-type state() :: map(). -type state() :: map().
-type noreply() :: {noreply, state(), timeout()}.
-type host_port() :: {inet:hostname(), inet:port_number()}. -type host_port() :: {inet:hostname(), inet:port_number()}.
-type ip_port() :: {inet:ip_address(), inet:port_number()}. -type ip_port() :: {inet:ip_address(), inet:port_number()}.
-type network_error() :: {error, inet:posix() | inet_res:res_error()}. -type network_error() :: {error, inet:posix() | inet_res:res_error()}.
@ -42,7 +43,8 @@
{tls, term()} | {tls, term()} |
{pkix, binary()} | {pkix, binary()} |
{auth, atom() | binary() | string()} | {auth, atom() | binary() | string()} |
{socket, inet:posix() | closed | timeout}. {socket, inet:posix() | closed | timeout} |
internal_failure.
-callback init(list()) -> {ok, state()} | {stop, term()} | ignore. -callback init(list()) -> {ok, state()} | {stop, term()} | ignore.
@ -107,7 +109,7 @@ close(_, _) ->
establish(State) -> establish(State) ->
process_stream_established(State). process_stream_established(State).
-spec set_timeout(state(), non_neg_integer() | infinity) -> state(). -spec set_timeout(state(), timeout()) -> state().
set_timeout(#{owner := Owner} = State, Timeout) when Owner == self() -> set_timeout(#{owner := Owner} = State, Timeout) when Owner == self() ->
case Timeout of case Timeout of
infinity -> State#{stream_timeout => infinity}; infinity -> State#{stream_timeout => infinity};
@ -148,12 +150,15 @@ format_error({tls, Reason}) ->
format("TLS failed: ~w", [Reason]); format("TLS failed: ~w", [Reason]);
format_error({auth, Reason}) -> format_error({auth, Reason}) ->
format("Authentication failed: ~s", [Reason]); format("Authentication failed: ~s", [Reason]);
format_error(internal_failure) ->
<<"Internal server error">>;
format_error(Err) -> format_error(Err) ->
format("Unrecognized error: ~w", [Err]). format("Unrecognized error: ~w", [Err]).
%%%=================================================================== %%%===================================================================
%%% gen_server callbacks %%% gen_server callbacks
%%%=================================================================== %%%===================================================================
-spec init(list()) -> {ok, state(), timeout()} | {stop, term()} | ignore.
init([Mod, SockMod, From, To, Opts]) -> init([Mod, SockMod, From, To, Opts]) ->
Time = p1_time_compat:monotonic_time(milli_seconds), Time = p1_time_compat:monotonic_time(milli_seconds),
State = #{owner => self(), State = #{owner => self(),
@ -183,17 +188,20 @@ init([Mod, SockMod, From, To, Opts]) ->
Err Err
end. end.
-spec handle_call(term(), term(), state()) -> noreply().
handle_call(Call, From, #{mod := Mod} = State) -> handle_call(Call, From, #{mod := Mod} = State) ->
noreply(try Mod:handle_call(Call, From, State) noreply(try Mod:handle_call(Call, From, State)
catch _:undef -> State catch _:undef -> State
end). end).
-spec handle_cast(term(), state()) -> noreply().
handle_cast(connect, #{remote_server := RemoteServer, handle_cast(connect, #{remote_server := RemoteServer,
sockmod := SockMod, sockmod := SockMod,
stream_state := connecting} = State) -> stream_state := connecting} = State) ->
noreply(
case ejabberd_idna:domain_utf8_to_ascii(RemoteServer) of case ejabberd_idna:domain_utf8_to_ascii(RemoteServer) of
false -> false ->
noreply(process_stream_close({error, {idna, bad_string}}, State)); process_stream_end({idna, bad_string}, State);
ASCIIName -> ASCIIName ->
case resolve(binary_to_list(ASCIIName), State) of case resolve(binary_to_list(ASCIIName), State) of
{ok, AddrPorts} -> {ok, AddrPorts} ->
@ -204,15 +212,14 @@ handle_cast(connect, #{remote_server := RemoteServer,
socket => Socket, socket => Socket,
socket_monitor => SocketMonitor}, socket_monitor => SocketMonitor},
State2 = State1#{stream_state => wait_for_stream}, State2 = State1#{stream_state => wait_for_stream},
noreply(send_header(State2)); send_header(State2);
{error, Why} -> {error, Why} ->
Err = {error, {socket, Why}}, process_stream_end({socket, Why}, State)
noreply(process_stream_close(Err, State))
end; end;
{error, Why} -> {error, Why} ->
noreply(process_stream_close({error, {dns, Why}}, State)) process_stream_end({dns, Why}, State)
end end
end; end);
handle_cast(connect, State) -> handle_cast(connect, State) ->
%% Ignoring connection attempts in other states %% Ignoring connection attempts in other states
noreply(State); noreply(State);
@ -225,23 +232,26 @@ handle_cast(Cast, #{mod := Mod} = State) ->
catch _:undef -> State catch _:undef -> State
end). end).
-spec handle_info(term(), state()) -> noreply().
handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}}, handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}},
#{stream_state := wait_for_stream, #{stream_state := wait_for_stream,
xmlns := XMLNS, lang := MyLang} = State) -> xmlns := XMLNS, lang := MyLang} = State) ->
El = #xmlel{name = Name, attrs = Attrs}, El = #xmlel{name = Name, attrs = Attrs},
noreply(
try xmpp:decode(El, XMLNS, []) of try xmpp:decode(El, XMLNS, []) of
#stream_start{} = Pkt -> #stream_start{} = Pkt ->
noreply(process_stream(Pkt, State)); process_stream(Pkt, State);
_ -> _ ->
noreply(send_element(State, xmpp:serr_invalid_xml())) send_element(State, xmpp:serr_invalid_xml())
catch _:{xmpp_codec, Why} -> catch _:{xmpp_codec, Why} ->
Txt = xmpp:io_format_error(Why), Txt = xmpp:io_format_error(Why),
Lang = select_lang(MyLang, xmpp:get_lang(El)), Lang = select_lang(MyLang, xmpp:get_lang(El)),
Err = xmpp:serr_invalid_xml(Txt, Lang), Err = xmpp:serr_invalid_xml(Txt, Lang),
noreply(send_element(State, Err)) send_element(State, Err)
end; end);
handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) -> handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) ->
State1 = send_header(State), State1 = send_header(State),
noreply(
case is_disconnected(State1) of case is_disconnected(State1) of
true -> State1; true -> State1;
false -> false ->
@ -251,10 +261,11 @@ handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) ->
_ -> _ ->
xmpp:serr_not_well_formed() xmpp:serr_not_well_formed()
end, end,
noreply(send_element(State1, Err)) send_element(State1, Err)
end; end);
handle_info({'$gen_event', {xmlstreamelement, El}}, handle_info({'$gen_event', {xmlstreamelement, El}},
#{xmlns := NS, lang := MyLang, mod := Mod} = State) -> #{xmlns := NS, lang := MyLang, mod := Mod} = State) ->
noreply(
try xmpp:decode(El, NS, [ignore_els]) of try xmpp:decode(El, NS, [ignore_els]) of
Pkt -> Pkt ->
State1 = try Mod:handle_recv(El, Pkt, State) State1 = try Mod:handle_recv(El, Pkt, State)
@ -262,7 +273,7 @@ handle_info({'$gen_event', {xmlstreamelement, El}},
end, end,
case is_disconnected(State1) of case is_disconnected(State1) of
true -> State1; true -> State1;
false -> noreply(process_element(Pkt, State1)) false -> process_element(Pkt, State1)
end end
catch _:{xmpp_codec, Why} -> catch _:{xmpp_codec, Why} ->
State1 = try Mod:handle_recv(El, undefined, State) State1 = try Mod:handle_recv(El, undefined, State)
@ -273,18 +284,18 @@ handle_info({'$gen_event', {xmlstreamelement, El}},
false -> false ->
Txt = xmpp:io_format_error(Why), Txt = xmpp:io_format_error(Why),
Lang = select_lang(MyLang, xmpp:get_lang(El)), Lang = select_lang(MyLang, xmpp:get_lang(El)),
noreply(send_error(State1, El, xmpp:err_bad_request(Txt, Lang))) send_error(State1, El, xmpp:err_bad_request(Txt, Lang))
end end
end; end);
handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}}, handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}},
#{mod := Mod} = State) -> #{mod := Mod} = State) ->
noreply(try Mod:handle_cdata(Data, State) noreply(try Mod:handle_cdata(Data, State)
catch _:undef -> State catch _:undef -> State
end); end);
handle_info({'$gen_event', {xmlstreamend, _}}, State) -> handle_info({'$gen_event', {xmlstreamend, _}}, State) ->
noreply(process_stream_end({error, {stream, reset}}, State)); noreply(process_stream_end({stream, reset}, State));
handle_info({'$gen_event', closed}, State) -> handle_info({'$gen_event', closed}, State) ->
noreply(process_stream_close({error, {socket, closed}}, State)); noreply(process_stream_end({socket, closed}, State));
handle_info(timeout, #{mod := Mod} = State) -> handle_info(timeout, #{mod := Mod} = State) ->
Disconnected = is_disconnected(State), Disconnected = is_disconnected(State),
noreply(try Mod:handle_timeout(State) noreply(try Mod:handle_timeout(State)
@ -295,12 +306,13 @@ handle_info(timeout, #{mod := Mod} = State) ->
end); end);
handle_info({'DOWN', MRef, _Type, _Object, _Info}, handle_info({'DOWN', MRef, _Type, _Object, _Info},
#{socket_monitor := MRef} = State) -> #{socket_monitor := MRef} = State) ->
noreply(process_stream_close({error, {socket, closed}}, State)); noreply(process_stream_end({socket, closed}, State));
handle_info(Info, #{mod := Mod} = State) -> handle_info(Info, #{mod := Mod} = State) ->
noreply(try Mod:handle_info(Info, State) noreply(try Mod:handle_info(Info, State)
catch _:undef -> State catch _:undef -> State
end). end).
-spec terminate(term(), state()) -> any().
terminate(Reason, #{mod := Mod} = State) -> terminate(Reason, #{mod := Mod} = State) ->
case get(already_terminated) of case get(already_terminated) of
true -> true ->
@ -319,7 +331,7 @@ code_change(OldVsn, #{mod := Mod} = State, Extra) ->
%%%=================================================================== %%%===================================================================
%%% Internal functions %%% Internal functions
%%%=================================================================== %%%===================================================================
-spec noreply(state()) -> {noreply, state(), non_neg_integer() | infinity}. -spec noreply(state()) -> noreply().
noreply(#{stream_timeout := infinity} = State) -> noreply(#{stream_timeout := infinity} = State) ->
{noreply, State, infinity}; {noreply, State, infinity};
noreply(#{stream_timeout := {MSecs, OldTime}} = State) -> noreply(#{stream_timeout := {MSecs, OldTime}} = State) ->
@ -335,15 +347,6 @@ new_id() ->
is_disconnected(#{stream_state := StreamState}) -> is_disconnected(#{stream_state := StreamState}) ->
StreamState == disconnected. StreamState == disconnected.
-spec process_stream_close(stop_reason(), state()) -> state().
process_stream_close(_, #{stream_state := disconnected} = State) ->
State;
process_stream_close(Reason, #{mod := Mod} = State) ->
State1 = send_trailer(State),
try Mod:handle_stream_close(Reason, State1)
catch _:undef -> stop(State1)
end.
-spec process_stream_end(stop_reason(), state()) -> state(). -spec process_stream_end(stop_reason(), state()) -> state().
process_stream_end(_, #{stream_state := disconnected} = State) -> process_stream_end(_, #{stream_state := disconnected} = State) ->
State; State;
@ -359,6 +362,8 @@ process_stream(#stream_start{xmlns = XML_NS,
#{xmlns := NS} = State) #{xmlns := NS} = State)
when XML_NS /= NS; STREAM_NS /= ?NS_STREAM -> when XML_NS /= NS; STREAM_NS /= ?NS_STREAM ->
send_element(State, xmpp:serr_invalid_namespace()); send_element(State, xmpp:serr_invalid_namespace());
process_stream(#stream_start{version = {N, _}}, State) when N > 1 ->
send_element(State, xmpp:serr_unsupported_version());
process_stream(#stream_start{lang = Lang, id = ID, process_stream(#stream_start{lang = Lang, id = ID,
version = Version} = StreamStart, version = Version} = StreamStart,
#{mod := Mod} = State) -> #{mod := Mod} = State) ->
@ -370,8 +375,10 @@ process_stream(#stream_start{lang = Lang, id = ID,
true -> State2; true -> State2;
false -> false ->
case Version of case Version of
{1,0} -> State2#{stream_state => wait_for_features}; {1, _} ->
_ -> process_stream_downgrade(StreamStart, State) State2#{stream_state => wait_for_features};
_ ->
process_stream_downgrade(StreamStart, State2)
end end
end. end.
@ -387,7 +394,7 @@ process_element(Pkt, #{stream_state := StateName} = State) ->
#sasl_failure{} when StateName == wait_for_sasl_response -> #sasl_failure{} when StateName == wait_for_sasl_response ->
process_sasl_failure(Pkt, State); process_sasl_failure(Pkt, State);
#stream_error{} -> #stream_error{} ->
process_stream_end({error, {stream, Pkt}}, State); process_stream_end({stream, Pkt}, State);
_ when is_record(Pkt, stream_features); _ when is_record(Pkt, stream_features);
is_record(Pkt, starttls_proceed); is_record(Pkt, starttls_proceed);
is_record(Pkt, starttls); is_record(Pkt, starttls);
@ -487,14 +494,23 @@ process_starttls(#{sockmod := SockMod, socket := Socket, mod := Mod} = State) ->
stream_encrypted => true}, stream_encrypted => true},
send_header(State1); send_header(State1);
{error, Why} -> {error, Why} ->
process_stream_close({error, {tls, Why}}, State) process_stream_end({tls, Why}, State)
end. end.
-spec process_stream_downgrade(stream_start(), state()) -> state(). -spec process_stream_downgrade(stream_start(), state()) -> state().
process_stream_downgrade(StreamStart, #{mod := Mod} = State) -> process_stream_downgrade(StreamStart,
try Mod:downgrade_stream(StreamStart, State) #{mod := Mod, lang := Lang,
stream_encrypted := Encrypted} = State) ->
TLSRequired = is_starttls_required(State),
if not Encrypted and TLSRequired ->
Txt = <<"Use of STARTTLS required">>,
send_element(State, xmpp:err_policy_violation(Txt, Lang));
true ->
State1 = State#{stream_state => downgraded},
try Mod:handle_stream_downgraded(StreamStart, State1)
catch _:undef -> catch _:undef ->
send_element(State, xmpp:serr_unsupported_version()) send_element(State1, xmpp:serr_unsupported_version())
end
end. end.
-spec process_cert_verification(state()) -> state(). -spec process_cert_verification(state()) -> state().
@ -509,7 +525,7 @@ process_cert_verification(#{stream_encrypted := true,
{ok, _} -> {ok, _} ->
State#{stream_verified => true}; State#{stream_verified => true};
{error, Why, _Peer} -> {error, Why, _Peer} ->
process_stream_close({error, {pkix, Why}}, State) process_stream_end({pkix, Why}, State)
end; end;
false -> false ->
State#{stream_verified => true} State#{stream_verified => true}
@ -538,7 +554,7 @@ process_sasl_success(#{mod := Mod,
-spec process_sasl_failure(sasl_failure(), state()) -> state(). -spec process_sasl_failure(sasl_failure(), state()) -> state().
process_sasl_failure(#sasl_failure{reason = Reason}, #{mod := Mod} = State) -> process_sasl_failure(#sasl_failure{reason = Reason}, #{mod := Mod} = State) ->
try Mod:handle_auth_failure(<<"EXTERNAL">>, Reason, State) try Mod:handle_auth_failure(<<"EXTERNAL">>, Reason, State)
catch _:undef -> process_stream_close({error, {auth, Reason}}, State) catch _:undef -> process_stream_end({auth, Reason}, State)
end. end.
-spec process_packet(xmpp_element(), state()) -> state(). -spec process_packet(xmpp_element(), state()) -> state().
@ -581,7 +597,7 @@ send_header(#{remote_server := RemoteServer,
version = {1,0}}), version = {1,0}}),
case send_text(State, fxml:element_to_header(Header)) of case send_text(State, fxml:element_to_header(Header)) of
ok -> State; ok -> State;
{error, Why} -> process_stream_close({error, {socket, Why}}, State) {error, Why} -> process_stream_end({socket, Why}, State)
end. end.
-spec send_element(state(), xmpp_element()) -> state(). -spec send_element(state(), xmpp_element()) -> state().
@ -596,11 +612,11 @@ send_element(#{xmlns := NS, mod := Mod} = State, Pkt) ->
false -> false ->
case send_text(State1, Data) of case send_text(State1, Data) of
_ when is_record(Pkt, stream_error) -> _ when is_record(Pkt, stream_error) ->
process_stream_end({error, {stream, Pkt}}, State1); process_stream_end({stream, Pkt}, State1);
ok -> ok ->
State1; State1;
{error, Why} -> {error, Why} ->
process_stream_close({error, {socket, Why}}, State1) process_stream_end({socket, Why}, State1)
end end
end. end.
@ -626,7 +642,7 @@ send_text(#{sockmod := SockMod, socket := Socket,
stream_state := StateName}, Data) when StateName /= disconnected -> stream_state := StateName}, Data) when StateName /= disconnected ->
SockMod:send(Socket, Data); SockMod:send(Socket, Data);
send_text(_, _) -> send_text(_, _) ->
{error, einval}. {error, closed}.
-spec send_trailer(state()) -> state(). -spec send_trailer(state()) -> state().
send_trailer(State) -> send_trailer(State) ->