Add xmpp_stream_out behaviour and rewrite s2s/SM code

This commit is contained in:
Evgeniy Khramtsov 2016-12-28 09:47:11 +03:00
parent 7f653cfe76
commit 309bdfbe28
28 changed files with 3690 additions and 2487 deletions

View File

@ -41,8 +41,6 @@
-define(COPYRIGHT, "Copyright (c) 2002-2016 ProcessOne").
-define(S2STIMEOUT, timer:minutes(10)).
%%-define(DBGFSM, true).
-record(scram,

View File

@ -36,13 +36,6 @@
-include("ejabberd.hrl").
-include("logger.hrl").
%%
-export_type([
mechanism/0,
mechanisms/0,
sasl_mechanism/0
]).
-record(sasl_mechanism,
{mechanism = <<"">> :: mechanism() | '$1',
module :: atom(),
@ -51,10 +44,15 @@
-type(mechanism() :: binary()).
-type(mechanisms() :: [mechanism(),...]).
-type(password_type() :: plain | digest | scram).
-type(props() :: [{username, binary()} |
{authzid, binary()} |
{mechanism, binary()} |
{auth_module, atom()}]).
-type sasl_property() :: {username, binary()} |
{authzid, binary()} |
{mechanism, binary()} |
{auth_module, atom()}.
-type sasl_return() :: {ok, [sasl_property()]} |
{ok, [sasl_property()], binary()} |
{continue, binary(), any()} |
{error, atom()} |
{error, atom(), binary()}.
-type(sasl_mechanism() :: #sasl_mechanism{}).
@ -71,14 +69,11 @@
mech_state
}).
-type sasl_state() :: #sasl_state{}.
-export_type([sasl_state/0]).
-export_type([mechanism/0, mechanisms/0, sasl_mechanism/0,
sasl_state/0, sasl_return/0, sasl_property/0]).
-callback mech_new(binary(), fun(), fun(), fun()) -> any().
-callback mech_step(any(), binary()) -> {ok, props()} |
{ok, props(), binary()} |
{continue, binary(), any()} |
{error, atom()} |
{error, atom(), binary()}.
-callback mech_step(any(), binary()) -> sasl_return().
start() ->
ets:new(sasl_mechanism,

View File

@ -169,7 +169,7 @@ broadcast_c2s_shutdown() ->
Children = ejabberd_sm:get_all_pids(),
lists:foreach(
fun(C2SPid) when node(C2SPid) == node() ->
C2SPid ! system_shutdown;
ejabberd_c2s:send(C2SPid, xmpp:serr_system_shutdown());
(_) ->
ok
end, Children).

View File

@ -42,7 +42,7 @@
get_password_s/2, get_password_with_authmodule/2,
is_user_exists/2, is_user_exists_in_other_modules/3,
remove_user/2, remove_user/3, plain_password_required/1,
store_type/1, entropy/1]).
store_type/1, entropy/1, backend_type/1]).
-export([auth_modules/1, opt_type/1]).
@ -412,6 +412,13 @@ entropy(B) ->
length(S) * math:log(lists:sum(Set)) / math:log(2)
end.
-spec backend_type(atom()) -> atom().
backend_type(Mod) ->
case atom_to_list(Mod) of
"ejabberd_auth_" ++ T -> list_to_atom(T);
_ -> Mod
end.
%%%----------------------------------------------------------------------
%%% Internal functions
%%%----------------------------------------------------------------------

View File

