mirror of
https://github.com/processone/ejabberd.git
synced 2024-12-22 17:28:25 +01:00
1133 lines
37 KiB
Erlang
1133 lines
37 KiB
Erlang
%%%-------------------------------------------------------------------
|
|
%%% Created : 14 Dec 2016 by Evgeny Khramtsov <ekhramtsov@process-one.net>
|
|
%%%
|
|
%%%
|
|
%%% ejabberd, Copyright (C) 2002-2018 ProcessOne
|
|
%%%
|
|
%%% This program is free software; you can redistribute it and/or
|
|
%%% modify it under the terms of the GNU General Public License as
|
|
%%% published by the Free Software Foundation; either version 2 of the
|
|
%%% License, or (at your option) any later version.
|
|
%%%
|
|
%%% This program is distributed in the hope that it will be useful,
|
|
%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
|
%%% General Public License for more details.
|
|
%%%
|
|
%%% You should have received a copy of the GNU General Public License along
|
|
%%% with this program; if not, write to the Free Software Foundation, Inc.,
|
|
%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
|
|
%%%
|
|
%%%-------------------------------------------------------------------
|
|
-module(xmpp_stream_out).
|
|
-define(GEN_SERVER, p1_server).
|
|
-behaviour(?GEN_SERVER).
|
|
|
|
-protocol({rfc, 6120}).
|
|
-protocol({xep, 114, '1.6'}).
|
|
-protocol({xep, 368, '1.0.0'}).
|
|
|
|
%% API
|
|
-export([start/3, start_link/3, call/3, cast/2, reply/2, connect/1,
|
|
stop/1, send/2, close/1, close/2, establish/1, format_error/1,
|
|
set_timeout/2, get_transport/1, change_shaper/2]).
|
|
%% gen_server callbacks
|
|
-export([init/1, handle_call/3, handle_cast/2, handle_info/2,
|
|
terminate/2, code_change/3]).
|
|
|
|
%%-define(DBGFSM, true).
|
|
-ifdef(DBGFSM).
|
|
-define(FSMOPTS, [{debug, [trace]}]).
|
|
-else.
|
|
-define(FSMOPTS, []).
|
|
-endif.
|
|
|
|
-define(TCP_SEND_TIMEOUT, 15000).
|
|
|
|
-include("xmpp.hrl").
|
|
-include_lib("kernel/include/inet.hrl").
|
|
|
|
%% Process state. init/1 seeds this map with owner, mod, server, user,
%% resource, lang, remote_server, xmlns, codec_options and the stream_*
%% flags; the connect path later adds socket, socket_monitor and ip.
-type state() :: map().
%% Value returned by the gen_server callbacks of this module; the third
%% element is the remaining idle timeout computed by noreply/1.
-type noreply() :: {noreply, state(), timeout()}.
%% Endpoint descriptions; not referenced in the visible part of this file,
%% presumably used by the resolve/connect helpers.
%% NOTE(review): the trailing boolean appears to be the "encrypted" flag
%% (handle_cast(connect, ...) destructures a connected endpoint as
%% {Addr, Port, Encrypted}) -- confirm against resolve/connect.
-type host_port() :: {inet:hostname(), inet:port_number(), boolean()}.
-type ip_port() :: {inet:ip_address(), inet:port_number(), boolean()}.
%% NOTE(review): the inner tuple looks like an SRV answer
%% (priority, weight, port, target) -- confirm.
-type h_addr_list() :: {{integer(), integer(), inet:port_number(), string()}, boolean()}.
-type network_error() :: {error, inet:posix() | inet_res:res_error()}.
-type tls_error_reason() :: inet:posix() | atom() | binary().
-type socket_error_reason() :: inet:posix() | atom().
%% Reasons handed to process_stream_end/2 (and from there to the callback
%% module's handle_stream_end/2); format_error/1 renders them as text.
-type stop_reason() :: {idna, bad_string} |
                       {dns, inet:posix() | inet_res:res_error()} |
                       {stream, reset | {in | out, stream_error()}} |
                       {tls, tls_error_reason()} |
                       {pkix, binary()} |
                       {auth, atom() | binary() | string()} |
                       {socket, socket_error_reason()} |
                       internal_failure.
-export_type([state/0, stop_reason/0]).
|
|
%% Callbacks of the module passed as `Mod' to start/3 / start_link/3.
%% Every callback is optional (see -optional_callbacks below). Most call
%% sites in this module catch the missing-callback error and fall back to
%% a default; the TLS defaults are tls_options -> [], tls_required ->
%% false, tls_enabled -> true and tls_verify -> true.
%% Generic gen_server-style hooks, forwarded from the callbacks of the
%% same name in this module.
-callback init(list()) -> {ok, state()} | {error, term()} | ignore.
-callback handle_cast(term(), state()) -> state().
-callback handle_call(term(), term(), state()) -> state().
-callback handle_info(term(), state()) -> state().
-callback terminate(term(), state()) -> any().
-callback code_change(term(), state(), term()) -> {ok, state()} | {error, term()}.
%% Stream lifecycle notifications.
-callback handle_stream_start(stream_start(), state()) -> state().
-callback handle_stream_established(state()) -> state().
-callback handle_stream_downgraded(stream_start(), state()) -> state().
-callback handle_stream_end(stop_reason(), state()) -> state().
-callback handle_cdata(binary(), state()) -> state().
-callback handle_send(xmpp_element(), ok | {error, inet:posix()}, state()) -> state().
-callback handle_recv(fxml:xmlel(), xmpp_element() | {error, term()}, state()) -> state().
-callback handle_timeout(state()) -> state().
-callback handle_authenticated_features(stream_features(), state()) -> state().
-callback handle_unauthenticated_features(stream_features(), state()) -> state().
-callback handle_auth_success(cyrsasl:mechanism(), state()) -> state().
%% NOTE(review): process_sasl_failure/2 passes `{auth, Reason}' as the
%% second argument, not a bare binary() as declared here -- confirm which
%% contract implementers rely on.
-callback handle_auth_failure(cyrsasl:mechanism(), binary(), state()) -> state().
-callback handle_packet(xmpp_element(), state()) -> state().
%% Configuration queries.
-callback tls_options(state()) -> [proplists:property()].
-callback tls_required(state()) -> boolean().
-callback tls_verify(state()) -> boolean().
-callback tls_enabled(state()) -> boolean().
-callback dns_timeout(state()) -> timeout().
-callback dns_retries(state()) -> non_neg_integer().
-callback default_port(state()) -> inet:port_number().
-callback address_families(state()) -> [inet:address_family()].
-callback connect_timeout(state()) -> timeout().

-optional_callbacks([init/1,
                     handle_cast/2,
                     handle_call/3,
                     handle_info/2,
                     terminate/2,
                     code_change/3,
                     handle_stream_start/2,
                     handle_stream_established/1,
                     handle_stream_downgraded/2,
                     handle_stream_end/2,
                     handle_cdata/2,
                     handle_send/3,
                     handle_recv/3,
                     handle_timeout/1,
                     handle_authenticated_features/2,
                     handle_unauthenticated_features/2,
                     handle_auth_success/2,
                     handle_auth_failure/3,
                     handle_packet/2,
                     tls_options/1,
                     tls_required/1,
                     tls_verify/1,
                     tls_enabled/1,
                     dns_timeout/1,
                     dns_retries/1,
                     default_port/1,
                     address_families/1,
                     connect_timeout/1]).
|
|
|
|
%%%===================================================================
|
|
%%% API
|
|
%%%===================================================================
|
|
%% @doc Start a stream process that is not linked to the caller.
%% `Mod' is the callback module; it is prepended to `Args' so that init/1
%% receives `[Mod | Args]'. `Opts' go to the underlying server, followed
%% by the debug options selected at compile time.
start(Mod, Args, Opts) ->
    InitArgs = [Mod | Args],
    ServerOpts = Opts ++ ?FSMOPTS,
    ?GEN_SERVER:start(?MODULE, InitArgs, ServerOpts).
|
|
|
|
%% @doc Start a stream process linked to the caller. Arguments are handled
%% exactly as in start/3: init/1 receives `[Mod | Args]' and the debug
%% options are appended to `Opts'.
start_link(Mod, Args, Opts) ->
    InitArgs = [Mod | Args],
    ServerOpts = Opts ++ ?FSMOPTS,
    ?GEN_SERVER:start_link(?MODULE, InitArgs, ServerOpts).
|
|
|
|
%% @doc Synchronous request to a stream process; thin wrapper over the
%% underlying server's call/3.
call(Server, Request, Timeout) ->
    ?GEN_SERVER:call(Server, Request, Timeout).
|
|
|
|
%% @doc Asynchronous message to a stream process; thin wrapper over the
%% underlying server's cast/2.
cast(Server, Request) ->
    ?GEN_SERVER:cast(Server, Request).
|
|
|
|
%% @doc Send a deferred reply to a caller of call/3; thin wrapper over the
%% underlying server's reply/2.
reply(Caller, Response) ->
    ?GEN_SERVER:reply(Caller, Response).
|
|
|
|
%% @doc Ask the stream process to start connecting to the remote server.
%% The request is asynchronous and is ignored unless the process is still
%% in the `connecting' state.
-spec connect(pid()) -> ok.
connect(Pid) ->
    Request = connect,
    cast(Pid, Request).
|
|
|
|
%% @doc Stop a stream.
%%
%% Given a pid, the process is asked to stop asynchronously. Given the
%% state map from inside the owning process, terminate/2 runs immediately
%% and the process exits with reason `normal'. Any other argument, or a
%% state owned by a different process, raises `badarg'.
-spec stop(pid()) -> ok;
          (state()) -> no_return().
stop(Ref) when is_pid(Ref) ->
    cast(Ref, stop);
stop(#{owner := OwnerPid} = State) when OwnerPid =:= self() ->
    terminate(normal, State),
    exit(normal);
stop(_Other) ->
    erlang:error(badarg).
|
|
|
|
%% @doc Send an element on the stream.
%%
%% With a pid the element is queued to the process via cast; with the
%% state map (only from the owning process) it is sent right away and the
%% updated state is returned. Anything else raises `badarg'.
-spec send(pid(), xmpp_element()) -> ok;
          (state(), xmpp_element()) -> state().
send(Ref, Pkt) when is_pid(Ref) ->
    cast(Ref, {send, Pkt});
send(#{owner := OwnerPid} = State, Pkt) when OwnerPid =:= self() ->
    send_pkt(State, Pkt);
send(_Target, _Pkt) ->
    erlang:error(badarg).
|
|
|
|
%% @doc Close the stream's socket.
%%
%% With a pid this is close(Pid, closed). With the state map (only from
%% the owning process) the socket is closed directly and the returned
%% state is marked disconnected. Anything else raises `badarg'.
-spec close(pid()) -> ok;
           (state()) -> state().
close(Ref) when is_pid(Ref) ->
    close(Ref, closed);
close(#{owner := OwnerPid} = State) when OwnerPid =:= self() ->
    close_socket(State);
close(_Other) ->
    erlang:error(badarg).
|
|
|
|
%% @doc Asynchronously close the stream of process `Pid'; `Reason' is
%% reported to the stream-end handling as `{socket, Reason}'.
-spec close(pid(), atom()) -> ok.
close(Pid, Reason) ->
    Request = {close, Reason},
    cast(Pid, Request).
|
|
|
|
%% @doc Mark the stream as established, for callback modules that finish
%% negotiation themselves. Delegates to process_stream_established/1,
%% which leaves disconnected or already-established streams untouched.
-spec establish(state()) -> state().
establish(StreamState) ->
    process_stream_established(StreamState).
|
|
|
|
%% @doc Set the idle timeout of the stream, measured from now.
%%
%% `infinity' disables it; otherwise the timeout is stored together with
%% the current monotonic time so that noreply/1 can compute the remaining
%% time. Only the owning process may call this; otherwise `badarg'.
-spec set_timeout(state(), timeout()) -> state().
set_timeout(#{owner := OwnerPid} = State, infinity) when OwnerPid =:= self() ->
    State#{stream_timeout => infinity};
set_timeout(#{owner := OwnerPid} = State, Timeout) when OwnerPid =:= self() ->
    Now = p1_time_compat:monotonic_time(milli_seconds),
    State#{stream_timeout => {Timeout, Now}};
set_timeout(_State, _Timeout) ->
    erlang:error(badarg).
|
|
|
|
%% @doc Return the transport of the stream's socket, as reported by
%% xmpp_socket:get_transport/1. Only valid from the owning process and
%% once a socket exists; otherwise raises `badarg'.
get_transport(#{owner := OwnerPid, socket := Sock}) when OwnerPid =:= self() ->
    xmpp_socket:get_transport(Sock);
get_transport(_State) ->
    erlang:error(badarg).
|
|
|
|
%% @doc Replace the traffic shaper of the stream's socket and return the
%% state with the updated socket. Only valid from the owning process and
%% once a socket exists; otherwise raises `badarg'.
-spec change_shaper(state(), shaper:shaper()) -> state().
change_shaper(#{owner := OwnerPid, socket := Sock} = State, Shaper)
  when OwnerPid =:= self() ->
    State#{socket => xmpp_socket:change_shaper(Sock, Shaper)};
change_shaper(_State, _Shaper) ->
    erlang:error(badarg).
|
|
|
|
%% @doc Render a stop_reason() as a human-readable message.
%% Fixed-text reasons are listed first, then the ones that embed details;
%% anything unrecognised is printed with ~w.
-spec format_error(stop_reason()) -> binary().
format_error({idna, _}) ->
    <<"Remote domain is not an IDN hostname">>;
format_error({stream, reset}) ->
    <<"Stream reset by peer">>;
format_error(internal_failure) ->
    <<"Internal server error">>;
format_error({dns, DNSReason}) ->
    format("DNS lookup failed: ~s", [format_inet_error(DNSReason)]);
format_error({socket, SockReason}) ->
    format("Connection failed: ~s", [format_inet_error(SockReason)]);
format_error({pkix, PKIXReason}) ->
    {_, ErrTxt} = xmpp_stream_pkix:format_error(PKIXReason),
    format("Peer certificate rejected: ~s", [ErrTxt]);
format_error({stream, {in, #stream_error{reason = Why, text = Txt}}}) ->
    format("Stream closed by peer: ~s", [format_stream_error(Why, Txt)]);
format_error({stream, {out, #stream_error{reason = Why, text = Txt}}}) ->
    format("Stream closed by us: ~s", [format_stream_error(Why, Txt)]);
format_error({tls, TLSReason}) ->
    format("TLS failed: ~s", [format_tls_error(TLSReason)]);
format_error({auth, AuthReason}) ->
    format("Authentication failed: ~s", [AuthReason]);
format_error(Other) ->
    format("Unrecognized error: ~w", [Other]).
|
|
|
|
%%%===================================================================
|
|
%%% gen_server callbacks
|
|
%%%===================================================================
|
|
%% @doc gen_server init/1 callback.
%%
%% Builds the initial state map: an outgoing stream in the `connecting'
%% state with a 30 second idle timeout and a fresh stream id. The callback
%% module can then adjust it through its init/1, which receives
%% `[State, Opts]'. Its result maps to the gen_server return: `{ok, S}'
%% becomes `{ok, S, Timeout}', `{error, R}' becomes `{stop, R}' and
%% `ignore' is passed through.
%% NOTE(review): any `undef' raised while running Mod:init/1 (not only a
%% missing export) is swallowed and the default state is used.
-spec init(list()) -> {ok, state(), timeout()} | {stop, term()} | ignore.
init([Mod, _SockMod, From, To, Opts]) ->
    Now = p1_time_compat:monotonic_time(milli_seconds),
    Initial = #{owner => self(),
                mod => Mod,
                server => From,
                user => <<"">>,
                resource => <<"">>,
                lang => <<"">>,
                remote_server => To,
                xmlns => ?NS_SERVER,
                codec_options => [ignore_els],
                stream_direction => out,
                stream_timeout => {timer:seconds(30), Now},
                stream_id => new_id(),
                stream_encrypted => false,
                stream_verified => false,
                stream_authenticated => false,
                stream_restarted => false,
                stream_state => connecting},
    InitResult = try Mod:init([Initial, Opts])
                 catch _:undef -> {ok, Initial}
                 end,
    case InitResult of
        ignore ->
            ignore;
        {error, Reason} ->
            {stop, Reason};
        {ok, Ready} ->
            {_, Ready1, Timeout} = noreply(Ready),
            {ok, Ready1, Timeout}
    end.
|
|
|
|
%% @doc gen_server handle_call/3 callback. Every call is delegated to the
%% callback module's handle_call/3. Without that callback the state is
%% kept and no reply is sent here.
-spec handle_call(term(), term(), state()) -> noreply().
handle_call(Request, From, State) ->
    NewState = try callback(handle_call, Request, From, State)
               catch _:{?MODULE, undef} -> State
               end,
    noreply(NewState).
|
|
|
|
%% @doc gen_server handle_cast/2 callback.
%%
%% `connect' converts the remote domain to ASCII, resolves it, opens the
%% connection and sends our stream header. It is only acted upon in the
%% `connecting' state, and every failure ends the stream with an
%% idna/dns/connection reason. `{send, Pkt}', `stop' and `{close, Reason}'
%% back the API functions of the same name. Anything else goes to the
%% callback module's handle_cast/2, if it has one.
-spec handle_cast(term(), state()) -> noreply().
handle_cast(connect, #{remote_server := RemoteServer,
                       stream_state := connecting} = State) ->
    NewState =
        case idna_to_ascii(RemoteServer) of
            false ->
                process_stream_end({idna, bad_string}, State);
            ASCIIName ->
                case resolve(binary_to_list(ASCIIName), State) of
                    {error, DNSError} ->
                        process_stream_end({dns, DNSError}, State);
                    {ok, AddrPorts} ->
                        case connect(AddrPorts, State) of
                            {error, {_Class, _Why} = ConnError} ->
                                process_stream_end(ConnError, State);
                            {ok, Socket, {Addr, Port, Encrypted}} ->
                                Monitor = xmpp_socket:monitor(Socket),
                                send_header(
                                  State#{ip => {Addr, Port},
                                         socket => Socket,
                                         stream_encrypted => Encrypted,
                                         socket_monitor => Monitor,
                                         stream_state => wait_for_stream})
                        end
                end
        end,
    noreply(NewState);
