More refactoring on session management

This commit is contained in:
Evgeniy Khramtsov 2016-12-30 00:00:36 +03:00
parent 309bdfbe28
commit e7fe4dc474
14 changed files with 569 additions and 401 deletions

View File

@ -188,7 +188,7 @@ try_register(User, Server, Password) ->
true -> {atomic, exists};
false ->
LServer = jid:nameprep(Server),
case lists:member(LServer, ?MYHOSTS) of
case ejabberd_router:is_my_host(LServer) of
true ->
Res = lists:foldl(fun (_M, {atomic, ok} = Res) -> Res;
(M, _) ->

View File

@ -37,17 +37,18 @@
compress_methods/1, bind/2, get_password_fun/1,
check_password_fun/1, check_password_digest_fun/1,
unauthenticated_stream_features/1, authenticated_stream_features/1,
handle_stream_start/2, handle_stream_end/2, handle_stream_close/2,
handle_stream_start/2, handle_stream_end/2,
handle_unauthenticated_packet/2, handle_authenticated_packet/2,
handle_auth_success/4, handle_auth_failure/4, handle_send/3,
handle_recv/3, handle_cdata/2, handle_unbinded_packet/2]).
%% Hooks
-export([handle_unexpected_info/2, handle_unexpected_cast/2,
reject_unauthenticated_packet/2, process_closed/2]).
-export([handle_unexpected_cast/2,
reject_unauthenticated_packet/2, process_closed/2,
process_terminated/2, process_info/2]).
%% API
-export([get_presence/1, get_subscription/2, get_subscribed/1,
open_session/1, call/3, send/2, close/1, close/2, stop/1, establish/1,
copy_state/2, add_hooks/0]).
reply/2, copy_state/2, set_timeout/2, add_hooks/1]).
-include("ejabberd.hrl").
-include("xmpp.hrl").
@ -76,6 +77,9 @@ socket_type() ->
call(Ref, Msg, Timeout) ->
xmpp_stream_in:call(Ref, Msg, Timeout).
reply(Ref, Reply) ->
xmpp_stream_in:reply(Ref, Reply).
-spec get_presence(pid()) -> presence().
get_presence(Ref) ->
call(Ref, get_presence, 1000).
@ -112,37 +116,39 @@ stop(Ref) ->
send(Pid, Pkt) when is_pid(Pid) ->
xmpp_stream_in:send(Pid, Pkt);
send(#{lserver := LServer} = State, Pkt) ->
case ejabberd_hooks:run_fold(c2s_filter_send, LServer, Pkt, [State]) of
drop -> State;
Pkt1 -> xmpp_stream_in:send(State, Pkt1)
case ejabberd_hooks:run_fold(c2s_filter_send, LServer, {Pkt, State}, []) of
{drop, State1} -> State1;
{Pkt1, State1} -> xmpp_stream_in:send(State1, Pkt1)
end.
-spec set_timeout(state(), timeout()) -> state().
set_timeout(State, Timeout) ->
xmpp_stream_in:set_timeout(State, Timeout).
-spec establish(state()) -> state().
establish(State) ->
xmpp_stream_in:establish(State).
-spec add_hooks() -> ok.
add_hooks() ->
lists:foreach(
fun(Host) ->
ejabberd_hooks:add(c2s_closed, Host, ?MODULE, process_closed, 100),
ejabberd_hooks:add(c2s_unauthenticated_packet, Host, ?MODULE,
reject_unauthenticated_packet, 100),
ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE,
handle_unexpected_info, 100),
ejabberd_hooks:add(c2s_handle_cast, Host, ?MODULE,
handle_unexpected_cast, 100)
end, ?MYHOSTS).
-spec add_hooks(binary()) -> ok.
add_hooks(Host) ->
ejabberd_hooks:add(c2s_closed, Host, ?MODULE, process_closed, 100),
ejabberd_hooks:add(c2s_terminated, Host, ?MODULE,
process_terminated, 100),
ejabberd_hooks:add(c2s_unauthenticated_packet, Host, ?MODULE,
reject_unauthenticated_packet, 100),
ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE,
process_info, 100),
ejabberd_hooks:add(c2s_handle_cast, Host, ?MODULE,
handle_unexpected_cast, 100).
%% Copies content of one c2s state to another.
%% This is needed for session migration from one pid to another.
-spec copy_state(state(), state()) -> state().
copy_state(#{owner := Owner} = NewState,
#{jid := JID, resource := Resource, sid := {Time, _},
auth_module := AuthModule, lserver := LServer,
pres_t := PresT, pres_a := PresA,
pres_f := PresF} = OldState) ->
#{jid := JID, resource := Resource, sid := {Time, _},
auth_module := AuthModule, lserver := LServer,
pres_t := PresT, pres_a := PresA,
pres_f := PresF} = OldState) ->
State1 = case OldState of
#{pres_last := Pres, pres_timestamp := PresTS} ->
NewState#{pres_last => Pres, pres_timestamp => PresTS};
@ -158,10 +164,46 @@ copy_state(#{owner := Owner} = NewState,
pres_f => PresF},
ejabberd_hooks:run_fold(c2s_copy_state, LServer, State2, [OldState]).
-spec open_session(state()) -> {ok, state()} | state().
open_session(#{user := U, server := S, resource := R,
sid := SID, ip := IP, auth_module := AuthModule} = State) ->
JID = jid:make(U, S, R),
change_shaper(State),
Conn = get_conn_type(State),
State1 = State#{conn => Conn, resource => R, jid => JID},
Prio = try maps:get(pres_last, State) of
Pres -> get_priority_from_presence(Pres)
catch _:{badkey, _} ->
undefined
end,
Info = [{ip, IP}, {conn, Conn}, {auth_module, AuthModule}],
ejabberd_sm:open_session(SID, U, S, R, Prio, Info),
xmpp_stream_in:establish(State1).
%%%===================================================================
%%% Hooks
%%%===================================================================
handle_unexpected_info(State, Info) ->
process_info(#{lserver := LServer} = State,
{route, From, To, Packet0}) ->
Packet = xmpp:set_from_to(Packet0, From, To),
{Pass, State1} = case Packet of
#presence{} ->
process_presence_in(State, Packet);
#message{} ->
process_message_in(State, Packet);
#iq{} ->
process_iq_in(State, Packet)
end,
if Pass ->
Packet1 = ejabberd_hooks:run_fold(
user_receive_packet, LServer, Packet, [State1]),
ejabberd_hooks:run(c2s_loop_debug, [{route, From, To, Packet}]),
send(State1, Packet1);
true ->
ejabberd_hooks:run(c2s_loop_debug, [{route, From, To, Packet}]),
State1
end;
process_info(State, Info) ->
?WARNING_MSG("got unexpected info: ~p", [Info]),
State.
@ -173,8 +215,22 @@ reject_unauthenticated_packet(State, Pkt) ->
Err = xmpp:err_not_authorized(),
xmpp_stream_in:send_error(State, Pkt, Err).
process_closed(State, _Reason) ->
stop(State).
process_closed(State, Reason) ->
stop(State#{stop_reason => Reason}).
process_terminated(#{socket := Socket, jid := JID} = State,
Reason) ->
Status = format_reason(State, Reason),
?INFO_MSG("(~s) Closing c2s connection for ~s: ~s",
[ejabberd_socket:pp(Socket), jid:to_string(JID), Status]),
Pres = #presence{type = unavailable,
status = xmpp:mk_text(Status),
from = JID, to = jid:remove_resource(JID)},
State1 = broadcast_presence_unavailable(State, Pres),
bounce_message_queue(),
State1;
process_terminated(State, _Reason) ->
State.
%%%===================================================================
%%% xmpp_stream_in callbacks
@ -248,25 +304,9 @@ bind(R, #{user := U, server := S, access := Access, lang := Lang,
end
end.
-spec open_session(state()) -> {ok, state()} | state().
open_session(#{user := U, server := S, resource := R,
sid := SID, ip := IP, auth_module := AuthModule} = State) ->
JID = jid:make(U, S, R),
change_shaper(State),
Conn = get_conn_type(State),
State1 = State#{conn => Conn, resource => R, jid => JID},
Prio = try maps:get(pres_last, State) of
Pres -> get_priority_from_presence(Pres)
catch _:{badkey, _} ->
undefined
end,
Info = [{ip, IP}, {conn, Conn}, {auth_module, AuthModule}],
ejabberd_sm:open_session(SID, U, S, R, Prio, Info),
State1.
handle_stream_start(StreamStart,
#{lserver := LServer, ip := IP, lang := Lang} = State) ->
case lists:member(LServer, ?MYHOSTS) of
case ejabberd_router:is_my_host(LServer) of
false ->
send(State, xmpp:serr_host_unknown());
true ->
@ -284,10 +324,8 @@ handle_stream_start(StreamStart,
end.
handle_stream_end(Reason, #{lserver := LServer} = State) ->
ejabberd_hooks:run_fold(c2s_closed, LServer, State, [Reason]).
handle_stream_close(_Reason, #{lserver := LServer} = State) ->
ejabberd_hooks:run_fold(c2s_closed, LServer, State, [normal]).
State1 = State#{stop_reason => Reason},
ejabberd_hooks:run_fold(c2s_closed, LServer, State1, [Reason]).
handle_auth_success(User, Mech, AuthModule,
#{socket := Socket, ip := IP, lserver := LServer} = State) ->
@ -296,8 +334,7 @@ handle_auth_success(User, Mech, AuthModule,
ejabberd_auth:backend_type(AuthModule),
ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]),
State1 = State#{auth_module => AuthModule},
ejabberd_hooks:run_fold(c2s_auth_result, LServer,
State1, [true, User]).
ejabberd_hooks:run_fold(c2s_auth_result, LServer, State1, [true, User]).
handle_auth_failure(User, Mech, Reason,
#{socket := Socket, ip := IP, lserver := LServer} = State) ->
@ -307,16 +344,13 @@ handle_auth_failure(User, Mech, Reason,
true -> ""
end,
ejabberd_config:may_hide_data(jlib:ip_to_list(IP)), Reason]),
ejabberd_hooks:run_fold(c2s_auth_result, LServer,
State, [false, User]).
ejabberd_hooks:run_fold(c2s_auth_result, LServer, State, [false, User]).
handle_unbinded_packet(Pkt, #{lserver := LServer} = State) ->
ejabberd_hooks:run_fold(c2s_unbinded_packet, LServer,
State, [Pkt]).
ejabberd_hooks:run_fold(c2s_unbinded_packet, LServer, State, [Pkt]).
handle_unauthenticated_packet(Pkt, #{lserver := LServer} = State) ->
ejabberd_hooks:run_fold(c2s_unauthenticated_packet,
LServer, State, [Pkt]).
ejabberd_hooks:run_fold(c2s_unauthenticated_packet, LServer, State, [Pkt]).
handle_authenticated_packet(Pkt, #{lserver := LServer} = State) when not ?is_stanza(Pkt) ->
ejabberd_hooks:run_fold(c2s_authenticated_packet,
@ -366,20 +400,22 @@ init([State, Opts]) ->
zlib => Zlib,
lang => ?MYLANG,
server => ?MYNAME,
lserver => ?MYNAME,
access => Access,
shaper => Shaper},
ejabberd_hooks:run_fold(c2s_init, {ok, State1}, [Opts]).
handle_call(get_presence, _From, #{jid := JID} = State) ->
handle_call(get_presence, From, #{jid := JID} = State) ->
Pres = try maps:get(pres_last, State)
catch _:{badkey, _} ->
BareJID = jid:remove_resource(JID),
#presence{from = JID, to = BareJID, type = unavailable}
end,
{reply, Pres, State};
handle_call(get_subscribed, _From, #{pres_f := PresF} = State) ->
Subscribed = ?SETS:to_list(PresF),
{reply, Subscribed, State};
reply(From, Pres),
State;
handle_call(get_subscribed, From, #{pres_f := PresF} = State) ->
reply(From, ?SETS:to_list(PresF)),
State;
handle_call(Request, From, #{lserver := LServer} = State) ->
ejabberd_hooks:run_fold(
c2s_handle_call, LServer, State, [Request, From]).
@ -387,30 +423,22 @@ handle_call(Request, From, #{lserver := LServer} = State) ->
handle_cast(Msg, #{lserver := LServer} = State) ->
ejabberd_hooks:run_fold(c2s_handle_cast, LServer, State, [Msg]).
handle_info({route, From, To, Packet0}, #{lserver := LServer} = State) ->
Packet = xmpp:set_from_to(Packet0, From, To),
{Pass, NewState} = case Packet of
#presence{} ->
process_presence_in(State, Packet);
#message{} ->
process_message_in(State, Packet);
#iq{} ->
process_iq_in(State, Packet)
end,
if Pass ->
Packet1 = ejabberd_hooks:run_fold(
user_receive_packet, LServer, Packet, [NewState]),
ejabberd_hooks:run(c2s_loop_debug, [{route, From, To, Packet}]),
send(NewState, Packet1);
true ->
ejabberd_hooks:run(c2s_loop_debug, [{route, From, To, Packet}]),
NewState
end;
handle_info(Info, #{lserver := LServer} = State) ->
ejabberd_hooks:run_fold(c2s_handle_info, LServer, State, [Info]).
terminate(_Reason, _State) ->
ok.
terminate(Reason, #{sid := SID, jid := _,
user := U, server := S, resource := R,
lserver := LServer} = State) ->
Status = format_reason(State, Reason),
case maps:is_key(pres_last, State) of
true ->
ejabberd_sm:close_session_unset_presence(SID, U, S, R, Status);
false ->
ejabberd_sm:close_session(SID, U, S, R)
end,
ejabberd_hooks:run_fold(c2s_terminated, LServer, State, [Reason]);
terminate(Reason, #{lserver := LServer} = State) ->
ejabberd_hooks:run_fold(c2s_terminated, LServer, State, [Reason]).
code_change(_OldVsn, State, _Extra) ->
{ok, State}.
@ -684,6 +712,15 @@ resource_conflict_action(U, S, R) ->
{accept_resource, Rnew}
end.
-spec bounce_message_queue() -> ok.
bounce_message_queue() ->
receive {route, From, To, Pkt} ->
ejabberd_router:route(From, To, Pkt),
bounce_message_queue()
after 0 ->
ok
end.
-spec new_uniq_id() -> binary().
new_uniq_id() ->
iolist_to_binary(
@ -735,6 +772,14 @@ do_some_magic(#{pres_a := PresA, pres_f := PresF} = State, From) ->
end
end.
-spec format_reason(state(), term()) -> binary().
format_reason(#{stop_reason := Reason}, _) ->
xmpp_stream_in:format_error(Reason);
format_reason(_, Reason) when Reason /= normal ->
<<"internal server error">>;
format_reason(_, _) ->
<<"">>.
transform_listen_option(Opt, Opts) ->
[Opt|Opts].

View File

@ -322,7 +322,7 @@ add_header(Name, Value, State)->
get_host_really_served(undefined, Provided) ->
Provided;
get_host_really_served(Default, Provided) ->
case lists:member(Provided, ?MYHOSTS) of
case ejabberd_router:is_my_host(Provided) of
true -> Provided;
false -> Default
end.

View File

@ -350,7 +350,7 @@ process_el({xmlstreamelement, #xmlel{name = <<"host">>,
JIDS = fxml:get_attr_s(<<"jid">>, Attrs),
case jid:from_string(JIDS) of
#jid{lserver = S} ->
case lists:member(S, ?MYHOSTS) of
case ejabberd_router:is_my_host(S) of
true ->
process_users(Els, State#state{server = S});
false ->

View File

@ -34,7 +34,7 @@
-export([tls_options/1, tls_required/1, tls_verify/1, tls_enabled/1,
compress_methods/1,
unauthenticated_stream_features/1, authenticated_stream_features/1,
handle_stream_start/2, handle_stream_end/2, handle_stream_close/2,
handle_stream_start/2, handle_stream_end/2,
handle_stream_established/1, handle_auth_success/4,
handle_auth_failure/4, handle_send/3, handle_recv/3, handle_cdata/2,
handle_unauthenticated_packet/2, handle_authenticated_packet/2]).
@ -160,9 +160,6 @@ handle_stream_start(_StreamStart, #{lserver := LServer} = State) ->
handle_stream_end(Reason, #{server_host := LServer} = State) ->
ejabberd_hooks:run_fold(s2s_in_closed, LServer, State, [Reason]).
handle_stream_close(_Reason, #{server_host := LServer} = State) ->
ejabberd_hooks:run_fold(s2s_in_closed, LServer, State, [normal]).
handle_stream_established(State) ->
set_idle_timeout(State#{established => true}).

View File

@ -15,14 +15,14 @@
%% xmpp_stream_out callbacks
-export([tls_options/1, tls_required/1, tls_verify/1, tls_enabled/1,
handle_auth_success/2, handle_auth_failure/3, handle_packet/2,
handle_stream_end/2, handle_stream_close/2,
handle_stream_end/2, handle_stream_downgraded/2,
handle_recv/3, handle_send/4, handle_cdata/2,
handle_stream_established/1, handle_timeout/1]).
-export([init/1, handle_call/3, handle_cast/2, handle_info/2,
terminate/2, code_change/3]).
%% Hooks
-export([process_auth_result/2, process_closed/2, handle_unexpected_info/2,
handle_unexpected_cast/2]).
handle_unexpected_cast/2, process_downgraded/2]).
%% API
-export([start/3, start_link/3, connect/1, close/1, stop/1, send/2,
route/2, establish/1, update_state/2, add_hooks/0]).
@ -83,7 +83,9 @@ add_hooks() ->
ejabberd_hooks:add(s2s_out_handle_info, Host, ?MODULE,
handle_unexpected_info, 100),
ejabberd_hooks:add(s2s_out_handle_cast, Host, ?MODULE,
handle_unexpected_cast, 100)
handle_unexpected_cast, 100),
ejabberd_hooks:add(s2s_out_downgraded, Host, ?MODULE,
process_downgraded, 100)
end, ?MYHOSTS).
%%%===================================================================
@ -95,25 +97,28 @@ process_auth_result(#{server := LServer, remote_server := RServer} = State,
?INFO_MSG("Closing outbound s2s connection ~s -> ~s: authentication failed;"
" bouncing for ~p seconds",
[LServer, RServer, Delay]),
State1 = close(State),
State2 = bounce_queue(State1),
xmpp_stream_out:set_timeout(State2, timer:seconds(Delay));
State1 = State#{on_route => bounce},
State2 = close(State1),
State3 = bounce_queue(State2),
xmpp_stream_out:set_timeout(State3, timer:seconds(Delay));
process_auth_result(State, true) ->
State.
process_closed(#{server := LServer, remote_server := RServer,
on_route := send} = State,
Reason) ->
?INFO_MSG("Closing outbound s2s connection ~s -> ~s: ~s",
[LServer, RServer, xmpp_stream_out:format_error(Reason)]),
stop(State);
process_closed(#{server := LServer, remote_server := RServer} = State,
_Reason) ->
Reason) ->
Delay = get_delay(),
?INFO_MSG("Closing outbound s2s connection ~s -> ~s: ~s; "
"bouncing for ~p seconds",
[LServer, RServer,
try maps:get(stop_reason, State) of
{error, Why} -> xmpp_stream_out:format_error(Why)
catch _:undef -> <<"unexplained reason">>
end,
Delay]),
State1 = bounce_queue(State),
xmpp_stream_out:set_timeout(State1, timer:seconds(Delay)).
[LServer, RServer, xmpp_stream_out:format_error(Reason), Delay]),
State1 = State#{on_route => bounce},
State2 = bounce_queue(State1),
xmpp_stream_out:set_timeout(State2, timer:seconds(Delay)).
handle_unexpected_info(State, Info) ->
?WARNING_MSG("got unexpected info: ~p", [Info]),
@ -123,6 +128,9 @@ handle_unexpected_cast(State, Msg) ->
?WARNING_MSG("got unexpected cast: ~p", [Msg]),
State.
process_downgraded(State, _StreamStart) ->
send(State, xmpp:serr_unsupported_version()).
%%%===================================================================
%%% gen_server callbacks
%%%===================================================================
@ -153,21 +161,19 @@ handle_auth_failure(Mech, Reason,
?INFO_MSG("(~s) Failed outbound s2s ~s authentication ~s -> ~s (~s): ~s",
[ejabberd_socket:pp(Socket), Mech, LServer, RServer,
ejabberd_config:may_hide_data(jlib:ip_to_list(IP)), Reason]),
State1 = State#{on_route => bounce,
stop_reason => {error, {auth, Reason}}},
State1 = State#{stop_reason => {auth, Reason}},
ejabberd_hooks:run_fold(s2s_out_auth_result, LServer, State1, [false]).
handle_packet(Pkt, #{server := LServer} = State) ->
ejabberd_hooks:run_fold(s2s_out_packet, LServer, State, [Pkt]).
handle_stream_end(Reason, #{server := LServer} = State) ->
State1 = State#{on_route => bounce, stop_reason => Reason},
ejabberd_hooks:run_fold(s2s_out_closed, LServer, State1, [normal]).
handle_stream_close(Reason, #{server := LServer} = State) ->
State1 = State#{on_route => bounce, stop_reason => Reason},
State1 = State#{stop_reason => Reason},
ejabberd_hooks:run_fold(s2s_out_closed, LServer, State1, [Reason]).
handle_stream_downgraded(StreamStart, #{server := LServer} = State) ->
ejabberd_hooks:run_fold(s2s_out_downgraded, LServer, State, [StreamStart]).
handle_stream_established(State) ->
State1 = State#{on_route => send},
State2 = resend_queue(State1),
@ -183,15 +189,10 @@ handle_send(Pkt, El, Data, #{server := LServer} = State) ->
ejabberd_hooks:run_fold(s2s_out_handle_send, LServer,
State, [Pkt, El, Data]).
handle_timeout(#{server := LServer, remote_server := RServer,
on_route := Action} = State) ->
handle_timeout(#{on_route := Action} = State) ->
case Action of
bounce -> stop(State);
queue -> send(State, xmpp:serr_connection_timeout());
send ->
?INFO_MSG("Closing outbound s2s connection ~s -> ~s: inactive",
[LServer, RServer]),
stop(State)
_ -> send(State, xmpp:serr_connection_timeout())
end.
init([#{server := LServer, remote_server := RServer} = State, Opts]) ->
@ -229,7 +230,7 @@ terminate(Reason, #{server := LServer,
ejabberd_s2s:remove_connection({LServer, RServer}, self()),
State1 = case Reason of
normal -> State;
_ -> State#{stop_reason => {error, internal_failure}}
_ -> State#{stop_reason => internal_failure}
end,
bounce_queue(State1),
bounce_message_queue(State1).
@ -258,8 +259,7 @@ bounce_queue(#{queue := Q} = State) ->
-spec bounce_message_queue(state()) -> state().
bounce_message_queue(State) ->
receive
{route, Pkt} ->
receive {route, Pkt} ->
State1 = bounce_packet(Pkt, State),
bounce_message_queue(State1)
after 0 ->
@ -278,21 +278,19 @@ bounce_packet(_, State) ->
State.
-spec mk_bounce_error(binary(), state()) -> stanza_error().
mk_bounce_error(Lang, State) ->
try maps:get(stop_reason, State) of
{error, internal_failure} ->
mk_bounce_error(Lang, #{stop_reason := Why}) ->
Reason = xmpp_stream_out:format_error(Why),
case Why of
internal_failure ->
xmpp:err_internal_server_error();
{error, Why} ->
Reason = xmpp_stream_out:format_error(Why),
case Why of
{dns, _} ->
xmpp:err_remote_server_timeout(Reason, Lang);
_ ->
xmpp:err_remote_server_not_found(Reason, Lang)
end
catch _:{badkey, _} ->
xmpp:err_remote_server_not_found()
end.
{dns, _} ->
xmpp:err_remote_server_not_found(Reason, Lang);
_ ->
xmpp:err_remote_server_timeout(Reason, Lang)
end;
mk_bounce_error(_Lang, _State) ->
%% We should not be here. Probably :)
xmpp:err_remote_server_not_found().
-spec get_delay() -> non_neg_integer().
get_delay() ->

View File

@ -99,7 +99,7 @@ handle_stream_start(_StreamStart,
#{remote_server := RemoteServer,
lang := Lang,
host_opts := HostOpts} = State) ->
case lists:member(RemoteServer, ?MYHOSTS) of
case ejabberd_router:is_my_host(RemoteServer) of
true ->
Txt = <<"Unable to register route on existing local domain">>,
xmpp_stream_in:send(State, xmpp:serr_conflict(Txt, Lang));

View File

@ -390,7 +390,8 @@ init([]) ->
ejabberd_hooks:add(offline_message_hook, Host,
ejabberd_sm, bounce_offline_message, 100),
ejabberd_hooks:add(remove_user, Host,
ejabberd_sm, disconnect_removed_user, 100)
ejabberd_sm, disconnect_removed_user, 100),
ejabberd_c2s:add_hooks(Host)
end, ?MYHOSTS),
ejabberd_commands:register_commands(get_commands_spec()),
{ok, #state{}}.

View File

@ -192,7 +192,7 @@ process([<<"server">>, SHost | RPath] = Path,
method = Method} =
Request) ->
Host = jid:nameprep(SHost),
case lists:member(Host, ?MYHOSTS) of
case ejabberd_router:is_my_host(Host) of
true ->
case get_auth_admin(Auth, HostHTTP, Path, Method) of
{ok, {User, Server}} ->

View File

@ -133,8 +133,7 @@ open_session(State, IQ, R) ->
case ejabberd_c2s:bind(R, State) of
{ok, State1} ->
Res = xmpp:make_iq_result(IQ),
State2 = ejabberd_c2s:send(State1, Res),
ejabberd_c2s:establish(State2);
ejabberd_c2s:send(State1, Res);
{error, Err, State1} ->
Res = xmpp:make_error(IQ, Err),
ejabberd_c2s:send(State1, Res)

View File

@ -28,7 +28,8 @@
%% gen_mod API
-export([start/2, stop/1, depends/2, mod_opt_type/1]).
%% Hooks
-export([s2s_out_auth_result/2, s2s_in_packet/2, s2s_out_packet/2,
-export([s2s_out_auth_result/2, s2s_out_downgraded/2,
s2s_in_packet/2, s2s_out_packet/2,
s2s_in_features/2, s2s_out_init/2, s2s_out_closed/2]).
-include("ejabberd.hrl").
@ -57,6 +58,8 @@ start(Host, _Opts) ->
s2s_in_packet, 50),
ejabberd_hooks:add(s2s_out_packet, Host, ?MODULE,
s2s_out_packet, 50),
ejabberd_hooks:add(s2s_out_downgraded, Host, ?MODULE,
s2s_out_downgraded, 50),
ejabberd_hooks:add(s2s_out_auth_result, Host, ?MODULE,
s2s_out_auth_result, 50)
end.
@ -74,6 +77,8 @@ stop(Host) ->
s2s_in_packet, 50),
ejabberd_hooks:delete(s2s_out_packet, Host, ?MODULE,
s2s_out_packet, 50),
ejabberd_hooks:delete(s2s_out_downgraded, Host, ?MODULE,
s2s_out_downgraded, 50),
ejabberd_hooks:delete(s2s_out_auth_result, Host, ?MODULE,
s2s_out_auth_result, 50).
@ -104,47 +109,56 @@ s2s_out_init(Acc, _Opts) ->
s2s_out_closed(#{server := LServer,
remote_server := RServer,
db_verify := {StreamID, _Key, _Pid}} = State, _Reason) ->
db_verify := {StreamID, _Key, _Pid}} = State, Reason) ->
%% Outbound s2s verificating connection (created at step 1) is
%% closed suddenly without receiving the response.
%% Building a response on our own
Response = #db_verify{from = RServer, to = LServer,
id = StreamID, type = error,
sub_els = [mk_error(internal_server_error)]},
sub_els = [mk_error(Reason)]},
s2s_out_packet(State, Response);
s2s_out_closed(State, _Reason) ->
State.
s2s_out_auth_result(#{server := LServer,
remote_server := RServer,
db_verify := {StreamID, Key, _Pid}} = State,
_) ->
s2s_out_auth_result(#{db_verify := _} = State, _) ->
%% The temporary outbound s2s connect (intended for verification)
%% has passed authentication state (either successfully or not, no matter)
%% and at this point we can send verification request as described
%% in section 2.1.2, step 2
Request = #db_verify{from = LServer, to = RServer,
key = Key, id = StreamID},
{stop, ejabberd_s2s_out:send(State, Request)};
{stop, send_verify_request(State)};
s2s_out_auth_result(#{db_enabled := true,
socket := Socket, ip := IP,
server := LServer,
remote_server := RServer,
stream_remote_id := StreamID} = State, false) ->
remote_server := RServer} = State, false) ->
%% SASL authentication has failed, retrying with dialback
%% Sending dialback request, section 2.1.1, step 1
?INFO_MSG("(~s) Retrying with s2s dialback authentication: ~s -> ~s (~s)",
[ejabberd_socket:pp(Socket), LServer, RServer,
ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]),
Key = make_key(LServer, RServer, StreamID),
State1 = maps:remove(stop_reason, State#{on_route => queue}),
State2 = ejabberd_s2s_out:send(State1, #db_result{from = LServer,
to = RServer,
key = Key}),
{stop, State2};
{stop, send_db_request(State1)};
s2s_out_auth_result(State, _) ->
State.
s2s_out_downgraded(#{db_verify := _} = State, _) ->
%% The verifying outbound s2s connection detected non-RFC compliant
%% server, send verification request immediately without auth phase,
%% section 2.1.2, step 2
{stop, send_verify_request(State)};
s2s_out_downgraded(#{db_enabled := true,
socket := Socket, ip := IP,
server := LServer,
remote_server := RServer} = State, _) ->
%% non-RFC compliant server detected, send dialback request instantly,
%% section 2.1.1, step 1
?INFO_MSG("(~s) Trying s2s dialback authentication with "
"non-RFC compliant server: ~s -> ~s (~s)",
[ejabberd_socket:pp(Socket), LServer, RServer,
ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]),
{stop, send_db_request(State)};
s2s_out_downgraded(State, _) ->
State.
s2s_in_packet(#{stream_id := StreamID} = State,
#db_result{from = From, to = To, key = Key, type = undefined}) ->
%% Received dialback request, section 2.2.1, step 1
@ -220,6 +234,23 @@ make_key(From, To, StreamID) ->
crypto:hmac(sha256, p1_sha:to_hexlist(crypto:hash(sha256, Secret)),
[To, " ", From, " ", StreamID])).
-spec send_verify_request(ejabberd_s2s_out:state()) -> ejabberd_s2s_out:state().
send_verify_request(#{server := LServer,
remote_server := RServer,
db_verify := {StreamID, Key, _Pid}} = State) ->
Request = #db_verify{from = LServer, to = RServer,
key = Key, id = StreamID},
ejabberd_s2s_out:send(State, Request).
-spec send_db_request(ejabberd_s2s_out:state()) -> ejabberd_s2s_out:state().
send_db_request(#{server := LServer,
remote_server := RServer,
stream_remote_id := StreamID} = State) ->
Key = make_key(LServer, RServer, StreamID),
ejabberd_s2s_out:send(State, #db_result{from = LServer,
to = RServer,
key = Key}).
-spec send_db_result(ejabberd_s2s_in:state(), db_verify()) -> ejabberd_s2s_in:state().
send_db_result(State, #db_verify{from = From, to = To,
type = Type, sub_els = Els}) ->
@ -255,6 +286,9 @@ mk_error(forbidden) ->
xmpp:err_forbidden(<<"Denied by ACL">>, ?MYLANG);
mk_error(host_unknown) ->
xmpp:err_not_allowed(<<"Host unknown">>, ?MYLANG);
mk_error({_Class, _Reason} = Why) ->
Txt = xmpp_stream_out:format_error(Why),
xmpp:err_remote_server_not_found(Txt, ?MYLANG);
mk_error(_) ->
xmpp:err_internal_server_error().

View File

@ -30,8 +30,9 @@
%% hooks
-export([c2s_stream_init/2, c2s_stream_started/2, c2s_stream_features/2,
c2s_authenticated_packet/2, c2s_unauthenticated_packet/2,
c2s_unbinded_packet/2, c2s_closed/2,
c2s_handle_send/3, c2s_filter_send/2, c2s_handle_info/2]).
c2s_unbinded_packet/2, c2s_closed/2, c2s_terminated/2,
c2s_handle_send/3, c2s_filter_send/1, c2s_handle_info/2,
c2s_handle_call/3, c2s_handle_recv/3]).
-include("xmpp.hrl").
-include("logger.hrl").
@ -60,13 +61,13 @@ start(Host, _Opts) ->
c2s_unbinded_packet, 50),
ejabberd_hooks:add(c2s_authenticated_packet, Host, ?MODULE,
c2s_authenticated_packet, 50),
ejabberd_hooks:add(c2s_handle_send, Host, ?MODULE,
c2s_handle_send, 50),
ejabberd_hooks:add(c2s_filter_send, Host, ?MODULE,
c2s_filter_send, 50),
ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE,
c2s_handle_info, 50),
ejabberd_hooks:add(c2s_closed, Host, ?MODULE, c2s_closed, 50).
ejabberd_hooks:add(c2s_handle_send, Host, ?MODULE, c2s_handle_send, 50),
ejabberd_hooks:add(c2s_handle_recv, Host, ?MODULE, c2s_handle_recv, 50),
ejabberd_hooks:add(c2s_filter_send, Host, ?MODULE, c2s_filter_send, 50),
ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE, c2s_handle_info, 50),
ejabberd_hooks:add(c2s_handle_call, Host, ?MODULE, c2s_handle_call, 50),
ejabberd_hooks:add(c2s_closed, Host, ?MODULE, c2s_closed, 50),
ejabberd_hooks:add(c2s_terminated, Host, ?MODULE, c2s_terminated, 50).
stop(Host) ->
%% TODO: do something with global 'c2s_init' hook
@ -80,13 +81,13 @@ stop(Host) ->
c2s_unbinded_packet, 50),
ejabberd_hooks:delete(c2s_authenticated_packet, Host, ?MODULE,
c2s_authenticated_packet, 50),
ejabberd_hooks:delete(c2s_handle_send, Host, ?MODULE,
c2s_handle_send, 50),
ejabberd_hooks:delete(c2s_filter_send, Host, ?MODULE,
c2s_filter_send, 50),
ejabberd_hooks:delete(c2s_handle_info, Host, ?MODULE,
c2s_handle_info, 50),
ejabberd_hooks:delete(c2s_closed, Host, ?MODULE, c2s_closed, 50).
ejabberd_hooks:delete(c2s_handle_send, Host, ?MODULE, c2s_handle_send, 50),
ejabberd_hooks:delete(c2s_handle_recv, Host, ?MODULE, c2s_handle_recv, 50),
ejabberd_hooks:delete(c2s_filter_send, Host, ?MODULE, c2s_filter_send, 50),
ejabberd_hooks:delete(c2s_handle_info, Host, ?MODULE, c2s_handle_info, 50),
ejabberd_hooks:delete(c2s_handle_call, Host, ?MODULE, c2s_handle_call, 50),
ejabberd_hooks:delete(c2s_closed, Host, ?MODULE, c2s_closed, 50),
ejabberd_hooks:delete(c2s_terminated, Host, ?MODULE, c2s_terminated, 50).
depends(_Host, _Opts) ->
[].
@ -115,7 +116,10 @@ c2s_stream_started(#{lserver := LServer, mgmt_options := Opts} = State,
mgmt_timeout => ResumeTimeout,
mgmt_max_timeout => MaxResumeTimeout,
mgmt_ack_timeout => get_ack_timeout(LServer, Opts),
mgmt_resend => get_resend_on_timeout(LServer, Opts)};
mgmt_resend => get_resend_on_timeout(LServer, Opts),
mgmt_stanzas_in => 0,
mgmt_stanzas_out => 0,
mgmt_stanzas_req => 0};
c2s_stream_started(State, _StreamStart) ->
State.
@ -143,8 +147,8 @@ c2s_unbinded_packet(State, #sm_resume{} = Pkt) ->
case handle_resume(State, Pkt) of
{ok, ResumedState} ->
{stop, ResumedState};
error ->
{stop, State}
{error, State1} ->
{stop, State1}
end;
c2s_unbinded_packet(State, Pkt) when ?is_sm_packet(Pkt) ->
c2s_unauthenticated_packet(State, Pkt);
@ -161,12 +165,26 @@ c2s_authenticated_packet(#{mgmt_state := MgmtState} = State, Pkt)
c2s_authenticated_packet(State, Pkt) ->
update_num_stanzas_in(State, Pkt).
c2s_handle_recv(#{lang := Lang} = State, El, {error, Why}) ->
Xmlns = xmpp:get_ns(El),
if Xmlns == ?NS_STREAM_MGMT_2; Xmlns == ?NS_STREAM_MGMT_3 ->
Txt = xmpp:io_format_error(Why),
Err = #sm_failed{reason = 'bad-request',
text = xmpp:mk_text(Txt, Lang),
xmlns = Xmlns},
send(State, Err);
true ->
State
end;
c2s_handle_recv(State, _, _) ->
State.
c2s_handle_send(#{mgmt_state := MgmtState} = State, Pkt, Result)
when MgmtState == pending; MgmtState == active ->
State1 = mgmt_queue_add(State, Pkt),
case Result of
ok when ?is_stanza(Pkt) ->
send_ack(State1);
send_rack(State1);
ok ->
State1;
{error, _} ->
@ -175,21 +193,57 @@ c2s_handle_send(#{mgmt_state := MgmtState} = State, Pkt, Result)
c2s_handle_send(State, _Pkt, _Result) ->
State.
c2s_filter_send(Pkt, _State) ->
Pkt.
c2s_filter_send({Pkt, State}) ->
{Pkt, State}.
c2s_handle_info(#{mgmt_ack_timer := T, jid := JID} = State,
{timeout, T, ack_timeout}) ->
?DEBUG("Timeout waiting for stream management acknowledgement of ~s",
c2s_handle_call(#{sid := {Time, _}} = State,
{resume_session, Time}, From) ->
ejabberd_c2s:reply(From, {resume, State}),
{stop, State#{mgmt_state => resumed}};
c2s_handle_call(State, {resume_session, _}, From) ->
ejabberd_c2s:reply(From, {error, <<"Previous session not found">>}),
{stop, State};
c2s_handle_call(State, _Call, _From) ->
State.
c2s_handle_info(#{mgmt_ack_timer := TRef, jid := JID} = State,
{timeout, TRef, ack_timeout}) ->
?DEBUG("Timed out waiting for stream management acknowledgement of ~s",
[jid:to_string(JID)]),
State1 = ejabberd_c2s:close(State, _SendTrailer = false),
c2s_closed(State1, ack_timeout);
{stop, transition_to_pending(State1)};
c2s_handle_info(#{mgmt_state := pending, jid := JID} = State,
{timeout, _, pending_timeout}) ->
?DEBUG("Timed out waiting for resumption of stream for ~s",
[jid:to_string(JID)]),
ejabberd_c2s:stop(State#{mgmt_state => timeout});
c2s_handle_info(State, _) ->
State.
c2s_closed(#{mgmt_state := active} = State, Reason) when Reason /= normal ->
{stop, transition_to_pending(State)};
c2s_closed(State, _) ->
c2s_closed(State, {stream, _}) ->
State;
c2s_closed(#{mgmt_state := active} = State, Reason) ->
{stop, transition_to_pending(State#{stop_reason => Reason})};
c2s_closed(State, _Reason) ->
State.
c2s_terminated(#{mgmt_state := resumed, jid := JID} = State, _Reason) ->
?INFO_MSG("Closing former stream of resumed session for ~s",
[jid:to_string(JID)]),
bounce_message_queue(),
{stop, State};
c2s_terminated(#{mgmt_state := MgmtState, mgmt_stanzas_in := In, sid := SID,
user := U, server := S, resource := R} = State, _Reason) ->
case MgmtState of
timeout ->
Info = [{num_stanzas_in, In}],
ejabberd_sm:set_offline_info(SID, U, S, R, Info);
_ ->
ok
end,
route_unacked_stanzas(State),
State;
c2s_terminated(State, _Reason) ->
State.
%%%===================================================================
@ -201,17 +255,14 @@ negotiate_stream_mgmt(Pkt, State) ->
case Pkt of
#sm_enable{} ->
handle_enable(State#{mgmt_xmlns => Xmlns}, Pkt);
_ when is_record(Pkt, sm_a);
is_record(Pkt, sm_r);
is_record(Pkt, sm_resume) ->
Err = #sm_failed{reason = 'unexpected-request', xmlns = Xmlns},
send(State, Err);
_ ->
Res = if is_record(Pkt, sm_a);
is_record(Pkt, sm_r);
is_record(Pkt, sm_resume) ->
#sm_failed{reason = 'unexpected-request',
xmlns = Xmlns};
true ->
#sm_failed{reason = 'bad-request',
xmlns = Xmlns}
end,
send(State, Res)
Err = #sm_failed{reason = 'bad-request', xmlns = Xmlns},
send(State, Err)
end.
-spec perform_stream_mgmt(xmpp_element(), state()) -> state().
@ -223,16 +274,13 @@ perform_stream_mgmt(Pkt, #{mgmt_xmlns := Xmlns} = State) ->
handle_r(State);
#sm_a{} ->
handle_a(State, Pkt);
_ when is_record(Pkt, sm_enable);
is_record(Pkt, sm_resume) ->
send(State, #sm_failed{reason = 'unexpected-request',
xmlns = Xmlns});
_ ->
Res = if is_record(Pkt, sm_enable);
is_record(Pkt, sm_resume) ->
#sm_failed{reason = 'unexpected-request',
xmlns = Xmlns};
true ->
#sm_failed{reason = 'bad-request',
xmlns = Xmlns}
end,
send(State, Res)
send(State, #sm_failed{reason = 'bad-request',
xmlns = Xmlns})
end;
_ ->
send(State, #sm_failed{reason = 'unsupported-version', xmlns = Xmlns})
@ -241,7 +289,7 @@ perform_stream_mgmt(Pkt, #{mgmt_xmlns := Xmlns} = State) ->
-spec handle_enable(state(), sm_enable()) -> state().
handle_enable(#{mgmt_timeout := DefaultTimeout,
mgmt_max_timeout := MaxTimeout,
xmlns := Xmlns, jid := JID} = State,
mgmt_xmlns := Xmlns, jid := JID} = State,
#sm_enable{resume = Resume, max = Max}) ->
Timeout = if Resume == false ->
0;
@ -264,7 +312,7 @@ handle_enable(#{mgmt_timeout := DefaultTimeout,
end,
State1 = State#{mgmt_state => active,
mgmt_queue => queue_new(),
mgmt_timeout => Timeout * 1000},
mgmt_timeout => Timeout},
send(State1, Res).
-spec handle_r(state()) -> state().
@ -275,23 +323,26 @@ handle_r(#{mgmt_xmlns := Xmlns, mgmt_stanzas_in := H} = State) ->
-spec handle_a(state(), sm_a()) -> state().
handle_a(State, #sm_a{h = H}) ->
State1 = check_h_attribute(State, H),
resend_ack(State1).
resend_rack(State1).
-spec handle_resume(state(), sm_resume()) -> {ok, state()} | {error, state()}.
handle_resume(#{lserver := LServer, jid := JID, socket := Socket} = State,
handle_resume(#{user := User, lserver := LServer,
lang := Lang, socket := Socket} = State,
#sm_resume{h = H, previd = PrevID, xmlns = Xmlns}) ->
R = case inherit_session_state(State, PrevID) of
{ok, InheritedState} ->
{ok, InheritedState, H};
{error, Err, InH} ->
{error, #sm_failed{reason = 'item-not-found',
text = xmpp:mk_text(Err, Lang),
h = InH, xmlns = Xmlns}, Err};
{error, Err} ->
{error, #sm_failed{reason = 'item-not-found',
text = xmpp:mk_text(Err, Lang),
xmlns = Xmlns}, Err}
end,
case R of
{ok, ResumedState, NumHandled} ->
{ok, #{jid := JID} = ResumedState, NumHandled} ->
State1 = check_h_attribute(ResumedState, NumHandled),
#{mgmt_xmlns := AttrXmlns, mgmt_stanzas_in := AttrH} = State1,
AttrId = make_resume_id(State1),
@ -307,14 +358,20 @@ handle_resume(#{lserver := LServer, jid := JID, socket := Socket} = State,
[ejabberd_socket:pp(Socket), jid:to_string(JID)]),
{ok, State5};
{error, El, Msg} ->
?INFO_MSG("Cannot resume session for ~s: ~s", [jid:to_string(JID), Msg]),
?INFO_MSG("Cannot resume session for ~s@~s: ~s",
[User, LServer, Msg]),
{error, send(State, El)}
end.
-spec transition_to_pending(state()) -> state().
transition_to_pending(#{mgmt_state := active} = State) ->
%% TODO
State;
transition_to_pending(#{mgmt_state := active, jid := JID,
lserver := LServer, mgmt_timeout := Timeout} = State) ->
State1 = cancel_ack_timer(State),
?INFO_MSG("Waiting for resumption of stream for ~s", [jid:to_string(JID)]),
State2 = ejabberd_hooks:run_fold(c2s_session_pending, LServer, State1, []),
State3 = ejabberd_c2s:close(State2, _SendTrailer = false),
erlang:start_timer(timer:seconds(Timeout), self(), pending_timeout),
State3#{mgmt_state => pending};
transition_to_pending(State) ->
State.
@ -345,25 +402,25 @@ update_num_stanzas_in(#{mgmt_state := MgmtState,
update_num_stanzas_in(State, _El) ->
State.
send_ack(#{mgmt_ack_timer := _} = State) ->
send_rack(#{mgmt_ack_timer := _} = State) ->
State;
send_ack(#{mgmt_xmlns := Xmlns,
send_rack(#{mgmt_xmlns := Xmlns,
mgmt_stanzas_out := NumStanzasOut,
mgmt_ack_timeout := AckTimeout} = State) ->
State1 = send(State, #sm_r{xmlns = Xmlns}),
TRef = erlang:start_timer(AckTimeout, self(), ack_timeout),
State1#{mgmt_ack_timer => TRef, mgmt_stanzas_req => NumStanzasOut}.
resend_ack(#{mgmt_ack_timer := _,
mgmt_queue := Queue,
mgmt_stanzas_out := NumStanzasOut,
mgmt_stanzas_req := NumStanzasReq} = State) ->
resend_rack(#{mgmt_ack_timer := _,
mgmt_queue := Queue,
mgmt_stanzas_out := NumStanzasOut,
mgmt_stanzas_req := NumStanzasReq} = State) ->
State1 = cancel_ack_timer(State),
case NumStanzasReq < NumStanzasOut andalso not queue_is_empty(Queue) of
true -> send_ack(State1);
true -> send_rack(State1);
false -> State1
end;
resend_ack(State) ->
resend_rack(State) ->
State.
-spec mgmt_queue_add(state(), xmpp_element()) -> state().
@ -492,10 +549,22 @@ inherit_session_state(#{user := U, server := S} = State, ResumeID) ->
OldPID ->
OldSID = {Time, OldPID},
try resume_session(OldSID, State) of
{resume, OldState} ->
{resume, #{mgmt_xmlns := Xmlns,
mgmt_queue := Queue,
mgmt_timeout := Timeout,
mgmt_stanzas_in := NumStanzasIn,
mgmt_stanzas_out := NumStanzasOut} = OldState} ->
State1 = ejabberd_c2s:copy_state(State, OldState),
State2 = ejabberd_c2s:open_session(State1),
{ok, State2};
State2 = State1#{mgmt_xmlns => Xmlns,
mgmt_queue => Queue,
mgmt_timeout => Timeout,
mgmt_stanzas_in => NumStanzasIn,
mgmt_stanzas_out => NumStanzasOut,
mgmt_state => active},
ejabberd_sm:close_session(OldSID, U, S, R),
State3 = ejabberd_c2s:open_session(State2),
ejabberd_c2s:stop(OldPID),
{ok, State3};
{error, Msg} ->
{error, Msg}
catch exit:{noproc, _} ->
@ -591,6 +660,15 @@ cancel_ack_timer(#{mgmt_ack_timer := TRef} = State) ->
cancel_ack_timer(State) ->
State.
-spec bounce_message_queue() -> ok.
bounce_message_queue() ->
receive {route, From, To, Pkt} ->
ejabberd_router:route(From, To, Pkt),
bounce_message_queue()
after 0 ->
ok
end.
%%%===================================================================
%%% Configuration processing
%%%===================================================================

View File

@ -44,7 +44,8 @@
-type state() :: map().
-type stop_reason() :: {stream, reset | stream_error()} |
{tls, term()} |
{socket, inet:posix() | closed | timeout}.
{socket, inet:posix() | closed | timeout} |
internal_failure.
-callback init(list()) -> {ok, state()} | {stop, term()} | ignore.
-callback handle_cast(term(), state()) -> state().
@ -54,7 +55,6 @@
-callback code_change(term(), state(), term()) -> {ok, state()} | {error, term()}.
-callback handle_stream_start(state()) -> state().
-callback handle_stream_end(stop_reason(), state()) -> state().
-callback handle_stream_close(stop_reason(), state()) -> state().
-callback handle_cdata(binary(), state()) -> state().
-callback handle_unauthenticated_packet(xmpp_element(), state()) -> state().
-callback handle_authenticated_packet(xmpp_element(), state()) -> state().
@ -83,7 +83,6 @@
code_change/3,
handle_stream_start/1,
handle_stream_end/2,
handle_stream_close/2,
handle_cdata/2,
handle_authenticated_packet/2,
handle_unauthenticated_packet/2,
@ -193,6 +192,8 @@ format_error({stream, #stream_error{reason = Reason, text = Txt}}) ->
format("Stream failed: ~s", [format_stream_error(Reason, Txt)]);
format_error({tls, Reason}) ->
format("TLS failed: ~w", [Reason]);
format_error(internal_failure) ->
<<"Internal server error">>;
format_error(Err) ->
format("Unrecognized error: ~w", [Err]).
@ -263,75 +264,78 @@ handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}},
#{stream_state := wait_for_stream,
xmlns := XMLNS, lang := MyLang} = State) ->
El = #xmlel{name = Name, attrs = Attrs},
try xmpp:decode(El, XMLNS, []) of
#stream_start{} = Pkt ->
State1 = send_header(State, Pkt),
case is_disconnected(State1) of
true -> State1;
false -> noreply(process_stream(Pkt, State1))
end;
_ ->
State1 = send_header(State),
case is_disconnected(State1) of
true -> State1;
false -> noreply(send_element(State1, xmpp:serr_invalid_xml()))
end
catch _:{xmpp_codec, Why} ->
State1 = send_header(State),
case is_disconnected(State1) of
true -> State1;
false ->
Txt = xmpp:io_format_error(Why),
Lang = select_lang(MyLang, xmpp:get_lang(El)),
Err = xmpp:serr_invalid_xml(Txt, Lang),
noreply(send_element(State1, Err))
end
end;
noreply(
try xmpp:decode(El, XMLNS, []) of
#stream_start{} = Pkt ->
State1 = send_header(State, Pkt),
case is_disconnected(State1) of
true -> State1;
false -> process_stream(Pkt, State1)
end;
_ ->
State1 = send_header(State),
case is_disconnected(State1) of
true -> State1;
false -> send_element(State1, xmpp:serr_invalid_xml())
end
catch _:{xmpp_codec, Why} ->
State1 = send_header(State),
case is_disconnected(State1) of
true -> State1;
false ->
Txt = xmpp:io_format_error(Why),
Lang = select_lang(MyLang, xmpp:get_lang(El)),
Err = xmpp:serr_invalid_xml(Txt, Lang),
send_element(State1, Err)
end
end);
handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) ->
State1 = send_header(State),
case is_disconnected(State1) of
true -> State1;
false ->
Err = case Reason of
<<"XML stanza is too big">> ->
xmpp:serr_policy_violation(Reason, Lang);
_ ->
xmpp:serr_not_well_formed()
end,
noreply(send_element(State1, Err))
end;
noreply(
case is_disconnected(State1) of
true -> State1;
false ->
Err = case Reason of
<<"XML stanza is too big">> ->
xmpp:serr_policy_violation(Reason, Lang);
_ ->
xmpp:serr_not_well_formed()
end,
send_element(State1, Err)
end);
handle_info({'$gen_event', {xmlstreamelement, El}},
#{xmlns := NS, lang := MyLang, mod := Mod} = State) ->
try xmpp:decode(El, NS, [ignore_els]) of
Pkt ->
State1 = try Mod:handle_recv(El, Pkt, State)
catch _:undef -> State
end,
case is_disconnected(State1) of
true -> State1;
false -> noreply(process_element(Pkt, State1))
end
catch _:{xmpp_codec, Why} ->
State1 = try Mod:handle_recv(El, {error, Why}, State)
catch _:undef -> State
end,
case is_disconnected(State1) of
true -> State1;
false ->
Txt = xmpp:io_format_error(Why),
Lang = select_lang(MyLang, xmpp:get_lang(El)),
noreply(send_error(State1, El, xmpp:err_bad_request(Txt, Lang)))
end
end;
noreply(
try xmpp:decode(El, NS, [ignore_els]) of
Pkt ->
State1 = try Mod:handle_recv(El, Pkt, State)
catch _:undef -> State
end,
case is_disconnected(State1) of
true -> State1;
false -> process_element(Pkt, State1)
end
catch _:{xmpp_codec, Why} ->
State1 = try Mod:handle_recv(El, {error, Why}, State)
catch _:undef -> State
end,
case is_disconnected(State1) of
true -> State1;
false ->
Txt = xmpp:io_format_error(Why),
Lang = select_lang(MyLang, xmpp:get_lang(El)),
send_error(State1, El, xmpp:err_bad_request(Txt, Lang))
end
end);
handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}},
#{mod := Mod} = State) ->
noreply(try Mod:handle_cdata(Data, State)
catch _:undef -> State
end);
handle_info({'$gen_event', {xmlstreamend, _}}, State) ->
noreply(process_stream_end({error, {stream, reset}}, State));
noreply(process_stream_end({stream, reset}, State));
handle_info({'$gen_event', closed}, State) ->
noreply(process_stream_close({error, {socket, closed}}, State));
noreply(process_stream_end({socket, closed}, State));
handle_info(timeout, #{mod := Mod} = State) ->
Disconnected = is_disconnected(State),
noreply(try Mod:handle_timeout(State)
@ -342,7 +346,7 @@ handle_info(timeout, #{mod := Mod} = State) ->
end);
handle_info({'DOWN', MRef, _Type, _Object, _Info},
#{socket_monitor := MRef} = State) ->
noreply(process_stream_close({error, {socket, closed}}, State));
noreply(process_stream_end({socket, closed}, State));
handle_info(Info, #{mod := Mod} = State) ->
noreply(try Mod:handle_info(Info, State)
catch _:undef -> State
@ -390,15 +394,6 @@ peername(SockMod, Socket) ->
_ -> SockMod:peername(Socket)
end.
-spec process_stream_close(stop_reason(), state()) -> state().
process_stream_close(_, #{stream_state := disconnected} = State) ->
State;
process_stream_close(Reason, #{mod := Mod} = State) ->
State1 = send_trailer(State),
try Mod:handle_stream_close(Reason, State1)
catch _:undef -> stop(State1)
end.
-spec process_stream_end(stop_reason(), state()) -> state().
process_stream_end(_, #{stream_state := disconnected} = State) ->
State;
@ -414,6 +409,8 @@ process_stream(#stream_start{xmlns = XML_NS,
#{xmlns := NS} = State)
when XML_NS /= NS; STREAM_NS /= ?NS_STREAM ->
send_element(State, xmpp:serr_invalid_namespace());
process_stream(#stream_start{version = {N, _}}, State) when N > 1 ->
send_element(State, xmpp:serr_unsupported_version());
process_stream(#stream_start{lang = Lang},
#{xmlns := ?NS_CLIENT, lang := DefaultLang} = State)
when size(Lang) > 35 ->
@ -520,7 +517,7 @@ process_element(Pkt, #{stream_state := StateName, lang := Lang} = State) ->
#handshake{} ->
State;
#stream_error{} ->
process_stream_end({error, {stream, Pkt}}, State);
process_stream_end({stream, Pkt}, State);
_ when StateName == wait_for_sasl_request;
StateName == wait_for_handshake;
StateName == wait_for_sasl_response ->
@ -704,7 +701,7 @@ process_starttls_failure(Why, State) ->
State1 = send_element(State, #starttls_failure{}),
case is_disconnected(State1) of
true -> State1;
false -> process_stream_end({error, {tls, Why}}, State1)
false -> process_stream_end({tls, Why}, State1)
end.
-spec process_sasl_request(sasl_auth(), state()) -> state().
@ -939,8 +936,8 @@ set_from_to(Pkt, #{lang := Lang}) ->
end.
-spec send_header(state()) -> state().
send_header(State) ->
send_header(State, #stream_start{}).
send_header(#{stream_version := Version} = State) ->
send_header(State, #stream_start{version = Version}).
-spec send_header(state(), stream_start()) -> state().
send_header(#{stream_id := StreamID,
@ -959,8 +956,9 @@ send_header(#{stream_id := StreamID,
undefined -> jid:make(DefaultServer)
end,
Version = case HisVersion of
undefined -> MyVersion;
_ -> HisVersion
undefined -> undefined;
{0,_} -> HisVersion;
_ -> MyVersion
end,
Header = xmpp:encode(#stream_start{version = Version,
lang = Lang,
@ -969,10 +967,12 @@ send_header(#{stream_id := StreamID,
db_xmlns = NS_DB,
id = StreamID,
from = From}),
State1 = State#{lang => Lang, stream_header_sent => true},
State1 = State#{lang => Lang,
stream_version => Version,
stream_header_sent => true},
case send_text(State1, fxml:element_to_header(Header)) of
ok -> State1;
{error, Why} -> process_stream_close({error, {socket, Why}}, State1)
{error, Why} -> process_stream_end({socket, Why}, State1)
end;
send_header(State, _) ->
State.
@ -987,11 +987,11 @@ send_element(#{xmlns := NS, mod := Mod} = State, Pkt) ->
end,
case Result of
_ when is_record(Pkt, stream_error) ->
process_stream_end({error, {stream, Pkt}}, State1);
process_stream_end({stream, Pkt}, State1);
ok ->
State1;
{error, Why} ->
process_stream_close({error, {socket, Why}}, State1)
process_stream_end({socket, Why}, State1)
end.
-spec send_error(state(), xmpp_element(), stanza_error()) -> state().
@ -1022,7 +1022,7 @@ send_text(#{socket := Sock, sockmod := SockMod,
stream_header_sent := true}, Data) when StateName /= disconnected ->
SockMod:send(Sock, Data);
send_text(_, _) ->
{error, einval}.
{error, closed}.
-spec close_socket(state()) -> state().
close_socket(#{sockmod := SockMod, socket := Socket} = State) ->

View File

@ -33,6 +33,7 @@
-include_lib("kernel/include/inet.hrl").
-type state() :: map().
-type noreply() :: {noreply, state(), timeout()}.
-type host_port() :: {inet:hostname(), inet:port_number()}.
-type ip_port() :: {inet:ip_address(), inet:port_number()}.
-type network_error() :: {error, inet:posix() | inet_res:res_error()}.
@ -42,7 +43,8 @@
{tls, term()} |
{pkix, binary()} |
{auth, atom() | binary() | string()} |
{socket, inet:posix() | closed | timeout}.
{socket, inet:posix() | closed | timeout} |
internal_failure.
-callback init(list()) -> {ok, state()} | {stop, term()} | ignore.
@ -107,7 +109,7 @@ close(_, _) ->
establish(State) ->
process_stream_established(State).
-spec set_timeout(state(), non_neg_integer() | infinity) -> state().
-spec set_timeout(state(), timeout()) -> state().
set_timeout(#{owner := Owner} = State, Timeout) when Owner == self() ->
case Timeout of
infinity -> State#{stream_timeout => infinity};
@ -148,12 +150,15 @@ format_error({tls, Reason}) ->
format("TLS failed: ~w", [Reason]);
format_error({auth, Reason}) ->
format("Authentication failed: ~s", [Reason]);
format_error(internal_failure) ->
<<"Internal server error">>;
format_error(Err) ->
format("Unrecognized error: ~w", [Err]).
%%%===================================================================
%%% gen_server callbacks
%%%===================================================================
-spec init(list()) -> {ok, state(), timeout()} | {stop, term()} | ignore.
init([Mod, SockMod, From, To, Opts]) ->
Time = p1_time_compat:monotonic_time(milli_seconds),
State = #{owner => self(),
@ -183,36 +188,38 @@ init([Mod, SockMod, From, To, Opts]) ->
Err
end.
-spec handle_call(term(), term(), state()) -> noreply().
handle_call(Call, From, #{mod := Mod} = State) ->
noreply(try Mod:handle_call(Call, From, State)
catch _:undef -> State
end).
-spec handle_cast(term(), state()) -> noreply().
handle_cast(connect, #{remote_server := RemoteServer,
sockmod := SockMod,
stream_state := connecting} = State) ->
case ejabberd_idna:domain_utf8_to_ascii(RemoteServer) of
false ->
noreply(process_stream_close({error, {idna, bad_string}}, State));
ASCIIName ->
case resolve(binary_to_list(ASCIIName), State) of
{ok, AddrPorts} ->
case connect(AddrPorts, State) of
{ok, Socket, AddrPort} ->
SocketMonitor = SockMod:monitor(Socket),
State1 = State#{ip => AddrPort,
socket => Socket,
socket_monitor => SocketMonitor},
State2 = State1#{stream_state => wait_for_stream},
noreply(send_header(State2));
{error, Why} ->
Err = {error, {socket, Why}},
noreply(process_stream_close(Err, State))
end;
{error, Why} ->
noreply(process_stream_close({error, {dns, Why}}, State))
end
end;
noreply(
case ejabberd_idna:domain_utf8_to_ascii(RemoteServer) of
false ->
process_stream_end({idna, bad_string}, State);
ASCIIName ->
case resolve(binary_to_list(ASCIIName), State) of
{ok, AddrPorts} ->
case connect(AddrPorts, State) of
{ok, Socket, AddrPort} ->
SocketMonitor = SockMod:monitor(Socket),
State1 = State#{ip => AddrPort,
socket => Socket,
socket_monitor => SocketMonitor},
State2 = State1#{stream_state => wait_for_stream},
send_header(State2);
{error, Why} ->
process_stream_end({socket, Why}, State)
end;
{error, Why} ->
process_stream_end({dns, Why}, State)
end
end);
handle_cast(connect, State) ->
%% Ignoring connection attempts in other states
noreply(State);
@ -225,66 +232,70 @@ handle_cast(Cast, #{mod := Mod} = State) ->
catch _:undef -> State
end).
-spec handle_info(term(), state()) -> noreply().
handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}},
#{stream_state := wait_for_stream,
xmlns := XMLNS, lang := MyLang} = State) ->
El = #xmlel{name = Name, attrs = Attrs},
try xmpp:decode(El, XMLNS, []) of
#stream_start{} = Pkt ->
noreply(process_stream(Pkt, State));
_ ->
noreply(send_element(State, xmpp:serr_invalid_xml()))
catch _:{xmpp_codec, Why} ->
Txt = xmpp:io_format_error(Why),
Lang = select_lang(MyLang, xmpp:get_lang(El)),
Err = xmpp:serr_invalid_xml(Txt, Lang),
noreply(send_element(State, Err))
end;
noreply(
try xmpp:decode(El, XMLNS, []) of
#stream_start{} = Pkt ->
process_stream(Pkt, State);
_ ->
send_element(State, xmpp:serr_invalid_xml())
catch _:{xmpp_codec, Why} ->
Txt = xmpp:io_format_error(Why),
Lang = select_lang(MyLang, xmpp:get_lang(El)),
Err = xmpp:serr_invalid_xml(Txt, Lang),
send_element(State, Err)
end);
handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) ->
State1 = send_header(State),
case is_disconnected(State1) of
true -> State1;
false ->
Err = case Reason of
<<"XML stanza is too big">> ->
xmpp:serr_policy_violation(Reason, Lang);
_ ->
xmpp:serr_not_well_formed()
end,
noreply(send_element(State1, Err))
end;
noreply(
case is_disconnected(State1) of
true -> State1;
false ->
Err = case Reason of
<<"XML stanza is too big">> ->
xmpp:serr_policy_violation(Reason, Lang);
_ ->
xmpp:serr_not_well_formed()
end,
send_element(State1, Err)
end);
handle_info({'$gen_event', {xmlstreamelement, El}},
#{xmlns := NS, lang := MyLang, mod := Mod} = State) ->
try xmpp:decode(El, NS, [ignore_els]) of
Pkt ->
State1 = try Mod:handle_recv(El, Pkt, State)
catch _:undef -> State
end,
case is_disconnected(State1) of
true -> State1;
false -> noreply(process_element(Pkt, State1))
end
catch _:{xmpp_codec, Why} ->
State1 = try Mod:handle_recv(El, undefined, State)
catch _:undef -> State
end,
case is_disconnected(State1) of
true -> State1;
false ->
Txt = xmpp:io_format_error(Why),
Lang = select_lang(MyLang, xmpp:get_lang(El)),
noreply(send_error(State1, El, xmpp:err_bad_request(Txt, Lang)))
end
end;
noreply(
try xmpp:decode(El, NS, [ignore_els]) of
Pkt ->
State1 = try Mod:handle_recv(El, Pkt, State)
catch _:undef -> State
end,
case is_disconnected(State1) of
true -> State1;
false -> process_element(Pkt, State1)
end
catch _:{xmpp_codec, Why} ->
State1 = try Mod:handle_recv(El, undefined, State)
catch _:undef -> State
end,
case is_disconnected(State1) of
true -> State1;
false ->
Txt = xmpp:io_format_error(Why),
Lang = select_lang(MyLang, xmpp:get_lang(El)),
send_error(State1, El, xmpp:err_bad_request(Txt, Lang))
end
end);
handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}},
#{mod := Mod} = State) ->
noreply(try Mod:handle_cdata(Data, State)
catch _:undef -> State
end);
handle_info({'$gen_event', {xmlstreamend, _}}, State) ->
noreply(process_stream_end({error, {stream, reset}}, State));
noreply(process_stream_end({stream, reset}, State));
handle_info({'$gen_event', closed}, State) ->
noreply(process_stream_close({error, {socket, closed}}, State));
noreply(process_stream_end({socket, closed}, State));
handle_info(timeout, #{mod := Mod} = State) ->
Disconnected = is_disconnected(State),
noreply(try Mod:handle_timeout(State)
@ -295,12 +306,13 @@ handle_info(timeout, #{mod := Mod} = State) ->
end);
handle_info({'DOWN', MRef, _Type, _Object, _Info},
#{socket_monitor := MRef} = State) ->
noreply(process_stream_close({error, {socket, closed}}, State));
noreply(process_stream_end({socket, closed}, State));
handle_info(Info, #{mod := Mod} = State) ->
noreply(try Mod:handle_info(Info, State)
catch _:undef -> State
end).
-spec terminate(term(), state()) -> any().
terminate(Reason, #{mod := Mod} = State) ->
case get(already_terminated) of
true ->
@ -319,7 +331,7 @@ code_change(OldVsn, #{mod := Mod} = State, Extra) ->
%%%===================================================================
%%% Internal functions
%%%===================================================================
-spec noreply(state()) -> {noreply, state(), non_neg_integer() | infinity}.
-spec noreply(state()) -> noreply().
noreply(#{stream_timeout := infinity} = State) ->
{noreply, State, infinity};
noreply(#{stream_timeout := {MSecs, OldTime}} = State) ->
@ -335,15 +347,6 @@ new_id() ->
is_disconnected(#{stream_state := StreamState}) ->
StreamState == disconnected.
-spec process_stream_close(stop_reason(), state()) -> state().
process_stream_close(_, #{stream_state := disconnected} = State) ->
State;
process_stream_close(Reason, #{mod := Mod} = State) ->
State1 = send_trailer(State),
try Mod:handle_stream_close(Reason, State1)
catch _:undef -> stop(State1)
end.
-spec process_stream_end(stop_reason(), state()) -> state().
process_stream_end(_, #{stream_state := disconnected} = State) ->
State;
@ -359,6 +362,8 @@ process_stream(#stream_start{xmlns = XML_NS,
#{xmlns := NS} = State)
when XML_NS /= NS; STREAM_NS /= ?NS_STREAM ->
send_element(State, xmpp:serr_invalid_namespace());
process_stream(#stream_start{version = {N, _}}, State) when N > 1 ->
send_element(State, xmpp:serr_unsupported_version());
process_stream(#stream_start{lang = Lang, id = ID,
version = Version} = StreamStart,
#{mod := Mod} = State) ->
@ -370,8 +375,10 @@ process_stream(#stream_start{lang = Lang, id = ID,
true -> State2;
false ->
case Version of
{1,0} -> State2#{stream_state => wait_for_features};
_ -> process_stream_downgrade(StreamStart, State)
{1, _} ->
State2#{stream_state => wait_for_features};
_ ->
process_stream_downgrade(StreamStart, State2)
end
end.
@ -387,7 +394,7 @@ process_element(Pkt, #{stream_state := StateName} = State) ->
#sasl_failure{} when StateName == wait_for_sasl_response ->
process_sasl_failure(Pkt, State);
#stream_error{} ->
process_stream_end({error, {stream, Pkt}}, State);
process_stream_end({stream, Pkt}, State);
_ when is_record(Pkt, stream_features);
is_record(Pkt, starttls_proceed);
is_record(Pkt, starttls);
@ -487,14 +494,23 @@ process_starttls(#{sockmod := SockMod, socket := Socket, mod := Mod} = State) ->
stream_encrypted => true},
send_header(State1);
{error, Why} ->
process_stream_close({error, {tls, Why}}, State)
process_stream_end({tls, Why}, State)
end.
-spec process_stream_downgrade(stream_start(), state()) -> state().
process_stream_downgrade(StreamStart, #{mod := Mod} = State) ->
try Mod:downgrade_stream(StreamStart, State)
catch _:undef ->
send_element(State, xmpp:serr_unsupported_version())
process_stream_downgrade(StreamStart,
#{mod := Mod, lang := Lang,
stream_encrypted := Encrypted} = State) ->
TLSRequired = is_starttls_required(State),
if not Encrypted and TLSRequired ->
Txt = <<"Use of STARTTLS required">>,
send_element(State, xmpp:err_policy_violation(Txt, Lang));
true ->
State1 = State#{stream_state => downgraded},
try Mod:handle_stream_downgraded(StreamStart, State1)
catch _:undef ->
send_element(State1, xmpp:serr_unsupported_version())
end
end.
-spec process_cert_verification(state()) -> state().
@ -509,7 +525,7 @@ process_cert_verification(#{stream_encrypted := true,
{ok, _} ->
State#{stream_verified => true};
{error, Why, _Peer} ->
process_stream_close({error, {pkix, Why}}, State)
process_stream_end({pkix, Why}, State)
end;
false ->
State#{stream_verified => true}
@ -538,7 +554,7 @@ process_sasl_success(#{mod := Mod,
-spec process_sasl_failure(sasl_failure(), state()) -> state().
process_sasl_failure(#sasl_failure{reason = Reason}, #{mod := Mod} = State) ->
try Mod:handle_auth_failure(<<"EXTERNAL">>, Reason, State)
catch _:undef -> process_stream_close({error, {auth, Reason}}, State)
catch _:undef -> process_stream_end({auth, Reason}, State)
end.
-spec process_packet(xmpp_element(), state()) -> state().
@ -581,7 +597,7 @@ send_header(#{remote_server := RemoteServer,
version = {1,0}}),
case send_text(State, fxml:element_to_header(Header)) of
ok -> State;
{error, Why} -> process_stream_close({error, {socket, Why}}, State)
{error, Why} -> process_stream_end({socket, Why}, State)
end.
-spec send_element(state(), xmpp_element()) -> state().
@ -596,11 +612,11 @@ send_element(#{xmlns := NS, mod := Mod} = State, Pkt) ->
false ->
case send_text(State1, Data) of
_ when is_record(Pkt, stream_error) ->
process_stream_end({error, {stream, Pkt}}, State1);
process_stream_end({stream, Pkt}, State1);
ok ->
State1;
{error, Why} ->
process_stream_close({error, {socket, Why}}, State1)
process_stream_end({socket, Why}, State1)
end
end.
@ -626,7 +642,7 @@ send_text(#{sockmod := SockMod, socket := Socket,
stream_state := StateName}, Data) when StateName /= disconnected ->
SockMod:send(Socket, Data);
send_text(_, _) ->
{error, einval}.
{error, closed}.
-spec send_trailer(state()) -> state().
send_trailer(State) ->