@ -22,26 +22,32 @@
-module(ejabberd_c2s).
-behaviour(xmpp_stream_in).
-behaviour(ejabberd_config).
-behaviour(ejabberd_socket).
-protocol({rfc, 6121}).
%% ejabberd_socket callbacks
-export([start/2, socket_type/0]).
-export([start/2, start_link/2, socket_type/0]).
%% ejabberd_config callbacks
-export([opt_type/1, transform_listen_option/2]).
%% xmpp_stream_in callbacks
-export([init/1, handle_call/3, handle_cast/2,
handle_info/2, terminate/2, code_change/3]).
-export([tls_options/1, tls_required/1, compress_methods/1,
sasl_mechanisms/1, init_sasl/1, bind/2, handshake/2,
-export([tls_options/1, tls_required/1, tls_verify/1,
compress_methods/1, bind/2, get_password_fun/1,
check_password_fun/1, check_password_digest_fun/1,
unauthenticated_stream_features/1, authenticated_stream_features/1,
handle_stream_start/1, handle_stream_end/1, handle_stream_close/1,
handle_stream_start/2, handle_stream_end/2, handle_stream_close/2,
handle_unauthenticated_packet/2, handle_authenticated_packet/2,
handle_auth_success/4, handle_auth_failure/4, handle_send/5,
handle_unbinded_packet/2, handle_cdata/2]).
handle_auth_success/4, handle_auth_failure/4, handle_send/3,
handle_recv/3, handle_cdata/2, handle_unbinded_packet/2]).
%% Hooks
-export([handle_unexpected_info/2, handle_unexpected_cast/2,
reject_unauthenticated_packet/2, process_closed/2]).
%% API
-export([get_presence/1, get_subscription/2, get_subscribed/1,
send/2, close/1]).
open_session/1, call/3, send/2, close/1, close/2, stop/1, establish/1,
copy_state/2, add_hooks/0]).
-include("ejabberd.hrl").
-include("xmpp.hrl").
@ -49,30 +55,30 @@
-define(SETS, gb_sets).
%%-define(DBGFSM, true).
-ifdef(DBGFSM).
-define(FSMOPTS, [{debug, [trace]}]).
-else.
-define(FSMOPTS, []).
-endif.
-type state() :: map().
-type next_state() :: {noreply, state()} | {stop, term(), state()}.
-export_type([state/0, next_state/0]).
-export_type([state/0]).
%%%===================================================================
%%% ejabberd_socket API
%%%===================================================================
start(SockData, Opts) ->
xmpp_stream_in:start(?MODULE, [SockData, Opts],
fsm_limit_opts(Opts) ++ ?FSMOPTS).
ejabberd_config:fsm_limit_opts(Opts)).
start_link(SockData, Opts) ->
xmpp_stream_in:start_link(?MODULE, [SockData, Opts],
ejabberd_config:fsm_limit_opts(Opts)).
socket_type() ->
xml_stream.
-spec call(pid(), term(), non_neg_integer() | infinity) -> term().
call(Ref, Msg, Timeout) ->
xmpp_stream_in:call(Ref, Msg, Timeout).
-spec get_presence(pid()) -> presence().
get_presence(Ref) ->
xmpp_stream_in:call(Ref, get_presence, 1000).
call(Ref, get_presence, 1000).
-spec get_subscription(jid() | ljid(), state()) -> both | from | to | none.
get_subscription(#jid{} = From, State) ->
@ -90,15 +96,85 @@ get_subscription(LFrom, #{pres_f := PresF, pres_t := PresT}) ->
-spec get_subscribed(pid()) -> [ljid()].
%% Return list of all available resources of contacts
get_subscribed(Ref) ->
xmpp_stream_in:call(Ref, get_subscribed, 1000).
call(Ref, get_subscribed, 1000).
-spec close(pid()) -> ok.
close(Ref) ->
xmpp_stream_in:cast(Ref, closed).
xmpp_stream_in:close(Ref).
-spec send(state(), xmpp_element()) -> next_state().
send(State, Pkt) ->
xmpp_stream_in:send(State, Pkt).
close(Ref, SendTrailer) ->
xmpp_stream_in:close(Ref, SendTrailer).
stop(Ref) ->
xmpp_stream_in:stop(Ref).
-spec send(pid(), xmpp_element()) -> ok;
(state(), xmpp_element()) -> state().
send(Pid, Pkt) when is_pid(Pid) ->
xmpp_stream_in:send(Pid, Pkt);
send(#{lserver := LServer} = State, Pkt) ->
case ejabberd_hooks:run_fold(c2s_filter_send, LServer, Pkt, [State]) of
drop -> State;
Pkt1 -> xmpp_stream_in:send(State, Pkt1)
end.
-spec establish(state()) -> state().
establish(State) ->
xmpp_stream_in:establish(State).
-spec add_hooks() -> ok.
add_hooks() ->
lists:foreach(
fun(Host) ->
ejabberd_hooks:add(c2s_closed, Host, ?MODULE, process_closed, 100),
ejabberd_hooks:add(c2s_unauthenticated_packet, Host, ?MODULE,
reject_unauthenticated_packet, 100),
ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE,
handle_unexpected_info, 100),
ejabberd_hooks:add(c2s_handle_cast, Host, ?MODULE,
handle_unexpected_cast, 100)
end, ?MYHOSTS).
%% Copies content of one c2s state to another.
%% This is needed for session migration from one pid to another.
-spec copy_state(state(), state()) -> state().
copy_state(#{owner := Owner} = NewState,
#{jid := JID, resource := Resource, sid := {Time, _},
auth_module := AuthModule, lserver := LServer,
pres_t := PresT, pres_a := PresA,
pres_f := PresF} = OldState) ->
State1 = case OldState of
#{pres_last := Pres, pres_timestamp := PresTS} ->
NewState#{pres_last => Pres, pres_timestamp => PresTS};
_ ->
NewState
end,
Conn = get_conn_type(State1),
State2 = State1#{jid => JID, resource => Resource,
conn => Conn,
sid => {Time, Owner},
auth_module => AuthModule,
pres_t => PresT, pres_a => PresA,
pres_f => PresF},
ejabberd_hooks:run_fold(c2s_copy_state, LServer, State2, [OldState]).
%%%===================================================================
%%% Hooks
%%%===================================================================
handle_unexpected_info(State, Info) ->
?WARNING_MSG("got unexpected info: ~p", [Info]),
State.
handle_unexpected_cast(State, Msg) ->
?WARNING_MSG("got unexpected cast: ~p", [Msg]),
State.
reject_unauthenticated_packet(State, Pkt) ->
Err = xmpp:err_not_authorized(),
xmpp_stream_in:send_error(State, Pkt, Err).
process_closed(State, _Reason) ->
stop(State).
%%%===================================================================
%%% xmpp_stream_in callbacks
@ -115,128 +191,158 @@ tls_options(#{lserver := LServer, tls_options := TLSOpts}) ->
tls_required(#{tls_required := TLSRequired}) ->
TLSRequired.
tls_verify(#{tls_verify := TLSVerify}) ->
TLSVerify.
compress_methods(#{zlib := true}) ->
[<<"zlib">>];
compress_methods(_) ->
[].
sasl_mechanisms(#{lserver := LServer}) ->
cyrsasl:listmech(LServer).
unauthenticated_stream_features(#{lserver := LServer}) ->
ejabberd_hooks:run_fold(c2s_pre_auth_features, LServer, [], [LServer]).
authenticated_stream_features(#{lserver := LServer}) ->
ejabberd_hooks:run_fold(c2s_post_auth_features, LServer, [], [LServer]).
init_sasl(#{lserver := LServer}) ->
cyrsasl:server_new(
<<"jabber">>, LServer, <<"">>, [],
fun(U) ->
ejabberd_auth:get_password_with_authmodule(U, LServer)
end,
fun(U, AuthzId, P) ->
ejabberd_auth:check_password_with_authmodule(U, AuthzId, LServer, P)
end,
fun(U, AuthzId, P, D, DG) ->
ejabberd_auth:check_password_with_authmodule(U, AuthzId, LServer, P, D, DG)
end).
get_password_fun(#{lserver := LServer}) ->
fun(U) ->
ejabberd_auth:get_password_with_authmodule(U, LServer)
end.
check_password_fun(#{lserver := LServer}) ->
fun(U, AuthzId, P) ->
ejabberd_auth:check_password_with_authmodule(U, AuthzId, LServer, P)
end.
check_password_digest_fun(#{lserver := LServer}) ->
fun(U, AuthzId, P, D, DG) ->
ejabberd_auth:check_password_with_authmodule(U, AuthzId, LServer, P, D, DG)
end.
bind(<<"">>, State) ->
bind(new_uniq_id(), State);
bind(R, #{user := U, server := S} = State) ->
bind(R, #{user := U, server := S, access := Access, lang := Lang,
lserver := LServer, socket := Socket, ip := IP} = State) ->
case resource_conflict_action(U, S, R) of
closenew ->
{error, xmpp:err_conflict(), State};
{accept_resource, Resource} ->
open_session(State, Resource)
JID = jid:make(U, S, Resource),
case acl:access_matches(Access,
#{usr => jid:split(JID), ip => IP},
LServer) of
allow ->
State1 = open_session(State#{resource => Resource}),
State2 = ejabberd_hooks:run_fold(
c2s_session_opened, LServer, State1, []),
?INFO_MSG("(~s) Opened session for ~s",
[ejabberd_socket:pp(Socket), jid:to_string(JID)]),
{ok, State2};
deny ->
ejabberd_hooks:run(forbidden_session_hook, LServer, [JID]),
?INFO_MSG("(~s) Forbidden session for ~s",
[ejabberd_socket:pp(Socket), jid:to_string(JID)]),
Txt = <<"Denied by ACL">>,
{error, xmpp:err_not_allowed(Txt, Lang), State}
end
end.
handshake(_Data, State) ->
%% This is only for jabber component
{ok, State}.
-spec open_session(state()) -> {ok, state()} | state().
open_session(#{user := U, server := S, resource := R,
sid := SID, ip := IP, auth_module := AuthModule} = State) ->
JID = jid:make(U, S, R),
change_shaper(State),
Conn = get_conn_type(State),
State1 = State#{conn => Conn, resource => R, jid => JID},
Prio = try maps:get(pres_last, State) of
Pres -> get_priority_from_presence(Pres)
catch _:{badkey, _} ->
undefined
end,
Info = [{ip, IP}, {conn, Conn}, {auth_module, AuthModule}],
ejabberd_sm:open_session(SID, U, S, R, Prio, Info),
State1.
handle_stream_start(#{lserver := LServer, ip := IP, lang := Lang} = State) ->
handle_stream_start(StreamStart,
#{lserver := LServer, ip := IP, lang := Lang} = State) ->
case lists:member(LServer, ?MYHOSTS) of
false ->
xmpp_stream_in:send(State, xmpp:serr_host_unknown());
send(State, xmpp:serr_host_unknown());
true ->
case check_bl_c2s(IP, Lang) of
false ->
change_shaper(State),
{noreply, State};
ejabberd_hooks:run_fold(
c2s_stream_started, LServer, State, [StreamStart]);
{true, LogReason, ReasonT} ->
?INFO_MSG("Connection attempt from blacklisted IP ~s: ~s",
[jlib:ip_to_list(IP), LogReason]),
Err = xmpp:serr_policy_violation(ReasonT, Lang),
xmpp_stream_in:send(State, Err)
send(State, Err)
end
end.
handle_stream_end(State) ->
{stop, normal, State}.
handle_stream_end(Reason, #{lserver := LServer} = State) ->
ejabberd_hooks:run_fold(c2s_closed, LServer, State, [Reason]).
handle_stream_close(State) ->
{stop, normal, State}.
handle_stream_close(_Reason, #{lserver := LServer} = State) ->
ejabberd_hooks:run_fold(c2s_closed, LServer, State, [normal]).
handle_auth_success(User, Mech, AuthModule,
#{socket := Socket, ip := IP, lserver := LServer} = State) ->
?INFO_MSG("(~w) Accepted ~s authentication for ~s@~s by ~p from ~s",
[Socket, Mech, User, LServer, AuthModule,
?INFO_MSG("(~s) Accepted c2s ~s authentication for ~s@~s by ~s backend from ~s",
[ejabberd_socket:pp(Socket), Mech, User, LServer,
ejabberd_auth:backend_type(AuthModule),
ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]),
State1 = State#{auth_module => AuthModule},
ejabberd_hooks:run_fold(c2s_auth_result, LServer,
{noreply, State1}, [true, User]).
State1, [true, User]).
handle_auth_failure(User, Mech, Reason,
#{socket := Socket, ip := IP, lserver := LServer} = State) ->
?INFO_MSG("(~w) Failed ~s authentication ~sfrom ~s: ~s",
[Socket, Mech,
?INFO_MSG("(~s) Failed c2s ~s authentication ~sfrom ~s: ~s",
[ejabberd_socket:pp(Socket), Mech,
if User /= <<"">> -> ["for ", User, "@", LServer, " "];
true -> ""
end,
ejabberd_config:may_hide_data(jlib:ip_to_list(IP)), Reason]),
ejabberd_hooks:run_fold(c2s_auth_result, LServer,
{noreply, State}, [false, User]).
State, [false, User]).
handle_unbinded_packet(Pkt, #{lserver := LServer} = State) ->
ejabberd_hooks:run_fold(c2s_unbinded_packet, LServer,
{noreply, State}, [Pkt]).
State, [Pkt]).
handle_unauthenticated_packet(Pkt, #{lserver := LServer} = State) ->
ejabberd_hooks:run_fold(c2s_unauthenticated_packet,
LServer, {noreply, State}, [Pkt]).
LServer, State, [Pkt]).
handle_authenticated_packet(Pkt, #{lserver := LServer} = State) when not ?is_stanza(Pkt) ->
ejabberd_hooks:run_fold(c2s_authenticated_packet,
LServer, {noreply, State}, [Pkt]);
LServer, State, [Pkt]);
handle_authenticated_packet(Pkt, #{lserver := LServer} = State) ->
case ejabberd_hooks:run_fold(c2s_authenticated_packet,
LServer, {noreply, State}, [Pkt]) of
{noreply, State1} ->
Pkt1 = ejabberd_hooks:run_fold(user_send_packet, LServer, Pkt, [State1]),
Res = case Pkt1 of
#presence{to = #jid{lresource = <<"">>}} ->
process_self_presence(State1, Pkt1);
#presence{} ->
process_presence_out(State1, Pkt1);
_ ->
check_privacy_then_route(State1, Pkt1)
end,
ejabberd_hooks:run(c2s_loop_debug, [{xmlstreamelement, Pkt}]),
Res;
Err ->
ejabberd_hooks:run(c2s_loop_debug, [{xmlstreamelement, Pkt}]),
Err
State1 = ejabberd_hooks:run_fold(c2s_authenticated_packet,
LServer, State, [Pkt]),
Pkt1 = ejabberd_hooks:run_fold(user_send_packet, LServer, Pkt, [State1]),
case Pkt1 of
#presence{to = #jid{lresource = <<"">>}} ->
process_self_presence(State1, Pkt1);
#presence{} ->
process_presence_out(State1, Pkt1);
_ ->
check_privacy_then_route(State1, Pkt1)
end.
handle_cdata(Data, #{lserver := LServer} = State) ->
ejabberd_hooks:run_fold(c2s_handle_cdata, LServer,
{noreply, State}, [Data]).
State, [Data]).
handle_send(Reason, Pkt, El, Data, #{lserver := LServer} = State) ->
ejabberd_hooks:run_fold(c2s_handle_send, LServer,
{noreply, State}, [Reason, Pkt, El, Data]).
handle_recv(El, Pkt, #{lserver := LServer} = State) ->
ejabberd_hooks:run_fold(c2s_handle_recv, LServer, State, [El, Pkt]).
handle_send(Pkt, Result, #{lserver := LServer} = State) ->
ejabberd_hooks:run_fold(c2s_handle_send, LServer, State, [Pkt, Result]).
init([State, Opts]) ->
Access = gen_mod:get_opt(access, Opts, fun acl:access_rules_validator/1, all),
@ -262,15 +368,13 @@ init([State, Opts]) ->
server => ?MYNAME,
access => Access,
shaper => Shaper},
ejabberd_hooks:run_fold(c2s_init, {ok, State1}, []).
ejabberd_hooks:run_fold(c2s_init, {ok, State1}, [Opts]).
handle_call(get_presence, _From, #{jid := JID} = State) ->
Pres = case maps:get(pres_last, State, undefined) of
undefined ->
Pres = try maps:get(pres_last, State)
catch _:{badkey, _} ->
BareJID = jid:remove_resource(JID),
#presence{from = JID, to = BareJID, type = unavailable};
P ->
P
#presence{from = JID, to = BareJID, type = unavailable}
end,
{reply, Pres, State};
handle_call(get_subscribed, _From, #{pres_f := PresF} = State) ->
@ -278,12 +382,10 @@ handle_call(get_subscribed, _From, #{pres_f := PresF} = State) ->
{reply, Subscribed, State};
handle_call(Request, From, #{lserver := LServer} = State) ->
ejabberd_hooks:run_fold(
c2s_handle_call, LServer, {noreply, State}, [Request, From]).
c2s_handle_call, LServer, State, [Request, From]).
handle_cast(closed, State) ->
handle_stream_close(State);
handle_cast(Msg, #{lserver := LServer} = State) ->
ejabberd_hooks:run_fold(c2s_handle_cast, LServer, {noreply, State}, [Msg]).
ejabberd_hooks:run_fold(c2s_handle_cast, LServer, State, [Msg]).
handle_info({route, From, To, Packet0}, #{lserver := LServer} = State) ->
Packet = xmpp:set_from_to(Packet0, From, To),
@ -299,15 +401,13 @@ handle_info({route, From, To, Packet0}, #{lserver := LServer} = State) ->
Packet1 = ejabberd_hooks:run_fold(
user_receive_packet, LServer, Packet, [NewState]),
ejabberd_hooks:run(c2s_loop_debug, [{route, From, To, Packet}]),
xmpp_stream_in:send(NewState, Packet1);
send(NewState, Packet1);
true ->
ejabberd_hooks:run(c2s_loop_debug, [{route, From, To, Packet}]),
{noreply, NewState}
NewState
end;
handle_info(system_shutdown, State) ->
xmpp_stream_in:send(State, xmpp:serr_system_shutdown());
handle_info(Info, #{lserver := LServer} = State) ->
ejabberd_hooks:run_fold(c2s_handle_info, LServer, {noreply, State}, [Info]).
ejabberd_hooks:run_fold(c2s_handle_info, LServer, State, [Info]).
terminate(_Reason, _State) ->
ok.
@ -323,33 +423,6 @@ code_change(_OldVsn, State, _Extra) ->
check_bl_c2s({IP, _Port}, Lang) ->
ejabberd_hooks:run_fold(check_bl_c2s, false, [IP, Lang]).
-spec open_session(state(), binary()) -> {ok, state()} | {error, stanza_error(), state()}.
open_session(#{user := U, server := S, lserver := LServer, sid := SID,
socket := Socket, ip := IP, auth_module := AuthMod,
access := Access, lang := Lang} = State, R) ->
JID = jid:make(U, S, R),
case acl:access_matches(Access,
#{usr => jid:split(JID), ip => IP},
LServer) of
allow ->
?INFO_MSG("(~w) Opened session for ~s",
[Socket, jid:to_string(JID)]),
change_shaper(State),
Conn = get_conn_type(State),
Info = [{ip, IP}, {conn, Conn}, {auth_module, AuthMod}],
ejabberd_sm:open_session(SID, U, LServer, R, Info),
State1 = State#{conn => Conn, resource => R, jid => JID},
State2 = ejabberd_hooks:run_fold(
c2s_session_opened, LServer, State1, []),
{ok, State2};
deny ->
ejabberd_hooks:run(forbidden_session_hook, LServer, [JID]),
?INFO_MSG("(~w) Forbidden session for ~s",
[Socket, jid:to_string(JID)]),
Txt = <<"Denied by ACL">>,
{error, xmpp:err_not_allowed(Txt, Lang), State}
end.
-spec process_iq_in(state(), iq()) -> {boolean(), state()}.
process_iq_in(State, #iq{} = IQ) ->
case privacy_check_packet(State, IQ, in) of
@ -433,7 +506,7 @@ route_probe_reply(From, To, #{lserver := LServer, pres_f := PresF,
route_probe_reply(_, _, _) ->
ok.
-spec process_presence_out(state(), presence()) -> next_state().
-spec process_presence_out(state(), presence()) -> state().
process_presence_out(#{user := User, server := Server, lserver := LServer,
jid := JID, lang := Lang, pres_a := PresA} = State,
#presence{from = From, to = To, type = Type} = Pres) ->
@ -461,21 +534,21 @@ process_presence_out(#{user := User, server := Server, lserver := LServer,
[User, Server, To, Type]),
BareFrom = jid:remove_resource(From),
route(xmpp:set_from_to(Pres, BareFrom, To)),
{noreply, State}
State
end;
allow when Type == error; Type == probe ->
route(Pres),
{noreply, State};
State;
allow ->
route(Pres),
A = case Type of
available -> ?SETS:add_element(LTo, PresA);
unavailable -> ?SETS:del_element(LTo, PresA)
end,
{noreply, State#{pres_a => A}}
State#{pres_a => A}
end.
-spec process_self_presence(state(), presence()) -> {noreply, state()}.
-spec process_self_presence(state(), presence()) -> state().
process_self_presence(#{ip := IP, conn := Conn,
auth_module := AuthMod, sid := SID,
user := U, server := S, resource := R} = State,
@ -484,8 +557,7 @@ process_self_presence(#{ip := IP, conn := Conn,
Info = [{ip, IP}, {conn, Conn}, {auth_module, AuthMod}],
ejabberd_sm:unset_presence(SID, U, S, R, Status, Info),
State1 = broadcast_presence_unavailable(State, Pres),
State2 = maps:remove(pres_last, maps:remove(pres_timestamp, State1)),
{noreply, State2};
maps:remove(pres_last, maps:remove(pres_timestamp, State1));
process_self_presence(#{lserver := LServer} = State,
#presence{type = available} = Pres) ->
PreviousPres = maps:get(pres_last, State, undefined),
@ -494,10 +566,9 @@ process_self_presence(#{lserver := LServer} = State,
State2 = State1#{pres_last => Pres,
pres_timestamp => p1_time_compat:timestamp()},
FromUnavailable = PreviousPres == undefined,
State3 = broadcast_presence_available(State2, Pres, FromUnavailable),
{noreply, State3};
broadcast_presence_available(State2, Pres, FromUnavailable);
process_self_presence(State, _Pres) ->
{noreply, State}.
State.
-spec update_priority(state(), presence()) -> ok.
update_priority(#{ip := IP, conn := Conn, auth_module := AuthMod,
@ -529,7 +600,7 @@ broadcast_presence_available(#{pres_a := PresA, pres_f := PresF} = State,
route_multiple(State, JIDs, Pres),
State.
-spec check_privacy_then_route(state(), stanza()) -> next_state().
-spec check_privacy_then_route(state(), stanza()) -> state().
check_privacy_then_route(#{lang := Lang} = State, Pkt) ->
case privacy_check_packet(State, Pkt, out) of
deny ->
@ -539,7 +610,7 @@ check_privacy_then_route(#{lang := Lang} = State, Pkt) ->
xmpp_stream_in:send_error(State, Pkt, Err);
allow ->
route(Pkt),
{noreply, State}
State
end.
-spec privacy_check_packet(state(), stanza(), in | out) -> allow | deny.
@ -664,25 +735,10 @@ do_some_magic(#{pres_a := PresA, pres_f := PresF} = State, From) ->
end
end.
-spec fsm_limit_opts([proplists:property()]) -> [proplists:property()].
fsm_limit_opts(Opts) ->
case lists:keysearch(max_fsm_queue, 1, Opts) of
{value, {_, N}} when is_integer(N) -> [{max_queue, N}];
_ ->
case ejabberd_config:get_option(
max_fsm_queue,
fun(I) when is_integer(I), I > 0 -> I end) of
undefined -> [];
N -> [{max_queue, N}]
end
end.
transform_listen_option(Opt, Opts) ->
[Opt|Opts].
opt_type(domain_certfile) -> fun iolist_to_binary/1;
opt_type(max_fsm_queue) ->
fun (I) when is_integer(I), I > 0 -> I end;
opt_type(resource_conflict) ->
fun (setresource) -> setresource;
(closeold) -> closeold;
@ -690,4 +746,4 @@ opt_type(resource_conflict) ->
(acceptnew) -> acceptnew
end;
opt_type(_) ->
[domain_certfile, max_fsm_queue, resource_conflict].
[domain_certfile, resource_conflict].

View File

@ -38,7 +38,8 @@
transform_options/1, collect_options/1, default_db/2,
convert_to_yaml/1, convert_to_yaml/2, v_db/2,
env_binary_to_list/2, opt_type/1, may_hide_data/1,
is_elixir_enabled/0, v_dbs/1, v_dbs_mods/1]).
is_elixir_enabled/0, v_dbs/1, v_dbs_mods/1,
fsm_limit_opts/1]).
-export([start/2]).
@ -1403,6 +1404,8 @@ opt_type(hosts) ->
end;
opt_type(language) ->
fun iolist_to_binary/1;
opt_type(max_fsm_queue) ->
fun (I) when is_integer(I), I > 0 -> I end;
opt_type(_) ->
[hide_sensitive_log_data, hosts, language].
@ -1421,3 +1424,17 @@ may_hide_data(Data) ->
true ->
"hidden_by_ejabberd"
end.
-spec fsm_limit_opts([proplists:property()]) -> [{max_queue, pos_integer()}].
fsm_limit_opts(Opts) ->
case lists:keyfind(max_fsm_queue, 1, Opts) of
{_, I} when is_integer(I), I>0 ->
[{max_queue, I}];
false ->
case get_option(
max_fsm_queue,
fun(I) when is_integer(I), I>0 -> I end) of
undefined -> [];
N -> [{max_queue, N}]
end
end.

View File

@ -376,8 +376,11 @@ run_fold1([{_Seq, Module, Function} | Ls], Hook, Val, Args) ->
end.
safe_apply(Module, Function, Args) ->
if is_function(Function) ->
catch apply(Function, Args);
true ->
catch apply(Module, Function, Args)
try if is_function(Function) ->
apply(Function, Args);
true ->
apply(Module, Function, Args)
end
catch E:R when E /= exit, R /= normal ->
{'EXIT', {E, {R, erlang:get_stacktrace()}}}
end.

View File

@ -330,9 +330,9 @@ accept(ListenSocket, Module, Opts, Interval) ->
{ok, Socket} ->
case {inet:sockname(Socket), inet:peername(Socket)} of
{{ok, {Addr, Port}}, {ok, {PAddr, PPort}}} ->
?INFO_MSG("(~w) Accepted connection ~s:~p -> ~s:~p",
[Socket, ejabberd_config:may_hide_data(inet_parse:ntoa(PAddr)), PPort,
inet_parse:ntoa(Addr), Port]);
?INFO_MSG("Accepted connection ~s:~p -> ~s:~p",
[ejabberd_config:may_hide_data(inet_parse:ntoa(PAddr)),
PPort, inet_parse:ntoa(Addr), Port]);
_ ->
ok
end,

View File

@ -43,7 +43,9 @@
unregister_route/1,
unregister_routes/1,
dirty_get_all_routes/0,
dirty_get_all_domains/0
dirty_get_all_domains/0,
is_my_route/1,
is_my_host/1
]).
-export([start_link/0]).
@ -110,12 +112,12 @@ register_route(Domain) ->
[?MODULE, ?MODULE]),
register_route(Domain, ?MYNAME).
-spec register_route(binary(), binary()) -> term().
-spec register_route(binary(), binary()) -> ok.
register_route(Domain, ServerHost) ->
register_route(Domain, ServerHost, undefined).
-spec register_route(binary(), binary(), local_hint()) -> term().
-spec register_route(binary(), binary(), local_hint()) -> ok.
register_route(Domain, ServerHost, LocalHint) ->
case {jid:nameprep(Domain), jid:nameprep(ServerHost)} of
@ -165,6 +167,11 @@ register_route(Domain, ServerHost, LocalHint) ->
end
end,
mnesia:transaction(F)
end,
if LocalHint == undefined ->
?INFO_MSG("Route registered: ~s", [LDomain]);
true ->
ok
end
end.
@ -175,7 +182,7 @@ register_routes(Domains) ->
end,
Domains).
-spec unregister_route(binary()) -> term().
-spec unregister_route(binary()) -> ok.
unregister_route(Domain) ->
case jid:nameprep(Domain) of
@ -210,7 +217,8 @@ unregister_route(Domain) ->
end
end,
mnesia:transaction(F)
end
end,
?INFO_MSG("Route unregistered: ~s", [LDomain])
end.
-spec unregister_routes([binary()]) -> ok.
@ -245,6 +253,29 @@ host_of_route(Domain) ->
end
end.
-spec is_my_route(binary()) -> boolean().
is_my_route(Domain) ->
case jid:nameprep(Domain) of
error ->
erlang:error({invalid_domain, Domain});
LDomain ->
mnesia:dirty_read(route, LDomain) /= []
end.
-spec is_my_host(binary()) -> boolean().
is_my_host(Domain) ->
case jid:nameprep(Domain) of
error ->
erlang:error({invalid_domain, Domain});
LDomain ->
case mnesia:dirty_read(route, LDomain) of
[#route{server_host = Host}|_] ->
Host == LDomain;
[] ->
false
end
end.
-spec process_iq(jid(), jid(), iq() | xmlel()) -> any().
process_iq(From, To, #iq{} = IQ) ->
if To#jid.luser == <<"">> ->

View File

@ -35,16 +35,16 @@
%% API
-export([start_link/0, route/3, have_connection/1,
make_key/2, get_connections_pids/1, try_register/1,
remove_connection/2, find_connection/2,
get_connections_pids/1, try_register/1,
remove_connection/2, start_connection/2, start_connection/3,
dirty_get_connections/0, allow_host/2,
incoming_s2s_number/0, outgoing_s2s_number/0,
stop_all_connections/0,
clean_temporarily_blocked_table/0,
list_temporarily_blocked_hosts/0,
external_host_overloaded/1, is_temporarly_blocked/1,
check_peer_certificate/3,
get_commands_spec/0]).
get_commands_spec/0, zlib_enabled/1, get_idle_timeout/1,
tls_required/1, tls_verify/1, tls_enabled/1, tls_options/2]).
%% gen_server callbacks
-export([init/1, handle_call/3, handle_cast/2,
@ -196,39 +196,94 @@ try_register(FromTo) ->
dirty_get_connections() ->
mnesia:dirty_all_keys(s2s).
check_peer_certificate(SockMod, Sock, Peer) ->
case SockMod:get_peer_certificate(Sock) of
{ok, Cert} ->
case SockMod:get_verify_result(Sock) of
0 ->
case ejabberd_idna:domain_utf8_to_ascii(Peer) of
false ->
{error, <<"Cannot decode remote server name">>};
AsciiPeer ->
case
lists:any(fun(D) -> match_domain(AsciiPeer, D) end,
get_cert_domains(Cert)) of
true ->
{ok, <<"Verification successful">>};
false ->
{error, <<"Certificate host name mismatch">>}
end
end;
VerifyRes ->
{error, fast_tls:get_cert_verify_string(VerifyRes, Cert)}
end;
{error, _Reason} ->
{error, <<"Cannot get peer certificate">>};
error ->
{error, <<"Cannot get peer certificate">>}
-spec tls_options(binary(), [proplists:property()]) -> [proplists:property()].
tls_options(LServer, DefaultOpts) ->
TLSOpts1 = case ejabberd_config:get_option(
{s2s_certfile, LServer},
fun iolist_to_binary/1,
ejabberd_config:get_option(
{domain_certfile, LServer},
fun iolist_to_binary/1)) of
undefined -> [];
CertFile -> lists:keystore(certfile, 1, DefaultOpts,
{certfile, CertFile})
end,
TLSOpts2 = case ejabberd_config:get_option(
{s2s_ciphers, LServer},
fun iolist_to_binary/1) of
undefined -> TLSOpts1;
Ciphers -> lists:keystore(ciphers, 1, TLSOpts1,
{ciphers, Ciphers})
end,
TLSOpts3 = case ejabberd_config:get_option(
{s2s_protocol_options, LServer},
fun (Options) -> str:join(Options, <<$|>>) end) of
undefined -> TLSOpts2;
ProtoOpts -> lists:keystore(protocol_options, 1, TLSOpts2,
{protocol_options, ProtoOpts})
end,
TLSOpts4 = case ejabberd_config:get_option(
{s2s_dhfile, LServer},
fun iolist_to_binary/1) of
undefined -> TLSOpts3;
DHFile -> lists:keystore(dhfile, 1, TLSOpts3,
{dhfile, DHFile})
end,
TLSOpts5 = case ejabberd_config:get_option(
{s2s_cafile, LServer},
fun iolist_to_binary/1) of
undefined -> TLSOpts4;
CAFile -> lists:keystore(cafile, 1, TLSOpts4,
{cafile, CAFile})
end,
case ejabberd_config:get_option(
{s2s_tls_compression, LServer},
fun(B) when is_boolean(B) -> B end) of
undefined -> TLSOpts5;
false -> [compression_none | TLSOpts5];
true -> lists:delete(compression_none, TLSOpts5)
end.
-spec make_key({binary(), binary()}, binary()) -> binary().
make_key({From, To}, StreamID) ->
Secret = ejabberd_config:get_option(shared_key, fun(V) -> V end),
p1_sha:to_hexlist(
crypto:hmac(sha256, p1_sha:to_hexlist(crypto:hash(sha256, Secret)),
[To, " ", From, " ", StreamID])).
-spec tls_required(binary()) -> boolean().
tls_required(LServer) ->
TLS = use_starttls(LServer),
TLS == required orelse TLS == required_trusted.
-spec tls_verify(binary()) -> boolean().
tls_verify(LServer) ->
TLS = use_starttls(LServer),
TLS == required_trusted.
-spec tls_enabled(binary()) -> boolean().
tls_enabled(LServer) ->
TLS = use_starttls(LServer),
TLS == true orelse TLS == optional.
-spec zlib_enabled(binary()) -> boolean().
zlib_enabled(LServer) ->
ejabberd_config:get_option(
{s2s_zlib, LServer},
fun(B) when is_boolean(B) -> B end,
false).
-spec use_starttls(binary()) -> boolean() | optional | required | required_trusted.
use_starttls(LServer) ->
ejabberd_config:get_option(
{s2s_use_starttls, LServer},
fun(true) -> true;
(false) -> false;
(optional) -> optional;
(required) -> required;
(required_trusted) -> required_trusted
end, false).
-spec get_idle_timeout(binary()) -> non_neg_integer() | infinity.
get_idle_timeout(LServer) ->
ejabberd_config:get_option(
{s2s_timeout, LServer},
fun(I) when is_integer(I), I >= 0 -> timer:seconds(I);
(infinity) -> infinity
end, timer:minutes(10)).
%%====================================================================
%% gen_server callbacks
@ -246,6 +301,8 @@ init([]) ->
ejabberd_mnesia:create(?MODULE, temporarily_blocked,
[{ram_copies, [node()]},
{attributes, record_info(fields, temporarily_blocked)}]),
ejabberd_s2s_in:add_hooks(),
ejabberd_s2s_out:add_hooks(),
{ok, #state{}}.
handle_call(_Request, _From, State) ->
@ -291,30 +348,36 @@ clean_table_from_bad_node(Node) ->
end,
mnesia:async_dirty(F).
-spec do_route(jid(), jid(), stanza()) -> ok | false.
-spec do_route(jid(), jid(), stanza()) -> ok.
do_route(From, To, Packet) ->
?DEBUG("s2s manager~n\tfrom ~p~n\tto ~p~n\tpacket "
"~P~n",
[From, To, Packet, 8]),
case find_connection(From, To) of
{atomic, Pid} when is_pid(Pid) ->
?DEBUG("sending to process ~p~n", [Pid]),
#jid{lserver = MyServer} = From,
ejabberd_hooks:run(s2s_send_packet, MyServer,
[From, To, Packet]),
send_element(Pid, xmpp:set_from_to(Packet, From, To)),
ok;
{aborted, _Reason} ->
Lang = xmpp:get_lang(Packet),
Txt = <<"No s2s connection found">>,
Err = xmpp:err_service_unavailable(Txt, Lang),
ejabberd_router:route_error(To, From, Packet, Err),
false
case start_connection(From, To) of
{ok, Pid} when is_pid(Pid) ->
?DEBUG("sending to process ~p~n", [Pid]),
#jid{lserver = MyServer} = From,
ejabberd_hooks:run(s2s_send_packet, MyServer, [From, To, Packet]),
ejabberd_s2s_out:route(Pid, xmpp:set_from_to(Packet, From, To));
{error, Reason} ->
Err = case Reason of
forbidden ->
Lang = xmpp:get_lang(Packet),
xmpp:err_forbidden(<<"Denied by ACL">>, Lang);
internal_server_error ->
xmpp:err_internal_server_error()
end,
ejabberd_router:route_error(To, From, Packet, Err)
end.
-spec find_connection(jid(), jid()) -> {aborted, any()} | {atomic, pid()}.
-spec start_connection(jid(), jid()) -> {ok, pid()} |
{error, forbidden | internal_server_error}.
start_connection(From, To) ->
start_connection(From, To, []).
find_connection(From, To) ->
-spec start_connection(jid(), jid(), [proplists:property()])
-> {ok, pid()} | {error, forbidden | internal_server_error}.
start_connection(From, To, Opts) ->
#jid{lserver = MyServer} = From,
#jid{lserver = Server} = To,
FromTo = {MyServer, Server},
@ -323,15 +386,13 @@ find_connection(From, To) ->
MaxS2SConnectionsNumberPerNode =
max_s2s_connections_number_per_node(FromTo),
?DEBUG("Finding connection for ~p~n", [FromTo]),
case catch mnesia:dirty_read(s2s, FromTo) of
{'EXIT', Reason} -> {aborted, Reason};
case mnesia:dirty_read(s2s, FromTo) of
[] ->
%% We try to establish all the connections if the host is not a
%% service and if the s2s host is not blacklisted or
%% is in whitelist:
case not is_service(From, To) andalso
allow_host(MyServer, Server)
of
LServer = ejabberd_router:host_of_route(MyServer),
case not is_service(From, To) andalso allow_host(LServer, Server) of
true ->
NeededConnections = needed_connections_number([],
MaxS2SConnectionsNumber,
@ -339,8 +400,8 @@ find_connection(From, To) ->
open_several_connections(NeededConnections, MyServer,
Server, From, FromTo,
MaxS2SConnectionsNumber,
MaxS2SConnectionsNumberPerNode);
false -> {aborted, error}
MaxS2SConnectionsNumberPerNode, Opts);
false -> {error, forbidden}
end;
L when is_list(L) ->
NeededConnections = needed_connections_number(L,
@ -351,10 +412,10 @@ find_connection(From, To) ->
open_several_connections(NeededConnections, MyServer,
Server, From, FromTo,
MaxS2SConnectionsNumber,
MaxS2SConnectionsNumberPerNode);
MaxS2SConnectionsNumberPerNode, Opts);
true ->
%% We choose a connexion from the pool of opened ones.
{atomic, choose_connection(From, L)}
{ok, choose_connection(From, L)}
end
end.
@ -377,20 +438,22 @@ choose_pid(From, Pids) ->
open_several_connections(N, MyServer, Server, From,
FromTo, MaxS2SConnectionsNumber,
MaxS2SConnectionsNumberPerNode) ->
ConnectionsResult = [new_connection(MyServer, Server,
From, FromTo, MaxS2SConnectionsNumber,
MaxS2SConnectionsNumberPerNode)
|| _N <- lists:seq(1, N)],
case [PID || {atomic, PID} <- ConnectionsResult] of
[] -> hd(ConnectionsResult);
PIDs -> {atomic, choose_pid(From, PIDs)}
MaxS2SConnectionsNumberPerNode, Opts) ->
case lists:flatmap(
fun(_) ->
new_connection(MyServer, Server,
From, FromTo, MaxS2SConnectionsNumber,
MaxS2SConnectionsNumberPerNode, Opts)
end, lists:seq(1, N)) of
[] ->
{error, internal_server_error};
PIDs ->
{ok, choose_pid(From, PIDs)}
end.
new_connection(MyServer, Server, From, FromTo,
MaxS2SConnectionsNumber, MaxS2SConnectionsNumberPerNode) ->
{ok, Pid} = ejabberd_s2s_out:start(
MyServer, Server, new),
MaxS2SConnectionsNumber, MaxS2SConnectionsNumberPerNode, Opts) ->
{ok, Pid} = ejabberd_s2s_out:start(MyServer, Server, Opts),
F = fun() ->
L = mnesia:read({s2s, FromTo}),
NeededConnections = needed_connections_number(L,
@ -398,17 +461,21 @@ new_connection(MyServer, Server, From, FromTo,
MaxS2SConnectionsNumberPerNode),
if NeededConnections > 0 ->
mnesia:write(#s2s{fromto = FromTo, pid = Pid}),
?INFO_MSG("New s2s connection started ~p", [Pid]),
Pid;
true -> choose_connection(From, L)
end
end,
TRes = mnesia:transaction(F),
case TRes of
{atomic, Pid} -> ejabberd_s2s_out:start_connection(Pid);
_ -> ejabberd_s2s_out:stop_connection(Pid)
end,
TRes.
{atomic, Pid} ->
ejabberd_s2s_out:connect(Pid),
[Pid];
{aborted, Reason} ->
?ERROR_MSG("failed to register connection ~s -> ~s: ~p",
[MyServer, Server, Reason]),
ejabberd_s2s_out:stop(Pid),
[]
end.
-spec max_s2s_connections_number({binary(), binary()}) -> integer().
max_s2s_connections_number({From, To}) ->
@ -459,9 +526,6 @@ parent_domains(Domain) ->
end,
[], lists:reverse(str:tokens(Domain, <<".">>))).
send_element(Pid, El) ->
Pid ! {send_element, El}.
%%%----------------------------------------------------------------------
%%% ejabberd commands
@ -536,24 +600,13 @@ update_tables() ->
%% Check if host is in blacklist or white list
allow_host(MyServer, S2SHost) ->
allow_host2(MyServer, S2SHost) andalso
allow_host1(MyServer, S2SHost) andalso
not is_temporarly_blocked(S2SHost).
allow_host2(MyServer, S2SHost) ->
Hosts = (?MYHOSTS),
case lists:dropwhile(fun (ParentDomain) ->
not lists:member(ParentDomain, Hosts)
end,
parent_domains(MyServer))
of
[MyHost | _] -> allow_host1(MyHost, S2SHost);
[] -> allow_host1(MyServer, S2SHost)
end.
allow_host1(MyHost, S2SHost) ->
Rule = ejabberd_config:get_option(
s2s_access,
fun(A) -> A end,
{s2s_access, MyHost},
fun acl:access_rules_validator/1,
all),
JID = jid:make(S2SHost),
case acl:match_rule(MyHost, Rule, JID) of
@ -624,133 +677,34 @@ get_s2s_state(S2sPid) ->
end,
[{s2s_pid, S2sPid} | Infos].
get_cert_domains(Cert) ->
TBSCert = Cert#'Certificate'.tbsCertificate,
Subject = case TBSCert#'TBSCertificate'.subject of
{rdnSequence, Subj} -> lists:flatten(Subj);
_ -> []
end,
Extensions = case TBSCert#'TBSCertificate'.extensions of
Exts when is_list(Exts) -> Exts;
_ -> []
end,
lists:flatmap(fun (#'AttributeTypeAndValue'{type =
?'id-at-commonName',
value = Val}) ->
case 'OTP-PUB-KEY':decode('X520CommonName', Val) of
{ok, {_, D1}} ->
D = if is_binary(D1) -> D1;
is_list(D1) -> list_to_binary(D1);
true -> error
end,
if D /= error ->
case jid:from_string(D) of
#jid{luser = <<"">>, lserver = LD,
lresource = <<"">>} ->
[LD];
_ -> []
end;
true -> []
end;
_ -> []
end;
(_) -> []
end,
Subject)
++
lists:flatmap(fun (#'Extension'{extnID =
?'id-ce-subjectAltName',
extnValue = Val}) ->
BVal = if is_list(Val) -> list_to_binary(Val);
true -> Val
end,
case 'OTP-PUB-KEY':decode('SubjectAltName', BVal)
of
{ok, SANs} ->
lists:flatmap(fun ({otherName,
#'AnotherName'{'type-id' =
?'id-on-xmppAddr',
value =
XmppAddr}}) ->
case
'XmppAddr':decode('XmppAddr',
XmppAddr)
of
{ok, D}
when
is_binary(D) ->
case
jid:from_string((D))
of
#jid{luser =
<<"">>,
lserver =
LD,
lresource =
<<"">>} ->
case
ejabberd_idna:domain_utf8_to_ascii(LD)
of
false ->
[];
PCLD ->
[PCLD]
end;
_ -> []
end;
_ -> []
end;
({dNSName, D})
when is_list(D) ->
case
jid:from_string(list_to_binary(D))
of
#jid{luser = <<"">>,
lserver = LD,
lresource =
<<"">>} ->
[LD];
_ -> []
end;
(_) -> []
end,
SANs);
_ -> []
end;
(_) -> []
end,
Extensions).
match_domain(Domain, Domain) -> true;
match_domain(Domain, Pattern) ->
DLabels = str:tokens(Domain, <<".">>),
PLabels = str:tokens(Pattern, <<".">>),
match_labels(DLabels, PLabels).
match_labels([], []) -> true;
match_labels([], [_ | _]) -> false;
match_labels([_ | _], []) -> false;
match_labels([DL | DLabels], [PL | PLabels]) ->
case lists:all(fun (C) ->
$a =< C andalso C =< $z orelse
$0 =< C andalso C =< $9 orelse
C == $- orelse C == $*
end,
binary_to_list(PL))
of
true ->
Regexp = ejabberd_regexp:sh_to_awk(PL),
case ejabberd_regexp:run(DL, Regexp) of
match -> match_labels(DLabels, PLabels);
nomatch -> false
end;
false -> false
end.
opt_type(route_subdomains) ->
fun (s2s) -> s2s;
(local) -> local
end;
opt_type(s2s_access) ->
fun acl:access_rules_validator/1;
opt_type(_) -> [route_subdomains, s2s_access].
opt_type(domain_certfile) -> fun iolist_to_binary/1;
opt_type(s2s_certfile) -> fun iolist_to_binary/1;
opt_type(s2s_ciphers) -> fun iolist_to_binary/1;
opt_type(s2s_dhfile) -> fun iolist_to_binary/1;
opt_type(s2s_protocol_options) ->
fun (Options) -> str:join(Options, <<"|">>) end;
opt_type(s2s_tls_compression) ->
fun (true) -> true;
(false) -> false
end;
opt_type(s2s_use_starttls) ->
fun (true) -> true;
(false) -> false;
(optional) -> optional;
(required) -> required;
(required_trusted) -> required_trusted
end;
opt_type(s2s_timeout) ->
fun(I) when is_integer(I), I>=0 -> I;
(infinity) -> infinity
end;
opt_type(_) ->
[route_subdomains, s2s_access, s2s_certfile,
s2s_ciphers, s2s_dhfile, s2s_protocol_options,
s2s_tls_compression, s2s_use_starttls, s2s_timeout].

View File

@ -1,8 +1,5 @@
%%%----------------------------------------------------------------------
%%% File : ejabberd_s2s_in.erl
%%% Author : Alexey Shchepin <alexey@process-one.net>
%%% Purpose : Serve incoming s2s connection
%%% Created : 6 Dec 2002 by Alexey Shchepin <alexey@process-one.net>
%%%-------------------------------------------------------------------
%%% Created : 12 Dec 2016 by Evgeny Khramtsov <ekhramtsov@process-one.net>
%%%
%%%
%%% ejabberd, Copyright (C) 2002-2016 ProcessOne
@ -21,645 +18,280 @@
%%% with this program; if not, write to the Free Software Foundation, Inc.,
%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
%%%
%%%----------------------------------------------------------------------
%%%-------------------------------------------------------------------
-module(ejabberd_s2s_in).
-behaviour(xmpp_stream_in).
-behaviour(ejabberd_config).
-behaviour(ejabberd_socket).
-author('alexey@process-one.net').
-behaviour(p1_fsm).
%% External exports
%% ejabberd_socket callbacks
-export([start/2, start_link/2, socket_type/0]).
-export([init/1, wait_for_stream/2,
wait_for_feature_request/2, stream_established/2,
handle_event/3, handle_sync_event/4, code_change/4,
handle_info/3, print_state/1, terminate/3, opt_type/1]).
%% ejabberd_config callbacks
-export([opt_type/1]).
%% xmpp_stream_in callbacks
-export([init/1, handle_call/3, handle_cast/2,
handle_info/2, terminate/2, code_change/3]).
-export([tls_options/1, tls_required/1, tls_verify/1, tls_enabled/1,
compress_methods/1,
unauthenticated_stream_features/1, authenticated_stream_features/1,
handle_stream_start/2, handle_stream_end/2, handle_stream_close/2,
handle_stream_established/1, handle_auth_success/4,
handle_auth_failure/4, handle_send/3, handle_recv/3, handle_cdata/2,
handle_unauthenticated_packet/2, handle_authenticated_packet/2]).
%% Hooks
-export([handle_unexpected_info/2, handle_unexpected_cast/2,
reject_unauthenticated_packet/2, process_closed/2]).
%% API
-export([stop/1, close/1, send/2, update_state/2, establish/1, add_hooks/0]).
-include("ejabberd.hrl").
-include("xmpp.hrl").
-include("logger.hrl").
-include("xmpp.hrl").
-define(DICT, dict).
-record(state,
{socket :: ejabberd_socket:socket_state(),
sockmod = ejabberd_socket :: ejabberd_socket | ejabberd_frontend_socket,
streamid = <<"">> :: binary(),
shaper = none :: shaper:shaper(),
tls = false :: boolean(),
tls_enabled = false :: boolean(),
tls_required = false :: boolean(),
tls_certverify = false :: boolean(),
tls_options = [] :: list(),
server = <<"">> :: binary(),
authenticated = false :: boolean(),
auth_domain = <<"">> :: binary(),
connections = (?DICT):new() :: ?TDICT,
timer = make_ref() :: reference()}).
-type state_name() :: wait_for_stream | wait_for_feature_request | stream_established.
-type state() :: #state{}.
-type fsm_next() :: {next_state, state_name(), state()}.
-type fsm_stop() :: {stop, normal, state()}.
-type fsm_transition() :: fsm_stop() | fsm_next().
%%-define(DBGFSM, true).
-ifdef(DBGFSM).
-define(FSMOPTS, [{debug, [trace]}]).
-else.
-define(FSMOPTS, []).
-endif.
-type state() :: map().
-export_type([state/0]).
%%%===================================================================
%%% API
%%%===================================================================
start(SockData, Opts) ->
supervisor:start_child(ejabberd_s2s_in_sup,
[SockData, Opts]).
xmpp_stream_in:start(?MODULE, [SockData, Opts],
ejabberd_config:fsm_limit_opts(Opts)).
start_link(SockData, Opts) ->
p1_fsm:start_link(ejabberd_s2s_in, [SockData, Opts],
?FSMOPTS ++ fsm_limit_opts(Opts)).
xmpp_stream_in:start_link(?MODULE, [SockData, Opts],
ejabberd_config:fsm_limit_opts(Opts)).
socket_type() -> xml_stream.
close(Ref) ->
xmpp_stream_in:close(Ref).
%%%----------------------------------------------------------------------
%%% Callback functions from gen_fsm
%%%----------------------------------------------------------------------
stop(Ref) ->
xmpp_stream_in:stop(Ref).
init([{SockMod, Socket}, Opts]) ->
?DEBUG("started: ~p", [{SockMod, Socket}]),
Shaper = case lists:keysearch(shaper, 1, Opts) of
{value, {_, S}} -> S;
_ -> none
end,
{StartTLS, TLSRequired, TLSCertverify} =
case ejabberd_config:get_option(
s2s_use_starttls,
fun(false) -> false;
(true) -> true;
(optional) -> optional;
(required) -> required;
(required_trusted) -> required_trusted
end,
false) of
UseTls
when (UseTls == undefined) or
(UseTls == false) ->
{false, false, false};
UseTls
when (UseTls == true) or
(UseTls ==
optional) ->
{true, false, false};
required -> {true, true, false};
required_trusted ->
{true, true, true}
end,
TLSOpts1 = case ejabberd_config:get_option(
s2s_certfile,
fun iolist_to_binary/1) of
undefined -> [];
CertFile -> [{certfile, CertFile}]
end,
TLSOpts2 = case ejabberd_config:get_option(
s2s_ciphers, fun iolist_to_binary/1) of
undefined -> TLSOpts1;
Ciphers -> [{ciphers, Ciphers} | TLSOpts1]
end,
TLSOpts3 = case ejabberd_config:get_option(
s2s_protocol_options,
fun (Options) ->
[_|O] = lists:foldl(
fun(X, Acc) -> X ++ Acc end, [],
[["|" | binary_to_list(Opt)] || Opt <- Options, is_binary(Opt)]
),
iolist_to_binary(O)
end) of
undefined -> TLSOpts2;
ProtocolOpts -> [{protocol_options, ProtocolOpts} | TLSOpts2]
end,
TLSOpts4 = case ejabberd_config:get_option(
s2s_dhfile, fun iolist_to_binary/1) of
undefined -> TLSOpts3;
DHFile -> [{dhfile, DHFile} | TLSOpts3]
end,
TLSOpts = case proplists:get_bool(tls_compression, Opts) of
false -> [compression_none | TLSOpts4];
true -> TLSOpts4
end,
Timer = erlang:start_timer(?S2STIMEOUT, self(), []),
{ok, wait_for_stream,
#state{socket = Socket, sockmod = SockMod,
streamid = new_id(), shaper = Shaper, tls = StartTLS,
tls_enabled = false, tls_required = TLSRequired,
tls_certverify = TLSCertverify, tls_options = TLSOpts,
timer = Timer}}.
socket_type() ->
xml_stream.
%%----------------------------------------------------------------------
%% Func: StateName/2
%% Returns: {next_state, NextStateName, NextStateData} |
%% {next_state, NextStateName, NextStateData, Timeout} |
%% {stop, Reason, NewStateData}
%%----------------------------------------------------------------------
wait_for_stream({xmlstreamstart, Name, Attrs}, StateData) ->
try xmpp:decode(#xmlel{name = Name, attrs = Attrs}) of
#stream_start{xmlns = NS_SERVER, stream_xmlns = NS_STREAM}
when NS_SERVER /= ?NS_SERVER; NS_STREAM /= ?NS_STREAM ->
send_header(StateData, {1,0}),
send_element(StateData, xmpp:serr_invalid_namespace()),
{stop, normal, StateData};
#stream_start{to = #jid{lserver = Server},
from = From, version = {1,0}}
when StateData#state.tls and not StateData#state.authenticated ->
send_header(StateData, {1,0}),
Auth = if StateData#state.tls_enabled ->
case From of
#jid{} ->
{Result, Message} =
ejabberd_s2s:check_peer_certificate(
StateData#state.sockmod,
StateData#state.socket,
From#jid.lserver),
{Result, From#jid.lserver, Message};
undefined ->
{error, <<"(unknown)">>,
<<"Got no valid 'from' attribute">>}
end;
true ->
{no_verify, <<"(unknown)">>, <<"TLS not (yet) enabled">>}
end,
StartTLS = if StateData#state.tls_enabled -> [];
not StateData#state.tls_enabled and
not StateData#state.tls_required ->
[#starttls{required = false}];
not StateData#state.tls_enabled and
StateData#state.tls_required ->
[#starttls{required = true}]
end,
case Auth of
{error, RemoteServer, CertError}
when StateData#state.tls_certverify ->
?INFO_MSG("Closing s2s connection: ~s <--> ~s (~s)",
[StateData#state.server, RemoteServer, CertError]),
send_element(StateData,
xmpp:serr_policy_violation(CertError, ?MYLANG)),
{stop, normal, StateData};
{VerifyResult, RemoteServer, Msg} ->
{SASL, NewStateData} =
case VerifyResult of
ok ->
{[#sasl_mechanisms{list = [<<"EXTERNAL">>]}],
StateData#state{auth_domain = RemoteServer}};
error ->
?DEBUG("Won't accept certificate of ~s: ~s",
[RemoteServer, Msg]),
{[], StateData};
no_verify ->
{[], StateData}
end,
send_element(NewStateData,
#stream_features{
sub_els = SASL ++ StartTLS ++
ejabberd_hooks:run_fold(
s2s_stream_features, Server, [],
[Server])}),
{next_state, wait_for_feature_request,
NewStateData#state{server = Server}}
end;
#stream_start{to = #jid{lserver = Server},
version = {1,0}} when StateData#state.authenticated ->
send_header(StateData, {1,0}),
send_element(StateData,
#stream_features{
sub_els = ejabberd_hooks:run_fold(
s2s_stream_features, Server, [],
[Server])}),
{next_state, stream_established, StateData};
#stream_start{db_xmlns = ?NS_SERVER_DIALBACK}
when (StateData#state.tls_required and StateData#state.tls_enabled)
or (not StateData#state.tls_required) ->
send_header(StateData, undefined),
{next_state, stream_established, StateData};
#stream_start{} ->
send_header(StateData, {1,0}),
send_element(StateData, xmpp:serr_undefined_condition()),
{stop, normal, StateData};
_ ->
send_header(StateData, {1,0}),
send_element(StateData, xmpp:serr_invalid_xml()),
{stop, normal, StateData}
catch _:{xmpp_codec, Why} ->
Txt = xmpp:format_error(Why),
send_header(StateData, {1,0}),
send_element(StateData, xmpp:serr_invalid_xml(Txt, ?MYLANG)),
{stop, normal, StateData}
end;
wait_for_stream({xmlstreamerror, _}, StateData) ->
send_header(StateData, {1,0}),
send_element(StateData, xmpp:serr_not_well_formed()),
{stop, normal, StateData};
wait_for_stream(timeout, StateData) ->
send_header(StateData, {1,0}),
send_element(StateData, xmpp:serr_connection_timeout()),
{stop, normal, StateData};
wait_for_stream(closed, StateData) ->
{stop, normal, StateData}.
-spec send(pid(), xmpp_element()) -> ok;
(state(), xmpp_element()) -> state().
send(Stream, Pkt) ->
xmpp_stream_in:send(Stream, Pkt).
wait_for_feature_request({xmlstreamelement, El}, StateData) ->
decode_element(El, wait_for_feature_request, StateData);
wait_for_feature_request(#starttls{},
#state{tls = true, tls_enabled = false} = StateData) ->
case (StateData#state.sockmod):get_sockmod(StateData#state.socket) of
gen_tcp ->
?DEBUG("starttls", []),
Socket = StateData#state.socket,
TLSOpts1 = case
ejabberd_config:get_option(
{domain_certfile, StateData#state.server},
fun iolist_to_binary/1) of
undefined -> StateData#state.tls_options;
CertFile ->
lists:keystore(certfile, 1,
StateData#state.tls_options,
{certfile, CertFile})
end,
TLSOpts2 = case ejabberd_config:get_option(
{s2s_cafile, StateData#state.server},
fun iolist_to_binary/1) of
undefined -> TLSOpts1;
CAFile ->
lists:keystore(cafile, 1, TLSOpts1,
{cafile, CAFile})
end,
TLSOpts = case ejabberd_config:get_option(
{s2s_tls_compression, StateData#state.server},
fun(true) -> true;
(false) -> false
end, false) of
true -> lists:delete(compression_none, TLSOpts2);
false -> [compression_none | TLSOpts2]
end,
TLSSocket = (StateData#state.sockmod):starttls(
Socket, TLSOpts,
fxml:element_to_binary(
xmpp:encode(#starttls_proceed{}))),
{next_state, wait_for_stream,
StateData#state{socket = TLSSocket, streamid = new_id(),
tls_enabled = true, tls_options = TLSOpts}};
_ ->
send_element(StateData, #starttls_failure{}),
{stop, normal, StateData}
end;
wait_for_feature_request(#sasl_auth{mechanism = Mech},
#state{tls_enabled = true} = StateData) ->
case Mech of
<<"EXTERNAL">> when StateData#state.auth_domain /= <<"">> ->
AuthDomain = StateData#state.auth_domain,
AllowRemoteHost = ejabberd_s2s:allow_host(<<"">>, AuthDomain),
if AllowRemoteHost ->
(StateData#state.sockmod):reset_stream(StateData#state.socket),
send_element(StateData, #sasl_success{}),
?INFO_MSG("Accepted s2s EXTERNAL authentication for ~s (TLS=~p)",
[AuthDomain, StateData#state.tls_enabled]),
change_shaper(StateData, <<"">>, jid:make(AuthDomain)),
{next_state, wait_for_stream,
StateData#state{streamid = new_id(),
authenticated = true}};
true ->
Txt = xmpp:mk_text(<<"Denied by ACL">>, ?MYLANG),
send_element(StateData,
#sasl_failure{reason = 'not-authorized',
text = Txt}),
{stop, normal, StateData}
end;
_ ->
send_element(StateData, #sasl_failure{reason = 'invalid-mechanism'}),
{stop, normal, StateData}
end;
wait_for_feature_request({xmlstreamend, _Name}, StateData) ->
{stop, normal, StateData};
wait_for_feature_request({xmlstreamerror, _}, StateData) ->
send_element(StateData, xmpp:serr_not_well_formed()),
{stop, normal, StateData};
wait_for_feature_request(closed, StateData) ->
{stop, normal, StateData};
wait_for_feature_request(_Pkt, #state{tls_required = TLSRequired,
tls_enabled = TLSEnabled} = StateData)
when TLSRequired and not TLSEnabled ->
Txt = <<"Use of STARTTLS required">>,
send_element(StateData, xmpp:serr_policy_violation(Txt, ?MYLANG)),
{stop, normal, StateData};
wait_for_feature_request(El, StateData) ->
stream_established({xmlstreamelement, El}, StateData).
-spec establish(state()) -> state().
establish(State) ->
xmpp_stream_in:establish(State).
stream_established({xmlstreamelement, El}, StateData) ->
cancel_timer(StateData#state.timer),
Timer = erlang:start_timer(?S2STIMEOUT, self(), []),
decode_element(El, stream_established, StateData#state{timer = Timer});
stream_established(#db_result{to = To, from = From, key = Key},
StateData) ->
?DEBUG("GET KEY: ~p", [{To, From, Key}]),
case {ejabberd_s2s:allow_host(To, From),
lists:member(To, ejabberd_router:dirty_get_all_domains())} of
{true, true} ->
ejabberd_s2s_out:terminate_if_waiting_delay(To, From),
ejabberd_s2s_out:start(To, From,
{verify, self(), Key,
StateData#state.streamid}),
Conns = (?DICT):store({From, To},
wait_for_verification,
StateData#state.connections),
change_shaper(StateData, To, jid:make(From)),
{next_state, stream_established,
StateData#state{connections = Conns}};
{_, false} ->
send_element(StateData, xmpp:serr_host_unknown()),
{stop, normal, StateData};
{false, _} ->
send_element(StateData, xmpp:serr_invalid_from()),
{stop, normal, StateData}
end;
stream_established(#db_verify{to = To, from = From, id = Id, key = Key},
StateData) ->
?DEBUG("VERIFY KEY: ~p", [{To, From, Id, Key}]),
Type = case ejabberd_s2s:make_key({To, From}, Id) of
Key -> valid;
_ -> invalid
-spec update_state(pid(), fun((state()) -> state()) |
{module(), atom(), list()}) -> ok.
update_state(Ref, Callback) ->
xmpp_stream_in:cast(Ref, {update_state, Callback}).
-spec add_hooks() -> ok.
add_hooks() ->
lists:foreach(
fun(Host) ->
ejabberd_hooks:add(s2s_in_closed, Host, ?MODULE,
process_closed, 100),
ejabberd_hooks:add(s2s_in_unauthenticated_packet, Host, ?MODULE,
reject_unauthenticated_packet, 100),
ejabberd_hooks:add(s2s_in_handle_info, Host, ?MODULE,
handle_unexpected_info, 100),
ejabberd_hooks:add(s2s_in_handle_cast, Host, ?MODULE,
handle_unexpected_cast, 100)
end, ?MYHOSTS).
%%%===================================================================
%%% Hooks
%%%===================================================================
handle_unexpected_info(State, Info) ->
?WARNING_MSG("got unexpected info: ~p", [Info]),
State.
handle_unexpected_cast(State, Msg) ->
?WARNING_MSG("got unexpected cast: ~p", [Msg]),
State.
reject_unauthenticated_packet(State, Pkt) ->
Err = xmpp:err_not_authorized(),
xmpp_stream_in:send_error(State, Pkt, Err).
process_closed(State, _Reason) ->
stop(State).
%%%===================================================================
%%% xmpp_stream_in callbacks
%%%===================================================================
tls_options(#{tls_compression := Compression, server_host := LServer}) ->
Opts = case Compression of
false -> [compression_none];
true -> []
end,
send_element(StateData,
#db_verify{from = To, to = From, id = Id, type = Type}),
{next_state, stream_established, StateData};
stream_established(Pkt, StateData) when ?is_stanza(Pkt) ->
ejabberd_s2s:tls_options(LServer, Opts).
tls_required(#{server_host := LServer}) ->
ejabberd_s2s:tls_required(LServer).
tls_verify(#{server_host := LServer}) ->
ejabberd_s2s:tls_verify(LServer).
tls_enabled(#{server_host := LServer}) ->
ejabberd_s2s:tls_enabled(LServer).
compress_methods(#{server_host := LServer}) ->
case ejabberd_s2s:zlib_enabled(LServer) of
true -> [<<"zlib">>];
false -> []
end.
unauthenticated_stream_features(#{server_host := LServer}) ->
ejabberd_hooks:run_fold(s2s_in_pre_auth_features, LServer, [], [LServer]).
authenticated_stream_features(#{server_host := LServer}) ->
ejabberd_hooks:run_fold(s2s_in_post_auth_features, LServer, [], [LServer]).
handle_stream_start(_StreamStart, #{lserver := LServer} = State) ->
case check_to(jid:make(LServer), State) of
false ->
send(State, xmpp:serr_host_unknown());
true ->
ServerHost = ejabberd_router:host_of_route(LServer),
State#{server_host => ServerHost}
end.
handle_stream_end(Reason, #{server_host := LServer} = State) ->
ejabberd_hooks:run_fold(s2s_in_closed, LServer, State, [Reason]).
handle_stream_close(_Reason, #{server_host := LServer} = State) ->
ejabberd_hooks:run_fold(s2s_in_closed, LServer, State, [normal]).
handle_stream_established(State) ->
set_idle_timeout(State#{established => true}).
handle_auth_success(RServer, Mech, _AuthModule,
#{socket := Socket, ip := IP,
auth_domains := AuthDomains,
server_host := ServerHost,
lserver := LServer} = State) ->
?INFO_MSG("(~s) Accepted inbound s2s ~s authentication ~s -> ~s (~s)",
[ejabberd_socket:pp(Socket), Mech, RServer, LServer,
ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]),
State1 = case ejabberd_s2s:allow_host(ServerHost, RServer) of
true ->
AuthDomains1 = sets:add_element(RServer, AuthDomains),
State#{auth_domains => AuthDomains1};
false ->
State
end,
ejabberd_hooks:run_fold(s2s_in_auth_result, ServerHost, State1, [true, RServer]).
handle_auth_failure(RServer, Mech, Reason,
#{socket := Socket, ip := IP,
server_host := ServerHost,
lserver := LServer} = State) ->
?INFO_MSG("(~s) Failed inbound s2s ~s authentication ~s -> ~s (~s): ~s",
[ejabberd_socket:pp(Socket), Mech, RServer, LServer,
ejabberd_config:may_hide_data(jlib:ip_to_list(IP)), Reason]),
ejabberd_hooks:run_fold(s2s_in_auth_result,
ServerHost, State, [false, RServer]).
handle_unauthenticated_packet(Pkt, #{server_host := LServer} = State) ->
ejabberd_hooks:run_fold(s2s_in_unauthenticated_packet,
LServer, State, [Pkt]).
handle_authenticated_packet(Pkt, #{server_host := LServer} = State) when not ?is_stanza(Pkt) ->
ejabberd_hooks:run_fold(s2s_in_authenticated_packet, LServer, State, [Pkt]);
handle_authenticated_packet(Pkt, State) ->
From = xmpp:get_from(Pkt),
To = xmpp:get_to(Pkt),
if To /= undefined, From /= undefined ->
LFrom = From#jid.lserver,
LTo = To#jid.lserver,
if StateData#state.authenticated ->
case LFrom == StateData#state.auth_domain andalso
lists:member(LTo, ejabberd_router:dirty_get_all_domains()) of
true ->
ejabberd_hooks:run(s2s_receive_packet, LTo,
[From, To, Pkt]),
ejabberd_router:route(From, To, Pkt);
false ->
send_error(StateData, Pkt, xmpp:err_not_authorized())
end;
true ->
case (?DICT):find({LFrom, LTo}, StateData#state.connections) of
{ok, established} ->
ejabberd_hooks:run(s2s_receive_packet, LTo,
[From, To, Pkt]),
ejabberd_router:route(From, To, Pkt);
_ ->
send_error(StateData, Pkt, xmpp:err_not_authorized())
end
end;
true ->
send_error(StateData, Pkt, xmpp:err_jid_malformed())
end,
ejabberd_hooks:run(s2s_loop_debug, [{xmlstreamelement, Pkt}]),
{next_state, stream_established, StateData};
stream_established({valid, From, To}, StateData) ->
send_element(StateData,
#db_result{from = To, to = From, type = valid}),
?INFO_MSG("Accepted s2s dialback authentication for ~s (TLS=~p)",
[From, StateData#state.tls_enabled]),
NSD = StateData#state{connections =
(?DICT):store({From, To}, established,
StateData#state.connections)},
{next_state, stream_established, NSD};
stream_established({invalid, From, To}, StateData) ->
send_element(StateData,
#db_result{from = To, to = From, type = invalid}),
NSD = StateData#state{connections =
(?DICT):erase({From, To},
StateData#state.connections)},
{next_state, stream_established, NSD};
stream_established({xmlstreamend, _Name}, StateData) ->
{stop, normal, StateData};
stream_established({xmlstreamerror, _}, StateData) ->
send_element(StateData, xmpp:serr_not_well_formed()),
{stop, normal, StateData};
stream_established(timeout, StateData) ->
send_element(StateData, xmpp:serr_connection_timeout()),
{stop, normal, StateData};
stream_established(closed, StateData) ->
{stop, normal, StateData};
stream_established(Pkt, StateData) ->
ejabberd_hooks:run(s2s_loop_debug, [{xmlstreamelement, Pkt}]),
{next_state, stream_established, StateData}.
case check_from_to(From, To, State) of
ok ->
LServer = ejabberd_router:host_of_route(To#jid.lserver),
State1 = ejabberd_hooks:run_fold(s2s_in_authenticated_packet,
LServer, State, [Pkt]),
Pkt1 = ejabberd_hooks:run_fold(s2s_receive_packet, LServer,
Pkt, [State1]),
ejabberd_router:route(From, To, Pkt1),
State1;
{error, Err} ->
send(State, Err)
end.
%%----------------------------------------------------------------------
%% Func: StateName/3
%% Returns: {next_state, NextStateName, NextStateData} |
%% {next_state, NextStateName, NextStateData, Timeout} |
%% {reply, Reply, NextStateName, NextStateData} |
%% {reply, Reply, NextStateName, NextStateData, Timeout} |
%% {stop, Reason, NewStateData} |
%% {stop, Reason, Reply, NewStateData}
%%----------------------------------------------------------------------
%state_name(Event, From, StateData) ->
% Reply = ok,
% {reply, Reply, state_name, StateData}.
handle_cdata(Data, #{server_host := LServer} = State) ->
ejabberd_hooks:run_fold(s2s_in_handle_cdata, LServer, State, [Data]).
handle_event(_Event, StateName, StateData) ->
{next_state, StateName, StateData}.
handle_recv(El, Pkt, #{server_host := LServer} = State) ->
State1 = set_idle_timeout(State),
ejabberd_hooks:run_fold(s2s_in_handle_recv, LServer, State1, [El, Pkt]).
handle_sync_event(get_state_infos, _From, StateName,
StateData) ->
SockMod = StateData#state.sockmod,
{Addr, Port} = try
SockMod:peername(StateData#state.socket)
of
{ok, {A, P}} -> {A, P};
{error, _} -> {unknown, unknown}
catch
_:_ -> {unknown, unknown}
end,
Domains = get_external_hosts(StateData),
Infos = [{direction, in}, {statename, StateName},
{addr, Addr}, {port, Port},
{streamid, StateData#state.streamid},
{tls, StateData#state.tls},
{tls_enabled, StateData#state.tls_enabled},
{tls_options, StateData#state.tls_options},
{authenticated, StateData#state.authenticated},
{shaper, StateData#state.shaper}, {sockmod, SockMod},
{domains, Domains}],
Reply = {state_infos, Infos},
{reply, Reply, StateName, StateData};
%%----------------------------------------------------------------------
%% Func: handle_sync_event/4
%% Returns: {next_state, NextStateName, NextStateData} |
%% {next_state, NextStateName, NextStateData, Timeout} |
%% {reply, Reply, NextStateName, NextStateData} |
%% {reply, Reply, NextStateName, NextStateData, Timeout} |
%% {stop, Reason, NewStateData} |
%% {stop, Reason, Reply, NewStateData}
%%----------------------------------------------------------------------
handle_sync_event(_Event, _From, StateName,
StateData) ->
Reply = ok, {reply, Reply, StateName, StateData}.
handle_send(Pkt, Result, #{server_host := LServer} = State) ->
ejabberd_hooks:run_fold(s2s_in_handle_send, LServer,
State, [Pkt, Result]).
code_change(_OldVsn, StateName, StateData, _Extra) ->
{ok, StateName, StateData}.
init([State, Opts]) ->
Shaper = gen_mod:get_opt(shaper, Opts, fun acl:shaper_rules_validator/1, none),
TLSCompression = proplists:get_bool(tls_compression, Opts),
State1 = State#{tls_compression => TLSCompression,
auth_domains => sets:new(),
xmlns => ?NS_SERVER,
lang => ?MYLANG,
server => ?MYNAME,
lserver => ?MYNAME,
server_host => ?MYNAME,
established => false,
shaper => Shaper},
ejabberd_hooks:run_fold(s2s_in_init, {ok, State1}, [Opts]).
handle_info({send_text, Text}, StateName, StateData) ->
send_text(StateData, Text),
{next_state, StateName, StateData};
handle_info({timeout, Timer, _}, StateName,
#state{timer = Timer} = StateData) ->
if StateName == wait_for_stream ->
send_header(StateData, undefined);
true ->
ok
end,
send_element(StateData, xmpp:serr_connection_timeout()),
{stop, normal, StateData};
handle_info(_, StateName, StateData) ->
{next_state, StateName, StateData}.
handle_call(Request, From, #{server_host := LServer} = State) ->
ejabberd_hooks:run_fold(s2s_in_handle_call, LServer, State, [Request, From]).
terminate(Reason, _StateName, StateData) ->
?DEBUG("terminated: ~p", [Reason]),
case Reason of
{process_limit, _} ->
[ejabberd_s2s:external_host_overloaded(Host)
|| Host <- get_external_hosts(StateData)];
_ -> ok
end,
catch send_trailer(StateData),
(StateData#state.sockmod):close(StateData#state.socket),
handle_cast({update_state, Fun}, State) ->
case Fun of
{M, F, A} -> erlang:apply(M, F, [State|A]);
_ when is_function(Fun) -> Fun(State)
end;
handle_cast(Msg, #{server_host := LServer} = State) ->
ejabberd_hooks:run_fold(s2s_in_handle_cast, LServer, State, [Msg]).
handle_info(Info, #{server_host := LServer} = State) ->
ejabberd_hooks:run_fold(s2s_in_handle_info, LServer, State, [Info]).
terminate(_Reason, _State) ->
ok.
get_external_hosts(StateData) ->
case StateData#state.authenticated of
true -> [StateData#state.auth_domain];
false ->
Connections = StateData#state.connections,
[D
|| {{D, _}, established} <- dict:to_list(Connections)]
end.
code_change(_OldVsn, State, _Extra) ->
{ok, State}.
print_state(State) -> State.
%%%----------------------------------------------------------------------
%%%===================================================================
%%% Internal functions
%%%----------------------------------------------------------------------
-spec send_text(state(), iodata()) -> ok.
send_text(StateData, Text) ->
(StateData#state.sockmod):send(StateData#state.socket,
Text).
-spec send_element(state(), xmpp_element()) -> ok.
send_element(StateData, El) ->
El1 = xmpp:encode(El, ?NS_SERVER),
send_text(StateData, fxml:element_to_binary(El1)).
-spec send_error(state(), xmlel() | stanza(), stanza_error()) -> ok.
send_error(StateData, Stanza, Error) ->
Type = xmpp:get_type(Stanza),
if Type == error; Type == result;
Type == <<"error">>; Type == <<"result">> ->
ok;
true ->
send_element(StateData, xmpp:make_error(Stanza, Error))
%%%===================================================================
-spec check_from_to(jid(), jid(), state()) -> ok | {error, stream_error()}.
check_from_to(From, To, State) ->
case check_from(From, State) of
true ->
case check_to(To, State) of
true ->
ok;
false ->
{error, xmpp:serr_improper_addressing()}
end;
false ->
{error, xmpp:serr_invalid_from()}
end.
-spec send_trailer(state()) -> ok.
send_trailer(StateData) ->
send_text(StateData, <<"</stream:stream>">>).
-spec check_from(jid(), state()) -> boolean().
check_from(#jid{lserver = S1}, #{auth_domains := AuthDomains}) ->
sets:is_element(S1, AuthDomains).
-spec send_header(state(), undefined | {integer(), integer()}) -> ok.
send_header(StateData, Version) ->
Header = xmpp:encode(
#stream_start{xmlns = ?NS_SERVER,
stream_xmlns = ?NS_STREAM,
db_xmlns = ?NS_SERVER_DIALBACK,
id = StateData#state.streamid,
version = Version}),
send_text(StateData, fxml:element_to_header(Header)).
-spec check_to(jid(), state()) -> boolean().
check_to(#jid{lserver = LServer}, _State) ->
ejabberd_router:is_my_route(LServer).
-spec change_shaper(state(), binary(), jid()) -> ok.
change_shaper(StateData, Host, JID) ->
Shaper = acl:match_rule(Host, StateData#state.shaper,
JID),
(StateData#state.sockmod):change_shaper(StateData#state.socket,
Shaper).
-spec set_idle_timeout(state()) -> state().
set_idle_timeout(#{server_host := LServer,
established := true} = State) ->
Timeout = ejabberd_s2s:get_idle_timeout(LServer),
xmpp_stream_in:set_timeout(State, Timeout);
set_idle_timeout(State) ->
State.
-spec new_id() -> binary().
new_id() -> randoms:get_string().
-spec cancel_timer(reference()) -> ok.
cancel_timer(Timer) ->
erlang:cancel_timer(Timer),
receive {timeout, Timer, _} -> ok after 0 -> ok end.
fsm_limit_opts(Opts) ->
case lists:keysearch(max_fsm_queue, 1, Opts) of
{value, {_, N}} when is_integer(N) -> [{max_queue, N}];
_ ->
case ejabberd_config:get_option(
max_fsm_queue,
fun(I) when is_integer(I), I > 0 -> I end) of
undefined -> [];
N -> [{max_queue, N}]
end
end.
-spec decode_element(xmlel() | xmpp_element(), state_name(), state()) -> fsm_transition().
decode_element(#xmlel{} = El, StateName, StateData) ->
Opts = if StateName == stream_established ->
[ignore_els];
true ->
[]
end,
try xmpp:decode(El, ?NS_SERVER, Opts) of
Pkt -> ?MODULE:StateName(Pkt, StateData)
catch error:{xmpp_codec, Why} ->
case xmpp:is_stanza(El) of
true ->
Lang = xmpp:get_lang(El),
Txt = xmpp:format_error(Why),
send_error(StateData, El, xmpp:err_bad_request(Txt, Lang));
false ->
ok
end,
{next_state, StateName, StateData}
end;
decode_element(Pkt, StateName, StateData) ->
?MODULE:StateName(Pkt, StateData).
opt_type(domain_certfile) -> fun iolist_to_binary/1;
opt_type(max_fsm_queue) ->
fun (I) when is_integer(I), I > 0 -> I end;
opt_type(s2s_certfile) -> fun iolist_to_binary/1;
opt_type(s2s_cafile) -> fun iolist_to_binary/1;
opt_type(s2s_ciphers) -> fun iolist_to_binary/1;
opt_type(s2s_dhfile) -> fun iolist_to_binary/1;
opt_type(s2s_protocol_options) ->
fun (Options) ->
[_ | O] = lists:foldl(fun (X, Acc) -> X ++ Acc end, [],
[["|" | binary_to_list(Opt)]
|| Opt <- Options, is_binary(Opt)]),
iolist_to_binary(O)
end;
opt_type(s2s_tls_compression) ->
fun (true) -> true;
(false) -> false
end;
opt_type(s2s_use_starttls) ->
fun (false) -> false;
(true) -> true;
(optional) -> optional;
(required) -> required;
(required_trusted) -> required_trusted
end;
opt_type(_) ->
[domain_certfile, max_fsm_queue, s2s_certfile, s2s_cafile,
s2s_ciphers, s2s_dhfile, s2s_protocol_options,
s2s_tls_compression, s2s_use_starttls].
[].

File diff suppressed because it is too large Load Diff

View File

@ -22,17 +22,18 @@
-module(ejabberd_service).
-behaviour(xmpp_stream_in).
-behaviour(ejabberd_config).
-behaviour(ejabberd_socket).
-protocol({xep, 114, '1.6'}).
%% ejabberd_socket callbacks
-export([start/2, socket_type/0]).
-export([start/2, start_link/2, socket_type/0]).
%% ejabberd_config callbacks
-export([opt_type/1, transform_listen_option/2]).
%% xmpp_stream_in callbacks
-export([init/1, handle_call/3, handle_cast/2, handle_info/2,
terminate/2, code_change/3]).
-export([handshake/2, handle_stream_start/1, handle_authenticated_packet/2]).
-export([init/1, handle_info/2, terminate/2, code_change/3]).
-export([handle_stream_start/2, handle_auth_success/4, handle_auth_failure/4,
handle_authenticated_packet/2, get_password_fun/1]).
%% API
-export([send/2]).
@ -40,36 +41,32 @@
-include("xmpp.hrl").
-include("logger.hrl").
%%-define(DBGFSM, true).
-ifdef(DBGFSM).
-define(FSMOPTS, [{debug, [trace]}]).
-else.
-define(FSMOPTS, []).
-endif.
-type state() :: map().
-type next_state() :: {noreply, state()} | {stop, term(), state()}.
-export_type([state/0, next_state/0]).
-export_type([state/0]).
%%%===================================================================
%%% API
%%%===================================================================
start(SockData, Opts) ->
xmpp_stream_in:start(?MODULE, [SockData, Opts],
fsm_limit_opts(Opts) ++ ?FSMOPTS).
ejabberd_config:fsm_limit_opts(Opts)).
start_link(SockData, Opts) ->
xmpp_stream_in:start_link(?MODULE, [SockData, Opts],
ejabberd_config:fsm_limit_opts(Opts)).
socket_type() ->
xml_stream.
-spec send(state(), xmpp_element()) -> next_state().
send(State, Pkt) ->
xmpp_stream_in:send(State, Pkt).
-spec send(pid(), xmpp_element()) -> ok;
(state(), xmpp_element()) -> state().
send(Stream, Pkt) ->
xmpp_stream_in:send(Stream, Pkt).
%%%===================================================================
%%% xmpp_stream_in callbacks
%%%===================================================================
init([#{socket := Socket} = State, Opts]) ->
?INFO_MSG("(~w) External service connected", [Socket]),
init([State, Opts]) ->
Access = gen_mod:get_opt(access, Opts, fun acl:access_rules_validator/1, all),
Shaper = gen_mod:get_opt(shaper_rule, Opts, fun acl:shaper_rules_validator/1, none),
HostOpts = case lists:keyfind(hosts, 1, Opts) of
@ -96,66 +93,85 @@ init([#{socket := Socket} = State, Opts]) ->
server => ?MYNAME,
host_opts => HostOpts,
check_from => CheckFrom},
ejabberd_hooks:run_fold(component_init, {ok, State1}, []).
ejabberd_hooks:run_fold(component_init, {ok, State1}, [Opts]).
handle_stream_start(#{remote_server := RemoteServer,
handle_stream_start(_StreamStart,
#{remote_server := RemoteServer,
lang := Lang,
host_opts := HostOpts} = State) ->
NewHostOpts = case dict:is_key(RemoteServer, HostOpts) of
true ->
HostOpts;
false ->
case dict:find(global, HostOpts) of
{ok, GlobalPass} ->
dict:from_list([{RemoteServer, GlobalPass}]);
error ->
HostOpts
end
end,
{noreply, State#{host_opts => NewHostOpts}}.
handshake(Digest, #{remote_server := RemoteServer,
stream_id := StreamID,
host_opts := HostOpts} = State) ->
case dict:find(RemoteServer, HostOpts) of
{ok, Password} ->
case p1_sha:sha(<<StreamID/binary, Password/binary>>) of
Digest ->
lists:foreach(
fun (H) ->
ejabberd_router:register_route(H, ?MYNAME),
?INFO_MSG("Route registered for service ~p~n", [H]),
ejabberd_hooks:run(component_connected, [H])
end, dict:fetch_keys(HostOpts)),
{ok, State};
_ ->
?ERROR_MSG("Failed authentication for service ~s", [RemoteServer]),
{error, xmpp:serr_not_authorized(), State}
end;
_ ->
?ERROR_MSG("Failed authentication for service ~s", [RemoteServer]),
{error, xmpp:serr_not_authorized(), State}
case lists:member(RemoteServer, ?MYHOSTS) of
true ->
Txt = <<"Unable to register route on existing local domain">>,
xmpp_stream_in:send(State, xmpp:serr_conflict(Txt, Lang));
false ->
NewHostOpts = case dict:is_key(RemoteServer, HostOpts) of
true ->
HostOpts;
false ->
case dict:find(global, HostOpts) of
{ok, GlobalPass} ->
dict:from_list([{RemoteServer, GlobalPass}]);
error ->
HostOpts
end
end,
State#{host_opts => NewHostOpts}
end.
get_password_fun(#{remote_server := RemoteServer,
socket := Socket,
ip := IP,
host_opts := HostOpts}) ->
fun(_) ->
case dict:find(RemoteServer, HostOpts) of
{ok, Password} ->
{Password, undefined};
error ->
?ERROR_MSG("(~s) Domain ~s is unconfigured for "
"external component from ~s",
[ejabberd_socket:pp(Socket), RemoteServer,
ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]),
{false, undefined}
end
end.
handle_auth_success(_, Mech, _,
#{remote_server := RemoteServer, host_opts := HostOpts,
socket := Socket, ip := IP} = State) ->
?INFO_MSG("(~s) Accepted external component ~s authentication "
"for ~s from ~s",
[ejabberd_socket:pp(Socket), Mech, RemoteServer,
ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]),
lists:foreach(
fun (H) ->
ejabberd_router:register_route(H, ?MYNAME),
ejabberd_hooks:run(component_connected, [H])
end, dict:fetch_keys(HostOpts)),
State.
handle_auth_failure(_, Mech, Reason,
#{remote_server := RemoteServer,
socket := Socket, ip := IP} = State) ->
?ERROR_MSG("(~s) Failed external component ~s authentication "
"for ~s from ~s: ~s",
[ejabberd_socket:pp(Socket), Mech, RemoteServer,
ejabberd_config:may_hide_data(jlib:ip_to_list(IP)),
Reason]),
State.
handle_authenticated_packet(Pkt, #{lang := Lang} = State) ->
From = xmpp:get_from(Pkt),
case check_from(From, State) of
true ->
To = xmpp:get_to(Pkt),
ejabberd_router:route(From, To, Pkt),
{noreply, State};
State;
false ->
Txt = <<"Improper domain part of 'from' attribute">>,
Err = xmpp:serr_invalid_from(Txt, Lang),
xmpp_stream_in:send(State, Err)
end.
handle_call(_Request, _From, State) ->
Reply = ok,
{reply, Reply, State}.
handle_cast(_Msg, State) ->
{noreply, State}.
handle_info({route, From, To, Packet}, #{access := Access} = State) ->
case acl:match_rule(global, Access, From) of
allow ->
@ -165,16 +181,15 @@ handle_info({route, From, To, Packet}, #{access := Access} = State) ->
Lang = xmpp:get_lang(Packet),
Err = xmpp:err_not_allowed(<<"Denied by ACL">>, Lang),
ejabberd_router:route_error(To, From, Packet, Err),
{noreply, State}
State
end;
handle_info(Info, State) ->
?ERROR_MSG("Unexpected info: ~p", [Info]),
{noreply, State}.
State.
terminate(Reason, #{stream_state := StreamState, host_opts := HostOpts}) ->
?INFO_MSG("External service disconnected: ~p", [Reason]),
case StreamState of
session_established ->
established ->
lists:foreach(
fun(H) ->
ejabberd_router:unregister_route(H),
@ -220,19 +235,4 @@ transform_listen_option({host, Host, Os}, Opts) ->
transform_listen_option(Opt, Opts) ->
[Opt|Opts].
fsm_limit_opts(Opts) ->
case lists:keysearch(max_fsm_queue, 1, Opts) of
{value, {_, N}} when is_integer(N) ->
[{max_queue, N}];
_ ->
case ejabberd_config:get_option(
max_fsm_queue,
fun(I) when is_integer(I), I > 0 -> I end) of
undefined -> [];
N -> [{max_queue, N}]
end
end.
opt_type(max_fsm_queue) ->
fun (I) when is_integer(I), I > 0 -> I end;
opt_type(_) -> [max_fsm_queue].
opt_type(_) -> [].

View File

@ -359,20 +359,20 @@ unregister_iq_handler(Host, XMLNS) ->
ejabberd_sm ! {unregister_iq_handler, Host, XMLNS}.
%% Why the hell do we have so many similar kicks?
c2s_handle_info({noreply, #{lang := Lang} = State}, replaced) ->
c2s_handle_info(#{lang := Lang} = State, replaced) ->
State1 = State#{replaced => true},
Err = xmpp:serr_conflict(<<"Replaced by new connection">>, Lang),
ejabberd_c2s:send(State1, Err);
c2s_handle_info({noreply, #{lang := Lang} = State}, kick) ->
{stop, ejabberd_c2s:send(State1, Err)};
c2s_handle_info(#{lang := Lang} = State, kick) ->
Err = xmpp:serr_policy_violation(<<"has been kicked">>, Lang),
c2s_handle_info({noreply, State}, {kick, kicked_by_admin, Err});
c2s_handle_info({noreply, State}, {kick, _Reason, Err}) ->
ejabberd_c2s:send(State, Err);
c2s_handle_info({noreply, #{lang := Lang} = State}, {exit, Reason}) ->
c2s_handle_info(State, {kick, kicked_by_admin, Err});
c2s_handle_info(State, {kick, _Reason, Err}) ->
{stop, ejabberd_c2s:send(State, Err)};
c2s_handle_info(#{lang := Lang} = State, {exit, Reason}) ->
Err = xmpp:serr_conflict(Reason, Lang),
ejabberd_c2s:send(State, Err);
c2s_handle_info(Acc, _) ->
Acc.
{stop, ejabberd_c2s:send(State, Err)};
c2s_handle_info(State, _) ->
State.
%%====================================================================
%% gen_server callbacks

View File

@ -46,6 +46,7 @@
get_peer_certificate/1,
get_verify_result/1,
close/1,
pp/1,
sockname/1, peername/1]).
-include("ejabberd.hrl").
@ -71,6 +72,11 @@
-export_type([socket/0, socket_state/0, sockmod/0]).
-callback start({module(), socket_state()},
[proplists:property()]) -> {ok, pid()} | {error, term()} | ignore.
-callback start_link({module(), socket_state()},
[proplists:property()]) -> {ok, pid()} | {error, term()} | ignore.
-callback socket_type() -> xml_stream | independent | raw.
%%====================================================================
%% API
@ -109,7 +115,7 @@ start(Module, SockMod, Socket, Opts) ->
{error, _Reason} -> SockMod:close(Socket)
end,
ReceiverMod:become_controller(Receiver, Pid);
{error, _Reason} ->
_ ->
SockMod:close(Socket),
case ReceiverMod of
ejabberd_receiver -> ReceiverMod:close(Receiver);
@ -190,6 +196,7 @@ reset_stream(SocketData)
-spec send(socket_state(), iodata()) -> ok.
send(SocketData, Data) ->
?DEBUG("Send XML on stream = ~p", [Data]),
case catch (SocketData#socket_state.sockmod):send(
SocketData#socket_state.socket, Data) of
ok -> ok;
@ -238,8 +245,8 @@ get_transport(#socket_state{sockmod = SockMod,
fast_tls -> tls;
ezlib ->
case ezlib:get_sockmod(Socket) of
tcp -> tcp_zlib;
tls -> tls_zlib
gen_tcp -> tcp_zlib;
fast_tls -> tls_zlib
end;
ejabberd_bosh -> http_bind;
ejabberd_http_bind -> http_bind;
@ -268,3 +275,7 @@ peername(#socket_state{sockmod = SockMod,
gen_tcp -> inet:peername(Socket);
_ -> SockMod:peername(Socket)
end.
pp(#socket_state{receiver = Receiver} = State) ->
Transport = get_transport(State),
io_lib:format("~s|~w", [Transport, Receiver]).

View File

@ -38,8 +38,8 @@
-export([tolower/1, term_to_base64/1, base64_to_term/1,
decode_base64/1, encode_base64/1, ip_to_list/1,
atom_to_binary/1, binary_to_atom/1, tuple_to_binary/1,
l2i/1, i2l/1, i2l/2, queue_drop_while/2,
expr_to_term/1, term_to_expr/1]).
l2i/1, i2l/1, i2l/2, expr_to_term/1, term_to_expr/1,
queue_drop_while/2, queue_foldl/3, queue_foldr/3, queue_foreach/2]).
%% The following functions are used by gen_iq_handler.erl for providing backward
%% compatibility and must not be used in other parts of the code
@ -974,3 +974,33 @@ queue_drop_while(F, Q) ->
empty ->
Q
end.
-spec queue_foldl(fun((term(), T) -> T), T, ?TQUEUE) -> T.
queue_foldl(F, Acc, Q) ->
case queue:out(Q) of
{{value, Item}, Q1} ->
Acc1 = F(Item, Acc),
queue_foldl(F, Acc1, Q1);
{empty, _} ->
Acc
end.
-spec queue_foldr(fun((term(), T) -> T), T, ?TQUEUE) -> T.
queue_foldr(F, Acc, Q) ->
case queue:out_r(Q) of
{{value, Item}, Q1} ->
Acc1 = F(Item, Acc),
queue_foldr(F, Acc1, Q1);
{empty, _} ->
Acc
end.
-spec queue_foreach(fun((_) -> _), ?TQUEUE) -> ok.
queue_foreach(F, Q) ->
case queue:out(Q) of
{{value, Item}, Q1} ->
F(Item),
queue_foreach(F, Q1);
{empty, _} ->
ok
end.

View File

@ -54,8 +54,6 @@ start(Host, Opts) ->
process_iq_set, 40),
ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE,
c2s_handle_info, 40),
ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE,
c2s_handle_info, 40),
mod_disco:register_feature(Host, ?NS_BLOCKING),
gen_iq_handler:add_iq_handler(ejabberd_sm, Host,
?NS_BLOCKING, ?MODULE, process_iq, IQDisc).
@ -65,6 +63,8 @@ stop(Host) ->
process_iq_get, 40),
ejabberd_hooks:delete(privacy_iq_set, Host, ?MODULE,
process_iq_set, 40),
ejabberd_hooks:delete(c2s_handle_info, Host, ?MODULE,
c2s_handle_info, 40),
mod_disco:unregister_feature(Host, ?NS_BLOCKING),
gen_iq_handler:remove_iq_handler(ejabberd_sm, Host,
?NS_BLOCKING).
@ -253,8 +253,8 @@ process_blocklist_get(LUser, LServer, Lang) ->
{result, #block_list{items = Items}}
end.
-spec c2s_handle_info(ejabberd_c2s:next_state(), term()) -> ejabberd_c2s:next_state().
c2s_handle_info({noreply, #{user := U, server := S, resource := R} = State},
-spec c2s_handle_info(ejabberd_c2s:state(), term()) -> ejabberd_c2s:state().
c2s_handle_info(#{user := U, server := S, resource := R} = State,
{blocking, Action}) ->
SubEl = case Action of
{block, JIDs} ->
@ -272,7 +272,9 @@ c2s_handle_info({noreply, #{user := U, server := S, resource := R} = State},
%% No need to replace active privacy list here,
%% blocking pushes are always accompanied by
%% Privacy List pushes
ejabberd_c2s:send(State, PushIQ).
{stop, ejabberd_c2s:send(State, PushIQ)};
c2s_handle_info(State, _) ->
State.
-spec db_mod(binary()) -> module().
db_mod(LServer) ->

View File

@ -52,16 +52,16 @@ depends(_Host, _Opts) ->
mod_opt_type(_) ->
[].
c2s_unauthenticated_packet({noreply, State}, #iq{type = T, sub_els = [_]} = IQ)
c2s_unauthenticated_packet(State, #iq{type = T, sub_els = [_]} = IQ)
when T == get; T == set ->
case xmpp:get_subtag(IQ, #legacy_auth{}) of
#legacy_auth{} = Auth ->
{stop, authenticate(State, xmpp:set_els(IQ, [Auth]))};
false ->
{noreply, State}
State
end;
c2s_unauthenticated_packet(Acc, _) ->
Acc.
c2s_unauthenticated_packet(State, _) ->
State.
c2s_stream_features(Acc, LServer) ->
case gen_mod:is_loaded(LServer, ?MODULE) of
@ -112,14 +112,10 @@ authenticate(#{stream_id := StreamID, server := Server,
case ejabberd_auth:check_password_with_authmodule(
U, U, JID#jid.lserver, P, D, DGen) of
{true, AuthModule} ->
case ejabberd_c2s:handle_auth_success(
U, <<"legacy">>, AuthModule, State) of
{noreply, State1} ->
State2 = State1#{user := U},
open_session(State2, IQ, R);
Err ->
Err
end;
State1 = ejabberd_c2s:handle_auth_success(
U, <<"legacy">>, AuthModule, State),
State2 = State1#{user := U},
open_session(State2, IQ, R);
_ ->
Err = xmpp:make_error(IQ, xmpp:err_not_authorized()),
process_auth_failure(State, U, Err, 'not-authorized')
@ -137,23 +133,13 @@ open_session(State, IQ, R) ->
case ejabberd_c2s:bind(R, State) of
{ok, State1} ->
Res = xmpp:make_iq_result(IQ),
case ejabberd_c2s:send(State1, Res) of
{noreply, State2} ->
{noreply, State2#{stream_authenticated := true,
stream_state := session_established}};
Err ->
Err
end;
State2 = ejabberd_c2s:send(State1, Res),
ejabberd_c2s:establish(State2);
{error, Err, State1} ->
Res = xmpp:make_error(IQ, Err),
ejabberd_c2s:send(State1, Res)
end.
process_auth_failure(State, User, StanzaErr, Reason) ->
case ejabberd_c2s:send(State, StanzaErr) of
{noreply, State1} ->
ejabberd_c2s:handle_auth_failure(
User, <<"legacy">>, Reason, State1);
Err ->
Err
end.
State1 = ejabberd_c2s:send(State, StanzaErr),
ejabberd_c2s:handle_auth_failure(User, <<"legacy">>, Reason, State1).

View File

@ -309,11 +309,11 @@ get_info(_Acc, #jid{luser = U, lserver = S} = JID,
get_info(Acc, _From, _To, _Node, _Lang) ->
Acc.
-spec c2s_handle_info(ejabberd_c2s:next_state(), term()) -> ejabberd_c2s:next_state().
c2s_handle_info({noreply, State}, {resend_offline, Flag}) ->
{noreply, State#{resend_offline => Flag}};
c2s_handle_info(Acc, _) ->
Acc.
-spec c2s_handle_info(ejabberd_c2s:state(), term()) -> ejabberd_c2s:state().
c2s_handle_info(State, {resend_offline, Flag}) ->
{stop, State#{resend_offline => Flag}};
c2s_handle_info(State, _) ->
State.
-spec handle_offline_query(iq()) -> iq().
handle_offline_query(#iq{from = #jid{luser = U1, lserver = S1},

View File

@ -535,8 +535,8 @@ remove_user(User, Server) ->
Mod = gen_mod:db_mod(LServer, ?MODULE),
Mod:remove_user(LUser, LServer).
c2s_handle_info({noreply, #{privacy_list := Old,
user := U, server := S, resource := R} = State},
c2s_handle_info(#{privacy_list := Old,
user := U, server := S, resource := R} = State,
{privacy_list, New, Name}) ->
List = if Old#userlist.name == New#userlist.name -> New;
true -> Old
@ -548,9 +548,9 @@ c2s_handle_info({noreply, #{privacy_list := Old,
sub_els = [#privacy_query{
lists = [#privacy_list{name = Name}]}]},
State1 = State#{privacy_list => List},
ejabberd_c2s:send(State1, PushIQ);
c2s_handle_info(Acc, _) ->
Acc.
{stop, ejabberd_c2s:send(State1, PushIQ)};
c2s_handle_info(State, _) ->
State.
-spec updated_list(userlist(), userlist(), userlist()) -> userlist().
updated_list(_, #userlist{name = OldName} = Old,

View File

@ -3026,8 +3026,8 @@ broadcast_stanza({LUser, LServer, LResource}, Publisher, Node, Nidx, Type, NodeO
broadcast_stanza(Host, _Publisher, Node, Nidx, Type, NodeOptions, SubsByDepth, NotifyType, BaseStanza, SHIM) ->
broadcast_stanza(Host, Node, Nidx, Type, NodeOptions, SubsByDepth, NotifyType, BaseStanza, SHIM).
-spec c2s_handle_info(ejabberd_c2s:next_state(), term()) -> ejabberd_c2s:next_state().
c2s_handle_info({noreply, #{server := Server} = C2SState},
-spec c2s_handle_info(ejabberd_c2s:state(), term()) -> ejabberd_c2s:state().
c2s_handle_info(#{server := Server} = C2SState,
{pep_message, Feature, From, Packet}) ->
LServer = jid:nameprep(Server),
lists:foreach(
@ -3042,8 +3042,8 @@ c2s_handle_info({noreply, #{server := Server} = C2SState},
ok
end
end, mod_caps:list_features(C2SState)),
{noreply, C2SState};
c2s_handle_info({noreply, #{server := Server} = C2SState},
{stop, C2SState};
c2s_handle_info(#{server := Server} = C2SState,
{send_filtered, {pep_message, Feature}, From, To, Packet}) ->
LServer = jid:nameprep(Server),
case mod_caps:get_user_caps(To, C2SState) of
@ -3059,9 +3059,9 @@ c2s_handle_info({noreply, #{server := Server} = C2SState},
error ->
ok
end,
{noreply, C2SState};
c2s_handle_info(Acc, _) ->
Acc.
{stop, C2SState};
c2s_handle_info(C2SState, _) ->
C2SState.
subscribed_nodes_by_jid(NotifyType, SubsByDepth) ->
NodesToDeliver = fun (Depth, Node, Subs, Acc) ->

View File

@ -86,7 +86,7 @@ stream_feature_register(Acc, Host) ->
Acc
end.
c2s_unauthenticated_packet({noreply, #{ip := IP, server := Server} = State},
c2s_unauthenticated_packet(#{ip := IP, server := Server} = State,
#iq{type = T, sub_els = [_]} = IQ)
when T == set; T == get ->
case xmpp:get_subtag(IQ, #register{}) of
@ -97,10 +97,10 @@ c2s_unauthenticated_packet({noreply, #{ip := IP, server := Server} = State},
ResIQ1 = xmpp:set_from_to(ResIQ, jid:make(Server), undefined),
{stop, ejabberd_c2s:send(State, ResIQ1)};
false ->
{noreply, State}
State
end;
c2s_unauthenticated_packet(Acc, _) ->
Acc.
c2s_unauthenticated_packet(State, _) ->
State.
process_iq(#iq{from = From} = IQ) ->
process_iq(IQ, jid:tolower(From)).

View File

@ -464,10 +464,10 @@ push_item_version(Server, User, From, Item,
end,
ejabberd_sm:get_user_resources(User, Server)).
c2s_handle_info({noreply, State}, {item, JID, Sub}) ->
{noreply, roster_change(State, JID, Sub)};
c2s_handle_info(Acc, _) ->
Acc.
c2s_handle_info(State, {item, JID, Sub}) ->
{stop, roster_change(State, JID, Sub)};
c2s_handle_info(State, _) ->
State.
-spec roster_change(ejabberd_c2s:state(), jid(), subscription()) -> ejabberd_c2s:state().
roster_change(#{user := U, server := S, resource := R} = State,

273
src/mod_s2s_dialback.erl Normal file
View File

@ -0,0 +1,273 @@
%%%-------------------------------------------------------------------
%%% Created : 16 Dec 2016 by Evgeny Khramtsov <ekhramtsov@process-one.net>
%%%
%%%
%%% ejabberd, Copyright (C) 2002-2016 ProcessOne
%%%
%%% This program is free software; you can redistribute it and/or
%%% modify it under the terms of the GNU General Public License as
%%% published by the Free Software Foundation; either version 2 of the
%%% License, or (at your option) any later version.
%%%
%%% This program is distributed in the hope that it will be useful,
%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
%%% General Public License for more details.
%%%
%%% You should have received a copy of the GNU General Public License along
%%% with this program; if not, write to the Free Software Foundation, Inc.,
%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
%%%
%%%-------------------------------------------------------------------
-module(mod_s2s_dialback).
-behaviour(gen_mod).
-protocol({xep, 220, '1.1.1'}).
-protocol({xep, 185, '1.0'}).
%% gen_mod API
-export([start/2, stop/1, depends/2, mod_opt_type/1]).
%% Hooks
-export([s2s_out_auth_result/2, s2s_in_packet/2, s2s_out_packet/2,
s2s_in_features/2, s2s_out_init/2, s2s_out_closed/2]).
-include("ejabberd.hrl").
-include("xmpp.hrl").
-include("logger.hrl").
%%%===================================================================
%%% API
%%%===================================================================
start(Host, _Opts) ->
case ejabberd_s2s:tls_verify(Host) of
true ->
?ERROR_MSG("disabling ~s for host ~s because option "
"'s2s_use_starttls' is set to 'required_trusted'",
[?MODULE, Host]);
false ->
ejabberd_hooks:add(s2s_out_init, Host, ?MODULE, s2s_out_init, 50),
ejabberd_hooks:add(s2s_out_closed, Host, ?MODULE, s2s_out_closed, 50),
ejabberd_hooks:add(s2s_in_pre_auth_features, Host, ?MODULE,
s2s_in_features, 50),
ejabberd_hooks:add(s2s_in_post_auth_features, Host, ?MODULE,
s2s_in_features, 50),
ejabberd_hooks:add(s2s_in_unauthenticated_packet, Host, ?MODULE,
s2s_in_packet, 50),
ejabberd_hooks:add(s2s_in_authenticated_packet, Host, ?MODULE,
s2s_in_packet, 50),
ejabberd_hooks:add(s2s_out_packet, Host, ?MODULE,
s2s_out_packet, 50),
ejabberd_hooks:add(s2s_out_auth_result, Host, ?MODULE,
s2s_out_auth_result, 50)
end.
stop(Host) ->
ejabberd_hooks:delete(s2s_out_init, Host, ?MODULE, s2s_out_init, 50),
ejabberd_hooks:delete(s2s_out_closed, Host, ?MODULE, s2s_out_closed, 50),
ejabberd_hooks:delete(s2s_in_pre_auth_features, Host, ?MODULE,
s2s_in_features, 50),
ejabberd_hooks:delete(s2s_in_post_auth_features, Host, ?MODULE,
s2s_in_features, 50),
ejabberd_hooks:delete(s2s_in_unauthenticated_packet, Host, ?MODULE,
s2s_in_packet, 50),
ejabberd_hooks:delete(s2s_in_authenticated_packet, Host, ?MODULE,
s2s_in_packet, 50),
ejabberd_hooks:delete(s2s_out_packet, Host, ?MODULE,
s2s_out_packet, 50),
ejabberd_hooks:delete(s2s_out_auth_result, Host, ?MODULE,
s2s_out_auth_result, 50).
depends(_Host, _Opts) ->
[].
mod_opt_type(_) ->
[].
s2s_in_features(Acc, _) ->
[#db_feature{errors = true}|Acc].
s2s_out_init({ok, State}, Opts) ->
case proplists:get_value(db_verify, Opts) of
{StreamID, Key, Pid} ->
%% This is an outbound s2s connection created at step 1.
%% The purpose of this connection is to verify dialback key ONLY.
%% The connection is not registered in s2s table and thus is not
%% seen by anyone.
%% The connection will be closed immediately after receiving the
%% verification response (at step 3)
{ok, State#{db_verify => {StreamID, Key, Pid}}};
undefined ->
{ok, State#{db_enabled => true}}
end;
s2s_out_init(Acc, _Opts) ->
Acc.
s2s_out_closed(#{server := LServer,
remote_server := RServer,
db_verify := {StreamID, _Key, _Pid}} = State, _Reason) ->
%% Outbound s2s verificating connection (created at step 1) is
%% closed suddenly without receiving the response.
%% Building a response on our own
Response = #db_verify{from = RServer, to = LServer,
id = StreamID, type = error,
sub_els = [mk_error(internal_server_error)]},
s2s_out_packet(State, Response);
s2s_out_closed(State, _Reason) ->
State.
s2s_out_auth_result(#{server := LServer,
remote_server := RServer,
db_verify := {StreamID, Key, _Pid}} = State,
_) ->
%% The temporary outbound s2s connect (intended for verification)
%% has passed authentication state (either successfully or not, no matter)
%% and at this point we can send verification request as described
%% in section 2.1.2, step 2
Request = #db_verify{from = LServer, to = RServer,
key = Key, id = StreamID},
{stop, ejabberd_s2s_out:send(State, Request)};
s2s_out_auth_result(#{db_enabled := true,
socket := Socket, ip := IP,
server := LServer,
remote_server := RServer,
stream_remote_id := StreamID} = State, false) ->
%% SASL authentication has failed, retrying with dialback
%% Sending dialback request, section 2.1.1, step 1
?INFO_MSG("(~s) Retrying with s2s dialback authentication: ~s -> ~s (~s)",
[ejabberd_socket:pp(Socket), LServer, RServer,
ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]),
Key = make_key(LServer, RServer, StreamID),
State1 = maps:remove(stop_reason, State#{on_route => queue}),
State2 = ejabberd_s2s_out:send(State1, #db_result{from = LServer,
to = RServer,
key = Key}),
{stop, State2};
s2s_out_auth_result(State, _) ->
State.
s2s_in_packet(#{stream_id := StreamID} = State,
#db_result{from = From, to = To, key = Key, type = undefined}) ->
%% Received dialback request, section 2.2.1, step 1
try
ok = check_from_to(From, To),
%% We're creating a temporary outbound s2s connection to
%% send verification request and to receive verification response
{ok, Pid} = ejabberd_s2s_out:start(
To, From, [{db_verify, {StreamID, Key, self()}}]),
ejabberd_s2s_out:connect(Pid),
State
catch _:{badmatch, {error, Reason}} ->
send_db_result(State,
#db_verify{from = From, to = To, type = error,
sub_els = [mk_error(Reason)]})
end;
s2s_in_packet(State, #db_verify{to = To, from = From, key = Key,
id = StreamID, type = undefined}) ->
%% Received verification request, section 2.2.2, step 2
Type = case make_key(To, From, StreamID) of
Key -> valid;
_ -> invalid
end,
Response = #db_verify{from = To, to = From, id = StreamID, type = Type},
ejabberd_s2s_in:send(State, Response);
s2s_in_packet(State, Pkt) when is_record(Pkt, db_result);
is_record(Pkt, db_verify) ->
?WARNING_MSG("Got stray dialback packet:~n~s", [xmpp:pp(Pkt)]),
State;
s2s_in_packet(State, _) ->
State.
s2s_out_packet(#{server := LServer,
remote_server := RServer,
db_verify := {StreamID, _Key, Pid}} = State,
#db_verify{from = RServer, to = LServer,
id = StreamID, type = Type} = Response)
when Type /= undefined ->
%% Received verification response, section 2.1.2, step 3
%% This is a response for the request sent at step 2
ejabberd_s2s_in:update_state(
Pid, fun(S) -> send_db_result(S, Response) end),
%% At this point the connection is no longer needed and we can terminate it
ejabberd_s2s_out:stop(State);
s2s_out_packet(#{server := LServer, remote_server := RServer} = State,
#db_result{to = LServer, from = RServer,
type = Type} = Result) when Type /= undefined ->
%% Received dialback response, section 2.1.1, step 4
%% This is a response to the request sent at step 1
State1 = maps:remove(db_enabled, State),
case Type of
valid ->
State2 = ejabberd_s2s_out:handle_auth_success(<<"dialback">>, State1),
ejabberd_s2s_out:establish(State2);
_ ->
Reason = format_error(Result),
ejabberd_s2s_out:handle_auth_failure(<<"dialback">>, Reason, State1)
end;
s2s_out_packet(State, Pkt) when is_record(Pkt, db_result);
is_record(Pkt, db_verify) ->
?WARNING_MSG("Got stray dialback packet:~n~s", [xmpp:pp(Pkt)]),
State;
s2s_out_packet(State, _) ->
State.
%%%===================================================================
%%% Internal functions
%%%===================================================================
-spec make_key(binary(), binary(), binary()) -> binary().
make_key(From, To, StreamID) ->
Secret = ejabberd_config:get_option(shared_key, fun(V) -> V end),
p1_sha:to_hexlist(
crypto:hmac(sha256, p1_sha:to_hexlist(crypto:hash(sha256, Secret)),
[To, " ", From, " ", StreamID])).
-spec send_db_result(ejabberd_s2s_in:state(), db_verify()) -> ejabberd_s2s_in:state().
send_db_result(State, #db_verify{from = From, to = To,
type = Type, sub_els = Els}) ->
%% Sending dialback response, section 2.2.1, step 4
%% This is a response to the request received at step 1
Response = #db_result{from = To, to = From, type = Type, sub_els = Els},
State1 = ejabberd_s2s_in:send(State, Response),
case Type of
valid ->
State2 = ejabberd_s2s_in:handle_auth_success(
From, <<"dialback">>, undefined, State1),
ejabberd_s2s_in:establish(State2);
_ ->
Reason = format_error(Response),
ejabberd_s2s_in:handle_auth_failure(
From, <<"dialback">>, Reason, State1)
end.
-spec check_from_to(binary(), binary()) -> ok | {error, forbidden | host_unknown}.
check_from_to(From, To) ->
case ejabberd_router:is_my_route(To) of
false -> {error, host_unknown};
true ->
LServer = ejabberd_router:host_of_route(To),
case ejabberd_s2s:allow_host(LServer, From) of
true -> ok;
false -> {error, forbidden}
end
end.
-spec mk_error(term()) -> stanza_error().
mk_error(forbidden) ->
xmpp:err_forbidden(<<"Denied by ACL">>, ?MYLANG);
mk_error(host_unknown) ->
xmpp:err_not_allowed(<<"Host unknown">>, ?MYLANG);
mk_error(_) ->
xmpp:err_internal_server_error().
-spec format_error(db_result()) -> binary().
format_error(#db_result{type = invalid}) ->
<<"invalid dialback key">>;
format_error(#db_result{type = error, sub_els = Els}) ->
%% TODO: improve xmpp.erl
case xmpp:get_error(#message{sub_els = Els}) of
#stanza_error{reason = Reason} ->
erlang:atom_to_binary(Reason, latin1);
undefined ->
<<"unrecognized error">>
end;
format_error(_) ->
<<"unexpected dialback result">>.

660
src/mod_sm.erl Normal file
View File

@ -0,0 +1,660 @@
%%%-------------------------------------------------------------------
%%% Author : Holger Weiss <holger@zedat.fu-berlin.de>
%%% Created : 25 Dec 2016 by Evgeny Khramtsov <ekhramtsov@process-one.net>
%%%
%%%
%%% ejabberd, Copyright (C) 2002-2016 ProcessOne
%%%
%%% This program is free software; you can redistribute it and/or
%%% modify it under the terms of the GNU General Public License as
%%% published by the Free Software Foundation; either version 2 of the
%%% License, or (at your option) any later version.
%%%
%%% This program is distributed in the hope that it will be useful,
%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
%%% General Public License for more details.
%%%
%%% You should have received a copy of the GNU General Public License along
%%% with this program; if not, write to the Free Software Foundation, Inc.,
%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
%%%
%%%-------------------------------------------------------------------
-module(mod_sm).
-behaviour(gen_mod).
-author('holger@zedat.fu-berlin.de').
-protocol({xep, 198, '1.5.2'}).
%% gen_mod API
-export([start/2, stop/1, depends/2, mod_opt_type/1]).
%% hooks
-export([c2s_stream_init/2, c2s_stream_started/2, c2s_stream_features/2,
c2s_authenticated_packet/2, c2s_unauthenticated_packet/2,
c2s_unbinded_packet/2, c2s_closed/2,
c2s_handle_send/3, c2s_filter_send/2, c2s_handle_info/2]).
-include("xmpp.hrl").
-include("logger.hrl").
-define(is_sm_packet(Pkt),
is_record(Pkt, sm_enable) or
is_record(Pkt, sm_resume) or
is_record(Pkt, sm_a) or
is_record(Pkt, sm_r)).
-type state() :: ejabberd_c2s:state().
-type lqueue() :: {non_neg_integer(), queue:queue()}.
%%%===================================================================
%%% API
%%%===================================================================
start(Host, _Opts) ->
ejabberd_hooks:add(c2s_init, ?MODULE, c2s_stream_init, 50),
ejabberd_hooks:add(c2s_stream_started, Host, ?MODULE,
c2s_stream_started, 50),
ejabberd_hooks:add(c2s_post_auth_features, Host, ?MODULE,
c2s_stream_features, 50),
ejabberd_hooks:add(c2s_unauthenticated_packet, Host, ?MODULE,
c2s_unauthenticated_packet, 50),
ejabberd_hooks:add(c2s_unbinded_packet, Host, ?MODULE,
c2s_unbinded_packet, 50),
ejabberd_hooks:add(c2s_authenticated_packet, Host, ?MODULE,
c2s_authenticated_packet, 50),
ejabberd_hooks:add(c2s_handle_send, Host, ?MODULE,
c2s_handle_send, 50),
ejabberd_hooks:add(c2s_filter_send, Host, ?MODULE,
c2s_filter_send, 50),
ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE,
c2s_handle_info, 50),
ejabberd_hooks:add(c2s_closed, Host, ?MODULE, c2s_closed, 50).
stop(Host) ->
%% TODO: do something with global 'c2s_init' hook
ejabberd_hooks:delete(c2s_stream_started, Host, ?MODULE,
c2s_stream_started, 50),
ejabberd_hooks:delete(c2s_post_auth_features, Host, ?MODULE,
c2s_stream_features, 50),
ejabberd_hooks:delete(c2s_unauthenticated_packet, Host, ?MODULE,
c2s_unauthenticated_packet, 50),
ejabberd_hooks:delete(c2s_unbinded_packet, Host, ?MODULE,
c2s_unbinded_packet, 50),
ejabberd_hooks:delete(c2s_authenticated_packet, Host, ?MODULE,
c2s_authenticated_packet, 50),
ejabberd_hooks:delete(c2s_handle_send, Host, ?MODULE,
c2s_handle_send, 50),
ejabberd_hooks:delete(c2s_filter_send, Host, ?MODULE,
c2s_filter_send, 50),
ejabberd_hooks:delete(c2s_handle_info, Host, ?MODULE,
c2s_handle_info, 50),
ejabberd_hooks:delete(c2s_closed, Host, ?MODULE, c2s_closed, 50).
depends(_Host, _Opts) ->
[].
c2s_stream_init({ok, State}, Opts) ->
MgmtOpts = lists:filter(
fun({stream_management, _}) -> true;
({max_ack_queue, _}) -> true;
({resume_timeout, _}) -> true;
({max_resume_timeout, _}) -> true;
({ack_timeout, _}) -> true;
({resend_on_timeout, _}) -> true;
(_) -> false
end, Opts),
{ok, State#{mgmt_options => MgmtOpts}};
c2s_stream_init(Acc, _Opts) ->
Acc.
c2s_stream_started(#{lserver := LServer, mgmt_options := Opts} = State,
_StreamStart) ->
State1 = maps:remove(mgmt_options, State),
ResumeTimeout = get_resume_timeout(LServer, Opts),
MaxResumeTimeout = get_max_resume_timeout(LServer, Opts, ResumeTimeout),
State1#{mgmt_state => inactive,
mgmt_max_queue => get_max_ack_queue(LServer, Opts),
mgmt_timeout => ResumeTimeout,
mgmt_max_timeout => MaxResumeTimeout,
mgmt_ack_timeout => get_ack_timeout(LServer, Opts),
mgmt_resend => get_resend_on_timeout(LServer, Opts)};
c2s_stream_started(State, _StreamStart) ->
State.
c2s_stream_features(Acc, Host) ->
case gen_mod:is_loaded(Host, ?MODULE) of
true ->
[#feature_sm{xmlns = ?NS_STREAM_MGMT_2},
#feature_sm{xmlns = ?NS_STREAM_MGMT_3}|Acc];
false ->
Acc
end.
c2s_unauthenticated_packet(State, Pkt) when ?is_sm_packet(Pkt) ->
%% XEP-0198 says: "For client-to-server connections, the client MUST NOT
%% attempt to enable stream management until after it has completed Resource
%% Binding unless it is resuming a previous session". However, it also
%% says: "Stream management errors SHOULD be considered recoverable", so we
%% won't bail out.
Err = #sm_failed{reason = 'unexpected-request', xmlns = ?NS_STREAM_MGMT_3},
{stop, send(State, Err)};
c2s_unauthenticated_packet(State, _Pkt) ->
State.
c2s_unbinded_packet(State, #sm_resume{} = Pkt) ->
case handle_resume(State, Pkt) of
{ok, ResumedState} ->
{stop, ResumedState};
error ->
{stop, State}
end;
c2s_unbinded_packet(State, Pkt) when ?is_sm_packet(Pkt) ->
c2s_unauthenticated_packet(State, Pkt);
c2s_unbinded_packet(State, _Pkt) ->
State.
c2s_authenticated_packet(#{mgmt_state := MgmtState} = State, Pkt)
when ?is_sm_packet(Pkt) ->
if MgmtState == pending; MgmtState == active ->
{stop, perform_stream_mgmt(Pkt, State)};
true ->
{stop, negotiate_stream_mgmt(Pkt, State)}
end;
c2s_authenticated_packet(State, Pkt) ->
update_num_stanzas_in(State, Pkt).
c2s_handle_send(#{mgmt_state := MgmtState} = State, Pkt, Result)
when MgmtState == pending; MgmtState == active ->
State1 = mgmt_queue_add(State, Pkt),
case Result of
ok when ?is_stanza(Pkt) ->
send_ack(State1);
ok ->
State1;
{error, _} ->
transition_to_pending(State1)
end;
c2s_handle_send(State, _Pkt, _Result) ->
State.
c2s_filter_send(Pkt, _State) ->
Pkt.
c2s_handle_info(#{mgmt_ack_timer := T, jid := JID} = State,
{timeout, T, ack_timeout}) ->
?DEBUG("Timeout waiting for stream management acknowledgement of ~s",
[jid:to_string(JID)]),
State1 = ejabberd_c2s:close(State, _SendTrailer = false),
c2s_closed(State1, ack_timeout);
c2s_handle_info(State, _) ->
State.
c2s_closed(#{mgmt_state := active} = State, Reason) when Reason /= normal ->
{stop, transition_to_pending(State)};
c2s_closed(State, _) ->
State.
%%%===================================================================
%%% Internal functions
%%%===================================================================
-spec negotiate_stream_mgmt(xmpp_element(), state()) -> state().
negotiate_stream_mgmt(Pkt, State) ->
Xmlns = xmpp:get_ns(Pkt),
case Pkt of
#sm_enable{} ->
handle_enable(State#{mgmt_xmlns => Xmlns}, Pkt);
_ ->
Res = if is_record(Pkt, sm_a);
is_record(Pkt, sm_r);
is_record(Pkt, sm_resume) ->
#sm_failed{reason = 'unexpected-request',
xmlns = Xmlns};
true ->
#sm_failed{reason = 'bad-request',
xmlns = Xmlns}
end,
send(State, Res)
end.
-spec perform_stream_mgmt(xmpp_element(), state()) -> state().
perform_stream_mgmt(Pkt, #{mgmt_xmlns := Xmlns} = State) ->
case xmpp:get_ns(Pkt) of
Xmlns ->
case Pkt of
#sm_r{} ->
handle_r(State);
#sm_a{} ->
handle_a(State, Pkt);
_ ->
Res = if is_record(Pkt, sm_enable);
is_record(Pkt, sm_resume) ->
#sm_failed{reason = 'unexpected-request',
xmlns = Xmlns};
true ->
#sm_failed{reason = 'bad-request',
xmlns = Xmlns}
end,
send(State, Res)
end;
_ ->
send(State, #sm_failed{reason = 'unsupported-version', xmlns = Xmlns})
end.
-spec handle_enable(state(), sm_enable()) -> state().
handle_enable(#{mgmt_timeout := DefaultTimeout,
mgmt_max_timeout := MaxTimeout,
xmlns := Xmlns, jid := JID} = State,
#sm_enable{resume = Resume, max = Max}) ->
Timeout = if Resume == false ->
0;
Max /= undefined, Max > 0, Max =< MaxTimeout ->
Max;
true ->
DefaultTimeout
end,
Res = if Timeout > 0 ->
?INFO_MSG("Stream management with resumption enabled for ~s",
[jid:to_string(JID)]),
#sm_enabled{xmlns = Xmlns,
id = make_resume_id(State),
resume = true,
max = Timeout};
true ->
?INFO_MSG("Stream management without resumption enabled for ~s",
[jid:to_string(JID)]),
#sm_enabled{xmlns = Xmlns}
end,
State1 = State#{mgmt_state => active,
mgmt_queue => queue_new(),
mgmt_timeout => Timeout * 1000},
send(State1, Res).
-spec handle_r(state()) -> state().
handle_r(#{mgmt_xmlns := Xmlns, mgmt_stanzas_in := H} = State) ->
Res = #sm_a{xmlns = Xmlns, h = H},
send(State, Res).
-spec handle_a(state(), sm_a()) -> state().
handle_a(State, #sm_a{h = H}) ->
State1 = check_h_attribute(State, H),
resend_ack(State1).
-spec handle_resume(state(), sm_resume()) -> {ok, state()} | {error, state()}.
handle_resume(#{lserver := LServer, jid := JID, socket := Socket} = State,
#sm_resume{h = H, previd = PrevID, xmlns = Xmlns}) ->
R = case inherit_session_state(State, PrevID) of
{ok, InheritedState} ->
{ok, InheritedState, H};
{error, Err, InH} ->
{error, #sm_failed{reason = 'item-not-found',
h = InH, xmlns = Xmlns}, Err};
{error, Err} ->
{error, #sm_failed{reason = 'item-not-found',
xmlns = Xmlns}, Err}
end,
case R of
{ok, ResumedState, NumHandled} ->
State1 = check_h_attribute(ResumedState, NumHandled),
#{mgmt_xmlns := AttrXmlns, mgmt_stanzas_in := AttrH} = State1,
AttrId = make_resume_id(State1),
State2 = send(State1, #sm_resumed{xmlns = AttrXmlns,
h = AttrH,
previd = AttrId}),
State3 = resend_unacked_stanzas(State2),
State4 = send(State3, #sm_r{xmlns = AttrXmlns}),
%% TODO: move this to mod_client_state
%% csi_flush_queue(State4),
State5 = ejabberd_hooks:run_fold(c2s_session_resumed, LServer, State4, []),
?INFO_MSG("(~s) Resumed session for ~s",
[ejabberd_socket:pp(Socket), jid:to_string(JID)]),
{ok, State5};
{error, El, Msg} ->
?INFO_MSG("Cannot resume session for ~s: ~s", [jid:to_string(JID), Msg]),
{error, send(State, El)}
end.
-spec transition_to_pending(state()) -> state().
transition_to_pending(#{mgmt_state := active} = State) ->
%% TODO
State;
transition_to_pending(State) ->
State.
-spec check_h_attribute(state(), non_neg_integer()) -> state().
check_h_attribute(#{mgmt_stanzas_out := NumStanzasOut, jid := JID} = State, H)
when H > NumStanzasOut ->
?DEBUG("~s acknowledged ~B stanzas, but only ~B were sent",
[jid:to_string(JID), H, NumStanzasOut]),
mgmt_queue_drop(State#{mgmt_stanzas_out => H}, NumStanzasOut);
check_h_attribute(#{mgmt_stanzas_out := NumStanzasOut, jid := JID} = State, H) ->
?DEBUG("~s acknowledged ~B of ~B stanzas",
[jid:to_string(JID), H, NumStanzasOut]),
mgmt_queue_drop(State, H).
-spec update_num_stanzas_in(state(), xmpp_element()) -> state().
update_num_stanzas_in(#{mgmt_state := MgmtState,
mgmt_stanzas_in := NumStanzasIn} = State, El)
when MgmtState == active; MgmtState == pending ->
NewNum = case {xmpp:is_stanza(El), NumStanzasIn} of
{true, 4294967295} ->
0;
{true, Num} ->
Num + 1;
{false, Num} ->
Num
end,
State#{mgmt_stanzas_in => NewNum};
update_num_stanzas_in(State, _El) ->
State.
send_ack(#{mgmt_ack_timer := _} = State) ->
State;
send_ack(#{mgmt_xmlns := Xmlns,
mgmt_stanzas_out := NumStanzasOut,
mgmt_ack_timeout := AckTimeout} = State) ->
State1 = send(State, #sm_r{xmlns = Xmlns}),
TRef = erlang:start_timer(AckTimeout, self(), ack_timeout),
State1#{mgmt_ack_timer => TRef, mgmt_stanzas_req => NumStanzasOut}.
resend_ack(#{mgmt_ack_timer := _,
mgmt_queue := Queue,
mgmt_stanzas_out := NumStanzasOut,
mgmt_stanzas_req := NumStanzasReq} = State) ->
State1 = cancel_ack_timer(State),
case NumStanzasReq < NumStanzasOut andalso not queue_is_empty(Queue) of
true -> send_ack(State1);
false -> State1
end;
resend_ack(State) ->
State.
-spec mgmt_queue_add(state(), xmpp_element()) -> state().
mgmt_queue_add(#{mgmt_stanzas_out := NumStanzasOut,
mgmt_queue := Queue} = State, Stanza) when ?is_stanza(Stanza) ->
NewNum = case NumStanzasOut of
4294967295 -> 0;
Num -> Num + 1
end,
Queue1 = queue_in({NewNum, p1_time_compat:timestamp(), Stanza}, Queue),
State1 = State#{mgmt_queue => Queue1, mgmt_stanzas_out => NewNum},
check_queue_length(State1);
mgmt_queue_add(State, _Nonza) ->
State.
-spec mgmt_queue_drop(state(), non_neg_integer()) -> state().
mgmt_queue_drop(#{mgmt_queue := Queue} = State, NumHandled) ->
NewQueue = queue_dropwhile(
fun({N, _T, _E}) -> N =< NumHandled end, Queue),
State#{mgmt_queue => NewQueue}.
-spec check_queue_length(state()) -> state().
check_queue_length(#{mgmt_max_queue := Limit} = State)
when Limit == infinity; Limit == exceeded ->
State;
check_queue_length(#{mgmt_queue := Queue, mgmt_max_queue := Limit} = State) ->
case queue_len(Queue) > Limit of
true ->
State#{mgmt_max_queue => exceeded};
false ->
State
end.
-spec resend_unacked_stanzas(state()) -> state().
resend_unacked_stanzas(#{mgmt_state := MgmtState,
mgmt_queue := {QueueLen, _} = Queue,
jid := JID} = State)
when (MgmtState == active orelse
MgmtState == pending orelse
MgmtState == timeout) andalso QueueLen > 0 ->
?DEBUG("Resending ~B unacknowledged stanza(s) to ~s",
[QueueLen, jid:to_string(JID)]),
queue_foldl(
fun({_, Time, Pkt}, AccState) ->
NewPkt = add_resent_delay_info(AccState, Pkt, Time),
send(AccState, NewPkt)
end, State, Queue);
resend_unacked_stanzas(State) ->
State.
-spec route_unacked_stanzas(state()) -> ok.
route_unacked_stanzas(#{mgmt_state := MgmtState,
mgmt_resend := MgmtResend,
lang := Lang, user := User,
jid := JID, lserver := LServer,
mgmt_queue := {QueueLen, _} = Queue,
resource := Resource} = State)
when (MgmtState == active orelse
MgmtState == pending orelse
MgmtState == timeout) andalso QueueLen > 0 ->
ResendOnTimeout = case MgmtResend of
Resend when is_boolean(Resend) ->
Resend;
if_offline ->
case ejabberd_sm:get_user_resources(User, Resource) of
[Resource] ->
%% Same resource opened new session
true;
[] -> true;
_ -> false
end
end,
?DEBUG("Re-routing ~B unacknowledged stanza(s) to ~s",
[QueueLen, jid:to_string(JID)]),
queue_foreach(
fun({_, _Time, #presence{from = From}}) ->
?DEBUG("Dropping presence stanza from ~s", [jid:to_string(From)]);
({_, _Time, #iq{} = El}) ->
Txt = <<"User session terminated">>,
route_error(El, xmpp:err_service_unavailable(Txt, Lang));
({_, _Time, #message{from = From, meta = #{carbon_copy := true}}}) ->
%% XEP-0280 says: "When a receiving server attempts to deliver a
%% forked message, and that message bounces with an error for
%% any reason, the receiving server MUST NOT forward that error
%% back to the original sender." Resending such a stanza could
%% easily lead to unexpected results as well.
?DEBUG("Dropping forwarded message stanza from ~s",
[jid:to_string(From)]);
({_, Time, El}) ->
case ejabberd_hooks:run_fold(message_is_archived,
LServer, false,
[State, El]) of
true ->
?DEBUG("Dropping archived message stanza from ~s",
[jid:to_string(xmpp:get_from(El))]);
false when ResendOnTimeout ->
NewEl = add_resent_delay_info(State, El, Time),
route(NewEl);
false ->
Txt = <<"User session terminated">>,
route_error(El, xmpp:err_service_unavailable(Txt, Lang))
end
end, Queue);
route_unacked_stanzas(_State) ->
ok.
-spec inherit_session_state(state(), binary()) -> {ok, state()} |
{error, binary()} |
{error, binary(), non_neg_integer()}.
inherit_session_state(#{user := U, server := S} = State, ResumeID) ->
case jlib:base64_to_term(ResumeID) of
{term, {R, Time}} ->
case ejabberd_sm:get_session_pid(U, S, R) of
none ->
case ejabberd_sm:get_offline_info(Time, U, S, R) of
none ->
{error, <<"Previous session PID not found">>};
Info ->
case proplists:get_value(num_stanzas_in, Info) of
undefined ->
{error, <<"Previous session timed out">>};
H ->
{error, <<"Previous session timed out">>, H}
end
end;
OldPID ->
OldSID = {Time, OldPID},
try resume_session(OldSID, State) of
{resume, OldState} ->
State1 = ejabberd_c2s:copy_state(State, OldState),
State2 = ejabberd_c2s:open_session(State1),
{ok, State2};
{error, Msg} ->
{error, Msg}
catch exit:{noproc, _} ->
{error, <<"Previous session PID is dead">>};
exit:{timeout, _} ->
{error, <<"Session state copying timed out">>}
end
end;
_ ->
{error, <<"Invalid 'previd' value">>}
end.
-spec resume_session({integer(), pid()}, state()) -> {resume, state()} |
{error, binary()}.
resume_session({Time, Pid}, _State) ->
ejabberd_c2s:call(Pid, {resume_session, Time}, timer:seconds(15)).
-spec make_resume_id(state()) -> binary().
make_resume_id(#{sid := {Time, _}, resource := Resource}) ->
jlib:term_to_base64({Resource, Time}).
-spec add_resent_delay_info(state(), stanza(), erlang:timestamp()) -> stanza().
add_resent_delay_info(_State, #iq{} = El, _Time) ->
El;
add_resent_delay_info(#{lserver := LServer}, El, Time) ->
xmpp_util:add_delay_info(El, jid:make(LServer), Time, <<"Resent">>).
-spec route(stanza()) -> ok.
route(Pkt) ->
From = xmpp:get_from(Pkt),
To = xmpp:get_to(Pkt),
ejabberd_router:route(From, To, Pkt).
-spec route_error(stanza(), stanza_error()) -> ok.
route_error(Pkt, Err) ->
From = xmpp:get_from(Pkt),
To = xmpp:get_to(Pkt),
ejabberd_router:route_error(To, From, Pkt, Err).
-spec send(state(), xmpp_element()) -> state().
send(#{mod := Mod} = State, Pkt) ->
Mod:send(State, Pkt).
-spec queue_new() -> lqueue().
queue_new() ->
{0, queue:new()}.
-spec queue_in(term(), lqueue()) -> lqueue().
queue_in(Elem, {N, Q}) ->
{N+1, queue:in(Elem, Q)}.
-spec queue_len(lqueue()) -> non_neg_integer().
queue_len({N, _}) ->
N.
-spec queue_foldl(fun((term(), T) -> T), T, lqueue()) -> T.
queue_foldl(F, Acc, {_N, Q}) ->
jlib:queue_foldl(F, Acc, Q).
-spec queue_foreach(fun((_) -> _), lqueue()) -> ok.
queue_foreach(F, {_N, Q}) ->
jlib:queue_foreach(F, Q).
-spec queue_dropwhile(fun((term()) -> boolean()), lqueue()) -> lqueue().
queue_dropwhile(F, {N, Q}) ->
case queue:peek(Q) of
{value, Item} ->
case F(Item) of
true ->
queue_dropwhile(F, {N-1, queue:drop(Q)});
false ->
{N, Q}
end;
empty ->
{N, Q}
end.
-spec queue_is_empty(lqueue()) -> boolean().
queue_is_empty({N, _Q}) ->
N == 0.
-spec cancel_ack_timer(state()) -> state().
cancel_ack_timer(#{mgmt_ack_timer := TRef} = State) ->
case erlang:cancel_timer(TRef) of
false ->
receive {timeout, TRef, _} -> ok
after 0 -> ok
end;
_ ->
ok
end,
maps:remove(mgmt_ack_timer, State);
cancel_ack_timer(State) ->
State.
%%%===================================================================
%%% Configuration processing
%%%===================================================================
get_max_ack_queue(Host, Opts) ->
VFun = mod_opt_type(max_ack_queue),
case gen_mod:get_module_opt(Host, ?MODULE, max_ack_queue, VFun) of
undefined -> gen_mod:get_opt(max_ack_queue, Opts, VFun, 1000);
Limit -> Limit
end.
get_resume_timeout(Host, Opts) ->
VFun = mod_opt_type(resume_timeout),
case gen_mod:get_module_opt(Host, ?MODULE, resume_timeout, VFun) of
undefined -> gen_mod:get_opt(resume_timeout, Opts, VFun, 300);
Timeout -> Timeout
end.
get_max_resume_timeout(Host, Opts, ResumeTimeout) ->
VFun = mod_opt_type(max_resume_timeout),
case gen_mod:get_module_opt(Host, ?MODULE, max_resume_timeout, VFun) of
undefined ->
case gen_mod:get_opt(max_resume_timeout, Opts, VFun) of
undefined -> ResumeTimeout;
Max when Max >= ResumeTimeout -> Max;
_ -> ResumeTimeout
end;
Max when Max >= ResumeTimeout -> Max;
_ -> ResumeTimeout
end.
get_ack_timeout(Host, Opts) ->
VFun = mod_opt_type(ack_timeout),
T = case gen_mod:get_module_opt(Host, ?MODULE, ack_timeout, VFun) of
undefined -> gen_mod:get_opt(ack_timeout, Opts, VFun, 60);
AckTimeout -> AckTimeout
end,
case T of
infinity -> infinity;
_ -> timer:seconds(T)
end.
get_resend_on_timeout(Host, Opts) ->
VFun = mod_opt_type(resend_on_timeout),
case gen_mod:get_module_opt(Host, ?MODULE, resend_on_timeout, VFun) of
undefined -> gen_mod:get_opt(resend_on_timeout, Opts, VFun, false);
Resend -> Resend
end.
mod_opt_type(max_ack_queue) ->
fun(I) when is_integer(I), I > 0 -> I;
(infinity) -> infinity
end;
mod_opt_type(resume_timeout) ->
fun(I) when is_integer(I), I >= 0 -> I end;
mod_opt_type(max_resume_timeout) ->
fun(I) when is_integer(I), I >= 0 -> I end;
mod_opt_type(ack_timeout) ->
fun(I) when is_integer(I), I > 0 -> I;
(infinity) -> infinity
end;
mod_opt_type(resend_on_timeout) ->
fun(B) when is_boolean(B) -> B;
(if_offline) -> if_offline
end;
mod_opt_type(_) ->
[max_ack_queue, resume_timeout, max_resume_timeout, ack_timeout,
resend_on_timeout].

File diff suppressed because it is too large Load Diff

856
src/xmpp_stream_out.erl Normal file
View File

@ -0,0 +1,856 @@
%%%-------------------------------------------------------------------
%%% @author Evgeny Khramtsov <ekhramtsov@process-one.net>
%%% @copyright (C) 2016, Evgeny Khramtsov
%%% @doc
%%%
%%% @end
%%% Created : 14 Dec 2016 by Evgeny Khramtsov <ekhramtsov@process-one.net>
%%%-------------------------------------------------------------------
-module(xmpp_stream_out).
-behaviour(gen_server).
-protocol({rfc, 6120}).
%% API
-export([start/3, start_link/3, call/3, cast/2, reply/2, connect/1,
stop/1, send/2, close/1, close/2, establish/1, format_error/1,
set_timeout/2, get_transport/1, change_shaper/2]).
%% gen_server callbacks
-export([init/1, handle_call/3, handle_cast/2, handle_info/2,
terminate/2, code_change/3]).
%%-define(DBGFSM, true).
-ifdef(DBGFSM).
-define(FSMOPTS, [{debug, [trace]}]).
-else.
-define(FSMOPTS, []).
-endif.
-define(TCP_SEND_TIMEOUT, 15000).
-include("xmpp.hrl").
-include("logger.hrl").
-include_lib("kernel/include/inet.hrl").
-type state() :: map().
-type host_port() :: {inet:hostname(), inet:port_number()}.
-type ip_port() :: {inet:ip_address(), inet:port_number()}.
-type network_error() :: {error, inet:posix() | inet_res:res_error()}.
-type stop_reason() :: {idna, bad_string} |
{dns, inet:posix() | inet_res:res_error()} |
{stream, reset | stream_error()} |
{tls, term()} |
{pkix, binary()} |
{auth, atom() | binary() | string()} |
{socket, inet:posix() | closed | timeout}.
-callback init(list()) -> {ok, state()} | {stop, term()} | ignore.
%%%===================================================================
%%% API
%%%===================================================================
start(Mod, Args, Opts) ->
gen_server:start(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS).
start_link(Mod, Args, Opts) ->
gen_server:start_link(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS).
call(Ref, Msg, Timeout) ->
gen_server:call(Ref, Msg, Timeout).
cast(Ref, Msg) ->
gen_server:cast(Ref, Msg).
reply(Ref, Reply) ->
gen_server:reply(Ref, Reply).
-spec connect(pid()) -> ok.
connect(Ref) ->
cast(Ref, connect).
-spec stop(pid()) -> ok;
(state()) -> no_return().
stop(Pid) when is_pid(Pid) ->
cast(Pid, stop);
stop(#{owner := Owner} = State) when Owner == self() ->
terminate(normal, State),
exit(normal);
stop(_) ->
erlang:error(badarg).
-spec send(pid(), xmpp_element()) -> ok;
(state(), xmpp_element()) -> state().
send(Pid, Pkt) when is_pid(Pid) ->
cast(Pid, {send, Pkt});
send(#{owner := Owner} = State, Pkt) when Owner == self() ->
send_element(State, Pkt);
send(_, _) ->
erlang:error(badarg).
-spec close(pid()) -> ok;
(state()) -> state().
close(Ref) ->
close(Ref, true).
-spec close(pid(), boolean()) -> ok;
(state(), boolean()) -> state().
close(Pid, SendTrailer) when is_pid(Pid) ->
cast(Pid, {close, SendTrailer});
close(#{owner := Owner} = State, SendTrailer) when Owner == self() ->
if SendTrailer -> send_trailer(State);
true -> close_socket(State)
end;
close(_, _) ->
erlang:error(badarg).
-spec establish(state()) -> state().
establish(State) ->
process_stream_established(State).
-spec set_timeout(state(), non_neg_integer() | infinity) -> state().
set_timeout(#{owner := Owner} = State, Timeout) when Owner == self() ->
case Timeout of
infinity -> State#{stream_timeout => infinity};
_ ->
Time = p1_time_compat:monotonic_time(milli_seconds),
State#{stream_timeout => {Timeout, Time}}
end;
set_timeout(_, _) ->
erlang:error(badarg).
get_transport(#{sockmod := SockMod, socket := Socket, owner := Owner})
when Owner == self() ->
SockMod:get_transport(Socket);
get_transport(_) ->
erlang:error(badarg).
-spec change_shaper(state(), shaper:shaper()) -> ok.
change_shaper(#{sockmod := SockMod, socket := Socket, owner := Owner}, Shaper)
when Owner == self() ->
SockMod:change_shaper(Socket, Shaper);
change_shaper(_, _) ->
erlang:error(badarg).
-spec format_error(stop_reason()) -> binary().
format_error({idna, _}) ->
<<"Not an IDN hostname">>;
format_error({dns, Reason}) ->
format("DNS lookup failed: ~s", [format_inet_error(Reason)]);
format_error({socket, Reason}) ->
format("Connection failed: ~s", [format_inet_error(Reason)]);
format_error({pkix, Reason}) ->
format("Peer certificate rejected: ~s", [Reason]);
format_error({stream, reset}) ->
<<"Stream reset by peer">>;
format_error({stream, #stream_error{reason = Reason, text = Txt}}) ->
format("Stream failed: ~s", [format_stream_error(Reason, Txt)]);
format_error({tls, Reason}) ->
format("TLS failed: ~w", [Reason]);
format_error({auth, Reason}) ->
format("Authentication failed: ~s", [Reason]);
format_error(Err) ->
format("Unrecognized error: ~w", [Err]).
%%%===================================================================
%%% gen_server callbacks
%%%===================================================================
init([Mod, SockMod, From, To, Opts]) ->
Time = p1_time_compat:monotonic_time(milli_seconds),
State = #{owner => self(),
mod => Mod,
sockmod => SockMod,
server => From,
user => <<"">>,
resource => <<"">>,
lang => <<"">>,
remote_server => To,
xmlns => ?NS_SERVER,
stream_direction => out,
stream_timeout => {timer:seconds(30), Time},
stream_id => new_id(),
stream_encrypted => false,
stream_verified => false,
stream_authenticated => false,
stream_restarted => false,
stream_state => connecting},
case try Mod:init([State, Opts])
catch _:undef -> {ok, State}
end of
{ok, State1} ->
{_, State2, Timeout} = noreply(State1),
{ok, State2, Timeout};
Err ->
Err
end.
handle_call(Call, From, #{mod := Mod} = State) ->
noreply(try Mod:handle_call(Call, From, State)
catch _:undef -> State
end).
handle_cast(connect, #{remote_server := RemoteServer,
sockmod := SockMod,
stream_state := connecting} = State) ->
case ejabberd_idna:domain_utf8_to_ascii(RemoteServer) of
false ->
noreply(process_stream_close({error, {idna, bad_string}}, State));
ASCIIName ->
case resolve(binary_to_list(ASCIIName), State) of
{ok, AddrPorts} ->
case connect(AddrPorts, State) of
{ok, Socket, AddrPort} ->
SocketMonitor = SockMod:monitor(Socket),
State1 = State#{ip => AddrPort,
socket => Socket,
socket_monitor => SocketMonitor},
State2 = State1#{stream_state => wait_for_stream},
noreply(send_header(State2));
{error, Why} ->
Err = {error, {socket, Why}},
noreply(process_stream_close(Err, State))
end;
{error, Why} ->
noreply(process_stream_close({error, {dns, Why}}, State))
end
end;
handle_cast(connect, State) ->
%% Ignoring connection attempts in other states
noreply(State);
handle_cast({send, Pkt}, State) ->
noreply(send_element(State, Pkt));
handle_cast(stop, State) ->
{stop, normal, State};
handle_cast(Cast, #{mod := Mod} = State) ->
noreply(try Mod:handle_cast(Cast, State)
catch _:undef -> State
end).
handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}},
#{stream_state := wait_for_stream,
xmlns := XMLNS, lang := MyLang} = State) ->
El = #xmlel{name = Name, attrs = Attrs},
try xmpp:decode(El, XMLNS, []) of
#stream_start{} = Pkt ->
noreply(process_stream(Pkt, State));
_ ->
noreply(send_element(State, xmpp:serr_invalid_xml()))
catch _:{xmpp_codec, Why} ->
Txt = xmpp:io_format_error(Why),
Lang = select_lang(MyLang, xmpp:get_lang(El)),
Err = xmpp:serr_invalid_xml(Txt, Lang),
noreply(send_element(State, Err))
end;
handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) ->
State1 = send_header(State),
case is_disconnected(State1) of
true -> State1;
false ->
Err = case Reason of
<<"XML stanza is too big">> ->
xmpp:serr_policy_violation(Reason, Lang);
_ ->
xmpp:serr_not_well_formed()
end,
noreply(send_element(State1, Err))
end;
handle_info({'$gen_event', {xmlstreamelement, El}},
#{xmlns := NS, lang := MyLang, mod := Mod} = State) ->
try xmpp:decode(El, NS, [ignore_els]) of
Pkt ->
State1 = try Mod:handle_recv(El, Pkt, State)
catch _:undef -> State
end,
case is_disconnected(State1) of
true -> State1;
false -> noreply(process_element(Pkt, State1))
end
catch _:{xmpp_codec, Why} ->
State1 = try Mod:handle_recv(El, undefined, State)
catch _:undef -> State
end,
case is_disconnected(State1) of
true -> State1;
false ->
Txt = xmpp:io_format_error(Why),
Lang = select_lang(MyLang, xmpp:get_lang(El)),
noreply(send_error(State1, El, xmpp:err_bad_request(Txt, Lang)))
end
end;
handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}},
#{mod := Mod} = State) ->
noreply(try Mod:handle_cdata(Data, State)
catch _:undef -> State
end);
handle_info({'$gen_event', {xmlstreamend, _}}, State) ->
noreply(process_stream_end({error, {stream, reset}}, State));
handle_info({'$gen_event', closed}, State) ->
noreply(process_stream_close({error, {socket, closed}}, State));
handle_info(timeout, #{mod := Mod} = State) ->
Disconnected = is_disconnected(State),
noreply(try Mod:handle_timeout(State)
catch _:undef when not Disconnected ->
send_element(State, xmpp:serr_connection_timeout());
_:undef ->
stop(State)
end);
handle_info({'DOWN', MRef, _Type, _Object, _Info},
#{socket_monitor := MRef} = State) ->
noreply(process_stream_close({error, {socket, closed}}, State));
handle_info(Info, #{mod := Mod} = State) ->
noreply(try Mod:handle_info(Info, State)
catch _:undef -> State
end).
terminate(Reason, #{mod := Mod} = State) ->
case get(already_terminated) of
true ->
State;
_ ->
put(already_terminated, true),
try Mod:terminate(Reason, State)
catch _:undef -> ok
end,
send_trailer(State)
end.
code_change(OldVsn, #{mod := Mod} = State, Extra) ->
Mod:code_change(OldVsn, State, Extra).
%%%===================================================================
%%% Internal functions
%%%===================================================================
-spec noreply(state()) -> {noreply, state(), non_neg_integer() | infinity}.
noreply(#{stream_timeout := infinity} = State) ->
{noreply, State, infinity};
noreply(#{stream_timeout := {MSecs, OldTime}} = State) ->
NewTime = p1_time_compat:monotonic_time(milli_seconds),
Timeout = max(0, MSecs - NewTime + OldTime),
{noreply, State, Timeout}.
-spec new_id() -> binary().
new_id() ->
randoms:get_string().
-spec is_disconnected(state()) -> boolean().
is_disconnected(#{stream_state := StreamState}) ->
StreamState == disconnected.
-spec process_stream_close(stop_reason(), state()) -> state().
process_stream_close(_, #{stream_state := disconnected} = State) ->
State;
process_stream_close(Reason, #{mod := Mod} = State) ->
State1 = send_trailer(State),
try Mod:handle_stream_close(Reason, State1)
catch _:undef -> stop(State1)
end.
-spec process_stream_end(stop_reason(), state()) -> state().
process_stream_end(_, #{stream_state := disconnected} = State) ->
State;
process_stream_end(Reason, #{mod := Mod} = State) ->
State1 = send_trailer(State),
try Mod:handle_stream_end(Reason, State1)
catch _:undef -> stop(State1)
end.
-spec process_stream(stream_start(), state()) -> state().
process_stream(#stream_start{xmlns = XML_NS,
stream_xmlns = STREAM_NS},
#{xmlns := NS} = State)
when XML_NS /= NS; STREAM_NS /= ?NS_STREAM ->
send_element(State, xmpp:serr_invalid_namespace());
process_stream(#stream_start{lang = Lang, id = ID,
version = Version} = StreamStart,
#{mod := Mod} = State) ->
State1 = State#{stream_remote_id => ID, lang => Lang},
State2 = try Mod:handle_stream_start(StreamStart, State1)
catch _:undef -> State1
end,
case is_disconnected(State2) of
true -> State2;
false ->
case Version of
{1,0} -> State2#{stream_state => wait_for_features};
_ -> process_stream_downgrade(StreamStart, State)
end
end.
-spec process_element(xmpp_element(), state()) -> state().
process_element(Pkt, #{stream_state := StateName} = State) ->
case Pkt of
#stream_features{} when StateName == wait_for_features ->
process_features(Pkt, State);
#starttls_proceed{} when StateName == wait_for_starttls_response ->
process_starttls(State);
#sasl_success{} when StateName == wait_for_sasl_response ->
process_sasl_success(State);
#sasl_failure{} when StateName == wait_for_sasl_response ->
process_sasl_failure(Pkt, State);
#stream_error{} ->
process_stream_end({error, {stream, Pkt}}, State);
_ when is_record(Pkt, stream_features);
is_record(Pkt, starttls_proceed);
is_record(Pkt, starttls);
is_record(Pkt, sasl_auth);
is_record(Pkt, sasl_success);
is_record(Pkt, sasl_failure);
is_record(Pkt, sasl_response);
is_record(Pkt, sasl_abort);
is_record(Pkt, compress);
is_record(Pkt, handshake) ->
%% Do not pass this crap upstream
State;
_ ->
process_packet(Pkt, State)
end.
-spec process_features(stream_features(), state()) -> state().
process_features(StreamFeatures,
#{stream_authenticated := true, mod := Mod} = State) ->
State1 = try Mod:handle_authenticated_features(StreamFeatures, State)
catch _:undef -> State
end,
process_stream_established(State1);
process_features(#stream_features{sub_els = Els} = StreamFeatures,
#{stream_encrypted := Encrypted,
mod := Mod, lang := Lang} = State) ->
State1 = try Mod:handle_unauthenticated_features(StreamFeatures, State)
catch _:undef -> State
end,
case is_disconnected(State1) of
true -> State1;
false ->
TLSRequired = is_starttls_required(State1),
%% TODO: improve xmpp.erl
Msg = #message{sub_els = Els},
case xmpp:get_subtag(Msg, #starttls{}) of
false when TLSRequired and not Encrypted ->
Txt = <<"Use of STARTTLS required">>,
send_element(State1, xmpp:err_policy_violation(Txt, Lang));
#starttls{} when not Encrypted ->
State2 = State1#{stream_state => wait_for_starttls_response},
send_element(State2, #starttls{});
_ ->
State2 = process_cert_verification(State1),
case is_disconnected(State2) of
true -> State2;
false ->
case xmpp:get_subtag(Msg, #sasl_mechanisms{}) of
#sasl_mechanisms{list = Mechs} ->
process_sasl_mechanisms(Mechs, State2);
false ->
process_sasl_failure(
#sasl_failure{reason = 'invalid-mechanism'},
State2)
end
end
end
end.
-spec process_stream_established(state()) -> state().
process_stream_established(#{stream_state := StateName} = State)
when StateName == disconnected; StateName == established ->
State;
process_stream_established(#{mod := Mod} = State) ->
State1 = State#{stream_authenticated := true,
stream_state => established,
stream_timeout => infinity},
try Mod:handle_stream_established(State1)
catch _:undef -> State1
end.
-spec process_sasl_mechanisms([binary()], state()) -> state().
process_sasl_mechanisms(Mechs, #{user := User, server := Server} = State) ->
%% TODO: support other mechanisms
Mech = <<"EXTERNAL">>,
case lists:member(<<"EXTERNAL">>, Mechs) of
true ->
State1 = State#{stream_state => wait_for_sasl_response},
Authzid = jid:to_string(jid:make(User, Server)),
send_element(State1, #sasl_auth{mechanism = Mech, text = Authzid});
false ->
process_sasl_failure(
#sasl_failure{reason = 'invalid-mechanism'}, State)
end.
-spec process_starttls(state()) -> state().
process_starttls(#{sockmod := SockMod, socket := Socket, mod := Mod} = State) ->
TLSOpts = try Mod:tls_options(State)
catch _:undef -> []
end,
case SockMod:starttls(Socket, [connect|TLSOpts]) of
{ok, TLSSocket} ->
State1 = State#{socket => TLSSocket,
stream_id => new_id(),
stream_restarted => true,
stream_state => wait_for_stream,
stream_encrypted => true},
send_header(State1);
{error, Why} ->
process_stream_close({error, {tls, Why}}, State)
end.
-spec process_stream_downgrade(stream_start(), state()) -> state().
process_stream_downgrade(StreamStart, #{mod := Mod} = State) ->
try Mod:downgrade_stream(StreamStart, State)
catch _:undef ->
send_element(State, xmpp:serr_unsupported_version())
end.
-spec process_cert_verification(state()) -> state().
process_cert_verification(#{stream_encrypted := true,
stream_verified := false,
mod := Mod} = State) ->
case try Mod:tls_verify(State)
catch _:undef -> true
end of
true ->
case xmpp_stream_pkix:authenticate(State) of
{ok, _} ->
State#{stream_verified => true};
{error, Why, _Peer} ->
process_stream_close({error, {pkix, Why}}, State)
end;
false ->
State#{stream_verified => true}
end;
process_cert_verification(State) ->
State.
-spec process_sasl_success(state()) -> state().
process_sasl_success(#{mod := Mod,
sockmod := SockMod,
socket := Socket} = State) ->
State1 = try Mod:handle_auth_success(<<"EXTERNAL">>, State)
catch _:undef -> State
end,
case is_disconnected(State1) of
true -> State1;
false ->
SockMod:reset_stream(Socket),
State2 = State1#{stream_id => new_id(),
stream_restarted => true,
stream_state => wait_for_stream,
stream_authenticated => true},
send_header(State2)
end.
-spec process_sasl_failure(sasl_failure(), state()) -> state().
process_sasl_failure(#sasl_failure{reason = Reason}, #{mod := Mod} = State) ->
try Mod:handle_auth_failure(<<"EXTERNAL">>, Reason, State)
catch _:undef -> process_stream_close({error, {auth, Reason}}, State)
end.
-spec process_packet(xmpp_element(), state()) -> state().
process_packet(Pkt, #{mod := Mod} = State) ->
try Mod:handle_packet(Pkt, State)
catch _:undef -> State
end.
-spec is_starttls_required(state()) -> boolean().
is_starttls_required(#{mod := Mod} = State) ->
try Mod:tls_required(State)
catch _:undef -> false
end.
-spec send_header(state()) -> state().
send_header(#{remote_server := RemoteServer,
stream_encrypted := Encrypted,
lang := Lang,
xmlns := NS,
user := User,
resource := Resource,
server := Server} = State) ->
NS_DB = if NS == ?NS_SERVER -> ?NS_SERVER_DIALBACK;
true -> <<"">>
end,
From = if Encrypted ->
jid:make(User, Server, Resource);
NS == ?NS_SERVER ->
jid:make(Server);
true ->
undefined
end,
Header = xmpp:encode(
#stream_start{xmlns = NS,
lang = Lang,
stream_xmlns = ?NS_STREAM,
db_xmlns = NS_DB,
from = From,
to = jid:make(RemoteServer),
version = {1,0}}),
case send_text(State, fxml:element_to_header(Header)) of
ok -> State;
{error, Why} -> process_stream_close({error, {socket, Why}}, State)
end.
-spec send_element(state(), xmpp_element()) -> state().
send_element(#{xmlns := NS, mod := Mod} = State, Pkt) ->
El = xmpp:encode(Pkt, NS),
Data = fxml:element_to_binary(El),
State1 = try Mod:handle_send(Pkt, El, Data, State)
catch _:undef -> State
end,
case is_disconnected(State1) of
true -> State1;
false ->
case send_text(State1, Data) of
_ when is_record(Pkt, stream_error) ->
process_stream_end({error, {stream, Pkt}}, State1);
ok ->
State1;
{error, Why} ->
process_stream_close({error, {socket, Why}}, State1)
end
end.
-spec send_error(state(), xmpp_element(), stanza_error()) -> state().
send_error(State, Pkt, Err) ->
case xmpp:is_stanza(Pkt) of
true ->
case xmpp:get_type(Pkt) of
result -> State;
error -> State;
<<"result">> -> State;
<<"error">> -> State;
_ ->
ErrPkt = xmpp:make_error(Pkt, Err),
send_element(State, ErrPkt)
end;
false ->
State
end.
-spec send_text(state(), binary()) -> ok | {error, inet:posix()}.
send_text(#{sockmod := SockMod, socket := Socket,
stream_state := StateName}, Data) when StateName /= disconnected ->
SockMod:send(Socket, Data);
send_text(_, _) ->
{error, einval}.
-spec send_trailer(state()) -> state().
send_trailer(State) ->
send_text(State, <<"</stream:stream>">>),
close_socket(State).
-spec close_socket(state()) -> state().
close_socket(State) ->
case State of
#{sockmod := SockMod, socket := Socket} ->
SockMod:close(Socket);
_ ->
ok
end,
State#{stream_timeout => infinity,
stream_state => disconnected}.
-spec select_lang(binary(), binary()) -> binary().
select_lang(Lang, <<"">>) -> Lang;
select_lang(_, Lang) -> Lang.
-spec format_inet_error(atom()) -> string().
format_inet_error(Reason) ->
case inet:format_error(Reason) of
"unknown POSIX error" -> atom_to_list(Reason);
Txt -> Txt
end.
-spec format_stream_error(atom() | 'see-other-host'(), undefined | text()) -> string().
format_stream_error(Reason, Txt) ->
Slogan = case Reason of
#'see-other-host'{} -> "see-other-host";
_ -> atom_to_list(Reason)
end,
case Txt of
undefined -> Slogan;
#text{data = <<"">>} -> Slogan;
#text{data = Data} ->
binary_to_list(Data) ++ " (" ++ Slogan ++ ")"
end.
-spec format(io:format(), list()) -> binary().
format(Fmt, Args) ->
iolist_to_binary(io_lib:format(Fmt, Args)).
%%%===================================================================
%%% Connection stuff
%%%===================================================================
-spec resolve(string(), state()) -> {ok, [host_port()]} | network_error().
resolve(Host, State) ->
case srv_lookup(Host, State) of
{error, _Reason} ->
DefaultPort = get_default_port(State),
a_lookup([{Host, DefaultPort}], State);
{ok, HostPorts} ->
a_lookup(HostPorts, State)
end.
-spec srv_lookup(string(), state()) -> {ok, [host_port()]} | network_error().
srv_lookup(Host, State) ->
%% Only perform SRV lookups for FQDN names
case string:chr(Host, $.) of
0 ->
{error, nxdomain};
_ ->
case inet_parse:address(Host) of
{ok, _} ->
{error, nxdomain};
{error, _} ->
Timeout = get_dns_timeout(State),
Retries = get_dns_retries(State),
srv_lookup(Host, Timeout, Retries)
end
end.
-spec srv_lookup(string(), non_neg_integer(), integer()) ->
{ok, [host_port()]} | network_error().
srv_lookup(_Host, _Timeout, Retries) when Retries < 1 ->
{error, timeout};
srv_lookup(Host, Timeout, Retries) ->
SRVName = "_xmpp-server._tcp." ++ Host,
case inet_res:getbyname(SRVName, srv, Timeout) of
{ok, HostEntry} ->
host_entry_to_host_ports(HostEntry);
{error, _} ->
LegacySRVName = "_jabber._tcp." ++ Host,
case inet_res:getbyname(LegacySRVName, srv, Timeout) of
{error, timeout} ->
srv_lookup(Host, Timeout, Retries - 1);
{error, _} = Err ->
Err;
{ok, HostEntry} ->
host_entry_to_host_ports(HostEntry)
end
end.
-spec a_lookup([{inet:hostname(), inet:port_number()}], state()) ->
{ok, [ip_port()]} | network_error().
a_lookup(HostPorts, State) ->
HostPortFamilies = [{Host, Port, Family}
|| {Host, Port} <- HostPorts,
Family <- get_address_families(State)],
a_lookup(HostPortFamilies, State, {error, nxdomain}).
-spec a_lookup([{inet:hostname(), inet:port_number(), inet:address_family()}],
state(), network_error()) -> {ok, [ip_port()]} | network_error().
a_lookup([{Host, Port, Family}|HostPortFamilies], State, _) ->
Timeout = get_dns_timeout(State),
Retries = get_dns_retries(State),
case a_lookup(Host, Port, Family, Timeout, Retries) of
{error, _} = Err ->
a_lookup(HostPortFamilies, State, Err);
{ok, AddrPorts} ->
{ok, AddrPorts}
end;
a_lookup([], _State, Err) ->
Err.
-spec a_lookup(inet:hostname(), inet:port_number(), inet:address_family(),
non_neg_integer(), integer()) -> {ok, [ip_port()]} | network_error().
a_lookup(_Host, _Port, _Family, _Timeout, Retries) when Retries < 1 ->
{error, timeout};
a_lookup(Host, Port, Family, Timeout, Retries) ->
case inet:gethostbyname(Host, Family, Timeout) of
{error, timeout} ->
a_lookup(Host, Port, Family, Timeout, Retries - 1);
{error, _} = Err ->
Err;
{ok, HostEntry} ->
host_entry_to_addr_ports(HostEntry, Port)
end.
-spec host_entry_to_host_ports(inet:hostent()) -> {ok, [host_port()]} |
{error, nxdomain}.
host_entry_to_host_ports(#hostent{h_addr_list = AddrList}) ->
PrioHostPorts = lists:flatmap(
fun({Priority, Weight, Port, Host}) ->
N = case Weight of
0 -> 0;
_ -> (Weight + 1) * randoms:uniform()
end,
[{Priority * 65536 - N, Host, Port}];
(_) ->
[]
end, AddrList),
HostPorts = [{Host, Port}
|| {_Priority, Host, Port} <- lists:usort(PrioHostPorts)],
case HostPorts of
[] -> {error, nxdomain};
_ -> {ok, HostPorts}
end.
-spec host_entry_to_addr_ports(inet:hostent(), inet:port_number()) ->
{ok, [ip_port()]} | {error, nxdomain}.
host_entry_to_addr_ports(#hostent{h_addr_list = AddrList}, Port) ->
AddrPorts = lists:flatmap(
fun(Addr) ->
try get_addr_type(Addr) of
_ -> [{Addr, Port}]
catch _:_ ->
[]
end
end, AddrList),
case AddrPorts of
[] -> {error, nxdomain};
_ -> {ok, AddrPorts}
end.
-spec connect([ip_port()], state()) -> {ok, term(), ip_port()} | network_error().
connect(AddrPorts, #{sockmod := SockMod} = State) ->
Timeout = get_connect_timeout(State),
connect(AddrPorts, SockMod, Timeout, {error, nxdomain}).
-spec connect([ip_port()], module(), non_neg_integer(), network_error()) ->
{ok, term(), ip_port()} | network_error().
connect([{Addr, Port}|AddrPorts], SockMod, Timeout, _) ->
Type = get_addr_type(Addr),
case SockMod:connect(Addr, Port,
[binary, {packet, 0},
{send_timeout, ?TCP_SEND_TIMEOUT},
{send_timeout_close, true},
{active, false}, Type],
Timeout) of
{ok, Socket} ->
{ok, Socket, {Addr, Port}};
Err ->
connect(AddrPorts, SockMod, Timeout, Err)
end;
connect([], _SockMod, _Timeout, Err) ->
Err.
-spec get_addr_type(inet:ip_address()) -> inet:address_family().
get_addr_type({_, _, _, _}) -> inet;
get_addr_type({_, _, _, _, _, _, _, _}) -> inet6.
-spec get_dns_timeout(state()) -> non_neg_integer().
get_dns_timeout(#{mod := Mod} = State) ->
timer:seconds(
try Mod:dns_timeout(State)
catch _:undef -> 10
end).
-spec get_dns_retries(state()) -> non_neg_integer().
get_dns_retries(#{mod := Mod} = State) ->
try Mod:dns_retries(State)
catch _:undef -> 2
end.
-spec get_default_port(state()) -> inet:port_number().
get_default_port(#{mod := Mod, xmlns := NS} = State) ->
try Mod:default_port(State)
catch _:undef when NS == ?NS_SERVER -> 5269;
_:undef when NS == ?NS_CLIENT -> 5222
end.
-spec get_address_families(state()) -> [inet:address_family()].
get_address_families(#{mod := Mod} = State) ->
try Mod:address_families(State)
catch _:undef -> [inet, inet6]
end.
-spec get_connect_timeout(state()) -> non_neg_integer().
get_connect_timeout(#{mod := Mod} = State) ->
timer:seconds(
try Mod:connect_timeout(State)
catch _:undef -> 10
end).

159
src/xmpp_stream_pkix.erl Normal file
View File

@ -0,0 +1,159 @@
%%%-------------------------------------------------------------------
%%% @author Evgeny Khramtsov <ekhramtsov@process-one.net>
%%% @copyright (C) 2016, Evgeny Khramtsov
%%% @doc
%%%
%%% @end
%%% Created : 13 Dec 2016 by Evgeny Khramtsov <ekhramtsov@process-one.net>
%%%-------------------------------------------------------------------
-module(xmpp_stream_pkix).
%% API
-export([authenticate/1, authenticate/2]).
-include("xmpp.hrl").
-include_lib("public_key/include/public_key.hrl").
-include("XmppAddr.hrl").
%%%===================================================================
%%% API
%%%===================================================================
-spec authenticate(xmpp_stream_in:state() | xmpp_stream_out:state())
-> {ok, binary()} | {error, binary(), binary()}.
authenticate(State) ->
authenticate(State, <<"">>).
-spec authenticate(xmpp_stream_in:state() | xmpp_stream_out:state(), binary())
-> {ok, binary()} | {error, binary(), binary()}.
authenticate(#{xmlns := ?NS_SERVER, remote_server := Peer,
sockmod := SockMod, socket := Socket}, _Authzid) ->
case SockMod:get_peer_certificate(Socket) of
{ok, Cert} ->
case SockMod:get_verify_result(Socket) of
0 ->
case ejabberd_idna:domain_utf8_to_ascii(Peer) of
false ->
{error, <<"Cannot decode remote server name">>, Peer};
AsciiPeer ->
case lists:any(
fun(D) -> match_domain(AsciiPeer, D) end,
get_cert_domains(Cert)) of
true ->
{ok, Peer};
false ->
{error, <<"Certificate host name mismatch">>, Peer}
end
end;
VerifyRes ->
{error, fast_tls:get_cert_verify_string(VerifyRes, Cert), Peer}
end;
{error, _Reason} ->
{error, <<"Cannot get peer certificate">>, Peer};
error ->
{error, <<"Cannot get peer certificate">>, Peer}
end;
authenticate(_State, _Authzid) ->
%% TODO: client PKIX authentication
{error, <<"Client certificate verification not implemented">>, <<"">>}.
%%%===================================================================
%%% Internal functions
%%%===================================================================
get_cert_domains(Cert) ->
TBSCert = Cert#'Certificate'.tbsCertificate,
Subject = case TBSCert#'TBSCertificate'.subject of
{rdnSequence, Subj} -> lists:flatten(Subj);
_ -> []
end,
Extensions = case TBSCert#'TBSCertificate'.extensions of
Exts when is_list(Exts) -> Exts;
_ -> []
end,
lists:flatmap(
fun(#'AttributeTypeAndValue'{type = ?'id-at-commonName',value = Val}) ->
case 'OTP-PUB-KEY':decode('X520CommonName', Val) of
{ok, {_, D1}} ->
D = if is_binary(D1) -> D1;
is_list(D1) -> list_to_binary(D1);
true -> error
end,
if D /= error ->
case jid:from_string(D) of
#jid{luser = <<"">>, lserver = LD,
lresource = <<"">>} ->
[LD];
_ -> []
end;
true -> []
end;
_ -> []
end;
(_) -> []
end, Subject) ++
lists:flatmap(
fun(#'Extension'{extnID = ?'id-ce-subjectAltName',
extnValue = Val}) ->
BVal = if is_list(Val) -> list_to_binary(Val);
true -> Val
end,
case 'OTP-PUB-KEY':decode('SubjectAltName', BVal) of
{ok, SANs} ->
lists:flatmap(
fun({otherName, #'AnotherName'{'type-id' = ?'id-on-xmppAddr',
value = XmppAddr}}) ->
case 'XmppAddr':decode('XmppAddr', XmppAddr) of
{ok, D} when is_binary(D) ->
case jid:from_string(D) of
#jid{luser = <<"">>,
lserver = LD,
lresource = <<"">>} ->
case ejabberd_idna:domain_utf8_to_ascii(LD) of
false ->
[];
PCLD ->
[PCLD]
end;
_ -> []
end;
_ -> []
end;
({dNSName, D}) when is_list(D) ->
case jid:from_string(list_to_binary(D)) of
#jid{luser = <<"">>,
lserver = LD,
lresource = <<"">>} ->
[LD];
_ -> []
end;
(_) -> []
end, SANs);
_ -> []
end;
(_) -> []
end, Extensions).
match_domain(Domain, Domain) -> true;
match_domain(Domain, Pattern) ->
DLabels = str:tokens(Domain, <<".">>),
PLabels = str:tokens(Pattern, <<".">>),
match_labels(DLabels, PLabels).
match_labels([], []) -> true;
match_labels([], [_ | _]) -> false;
match_labels([_ | _], []) -> false;
match_labels([DL | DLabels], [PL | PLabels]) ->
case lists:all(fun (C) ->
$a =< C andalso C =< $z orelse
$0 =< C andalso C =< $9 orelse
C == $- orelse C == $*
end,
binary_to_list(PL))
of
true ->
Regexp = ejabberd_regexp:sh_to_awk(PL),
case ejabberd_regexp:run(DL, Regexp) of
match -> match_labels(DLabels, PLabels);
nomatch -> false
end;
false -> false
end.