handle_cast(connect, State) ->
    %% A connection request outside the `connecting' state is ignored.
    noreply(State);
handle_cast({send, Pkt}, State) ->
    noreply(send_pkt(State, Pkt));
handle_cast(stop, State) ->
    {stop, normal, State};
handle_cast({close, Reason}, State) ->
    %% The socket is closed first. process_stream_end/2 is then given the
    %% pre-close State (not Closed), so its `disconnected' short-circuit
    %% does not skip the stream-end handling.
    Closed = close_socket(State),
    NewState = case is_disconnected(State) of
                   true -> Closed;
                   false -> process_stream_end({socket, Reason}, State)
               end,
    noreply(NewState);
handle_cast(Msg, State) ->
    NewState = try callback(handle_cast, Msg, State)
               catch _:{?MODULE, undef} -> State
               end,
    noreply(NewState).
|
|
|
|
%% @doc gen_server handle_info/2 callback.
%%
%% Handles the XML stream events delivered by the socket layer
%% ('$gen_event' and '$gen_all_state_event' tuples), the idle `timeout',
%% the 'DOWN' message of the socket monitor and raw tcp / tcp_closed /
%% tcp_error messages. Any other message goes to the callback module's
%% handle_info/2, if it has one.
-spec handle_info(term(), state()) -> noreply().
handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}},
            #{stream_state := wait_for_stream,
              xmlns := XMLNS, lang := MyLang} = State) ->
    El = #xmlel{name = Name, attrs = Attrs},
    NewState =
        try xmpp:decode(El, XMLNS, []) of
            #stream_start{} = StreamStart ->
                process_stream(StreamStart, State);
            _NotAStreamStart ->
                send_pkt(State, xmpp:serr_invalid_xml())
        catch _:{xmpp_codec, Why} ->
                Txt = xmpp:io_format_error(Why),
                Lang = select_lang(MyLang, xmpp:get_lang(El)),
                send_pkt(State, xmpp:serr_invalid_xml(Txt, Lang))
        end,
    noreply(NewState);
handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang} = State) ->
    %% Our stream header is sent before the stream error is reported.
    WithHeader = send_header(State),
    NewState =
        case is_disconnected(WithHeader) of
            true ->
                WithHeader;
            false ->
                %% An oversized stanza is a policy violation; a `{_, Txt}'
                %% parser error is reported as not-well-formed XML.
                StreamErr = case Reason of
                                <<"XML stanza is too big">> ->
                                    xmpp:serr_policy_violation(Reason, Lang);
                                {_, Txt} ->
                                    xmpp:serr_not_well_formed(Txt, Lang)
                            end,
                send_pkt(WithHeader, StreamErr)
        end,
    noreply(NewState);
handle_info({'$gen_event', {xmlstreamelement, El}},
            #{xmlns := NS, codec_options := Opts} = State) ->
    %% Only the decoding is guarded; the callback and the element
    %% processing below run outside the try.
    Decoded = try {ok, xmpp:decode(El, NS, Opts)}
              catch _:{xmpp_codec, Why} -> {error, Why}
              end,
    %% handle_recv/3 gets the decoded element, or {error, Why} on failure.
    RecvArg = case Decoded of
                  {ok, P} -> P;
                  {error, _} = E -> E
              end,
    Received = try callback(handle_recv, El, RecvArg, State)
               catch _:{?MODULE, undef} -> State
               end,
    NewState =
        case {is_disconnected(Received), Decoded} of
            {true, _} ->
                Received;
            {false, {ok, Pkt}} ->
                process_element(Pkt, Received);
            {false, {error, CodecErr}} ->
                process_invalid_xml(Received, El, CodecErr)
        end,
    noreply(NewState);
handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}},
            State) ->
    NewState = try callback(handle_cdata, Data, State)
               catch _:{?MODULE, undef} -> State
               end,
    noreply(NewState);
handle_info({'$gen_event', {xmlstreamend, _}}, State) ->
    noreply(process_stream_end({stream, reset}, State));
handle_info({'$gen_event', closed}, State) ->
    noreply(process_stream_end({socket, closed}, State));
handle_info(timeout, #{lang := Lang} = State) ->
    %% Without a handle_timeout/1 callback a live stream is closed with a
    %% connection-timeout error, and a disconnected one is stopped.
    Disconnected = is_disconnected(State),
    NewState =
        try callback(handle_timeout, State)
        catch _:{?MODULE, undef} when not Disconnected ->
                Txt = <<"Idle connection">>,
                send_pkt(State, xmpp:serr_connection_timeout(Txt, Lang));
              _:{?MODULE, undef} ->
                stop(State)
        end,
    noreply(NewState);
handle_info({'DOWN', MRef, _Type, _Object, _Info},
            #{socket_monitor := MRef} = State) ->
    noreply(process_stream_end({socket, closed}, State));
handle_info({tcp, _Sock, Data}, #{socket := Socket} = State) ->
    NewState =
        case xmpp_socket:recv(Socket, Data) of
            {ok, NewSocket} ->
                State#{socket => NewSocket};
            {error, Reason} when is_atom(Reason) ->
                process_stream_end({socket, Reason}, State);
            {error, Reason} ->
                %% TODO: make fast_tls return atoms
                process_stream_end({tls, Reason}, State)
        end,
    noreply(NewState);
handle_info({tcp_closed, _}, State) ->
    %% Treated the same as a '$gen_event' `closed' event.
    noreply(process_stream_end({socket, closed}, State));
handle_info({tcp_error, _, Reason}, State) ->
    noreply(process_stream_end({socket, Reason}, State));
handle_info(Info, State) ->
    NewState = try callback(handle_info, Info, State)
               catch _:{?MODULE, undef} -> State
               end,
    noreply(NewState).
|
|
|
|
%% @doc gen_server terminate/2 callback.
%%
%% Runs at most once per process: a flag in the process dictionary marks
%% the first run. stop/1 calls terminate/2 directly before exiting,
%% presumably the reason for this guard. The first run calls the callback
%% module's terminate/2 (if any), then sends the stream trailer and closes
%% the socket.
-spec terminate(term(), state()) -> any().
terminate(Reason, State) ->
    FirstRun = get(already_terminated) =/= true,
    if FirstRun ->
           put(already_terminated, true),
           try callback(terminate, Reason, State)
           catch _:{?MODULE, undef} -> ok
           end,
           send_trailer(State);
       true ->
           State
    end.
|
|
|
|
%% @doc gen_server code_change/3 callback.
%%
%% Delegates to the callback module's code_change/3. That callback is
%% declared optional, so, like the other optional callbacks in this
%% module, a missing implementation falls back to a default: the state is
%% kept unchanged instead of crashing the code upgrade.
-spec code_change(term(), state(), term()) -> {ok, state()} | {error, term()}.
code_change(OldVsn, State, Extra) ->
    try callback(code_change, OldVsn, State, Extra)
    catch _:{?MODULE, undef} -> {ok, State}
    end.
|
|
|
|
%%%===================================================================
|
|
%%% Internal functions
|
|
%%%===================================================================
|
|
%% @doc Build a `{noreply, State, Timeout}' gen_server return.
%% With a finite stream timeout `{Interval, Start}', Timeout is the part
%% of Interval not yet elapsed since Start, clamped at 0.
-spec noreply(state()) -> noreply().
noreply(#{stream_timeout := infinity} = State) ->
    {noreply, State, infinity};
noreply(#{stream_timeout := {Interval, Start}} = State) ->
    Elapsed = p1_time_compat:monotonic_time(milli_seconds) - Start,
    {noreply, State, max(0, Interval - Elapsed)}.
|
|
|
|
%% @doc Generate a new stream id, using randoms:get_string/0.
-spec new_id() -> binary().
new_id() -> randoms:get_string().
|
|
|
|
%% @doc True when the stream has reached the `disconnected' state.
-spec is_disconnected(state()) -> boolean().
is_disconnected(#{stream_state := disconnected}) -> true;
is_disconnected(#{stream_state := _Other}) -> false.
|
|
|
|
%% @doc React to an element that failed to decode. Stanzas are answered
%% with a bad-request error, using the stanza's language when it has one.
%% Non-stanza elements are ignored.
-spec process_invalid_xml(state(), fxml:xmlel(), term()) -> state().
process_invalid_xml(#{lang := MyLang} = State, El, Reason) ->
    case xmpp:is_stanza(El) of
        false ->
            State;
        true ->
            Txt = xmpp:io_format_error(Reason),
            Lang = select_lang(MyLang, xmpp:get_lang(El)),
            BadRequest = xmpp:err_bad_request(Txt, Lang),
            send_error(State, El, BadRequest)
    end.
|
|
|
|
%% @doc Terminate the stream for `Reason'.
%%
%% A stream that is already disconnected is returned unchanged. Otherwise
%% the trailer is sent and the socket closed, and the callback module's
%% handle_stream_end/2 decides what happens next. Without that callback
%% the process is stopped.
-spec process_stream_end(stop_reason(), state()) -> state().
process_stream_end(_Reason, #{stream_state := disconnected} = State) ->
    State;
process_stream_end(Reason, State) ->
    Closed = send_trailer(State),
    try callback(handle_stream_end, Reason, Closed)
    catch _:{?MODULE, undef} -> stop(Closed)
    end.
|
|
|
|
%% @doc Validate the peer's stream header and advance the state machine.
%%
%% A wrong content or stream namespace yields invalid-namespace, and a
%% major version above 1 yields unsupported-version. Otherwise the remote
%% stream id and language are recorded and the callback module's
%% handle_stream_start/2 runs. The stream then moves to
%% `wait_for_features' (version 1.x) or to the downgrade path.
-spec process_stream(stream_start(), state()) -> state().
process_stream(#stream_start{xmlns = XML_NS,
                             stream_xmlns = STREAM_NS},
               #{xmlns := NS} = State)
  when XML_NS /= NS; STREAM_NS /= ?NS_STREAM ->
    send_pkt(State, xmpp:serr_invalid_namespace());
process_stream(#stream_start{version = {Major, _}}, State) when Major > 1 ->
    send_pkt(State, xmpp:serr_unsupported_version());
process_stream(#stream_start{lang = Lang, id = RemoteID,
                             version = Version} = StreamStart,
               State) ->
    Recorded = State#{stream_remote_id => RemoteID, lang => Lang},
    Started = try callback(handle_stream_start, StreamStart, Recorded)
              catch _:{?MODULE, undef} -> Recorded
              end,
    case {is_disconnected(Started), Version} of
        {true, _} -> Started;
        {false, {1, _}} -> Started#{stream_state => wait_for_features};
        {false, _} -> process_stream_downgrade(StreamStart, Started)
    end.
|
|
|
|
%% @doc Dispatch a decoded element according to the current stream state.
%%
%% Negotiation elements are acted upon only in the state that expects
%% them. A stream error ends the stream. Negotiation elements arriving at
%% any other time are dropped. Everything else goes to process_packet/2.
-spec process_element(xmpp_element(), state()) -> state().
process_element(Pkt, #{stream_state := StateName} = State) ->
    case {Pkt, StateName} of
        {#stream_features{}, wait_for_features} ->
            process_features(Pkt, State);
        {#starttls_proceed{}, wait_for_starttls_response} ->
            process_starttls(State);
        {#sasl_success{}, wait_for_sasl_response} ->
            process_sasl_success(State);
        {#sasl_failure{}, wait_for_sasl_response} ->
            process_sasl_failure(Pkt, State);
        {#stream_error{}, _} ->
            process_stream_end({stream, {in, Pkt}}, State);
        _ when is_record(Pkt, stream_features);
               is_record(Pkt, starttls_proceed);
               is_record(Pkt, starttls);
               is_record(Pkt, sasl_auth);
               is_record(Pkt, sasl_success);
               is_record(Pkt, sasl_failure);
               is_record(Pkt, sasl_response);
               is_record(Pkt, sasl_abort);
               is_record(Pkt, compress);
               is_record(Pkt, handshake) ->
            %% Do not pass this crap upstream
            State;
        _ ->
            process_packet(Pkt, State)
    end.
|
|
|
|
%% @doc Handle <stream:features/> received in state `wait_for_features'.
%%
%% Authenticated stream: the features are offered to the callback module's
%% handle_authenticated_features/2 and the stream is marked established.
%%
%% Unauthenticated stream: the features go to
%% handle_unauthenticated_features/2 first. They then drive STARTTLS
%% negotiation according to the local tls_required / tls_enabled policy.
%% Once the stream is encrypted, the peer certificate is verified and
%% SASL EXTERNAL authentication is attempted. Negotiation problems are
%% reported through process_sasl_failure/2 or a stream error.
-spec process_features(stream_features(), state()) -> state().
process_features(StreamFeatures,
                 #{stream_authenticated := true} = State) ->
    State1 = try callback(handle_authenticated_features, StreamFeatures, State)
             catch _:{?MODULE, undef} -> State
             end,
    process_stream_established(State1);
process_features(StreamFeatures,
                 #{stream_encrypted := Encrypted,
                   lang := Lang} = State) ->
    State1 = try callback(handle_unauthenticated_features, StreamFeatures, State)
             catch _:{?MODULE, undef} -> State
             end,
    case is_disconnected(State1) of
        true -> State1;
        false ->
            TLSRequired = is_starttls_required(State1),
            TLSAvailable = is_starttls_available(State1),
            %% Only the decoding of <starttls/> is guarded by this try;
            %% the `of' clauses run outside its protection.
            try xmpp:try_subtag(StreamFeatures, #starttls{}) of
                false when TLSRequired andalso not Encrypted ->
                    Txt = <<"Use of STARTTLS required">>,
                    send_pkt(State1, xmpp:serr_policy_violation(Txt, Lang));
                false when not Encrypted ->
                    process_sasl_failure(
                      <<"Peer doesn't support STARTTLS">>, State1);
                #starttls{required = true}
                  when not TLSAvailable andalso not Encrypted ->
                    Txt = <<"Use of STARTTLS forbidden">>,
                    send_pkt(State1, xmpp:serr_unsupported_feature(Txt, Lang));
                #starttls{} when TLSAvailable andalso not Encrypted ->
                    State2 = State1#{stream_state => wait_for_starttls_response},
                    send_pkt(State2, #starttls{});
                #starttls{} when not Encrypted ->
                    process_sasl_failure(
                      <<"STARTTLS is disabled in local configuration">>, State1);
                _ ->
                    %% Only reached when the stream is already encrypted:
                    %% verify the peer certificate, then authenticate.
                    State2 = process_cert_verification(State1),
                    case is_disconnected(State2) of
                        true -> State2;
                        false ->
                            try xmpp:try_subtag(StreamFeatures, #sasl_mechanisms{}) of
                                #sasl_mechanisms{list = Mechs} ->
                                    process_sasl_mechanisms(Mechs, State2);
                                false ->
                                    Txt = <<"Peer provided no SASL mechanisms; "
                                            "most likely it doesn't accept "
                                            "our certificate">>,
                                    process_sasl_failure(Txt, State2)
                            catch _:{xmpp_codec, Why} ->
                                    %% Use State2 so the stream_verified flag
                                    %% set by certificate verification is kept.
                                    Txt = xmpp:io_format_error(Why),
                                    process_sasl_failure(Txt, State2)
                            end
                    end
            catch _:{xmpp_codec, Why} ->
                    Txt = xmpp:io_format_error(Why),
                    process_sasl_failure(Txt, State1)
            end
    end.
|
|
|
|
%% @doc Move the stream to `established'. The stream is marked
%% authenticated, the idle timeout is disabled, and the callback module's
%% handle_stream_established/1 (if any) is notified. Disconnected and
%% already-established streams are returned unchanged.
-spec process_stream_established(state()) -> state().
process_stream_established(#{stream_state := disconnected} = State) ->
    State;
process_stream_established(#{stream_state := established} = State) ->
    State;
process_stream_established(State) ->
    Established = State#{stream_authenticated := true,
                         stream_state => established,
                         stream_timeout => infinity},
    try callback(handle_stream_established, Established)
    catch _:{?MODULE, undef} -> Established
    end.
|
|
|
|
%% @doc Start SASL authentication given the peer's offered mechanisms.
%% Only EXTERNAL is supported. When it is offered, an <auth/> element
%% carrying our JID (user@server) as the authorization identity is sent;
%% otherwise authentication fails.
-spec process_sasl_mechanisms([binary()], state()) -> state().
process_sasl_mechanisms(Mechs, #{user := User, server := Server} = State) ->
    %% TODO: support other mechanisms
    Mech = <<"EXTERNAL">>,
    case lists:member(Mech, Mechs) of
        false ->
            process_sasl_failure(
              <<"Peer doesn't support EXTERNAL authentication">>, State);
        true ->
            Waiting = State#{stream_state => wait_for_sasl_response},
            AuthzID = jid:encode(jid:make(User, Server)),
            send_pkt(Waiting, #sasl_auth{mechanism = Mech, text = AuthzID})
    end.
|
|
|
|
%% @doc Upgrade the socket to TLS after <proceed/>. On success the stream
%% is restarted with a new stream id and our header is sent again. A TLS
%% failure ends the stream with `{tls, Why}'.
-spec process_starttls(state()) -> state().
process_starttls(#{socket := Socket} = State) ->
    case starttls(Socket, State) of
        {error, Why} ->
            process_stream_end({tls, Why}, State);
        {ok, TLSSocket} ->
            Restarted = State#{socket => TLSSocket,
                               stream_id => new_id(),
                               stream_restarted => true,
                               stream_state => wait_for_stream,
                               stream_encrypted => true},
            send_header(Restarted)
    end.
|
|
|
|
%% @doc Handle a peer stream header whose version is not 1.x (versions
%% above 1 are rejected earlier by process_stream/2).
%%
%% If STARTTLS is required and the stream is not encrypted, the stream is
%% closed with policy-violation. Otherwise the state becomes `downgraded'
%% and the callback module's handle_stream_downgraded/2 decides what to
%% do; without that callback an unsupported-version error is sent.
-spec process_stream_downgrade(stream_start(), state()) -> state().
process_stream_downgrade(StreamStart,
                         #{lang := Lang,
                           stream_encrypted := IsEncrypted} = State) ->
    MustUseTLS = is_starttls_required(State),
    if not IsEncrypted and MustUseTLS ->
           Txt = <<"Use of STARTTLS required">>,
           send_pkt(State, xmpp:serr_policy_violation(Txt, Lang));
       true ->
           Downgraded = State#{stream_state => downgraded},
           try callback(handle_stream_downgraded, StreamStart, Downgraded)
           catch _:{?MODULE, undef} ->
                   send_pkt(Downgraded, xmpp:serr_unsupported_version())
           end
    end.
|
|
|
|
%% @doc Verify the peer certificate once, on an encrypted but not yet
%% verified stream.
%%
%% The check uses xmpp_stream_pkix:authenticate/1 and can be disabled
%% through the tls_verify/1 callback (default: verify). The stream is
%% marked verified on success or when verification is disabled. A
%% rejected certificate ends the stream with `{pkix, Why}'. Other states
%% are returned unchanged.
-spec process_cert_verification(state()) -> state().
process_cert_verification(#{stream_encrypted := true,
                            stream_verified := false} = State) ->
    ShouldVerify = try callback(tls_verify, State)
                   catch _:{?MODULE, undef} -> true
                   end,
    case ShouldVerify of
        false ->
            State#{stream_verified => true};
        true ->
            case xmpp_stream_pkix:authenticate(State) of
                {ok, _} ->
                    State#{stream_verified => true};
                {error, Why, _Peer} ->
                    process_stream_end({pkix, Why}, State)
            end
    end;
process_cert_verification(State) ->
    State.
|
|
|
|
%% @doc Handle <success/> from the peer. The parser stream is reset, the
%% stream is restarted as authenticated with a new id, and our header is
%% sent. If that succeeds, the callback module's handle_auth_success/2 (if
%% any) is notified with the EXTERNAL mechanism.
-spec process_sasl_success(state()) -> state().
process_sasl_success(#{socket := Socket} = State) ->
    ResetSocket = xmpp_socket:reset_stream(Socket),
    Restarted = State#{socket => ResetSocket,
                       stream_id => new_id(),
                       stream_restarted => true,
                       stream_state => wait_for_stream,
                       stream_authenticated => true},
    Sent = send_header(Restarted),
    case is_disconnected(Sent) of
        true ->
            Sent;
        false ->
            try callback(handle_auth_success, <<"EXTERNAL">>, Sent)
            catch _:{?MODULE, undef} -> Sent
            end
    end.
|
|
|
|
%% @doc Report an authentication failure.
%%
%% A <failure/> element from the peer is first turned into a text reason.
%% The reason is then given, wrapped as `{auth, Reason}', to the callback
%% module's handle_auth_failure/3; without that callback the stream ends
%% with the same `{auth, Reason}'.
-spec process_sasl_failure(sasl_failure() | binary(), state()) -> state().
process_sasl_failure(#sasl_failure{} = Failure, State) ->
    Text = format("Peer responded with error: ~s",
                  [format_sasl_failure(Failure)]),
    process_sasl_failure(Text, State);
process_sasl_failure(Reason, State) ->
    AuthError = {auth, Reason},
    try callback(handle_auth_failure, <<"EXTERNAL">>, AuthError, State)
    catch _:{?MODULE, undef} -> process_stream_end(AuthError, State)
    end.
|
|
|
|
%% @doc Hand a non-negotiation element to the callback module's
%% handle_packet/2; without that callback the element is ignored.
-spec process_packet(xmpp_element(), state()) -> state().
process_packet(Pkt, State) ->
    try
        callback(handle_packet, Pkt, State)
    catch
        _:{?MODULE, undef} -> State
    end.
|
|
|
|
%% @doc Whether STARTTLS is mandatory, as answered by the callback
%% module's tls_required/1. Defaults to `false' without that callback.
-spec is_starttls_required(state()) -> boolean().
is_starttls_required(State) ->
    Default = false,
    try
        callback(tls_required, State)
    catch
        _:{?MODULE, undef} -> Default
    end.
|
|
|
|
%% @doc Whether STARTTLS may be used, as answered by the callback
%% module's tls_enabled/1. Defaults to `true' without that callback.
-spec is_starttls_available(state()) -> boolean().
is_starttls_available(State) ->
    Default = true,
    try
        callback(tls_enabled, State)
    catch
        _:{?MODULE, undef} -> Default
    end.
|
|
|
|
%% @doc Send our opening <stream:stream> header to the remote server.
%%
%% The dialback namespace is declared only on jabber:server streams. The
%% `from' JID is built from user/server/resource once the stream is
%% encrypted, is the bare server domain on an unencrypted jabber:server
%% stream, and is left undefined otherwise. A socket error ends the stream
%% with `{socket, Why}'.
-spec send_header(state()) -> state().
send_header(#{remote_server := RemoteServer,
              stream_encrypted := Encrypted,
              lang := Lang,
              xmlns := NS,
              user := User,
              resource := Resource,
              server := Server} = State) ->
    DialbackNS = case NS of
                     ?NS_SERVER -> ?NS_SERVER_DIALBACK;
                     _ -> <<"">>
                 end,
    FromJID = case {Encrypted, NS} of
                  {true, _} -> jid:make(User, Server, Resource);
                  {_, ?NS_SERVER} -> jid:make(Server);
                  _ -> undefined
              end,
    Header = #stream_start{xmlns = NS,
                           lang = Lang,
                           stream_xmlns = ?NS_STREAM,
                           db_xmlns = DialbackNS,
                           from = FromJID,
                           to = jid:make(RemoteServer),
                           version = {1,0}},
    case socket_send(State, Header) of
        {error, Why} -> process_stream_end({socket, Why}, State);
        ok -> State
    end.
|
|
|
|
%% @doc Send an element and report the result to the callback module's
%% handle_send/3 (if any). Sending a stream error always ends the stream
%% with `{stream, {out, Pkt}}'. Otherwise a socket error ends it with
%% `{socket, Why}'.
-spec send_pkt(state(), xmpp_element() | xmlel()) -> state().
send_pkt(State, Pkt) ->
    SendResult = socket_send(State, Pkt),
    Notified = try callback(handle_send, Pkt, SendResult, State)
               catch _:{?MODULE, undef} -> State
               end,
    case {is_record(Pkt, stream_error), SendResult} of
        {true, _} ->
            process_stream_end({stream, {out, Pkt}}, Notified);
        {false, ok} ->
            Notified;
        {false, {error, Why}} ->
            process_stream_end({socket, Why}, Notified)
    end.
|
|
|
|
%% @doc Answer a stanza with a stanza error. Non-stanzas, and stanzas of
%% type result or error (atom or binary form), are never answered, which
%% avoids error loops.
-spec send_error(state(), xmpp_element() | xmlel(), stanza_error()) -> state().
send_error(State, Pkt, Err) ->
    case xmpp:is_stanza(Pkt) of
        false ->
            State;
        true ->
            case xmpp:get_type(Pkt) of
                Type when Type == result; Type == error;
                          Type == <<"result">>; Type == <<"error">> ->
                    State;
                _ ->
                    send_pkt(State, xmpp:make_error(Pkt, Err))
            end
    end.
|
|
|
|
%% @doc Write to the socket.
%%
%% The trailer is always attempted when a socket exists. Headers and
%% elements are sent only while the stream is not disconnected. Without
%% a socket, or when disconnected, the result is `{error, closed}'.
-spec socket_send(state(), xmpp_element() | xmlel() | trailer) -> ok | {error, inet:posix()}.
socket_send(#{socket := Socket, xmlns := NS,
              stream_state := StateName}, Pkt) ->
    case Pkt of
        trailer ->
            xmpp_socket:send_trailer(Socket);
        _ when StateName == disconnected ->
            {error, closed};
        #stream_start{} ->
            xmpp_socket:send_header(Socket, xmpp:encode(Pkt));
        _ ->
            xmpp_socket:send_element(Socket, xmpp:encode(Pkt, NS))
    end;
socket_send(_State, _Pkt) ->
    {error, closed}.
|
|
|
|
%% @doc Write the closing stream tag, then close the socket.
%%
%% The result of the write is deliberately ignored: the socket is torn
%% down right afterwards, so a failed trailer changes nothing.
-spec send_trailer(state()) -> state().
send_trailer(State) ->
    _ = socket_send(State, trailer),
    close_socket(State).
|
|
|
|
%% @doc Close the socket (if the state holds one) and mark the stream as
%% disconnected with no stream timeout.
%%
%% The socket entry itself is left in the map; only stream_state and
%% stream_timeout are updated. The result of xmpp_socket:close/1 is
%% ignored.
-spec close_socket(state()) -> state().
close_socket(State) ->
    _ = case maps:find(socket, State) of
            {ok, Socket} -> xmpp_socket:close(Socket);
            error -> ok
        end,
    State#{stream_state => disconnected,
           stream_timeout => infinity}.
|
|
|
|
%% @doc Upgrade Socket to TLS as the connecting side.
%%
%% Options handed to xmpp_socket:starttls/2, in this order:
%%   * connect      - we initiate the TLS handshake;
%%   * {sni, SNI}   - the remote server name as produced by idna_to_ascii/1;
%%   * {alpn, [Id]} - the XEP-0368 ALPN protocol id; only offered for c2s
%%                    (<<"xmpp-client">>) and s2s (<<"xmpp-server">>) streams;
%%   * the list returned by the callback module's tls_options/1, or []
%%     when that callback is not exported.
%%
%% Other namespaces handled by this module (e.g. ?NS_COMPONENT) have no
%% registered ALPN id. They negotiate TLS without offering ALPN instead
%% of crashing with case_clause.
%%
%% NOTE(review): idna_to_ascii/1 can return false for a malformed IP
%% literal; that value is passed through as {sni, false} unchanged.
-spec starttls(term(), state()) -> {ok, term()} | {error, tls_error_reason()}.
starttls(Socket, #{xmlns := NS,
                   remote_server := RemoteServer} = State) ->
    TLSOpts = try callback(tls_options, State)
              catch _:{?MODULE, undef} -> []
              end,
    SNI = idna_to_ascii(RemoteServer),
    ALPNOpts = case NS of
                   ?NS_SERVER -> [{alpn, [<<"xmpp-server">>]}];
                   ?NS_CLIENT -> [{alpn, [<<"xmpp-client">>]}];
                   _ -> []
               end,
    xmpp_socket:starttls(Socket, [connect, {sni, SNI}] ++ ALPNOpts ++ TLSOpts).
|
|
|
|
%% @doc Choose the stream language: NewLang wins unless it is empty, in
%% which case the current Lang is kept.
-spec select_lang(binary(), binary()) -> binary().
select_lang(Lang, NewLang) ->
    case NewLang of
        <<"">> -> Lang;
        _ -> NewLang
    end.
|
|
|
|
%% @doc Human-readable text for a socket error reason.
%%
%% 'closed' gets a fixed message. Any other reason is run through
%% inet:format_error/1; when that only yields the generic
%% "unknown POSIX error", the atom's own name is used instead.
-spec format_inet_error(atom()) -> string().
format_inet_error(closed) ->
    "connection closed";
format_inet_error(Reason) ->
    Text = inet:format_error(Reason),
    if Text == "unknown POSIX error" -> atom_to_list(Reason);
       true -> Text
    end.
|
|
|
|
%% @doc Render a stream error condition plus its optional text.
%%
%% The condition becomes a short slogan ("no reason" for undefined,
%% "see-other-host" for that record, otherwise the atom name). If the
%% error carries text, the result is "Text (slogan)"; otherwise it is
%% just the slogan.
-spec format_stream_error(atom() | 'see-other-host'(), [text()]) -> string().
format_stream_error(Reason, Txt) ->
    Slogan = if Reason == undefined -> "no reason";
                is_record(Reason, 'see-other-host') -> "see-other-host";
                true -> atom_to_list(Reason)
             end,
    case binary_to_list(xmpp:get_text(Txt)) of
        [] -> Slogan;
        Text -> Text ++ " (" ++ Slogan ++ ")"
    end.
|
|
|
|
%% @doc Render a TLS failure reason as a string.
%%
%% TLS errors arrive either as a binary message from the TLS layer (used
%% verbatim) or as an atom, which is formatted like a socket error.
-spec format_tls_error(atom() | binary()) -> list().
format_tls_error(Reason) when not is_atom(Reason) ->
    binary_to_list(Reason);
format_tls_error(Reason) ->
    format_inet_error(Reason).
|
|
|
|
%% @doc Render a SASL <failure/> as "Text (condition)", or just the
%% condition ("no reason" when absent) if the failure carries no text.
-spec format_sasl_failure(#sasl_failure{}) -> string().
format_sasl_failure(#sasl_failure{reason = Reason, text = Txt}) ->
    Slogan = if Reason == undefined -> "no reason";
                true -> atom_to_list(Reason)
             end,
    case binary_to_list(xmpp:get_text(Txt)) of
        [] -> Slogan;
        Text -> Text ++ " (" ++ Slogan ++ ")"
    end.
|
|
|
|
%% @doc io_lib:format/2, flattened into a binary.
-spec format(io:format(), list()) -> binary().
format(Fmt, Args) ->
    Chars = io_lib:format(Fmt, Args),
    iolist_to_binary(Chars).
|
|
|
|
%%%===================================================================
|
|
%%% Connection stuff
|
|
%%%===================================================================
|
|
%% @doc Convert a host name into the ASCII form used on the wire (SNI).
%%
%% * "[...]" IP-literal (RFC 7622): the brackets are stripped and the
%%   inner text is returned if it is a strictly valid IPv6 address;
%%   otherwise (including a missing closing bracket) the result is false.
%% * Plain IPv4/IPv6 addresses are returned unchanged.
%% * Anything else is treated as a domain name and handed to
%%   ejabberd_idna:domain_utf8_to_ascii/1.
-spec idna_to_ascii(binary()) -> binary() | false.
idna_to_ascii(<<$[, _/binary>> = Host) ->
    LastPos = byte_size(Host) - 1,
    case Host of
        <<_:LastPos/binary, $]>> ->
            IPv6 = binary:part(Host, 1, LastPos - 1),
            case inet:parse_ipv6strict_address(binary_to_list(IPv6)) of
                {ok, _} -> IPv6;
                {error, _} -> false
            end;
        _ ->
            false
    end;
idna_to_ascii(Host) ->
    case inet:parse_address(binary_to_list(Host)) of
        {error, _} -> ejabberd_idna:domain_utf8_to_ascii(Host);
        {ok, _} -> Host
    end.
|
|
|
|
%% @doc Resolve Host into a list of {Address, Port, TLS} candidates.
%%
%% SRV records are tried first. If that lookup fails for any reason, Host
%% itself is resolved on the default port of the stream, without direct
%% TLS.
-spec resolve(string(), state()) -> {ok, [ip_port()]} | network_error().
resolve(Host, State) ->
    case srv_lookup(Host, State) of
        {ok, HostPorts} ->
            a_lookup(HostPorts, State);
        {error, _} ->
            a_lookup([{Host, get_default_port(State), false}], State)
    end.
|
|
|
|
%% @doc SRV-based discovery of {Host, Port, TLS} targets.
%%
%% Returns {error, nxdomain} without querying DNS for component
%% connections, for names without a dot (not an FQDN) and for literal
%% IP addresses. Otherwise it queries SRV with the callback-configured
%% DNS timeout/retries and converts the answer into ordered host/port
%% pairs.
-spec srv_lookup(string(), state()) -> {ok, [host_port()]} | network_error().
srv_lookup(_Host, #{xmlns := ?NS_COMPONENT}) ->
    {error, nxdomain};
srv_lookup(Host, State) ->
    HasDot = lists:member($., Host),
    IsIPLiteral = HasDot andalso element(1, inet:parse_address(Host)) =:= ok,
    if not HasDot; IsIPLiteral ->
            {error, nxdomain};
       true ->
            Timeout = get_dns_timeout(State),
            Retries = get_dns_retries(State),
            case srv_lookup(Host, State, Timeout, Retries) of
                {ok, AddrList} -> h_addr_list_to_host_ports(AddrList);
                {error, _} = Err -> Err
            end
    end.
|
|
|
|
%% @doc Query the XMPP SRV records for Host.
%%
%% When direct TLS is available (is_starttls_available/1), the XEP-0368
%% "_xmpps-<service>._tcp." records are queried first; their targets are
%% tagged TLS = true and placed before the plain "_xmpp-<service>._tcp."
%% targets (tagged TLS = false). If only the TLS lookup succeeds, its
%% targets are returned alone; if both fail, the plain lookup's error is
%% returned.
%%
%% <service> is "client" for c2s streams (xmlns = ?NS_CLIENT) and
%% "server" otherwise, as required by RFC 6120 section 3.2.1 and
%% XEP-0368. Without this, client connections would resolve the remote
%% domain's server-to-server endpoints.
-spec srv_lookup(string(), state(), timeout(), integer()) ->
                        {ok, h_addr_list()} | network_error().
srv_lookup(Host, State, Timeout, Retries) ->
    Service = case State of
                  #{xmlns := ?NS_CLIENT} -> "client";
                  _ -> "server"
              end,
    TLSAddrs = case is_starttls_available(State) of
                   true ->
                       case srv_lookup("_xmpps-" ++ Service ++ "._tcp." ++ Host,
                                       Timeout, Retries) of
                           {ok, HostEnt} ->
                               [{A, true} || A <- HostEnt#hostent.h_addr_list];
                           {error, _} ->
                               []
                       end;
                   false ->
                       []
               end,
    case srv_lookup("_xmpp-" ++ Service ++ "._tcp." ++ Host, Timeout, Retries) of
        {ok, HostEntry} ->
            Addrs = [{A, false} || A <- HostEntry#hostent.h_addr_list],
            {ok, TLSAddrs ++ Addrs};
        {error, _} when TLSAddrs /= [] ->
            {ok, TLSAddrs};
        {error, _} = Err ->
            Err
    end.
|
|
|
|
%% @doc Single SRV query for SRVName, retried only on timeouts.
%%
%% Each timed-out attempt consumes one unit of Retries. When the budget
%% is below 1 (including a zero budget on entry), {error, timeout} is
%% returned without querying. Any other DNS error is returned at once.
-spec srv_lookup(string(), timeout(), integer()) ->
                        {ok, inet:hostent()} | network_error().
srv_lookup(SRVName, Timeout, Retries) when Retries >= 1 ->
    case inet_res:getbyname(SRVName, srv, Timeout) of
        {error, timeout} ->
            srv_lookup(SRVName, Timeout, Retries - 1);
        {ok, _HostEntry} = Found ->
            Found;
        {error, _} = Err ->
            Err
    end;
srv_lookup(_SRVName, _Timeout, _Retries) ->
    {error, timeout}.
|
|
|
|
%% @doc Resolve every {Host, Port, TLS} target in each address family
%% configured by the callback module, collecting all resolved addresses.
%%
%% Elements that are not 3-tuples are skipped. The overall error defaults
%% to {error, nxdomain} when nothing resolves and no lookup reports a
%% more specific error.
-spec a_lookup([host_port()], state()) ->
                      {ok, [ip_port()]} | network_error().
a_lookup(HostPorts, State) ->
    Expand = fun({Host, Port, TLS}) ->
                     [{Host, Port, TLS, Family}
                      || Family <- get_address_families(State)];
                (_) ->
                     []
             end,
    a_lookup(lists:flatmap(Expand, HostPorts), State, [], {error, nxdomain}).
|
|
|
|
%% @doc Walk the {Host, Port, TLS, Family} candidates in order.
%%
%% Resolved addresses are appended to Acc in candidate order. Each failed
%% lookup replaces Err, so the most recent error wins. The DNS timeout and
%% retry budget are re-read from the callback module for every candidate.
%% Returns {ok, Acc} if anything resolved, otherwise the last error.
-spec a_lookup([{inet:hostname(), inet:port_number(), boolean(), inet:address_family()}],
               state(), [ip_port()], network_error()) -> {ok, [ip_port()]} | network_error().
a_lookup(HostPortFamilies, State, Acc0, Err0) ->
    Step = fun({Host, Port, TLS, Family}, {AccIn, ErrIn}) ->
                   Timeout = get_dns_timeout(State),
                   Retries = get_dns_retries(State),
                   case a_lookup(Host, Port, TLS, Family, Timeout, Retries) of
                       {ok, AddrPorts} -> {AccIn ++ AddrPorts, ErrIn};
                       {error, _} = NewErr -> {AccIn, NewErr}
                   end
           end,
    {Acc, LastErr} = lists:foldl(Step, {Acc0, Err0}, HostPortFamilies),
    case Acc of
        [] -> LastErr;
        _ -> {ok, Acc}
    end.
|
|
|
|
%% @doc Resolve one host in one address family into {Addr, Port, TLS}
%% tuples.
%%
%% inet:gethostbyname/3 reports a timeout as {error, nxdomain}, so an
%% nxdomain that took at least Timeout milliseconds is treated as a
%% timeout and retried while Retries lasts. inet_res:gethostbyname/3 is
%% not used because it bypasses the local resolver configuration
%% (/etc/hosts, etc). Once Retries is below 1, {error, timeout} is
%% returned.
-spec a_lookup(inet:hostname(), inet:port_number(), boolean(), inet:address_family(),
               timeout(), integer()) -> {ok, [ip_port()]} | network_error().
a_lookup(Host, Port, TLS, Family, Timeout, Retries) when Retries >= 1 ->
    Started = p1_time_compat:monotonic_time(milli_seconds),
    case inet:gethostbyname(Host, Family, Timeout) of
        {ok, HostEntry} ->
            host_entry_to_addr_ports(HostEntry, Port, TLS);
        {error, nxdomain} = NotFound ->
            Elapsed = p1_time_compat:monotonic_time(milli_seconds) - Started,
            case Elapsed >= Timeout of
                true -> a_lookup(Host, Port, TLS, Family, Timeout, Retries - 1);
                false -> NotFound
            end;
        {error, _} = Err ->
            Err
    end;
a_lookup(_Host, _Port, _TLS, _Family, _Timeout, _Retries) ->
    {error, timeout}.
|
|
|
|
%% @doc Order SRV targets by priority, with weight-biased randomness.
%%
%% Each {{Priority, Weight, Port, Host}, TLS} entry gets the sort key
%% Priority * 65536 - Jitter, where Jitter is 0 for weight 0 and
%% (Weight + 1) * random otherwise. Lower priorities therefore come first,
%% and heavier weights tend to come earlier within the same priority.
%% Entries of any other shape are ignored. lists:usort/1 also drops exact
%% duplicates.
%% Returns {error, nxdomain} if no usable entry remains.
-spec h_addr_list_to_host_ports(h_addr_list()) -> {ok, [host_port()]} |
                                                  {error, nxdomain}.
h_addr_list_to_host_ports(AddrList) ->
    Jitter = fun(0) -> 0;
                (W) -> (W + 1) * randoms:uniform()
             end,
    Ranked = [{Priority * 65536 - Jitter(Weight), Host, Port, TLS}
              || {{Priority, Weight, Port, Host}, TLS} <- AddrList],
    case [{H, P, T} || {_Rank, H, P, T} <- lists:usort(Ranked)] of
        [] -> {error, nxdomain};
        HostPorts -> {ok, HostPorts}
    end.
|
|
|
|
%% @doc Turn a resolved host entry into {Addr, Port, TLS} tuples.
%%
%% Only addresses that get_addr_type/1 recognises (IPv4 or IPv6 tuples)
%% are kept. Returns {error, nxdomain} if none are left.
-spec host_entry_to_addr_ports(inet:hostent(), inet:port_number(), boolean()) ->
                                      {ok, [ip_port()]} | {error, nxdomain}.
host_entry_to_addr_ports(#hostent{h_addr_list = AddrList}, Port, TLS) ->
    IsIPAddr = fun(Addr) ->
                       try get_addr_type(Addr) of
                           _Family -> true
                       catch
                           error:function_clause -> false
                       end
               end,
    case [{Addr, Port, TLS} || Addr <- AddrList, IsIPAddr(Addr)] of
        [] -> {error, nxdomain};
        AddrPorts -> {ok, AddrPorts}
    end.
|
|
|
|
%% @doc Open a TCP connection to the first reachable candidate.
%%
%% Candidates tagged TLS = true (direct TLS, XEP-0368) are upgraded with
%% starttls/2 right away; a TLS failure is reported as {error, {tls, _}}.
%% A failure to connect to any candidate is reported as
%% {error, {socket, _}}.
-spec connect([ip_port()], state()) -> {ok, term(), ip_port()} |
                                       {error, {socket, socket_error_reason()}} |
                                       {error, {tls, tls_error_reason()}}.
connect(AddrPorts, State) ->
    case connect(AddrPorts, get_connect_timeout(State), {error, nxdomain}) of
        {ok, Socket, {_Addr, _Port, true} = AddrPort} ->
            case starttls(Socket, State) of
                {ok, TLSSocket} -> {ok, TLSSocket, AddrPort};
                {error, Why} -> {error, {tls, Why}}
            end;
        {ok, _Socket, {_Addr, _Port, false}} = Plain ->
            Plain;
        {error, Why} ->
            {error, {socket, Why}}
    end.
|
|
|
|
%% @doc Try each {Addr, Port, TLS} candidate in turn until one connects.
%%
%% Returns the socket together with the candidate that worked. When the
%% list is exhausted, it returns the error of the last attempt (or the
%% initial error if the list was empty). An option set rejected with
%% badarg is recorded as {error, einval} and the next candidate is tried.
-spec connect([ip_port()], timeout(), network_error()) ->
                     {ok, term(), ip_port()} | network_error().
connect([], _Timeout, LastErr) ->
    LastErr;
connect([{Addr, Port, _TLS} = AddrPort | Rest], Timeout, _) ->
    Opts = [binary, {packet, 0},
            {send_timeout, ?TCP_SEND_TIMEOUT},
            {send_timeout_close, true},
            {active, false}, get_addr_type(Addr)],
    try xmpp_socket:connect(Addr, Port, Opts, Timeout) of
        {ok, Socket} -> {ok, Socket, AddrPort};
        Err -> connect(Rest, Timeout, Err)
    catch
        _:badarg -> connect(Rest, Timeout, {error, einval})
    end.
|
|
|
|
%% @doc Address family of an IP tuple: 4 elements is IPv4, 8 is IPv6.
%% Anything else raises function_clause.
-spec get_addr_type(inet:ip_address()) -> inet:address_family().
get_addr_type(IP) when tuple_size(IP) =:= 4 -> inet;
get_addr_type(IP) when tuple_size(IP) =:= 8 -> inet6.
|
|
|
|
%% @doc DNS query timeout from the callback module's dns_timeout/1, or
%% 10 seconds when that callback is not exported.
-spec get_dns_timeout(state()) -> timeout().
get_dns_timeout(State) ->
    Default = timer:seconds(10),
    try callback(dns_timeout, State)
    catch
        _:{?MODULE, undef} -> Default
    end.
|
|
|
|
%% @doc DNS retry budget from the callback module's dns_retries/1, or 2
%% when that callback is not exported.
-spec get_dns_retries(state()) -> non_neg_integer().
get_dns_retries(State) ->
    Default = 2,
    try callback(dns_retries, State)
    catch
        _:{?MODULE, undef} -> Default
    end.
|
|
|
|
%% @doc Port to use when no SRV record is found.
%%
%% Taken from the callback module's default_port/1 if it is exported;
%% otherwise 5222 for c2s and 5269 for s2s streams. For any other
%% namespace without a default_port/1 callback, the {?MODULE, undef}
%% error is not caught and propagates to the caller.
-spec get_default_port(state()) -> inet:port_number().
get_default_port(#{xmlns := NS} = State) ->
    try callback(default_port, State)
    catch
        _:{?MODULE, undef} when NS == ?NS_CLIENT -> 5222;
        _:{?MODULE, undef} when NS == ?NS_SERVER -> 5269
    end.
|
|
|
|
%% @doc Address families to resolve, in order, from the callback module's
%% address_families/1; IPv4 then IPv6 when that callback is not exported.
-spec get_address_families(state()) -> [inet:address_family()].
get_address_families(State) ->
    Default = [inet, inet6],
    try callback(address_families, State)
    catch
        _:{?MODULE, undef} -> Default
    end.
|
|
|
|
%% @doc TCP connect timeout from the callback module's connect_timeout/1,
%% or 10 seconds when that callback is not exported.
-spec get_connect_timeout(state()) -> timeout().
get_connect_timeout(State) ->
    Default = timer:seconds(10),
    try callback(connect_timeout, State)
    catch
        _:{?MODULE, undef} -> Default
    end.
|
|
|
|
%%%===================================================================
|
|
%%% Callbacks
|
|
%%%===================================================================
|
|
%% @doc Call Mod:F(State) on the callback module stored under 'mod'.
%%
%% If Mod does not export F/1 (as reported by erlang:function_exported/3),
%% the error {?MODULE, undef} is raised. Callers catch it to fall back to
%% built-in defaults.
callback(F, #{mod := Mod} = State) ->
    IsExported = erlang:function_exported(Mod, F, 1),
    if IsExported -> Mod:F(State);
       true -> erlang:error({?MODULE, undef})
    end.
|
|
|
|
%% @doc Call Mod:F(Arg1, State) on the callback module stored under 'mod'.
%%
%% Raises {?MODULE, undef} when Mod does not export F/2, so that callers
%% can fall back to defaults.
callback(F, Arg1, #{mod := Mod} = State) ->
    IsExported = erlang:function_exported(Mod, F, 2),
    if IsExported -> Mod:F(Arg1, State);
       true -> erlang:error({?MODULE, undef})
    end.
|
|
|
|
%% @doc Three-argument callback dispatch: Mod:F(Arg1, Arg2, State).
%%
%% code_change is special-cased: it is forwarded as
%% Mod:code_change(OldVsn, State, Extra) when exported, and otherwise
%% succeeds as {ok, State} instead of raising. Any other F that Mod does
%% not export as F/3 raises {?MODULE, undef}.
callback(code_change, OldVsn, #{mod := Mod} = State, Extra) ->
    case erlang:function_exported(Mod, code_change, 3) of
        false -> {ok, State};
        true -> Mod:code_change(OldVsn, State, Extra)
    end;
callback(F, Arg1, Arg2, #{mod := Mod} = State) ->
    IsExported = erlang:function_exported(Mod, F, 3),
    if IsExported -> Mod:F(Arg1, Arg2, State);
       true -> erlang:error({?MODULE, undef})
    end.
|