mirror of
https://github.com/processone/ejabberd.git
synced 2024-11-26 16:26:24 +01:00
1033 lines
34 KiB
Erlang
1033 lines
34 KiB
Erlang
%%%-------------------------------------------------------------------
|
|
%%% Created : 14 Dec 2016 by Evgeny Khramtsov <ekhramtsov@process-one.net>
|
|
%%%
|
|
%%%
|
|
%%% ejabberd, Copyright (C) 2002-2017 ProcessOne
|
|
%%%
|
|
%%% This program is free software; you can redistribute it and/or
|
|
%%% modify it under the terms of the GNU General Public License as
|
|
%%% published by the Free Software Foundation; either version 2 of the
|
|
%%% License, or (at your option) any later version.
|
|
%%%
|
|
%%% This program is distributed in the hope that it will be useful,
|
|
%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
|
%%% General Public License for more details.
|
|
%%%
|
|
%%% You should have received a copy of the GNU General Public License along
|
|
%%% with this program; if not, write to the Free Software Foundation, Inc.,
|
|
%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
|
|
%%%
|
|
%%%-------------------------------------------------------------------
|
|
-module(xmpp_stream_out).
|
|
-define(GEN_SERVER, p1_server).
|
|
-behaviour(?GEN_SERVER).
|
|
|
|
-protocol({rfc, 6120}).
|
|
-protocol({xep, 114, '1.6'}).
|
|
|
|
%% API
|
|
-export([start/3, start_link/3, call/3, cast/2, reply/2, connect/1,
|
|
stop/1, send/2, close/1, close/2, establish/1, format_error/1,
|
|
set_timeout/2, get_transport/1, change_shaper/2]).
|
|
%% gen_server callbacks
|
|
-export([init/1, handle_call/3, handle_cast/2, handle_info/2,
|
|
terminate/2, code_change/3]).
|
|
|
|
%%-define(DBGFSM, true).
|
|
-ifdef(DBGFSM).
|
|
-define(FSMOPTS, [{debug, [trace]}]).
|
|
-else.
|
|
-define(FSMOPTS, []).
|
|
-endif.
|
|
|
|
-define(TCP_SEND_TIMEOUT, 15000).
|
|
|
|
-include("xmpp.hrl").
|
|
-include_lib("kernel/include/inet.hrl").
|
|
|
|
-type state() :: map().
|
|
-type noreply() :: {noreply, state(), timeout()}.
|
|
-type host_port() :: {inet:hostname(), inet:port_number()}.
|
|
-type ip_port() :: {inet:ip_address(), inet:port_number()}.
|
|
-type network_error() :: {error, inet:posix() | inet_res:res_error()}.
|
|
-type stop_reason() :: {idna, bad_string} |
|
|
{dns, inet:posix() | inet_res:res_error()} |
|
|
{stream, reset | {in | out, stream_error()}} |
|
|
{tls, inet:posix() | atom() | binary()} |
|
|
{pkix, binary()} |
|
|
{auth, atom() | binary() | string()} |
|
|
{socket, inet:posix() | atom()} |
|
|
internal_failure.
|
|
-export_type([state/0, stop_reason/0]).
|
|
-callback init(list()) -> {ok, state()} | {error, term()} | ignore.
|
|
-callback handle_cast(term(), state()) -> state().
|
|
-callback handle_call(term(), term(), state()) -> state().
|
|
-callback handle_info(term(), state()) -> state().
|
|
-callback terminate(term(), state()) -> any().
|
|
-callback code_change(term(), state(), term()) -> {ok, state()} | {error, term()}.
|
|
-callback handle_stream_start(stream_start(), state()) -> state().
|
|
-callback handle_stream_established(state()) -> state().
|
|
-callback handle_stream_downgraded(stream_start(), state()) -> state().
|
|
-callback handle_stream_end(stop_reason(), state()) -> state().
|
|
-callback handle_cdata(binary(), state()) -> state().
|
|
-callback handle_send(xmpp_element(), ok | {error, inet:posix()}, state()) -> state().
|
|
-callback handle_recv(fxml:xmlel(), xmpp_element() | {error, term()}, state()) -> state().
|
|
-callback handle_timeout(state()) -> state().
|
|
-callback handle_authenticated_features(stream_features(), state()) -> state().
|
|
-callback handle_unauthenticated_features(stream_features(), state()) -> state().
|
|
-callback handle_auth_success(cyrsasl:mechanism(), state()) -> state().
|
|
-callback handle_auth_failure(cyrsasl:mechanism(), binary(), state()) -> state().
|
|
-callback handle_packet(xmpp_element(), state()) -> state().
|
|
-callback tls_options(state()) -> [proplists:property()].
|
|
-callback tls_required(state()) -> boolean().
|
|
-callback tls_verify(state()) -> boolean().
|
|
-callback tls_enabled(state()) -> boolean().
|
|
-callback dns_timeout(state()) -> timeout().
|
|
-callback dns_retries(state()) -> non_neg_integer().
|
|
-callback default_port(state()) -> inet:port_number().
|
|
-callback address_families(state()) -> [inet:address_family()].
|
|
-callback connect_timeout(state()) -> timeout().
|
|
|
|
-optional_callbacks([init/1,
|
|
handle_cast/2,
|
|
handle_call/3,
|
|
handle_info/2,
|
|
terminate/2,
|
|
code_change/3,
|
|
handle_stream_start/2,
|
|
handle_stream_established/1,
|
|
handle_stream_downgraded/2,
|
|
handle_stream_end/2,
|
|
handle_cdata/2,
|
|
handle_send/3,
|
|
handle_recv/3,
|
|
handle_timeout/1,
|
|
handle_authenticated_features/2,
|
|
handle_unauthenticated_features/2,
|
|
handle_auth_success/2,
|
|
handle_auth_failure/3,
|
|
handle_packet/2,
|
|
tls_options/1,
|
|
tls_required/1,
|
|
tls_verify/1,
|
|
tls_enabled/1,
|
|
dns_timeout/1,
|
|
dns_retries/1,
|
|
default_port/1,
|
|
address_families/1,
|
|
connect_timeout/1]).
|
|
|
|
%%%===================================================================
|
|
%%% API
|
|
%%%===================================================================
|
|
start(Mod, Args, Opts) ->
|
|
?GEN_SERVER:start(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS).
|
|
|
|
start_link(Mod, Args, Opts) ->
|
|
?GEN_SERVER:start_link(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS).
|
|
|
|
call(Ref, Msg, Timeout) ->
|
|
?GEN_SERVER:call(Ref, Msg, Timeout).
|
|
|
|
cast(Ref, Msg) ->
|
|
?GEN_SERVER:cast(Ref, Msg).
|
|
|
|
reply(Ref, Reply) ->
|
|
?GEN_SERVER:reply(Ref, Reply).
|
|
|
|
-spec connect(pid()) -> ok.
|
|
connect(Ref) ->
|
|
cast(Ref, connect).
|
|
|
|
-spec stop(pid()) -> ok;
|
|
(state()) -> no_return().
|
|
stop(Pid) when is_pid(Pid) ->
|
|
cast(Pid, stop);
|
|
stop(#{owner := Owner} = State) when Owner == self() ->
|
|
terminate(normal, State),
|
|
exit(normal);
|
|
stop(_) ->
|
|
erlang:error(badarg).
|
|
|
|
-spec send(pid(), xmpp_element()) -> ok;
|
|
(state(), xmpp_element()) -> state().
|
|
send(Pid, Pkt) when is_pid(Pid) ->
|
|
cast(Pid, {send, Pkt});
|
|
send(#{owner := Owner} = State, Pkt) when Owner == self() ->
|
|
send_pkt(State, Pkt);
|
|
send(_, _) ->
|
|
erlang:error(badarg).
|
|
|
|
-spec close(pid()) -> ok;
|
|
(state()) -> state().
|
|
close(Pid) when is_pid(Pid) ->
|
|
close(Pid, closed);
|
|
close(#{owner := Owner} = State) when Owner == self() ->
|
|
close_socket(State);
|
|
close(_) ->
|
|
erlang:error(badarg).
|
|
|
|
-spec close(pid(), atom()) -> ok.
|
|
close(Pid, Reason) ->
|
|
cast(Pid, {close, Reason}).
|
|
|
|
-spec establish(state()) -> state().
|
|
establish(State) ->
|
|
process_stream_established(State).
|
|
|
|
-spec set_timeout(state(), timeout()) -> state().
|
|
set_timeout(#{owner := Owner} = State, Timeout) when Owner == self() ->
|
|
case Timeout of
|
|
infinity -> State#{stream_timeout => infinity};
|
|
_ ->
|
|
Time = p1_time_compat:monotonic_time(milli_seconds),
|
|
State#{stream_timeout => {Timeout, Time}}
|
|
end;
|
|
set_timeout(_, _) ->
|
|
erlang:error(badarg).
|
|
|
|
get_transport(#{sockmod := SockMod, socket := Socket, owner := Owner})
|
|
when Owner == self() ->
|
|
SockMod:get_transport(Socket);
|
|
get_transport(_) ->
|
|
erlang:error(badarg).
|
|
|
|
-spec change_shaper(state(), shaper:shaper()) -> ok.
|
|
change_shaper(#{sockmod := SockMod, socket := Socket, owner := Owner}, Shaper)
|
|
when Owner == self() ->
|
|
SockMod:change_shaper(Socket, Shaper);
|
|
change_shaper(_, _) ->
|
|
erlang:error(badarg).
|
|
|
|
-spec format_error(stop_reason()) -> binary().
|
|
format_error({idna, _}) ->
|
|
<<"Remote domain is not an IDN hostname">>;
|
|
format_error({dns, Reason}) ->
|
|
format("DNS lookup failed: ~s", [format_inet_error(Reason)]);
|
|
format_error({socket, Reason}) ->
|
|
format("Connection failed: ~s", [format_inet_error(Reason)]);
|
|
format_error({pkix, Reason}) ->
|
|
{_, ErrTxt} = xmpp_stream_pkix:format_error(Reason),
|
|
format("Peer certificate rejected: ~s", [ErrTxt]);
|
|
format_error({stream, reset}) ->
|
|
<<"Stream reset by peer">>;
|
|
format_error({stream, {in, #stream_error{reason = Reason, text = Txt}}}) ->
|
|
format("Stream closed by peer: ~s", [format_stream_error(Reason, Txt)]);
|
|
format_error({stream, {out, #stream_error{reason = Reason, text = Txt}}}) ->
|
|
format("Stream closed by us: ~s", [format_stream_error(Reason, Txt)]);
|
|
format_error({tls, Reason}) ->
|
|
format("TLS failed: ~s", [format_tls_error(Reason)]);
|
|
format_error({auth, Reason}) ->
|
|
format("Authentication failed: ~s", [Reason]);
|
|
format_error(internal_failure) ->
|
|
<<"Internal server error">>;
|
|
format_error(Err) ->
|
|
format("Unrecognized error: ~w", [Err]).
|
|
|
|
%%%===================================================================
|
|
%%% gen_server callbacks
|
|
%%%===================================================================
|
|
-spec init(list()) -> {ok, state(), timeout()} | {stop, term()} | ignore.
|
|
init([Mod, SockMod, From, To, Opts]) ->
|
|
Time = p1_time_compat:monotonic_time(milli_seconds),
|
|
State = #{owner => self(),
|
|
mod => Mod,
|
|
sockmod => SockMod,
|
|
server => From,
|
|
user => <<"">>,
|
|
resource => <<"">>,
|
|
lang => <<"">>,
|
|
remote_server => To,
|
|
xmlns => ?NS_SERVER,
|
|
stream_direction => out,
|
|
stream_timeout => {timer:seconds(30), Time},
|
|
stream_id => new_id(),
|
|
stream_encrypted => false,
|
|
stream_verified => false,
|
|
stream_authenticated => false,
|
|
stream_restarted => false,
|
|
stream_state => connecting},
|
|
case try Mod:init([State, Opts])
|
|
catch _:undef -> {ok, State}
|
|
end of
|
|
{ok, State1} ->
|
|
{_, State2, Timeout} = noreply(State1),
|
|
{ok, State2, Timeout};
|
|
{error, Reason} ->
|
|
{stop, Reason};
|
|
ignore ->
|
|
ignore
|
|
end.
|
|
|
|
-spec handle_call(term(), term(), state()) -> noreply().
|
|
handle_call(Call, From, #{mod := Mod} = State) ->
|
|
noreply(try Mod:handle_call(Call, From, State)
|
|
catch _:undef -> State
|
|
end).
|
|
|
|
-spec handle_cast(term(), state()) -> noreply().
|
|
handle_cast(connect, #{remote_server := RemoteServer,
|
|
sockmod := SockMod,
|
|
stream_state := connecting} = State) ->
|
|
noreply(
|
|
case idna_to_ascii(RemoteServer) of
|
|
false ->
|
|
process_stream_end({idna, bad_string}, State);
|
|
ASCIIName ->
|
|
case resolve(binary_to_list(ASCIIName), State) of
|
|
{ok, AddrPorts} ->
|
|
case connect(AddrPorts, State) of
|
|
{ok, Socket, AddrPort} ->
|
|
SocketMonitor = SockMod:monitor(Socket),
|
|
State1 = State#{ip => AddrPort,
|
|
socket => Socket,
|
|
socket_monitor => SocketMonitor},
|
|
State2 = State1#{stream_state => wait_for_stream},
|
|
send_header(State2);
|
|
{error, Why} ->
|
|
process_stream_end({socket, Why}, State)
|
|
end;
|
|
{error, Why} ->
|
|
process_stream_end({dns, Why}, State)
|
|
end
|
|
end);
|
|
handle_cast(connect, State) ->
|
|
%% Ignoring connection attempts in other states
|
|
noreply(State);
|
|
handle_cast({send, Pkt}, State) ->
|
|
noreply(send_pkt(State, Pkt));
|
|
handle_cast(stop, State) ->
|
|
{stop, normal, State};
|
|
handle_cast({close, Reason}, State) ->
|
|
State1 = close_socket(State),
|
|
noreply(
|
|
case is_disconnected(State) of
|
|
true -> State1;
|
|
false -> process_stream_end({socket, Reason}, State)
|
|
end);
|
|
handle_cast(Cast, #{mod := Mod} = State) ->
|
|
noreply(try Mod:handle_cast(Cast, State)
|
|
catch _:undef -> State
|
|
end).
|
|
|
|
-spec handle_info(term(), state()) -> noreply().
|
|
handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}},
|
|
#{stream_state := wait_for_stream,
|
|
xmlns := XMLNS, lang := MyLang} = State) ->
|
|
El = #xmlel{name = Name, attrs = Attrs},
|
|
noreply(
|
|
try xmpp:decode(El, XMLNS, []) of
|
|
#stream_start{} = Pkt ->
|
|
process_stream(Pkt, State);
|
|
_ ->
|
|
send_pkt(State, xmpp:serr_invalid_xml())
|
|
catch _:{xmpp_codec, Why} ->
|
|
Txt = xmpp:io_format_error(Why),
|
|
Lang = select_lang(MyLang, xmpp:get_lang(El)),
|
|
Err = xmpp:serr_invalid_xml(Txt, Lang),
|
|
send_pkt(State, Err)
|
|
end);
|
|
handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) ->
|
|
State1 = send_header(State),
|
|
noreply(
|
|
case is_disconnected(State1) of
|
|
true -> State1;
|
|
false ->
|
|
Err = case Reason of
|
|
<<"XML stanza is too big">> ->
|
|
xmpp:serr_policy_violation(Reason, Lang);
|
|
{_, Txt} ->
|
|
xmpp:serr_not_well_formed(Txt, Lang)
|
|
end,
|
|
send_pkt(State1, Err)
|
|
end);
|
|
handle_info({'$gen_event', {xmlstreamelement, El}},
|
|
#{xmlns := NS, mod := Mod} = State) ->
|
|
noreply(
|
|
try xmpp:decode(El, NS, [ignore_els]) of
|
|
Pkt ->
|
|
State1 = try Mod:handle_recv(El, Pkt, State)
|
|
catch _:undef -> State
|
|
end,
|
|
case is_disconnected(State1) of
|
|
true -> State1;
|
|
false -> process_element(Pkt, State1)
|
|
end
|
|
catch _:{xmpp_codec, Why} ->
|
|
State1 = try Mod:handle_recv(El, {error, Why}, State)
|
|
catch _:undef -> State
|
|
end,
|
|
case is_disconnected(State1) of
|
|
true -> State1;
|
|
false -> process_invalid_xml(State1, El, Why)
|
|
end
|
|
end);
|
|
handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}},
|
|
#{mod := Mod} = State) ->
|
|
noreply(try Mod:handle_cdata(Data, State)
|
|
catch _:undef -> State
|
|
end);
|
|
handle_info({'$gen_event', {xmlstreamend, _}}, State) ->
|
|
noreply(process_stream_end({stream, reset}, State));
|
|
handle_info({'$gen_event', closed}, State) ->
|
|
noreply(process_stream_end({socket, closed}, State));
|
|
handle_info(timeout, #{mod := Mod} = State) ->
|
|
Disconnected = is_disconnected(State),
|
|
noreply(try Mod:handle_timeout(State)
|
|
catch _:undef when not Disconnected ->
|
|
send_pkt(State, xmpp:serr_connection_timeout());
|
|
_:undef ->
|
|
stop(State)
|
|
end);
|
|
handle_info({'DOWN', MRef, _Type, _Object, _Info},
|
|
#{socket_monitor := MRef} = State) ->
|
|
noreply(process_stream_end({socket, closed}, State));
|
|
handle_info(Info, #{mod := Mod} = State) ->
|
|
noreply(try Mod:handle_info(Info, State)
|
|
catch _:undef -> State
|
|
end).
|
|
|
|
-spec terminate(term(), state()) -> any().
|
|
terminate(Reason, #{mod := Mod} = State) ->
|
|
case get(already_terminated) of
|
|
true ->
|
|
State;
|
|
_ ->
|
|
put(already_terminated, true),
|
|
try Mod:terminate(Reason, State)
|
|
catch _:undef -> ok
|
|
end,
|
|
send_trailer(State)
|
|
end.
|
|
|
|
code_change(OldVsn, #{mod := Mod} = State, Extra) ->
|
|
Mod:code_change(OldVsn, State, Extra).
|
|
|
|
%%%===================================================================
|
|
%%% Internal functions
|
|
%%%===================================================================
|
|
-spec noreply(state()) -> noreply().
|
|
noreply(#{stream_timeout := infinity} = State) ->
|
|
{noreply, State, infinity};
|
|
noreply(#{stream_timeout := {MSecs, OldTime}} = State) ->
|
|
NewTime = p1_time_compat:monotonic_time(milli_seconds),
|
|
Timeout = max(0, MSecs - NewTime + OldTime),
|
|
{noreply, State, Timeout}.
|
|
|
|
-spec new_id() -> binary().
|
|
new_id() ->
|
|
randoms:get_string().
|
|
|
|
-spec is_disconnected(state()) -> boolean().
|
|
is_disconnected(#{stream_state := StreamState}) ->
|
|
StreamState == disconnected.
|
|
|
|
-spec process_invalid_xml(state(), fxml:xmlel(), term()) -> state().
|
|
process_invalid_xml(#{lang := MyLang} = State, El, Reason) ->
|
|
case xmpp:is_stanza(El) of
|
|
true ->
|
|
Txt = xmpp:io_format_error(Reason),
|
|
Lang = select_lang(MyLang, xmpp:get_lang(El)),
|
|
send_error(State, El, xmpp:err_bad_request(Txt, Lang));
|
|
false ->
|
|
State
|
|
end.
|
|
|
|
-spec process_stream_end(stop_reason(), state()) -> state().
|
|
process_stream_end(_, #{stream_state := disconnected} = State) ->
|
|
State;
|
|
process_stream_end(Reason, #{mod := Mod} = State) ->
|
|
State1 = State#{stream_timeout => infinity,
|
|
stream_state => disconnected},
|
|
try Mod:handle_stream_end(Reason, State1)
|
|
catch _:undef -> stop(State1)
|
|
end.
|
|
|
|
-spec process_stream(stream_start(), state()) -> state().
|
|
process_stream(#stream_start{xmlns = XML_NS,
|
|
stream_xmlns = STREAM_NS},
|
|
#{xmlns := NS} = State)
|
|
when XML_NS /= NS; STREAM_NS /= ?NS_STREAM ->
|
|
send_pkt(State, xmpp:serr_invalid_namespace());
|
|
process_stream(#stream_start{version = {N, _}}, State) when N > 1 ->
|
|
send_pkt(State, xmpp:serr_unsupported_version());
|
|
process_stream(#stream_start{lang = Lang, id = ID,
|
|
version = Version} = StreamStart,
|
|
#{mod := Mod} = State) ->
|
|
State1 = State#{stream_remote_id => ID, lang => Lang},
|
|
State2 = try Mod:handle_stream_start(StreamStart, State1)
|
|
catch _:undef -> State1
|
|
end,
|
|
case is_disconnected(State2) of
|
|
true -> State2;
|
|
false ->
|
|
case Version of
|
|
{1, _} ->
|
|
State2#{stream_state => wait_for_features};
|
|
_ ->
|
|
process_stream_downgrade(StreamStart, State2)
|
|
end
|
|
end.
|
|
|
|
-spec process_element(xmpp_element(), state()) -> state().
|
|
process_element(Pkt, #{stream_state := StateName} = State) ->
|
|
case Pkt of
|
|
#stream_features{} when StateName == wait_for_features ->
|
|
process_features(Pkt, State);
|
|
#starttls_proceed{} when StateName == wait_for_starttls_response ->
|
|
process_starttls(State);
|
|
#sasl_success{} when StateName == wait_for_sasl_response ->
|
|
process_sasl_success(State);
|
|
#sasl_failure{} when StateName == wait_for_sasl_response ->
|
|
process_sasl_failure(Pkt, State);
|
|
#stream_error{} ->
|
|
process_stream_end({stream, {in, Pkt}}, State);
|
|
_ when is_record(Pkt, stream_features);
|
|
is_record(Pkt, starttls_proceed);
|
|
is_record(Pkt, starttls);
|
|
is_record(Pkt, sasl_auth);
|
|
is_record(Pkt, sasl_success);
|
|
is_record(Pkt, sasl_failure);
|
|
is_record(Pkt, sasl_response);
|
|
is_record(Pkt, sasl_abort);
|
|
is_record(Pkt, compress);
|
|
is_record(Pkt, handshake) ->
|
|
%% Do not pass this crap upstream
|
|
State;
|
|
_ ->
|
|
process_packet(Pkt, State)
|
|
end.
|
|
|
|
-spec process_features(stream_features(), state()) -> state().
|
|
process_features(StreamFeatures,
|
|
#{stream_authenticated := true, mod := Mod} = State) ->
|
|
State1 = try Mod:handle_authenticated_features(StreamFeatures, State)
|
|
catch _:undef -> State
|
|
end,
|
|
process_stream_established(State1);
|
|
process_features(#stream_features{sub_els = Els} = StreamFeatures,
|
|
#{stream_encrypted := Encrypted,
|
|
mod := Mod, lang := Lang} = State) ->
|
|
State1 = try Mod:handle_unauthenticated_features(StreamFeatures, State)
|
|
catch _:undef -> State
|
|
end,
|
|
case is_disconnected(State1) of
|
|
true -> State1;
|
|
false ->
|
|
TLSRequired = is_starttls_required(State1),
|
|
TLSAvailable = is_starttls_available(State1),
|
|
%% TODO: improve xmpp.erl
|
|
Msg = #message{sub_els = Els},
|
|
case xmpp:get_subtag(Msg, #starttls{}) of
|
|
false when TLSRequired and not Encrypted ->
|
|
Txt = <<"Use of STARTTLS required">>,
|
|
send_pkt(State1, xmpp:serr_policy_violation(Txt, Lang));
|
|
false when not Encrypted ->
|
|
process_sasl_failure(
|
|
<<"Peer doesn't support STARTTLS">>, State1);
|
|
#starttls{required = true} when not TLSAvailable and not Encrypted ->
|
|
Txt = <<"Use of STARTTLS forbidden">>,
|
|
send_pkt(State1, xmpp:serr_unsupported_feature(Txt, Lang));
|
|
#starttls{} when TLSAvailable and not Encrypted ->
|
|
State2 = State1#{stream_state => wait_for_starttls_response},
|
|
send_pkt(State2, #starttls{});
|
|
#starttls{} when not Encrypted ->
|
|
process_sasl_failure(
|
|
<<"STARTTLS is disabled in local configuration">>, State1);
|
|
_ ->
|
|
State2 = process_cert_verification(State1),
|
|
case is_disconnected(State2) of
|
|
true -> State2;
|
|
false ->
|
|
case xmpp:get_subtag(Msg, #sasl_mechanisms{}) of
|
|
#sasl_mechanisms{list = Mechs} ->
|
|
process_sasl_mechanisms(Mechs, State2);
|
|
false ->
|
|
process_sasl_failure(
|
|
<<"Peer provided no SASL mechanisms">>, State2)
|
|
end
|
|
end
|
|
end
|
|
end.
|
|
|
|
-spec process_stream_established(state()) -> state().
|
|
process_stream_established(#{stream_state := StateName} = State)
|
|
when StateName == disconnected; StateName == established ->
|
|
State;
|
|
process_stream_established(#{mod := Mod} = State) ->
|
|
State1 = State#{stream_authenticated := true,
|
|
stream_state => established,
|
|
stream_timeout => infinity},
|
|
try Mod:handle_stream_established(State1)
|
|
catch _:undef -> State1
|
|
end.
|
|
|
|
-spec process_sasl_mechanisms([binary()], state()) -> state().
|
|
process_sasl_mechanisms(Mechs, #{user := User, server := Server} = State) ->
|
|
%% TODO: support other mechanisms
|
|
Mech = <<"EXTERNAL">>,
|
|
case lists:member(<<"EXTERNAL">>, Mechs) of
|
|
true ->
|
|
State1 = State#{stream_state => wait_for_sasl_response},
|
|
Authzid = jid:encode(jid:make(User, Server)),
|
|
send_pkt(State1, #sasl_auth{mechanism = Mech, text = Authzid});
|
|
false ->
|
|
process_sasl_failure(
|
|
<<"Peer doesn't support EXTERNAL authentication">>, State)
|
|
end.
|
|
|
|
-spec process_starttls(state()) -> state().
|
|
process_starttls(#{sockmod := SockMod, socket := Socket, mod := Mod} = State) ->
|
|
TLSOpts = try Mod:tls_options(State)
|
|
catch _:undef -> []
|
|
end,
|
|
case SockMod:starttls(Socket, [connect|TLSOpts]) of
|
|
{ok, TLSSocket} ->
|
|
State1 = State#{socket => TLSSocket,
|
|
stream_id => new_id(),
|
|
stream_restarted => true,
|
|
stream_state => wait_for_stream,
|
|
stream_encrypted => true},
|
|
send_header(State1);
|
|
{error, Why} ->
|
|
process_stream_end({tls, Why}, State)
|
|
end.
|
|
|
|
-spec process_stream_downgrade(stream_start(), state()) -> state().
|
|
process_stream_downgrade(StreamStart,
|
|
#{mod := Mod, lang := Lang,
|
|
stream_encrypted := Encrypted} = State) ->
|
|
TLSRequired = is_starttls_required(State),
|
|
if not Encrypted and TLSRequired ->
|
|
Txt = <<"Use of STARTTLS required">>,
|
|
send_pkt(State, xmpp:serr_policy_violation(Txt, Lang));
|
|
true ->
|
|
State1 = State#{stream_state => downgraded},
|
|
try Mod:handle_stream_downgraded(StreamStart, State1)
|
|
catch _:undef ->
|
|
send_pkt(State1, xmpp:serr_unsupported_version())
|
|
end
|
|
end.
|
|
|
|
-spec process_cert_verification(state()) -> state().
|
|
process_cert_verification(#{stream_encrypted := true,
|
|
stream_verified := false,
|
|
mod := Mod} = State) ->
|
|
case try Mod:tls_verify(State)
|
|
catch _:undef -> true
|
|
end of
|
|
true ->
|
|
case xmpp_stream_pkix:authenticate(State) of
|
|
{ok, _} ->
|
|
State#{stream_verified => true};
|
|
{error, Why, _Peer} ->
|
|
process_stream_end({pkix, Why}, State)
|
|
end;
|
|
false ->
|
|
State#{stream_verified => true}
|
|
end;
|
|
process_cert_verification(State) ->
|
|
State.
|
|
|
|
-spec process_sasl_success(state()) -> state().
|
|
process_sasl_success(#{mod := Mod,
|
|
sockmod := SockMod,
|
|
socket := Socket} = State) ->
|
|
SockMod:reset_stream(Socket),
|
|
State1 = State#{stream_id => new_id(),
|
|
stream_restarted => true,
|
|
stream_state => wait_for_stream,
|
|
stream_authenticated => true},
|
|
State2 = send_header(State1),
|
|
case is_disconnected(State2) of
|
|
true -> State2;
|
|
false ->
|
|
try Mod:handle_auth_success(<<"EXTERNAL">>, State2)
|
|
catch _:undef -> State2
|
|
end
|
|
end.
|
|
|
|
-spec process_sasl_failure(sasl_failure() | binary(), state()) -> state().
|
|
process_sasl_failure(#sasl_failure{} = Failure, State) ->
|
|
Reason = format("Peer responded with error: ~s",
|
|
[format_sasl_failure(Failure)]),
|
|
process_sasl_failure(Reason, State);
|
|
process_sasl_failure(Reason, #{mod := Mod} = State) ->
|
|
try Mod:handle_auth_failure(<<"EXTERNAL">>, {auth, Reason}, State)
|
|
catch _:undef -> process_stream_end({auth, Reason}, State)
|
|
end.
|
|
|
|
-spec process_packet(xmpp_element(), state()) -> state().
|
|
process_packet(Pkt, #{mod := Mod} = State) ->
|
|
try Mod:handle_packet(Pkt, State)
|
|
catch _:undef -> State
|
|
end.
|
|
|
|
-spec is_starttls_required(state()) -> boolean().
|
|
is_starttls_required(#{mod := Mod} = State) ->
|
|
try Mod:tls_required(State)
|
|
catch _:undef -> false
|
|
end.
|
|
|
|
-spec is_starttls_available(state()) -> boolean().
|
|
is_starttls_available(#{mod := Mod} = State) ->
|
|
try Mod:tls_enabled(State)
|
|
catch _:undef -> true
|
|
end.
|
|
|
|
-spec send_header(state()) -> state().
|
|
send_header(#{remote_server := RemoteServer,
|
|
stream_encrypted := Encrypted,
|
|
lang := Lang,
|
|
xmlns := NS,
|
|
user := User,
|
|
resource := Resource,
|
|
server := Server} = State) ->
|
|
NS_DB = if NS == ?NS_SERVER -> ?NS_SERVER_DIALBACK;
|
|
true -> <<"">>
|
|
end,
|
|
From = if Encrypted ->
|
|
jid:make(User, Server, Resource);
|
|
NS == ?NS_SERVER ->
|
|
jid:make(Server);
|
|
true ->
|
|
undefined
|
|
end,
|
|
StreamStart = #stream_start{xmlns = NS,
|
|
lang = Lang,
|
|
stream_xmlns = ?NS_STREAM,
|
|
db_xmlns = NS_DB,
|
|
from = From,
|
|
to = jid:make(RemoteServer),
|
|
version = {1,0}},
|
|
case socket_send(State, StreamStart) of
|
|
ok -> State;
|
|
{error, Why} -> process_stream_end({socket, Why}, State)
|
|
end.
|
|
|
|
-spec send_pkt(state(), xmpp_element() | xmlel()) -> state().
|
|
send_pkt(#{mod := Mod} = State, Pkt) ->
|
|
Result = socket_send(State, Pkt),
|
|
State1 = try Mod:handle_send(Pkt, Result, State)
|
|
catch _:undef -> State
|
|
end,
|
|
case Result of
|
|
_ when is_record(Pkt, stream_error) ->
|
|
process_stream_end({stream, {out, Pkt}}, State1);
|
|
ok ->
|
|
State1;
|
|
{error, Why} ->
|
|
process_stream_end({socket, Why}, State1)
|
|
end.
|
|
|
|
-spec send_error(state(), xmpp_element() | xmlel(), stanza_error()) -> state().
|
|
send_error(State, Pkt, Err) ->
|
|
case xmpp:is_stanza(Pkt) of
|
|
true ->
|
|
case xmpp:get_type(Pkt) of
|
|
result -> State;
|
|
error -> State;
|
|
<<"result">> -> State;
|
|
<<"error">> -> State;
|
|
_ ->
|
|
ErrPkt = xmpp:make_error(Pkt, Err),
|
|
send_pkt(State, ErrPkt)
|
|
end;
|
|
false ->
|
|
State
|
|
end.
|
|
|
|
-spec socket_send(state(), xmpp_element() | xmlel() | trailer) -> ok | {error, inet:posix()}.
|
|
socket_send(#{sockmod := SockMod, socket := Socket, xmlns := NS,
|
|
stream_state := StateName}, Pkt) ->
|
|
case Pkt of
|
|
trailer ->
|
|
SockMod:send_trailer(Socket);
|
|
#stream_start{} when StateName /= disconnected ->
|
|
SockMod:send_header(Socket, xmpp:encode(Pkt));
|
|
_ when StateName /= disconnected ->
|
|
SockMod:send_element(Socket, xmpp:encode(Pkt, NS));
|
|
_ ->
|
|
{error, closed}
|
|
end;
|
|
socket_send(_, _) ->
|
|
{error, closed}.
|
|
|
|
-spec send_trailer(state()) -> state().
|
|
send_trailer(State) ->
|
|
socket_send(State, trailer),
|
|
close_socket(State).
|
|
|
|
-spec close_socket(state()) -> state().
|
|
close_socket(State) ->
|
|
case State of
|
|
#{sockmod := SockMod, socket := Socket} ->
|
|
SockMod:close(Socket);
|
|
_ ->
|
|
ok
|
|
end,
|
|
State#{stream_timeout => infinity,
|
|
stream_state => disconnected}.
|
|
|
|
-spec select_lang(binary(), binary()) -> binary().
|
|
select_lang(Lang, <<"">>) -> Lang;
|
|
select_lang(_, Lang) -> Lang.
|
|
|
|
-spec format_inet_error(atom()) -> string().
|
|
format_inet_error(closed) ->
|
|
"connection closed";
|
|
format_inet_error(Reason) ->
|
|
case inet:format_error(Reason) of
|
|
"unknown POSIX error" -> atom_to_list(Reason);
|
|
Txt -> Txt
|
|
end.
|
|
|
|
-spec format_stream_error(atom() | 'see-other-host'(), undefined | text()) -> string().
|
|
format_stream_error(Reason, Txt) ->
|
|
Slogan = case Reason of
|
|
undefined -> "no reason";
|
|
#'see-other-host'{} -> "see-other-host";
|
|
_ -> atom_to_list(Reason)
|
|
end,
|
|
case Txt of
|
|
undefined -> Slogan;
|
|
#text{data = <<"">>} -> Slogan;
|
|
#text{data = Data} ->
|
|
binary_to_list(Data) ++ " (" ++ Slogan ++ ")"
|
|
end.
|
|
|
|
-spec format_tls_error(atom() | binary()) -> list().
|
|
format_tls_error(Reason) when is_atom(Reason) ->
|
|
format_inet_error(Reason);
|
|
format_tls_error(Reason) ->
|
|
binary_to_list(Reason).
|
|
|
|
format_sasl_failure(#sasl_failure{reason = Reason, text = Txt}) ->
|
|
Slogan = case Reason of
|
|
undefined -> "no reason";
|
|
_ -> atom_to_list(Reason)
|
|
end,
|
|
case xmpp:get_text(Txt) of
|
|
<<"">> -> Slogan;
|
|
Data ->
|
|
binary_to_list(Data) ++ " (" ++ Slogan ++ ")"
|
|
end.
|
|
|
|
-spec format(io:format(), list()) -> binary().
|
|
format(Fmt, Args) ->
|
|
iolist_to_binary(io_lib:format(Fmt, Args)).
|
|
|
|
%%%===================================================================
|
|
%%% Connection stuff
|
|
%%%===================================================================
|
|
-spec idna_to_ascii(binary()) -> binary() | false.
|
|
idna_to_ascii(<<$[, _/binary>> = Host) ->
|
|
%% This is an IPv6 address in 'IP-literal' format (as per RFC7622)
|
|
%% We remove brackets here
|
|
case binary:last(Host) of
|
|
$] ->
|
|
IPv6 = binary:part(Host, {1, size(Host)-2}),
|
|
case inet:parse_ipv6strict_address(binary_to_list(IPv6)) of
|
|
{ok, _} -> IPv6;
|
|
{error, _} -> false
|
|
end;
|
|
_ ->
|
|
false
|
|
end;
|
|
idna_to_ascii(Host) ->
|
|
case inet:parse_address(binary_to_list(Host)) of
|
|
{ok, _} -> Host;
|
|
{error, _} -> ejabberd_idna:domain_utf8_to_ascii(Host)
|
|
end.
|
|
|
|
-spec resolve(string(), state()) -> {ok, [ip_port()]} | network_error().
|
|
resolve(Host, State) ->
|
|
case srv_lookup(Host, State) of
|
|
{error, _Reason} ->
|
|
DefaultPort = get_default_port(State),
|
|
a_lookup([{Host, DefaultPort}], State);
|
|
{ok, HostPorts} ->
|
|
a_lookup(HostPorts, State)
|
|
end.
|
|
|
|
-spec srv_lookup(string(), state()) -> {ok, [host_port()]} | network_error().
|
|
srv_lookup(_Host, #{xmlns := ?NS_COMPONENT}) ->
|
|
%% Do not attempt to lookup SRV for component connections
|
|
{error, nxdomain};
|
|
srv_lookup(Host, State) ->
|
|
%% Only perform SRV lookups for FQDN names
|
|
case string:chr(Host, $.) of
|
|
0 ->
|
|
{error, nxdomain};
|
|
_ ->
|
|
case inet:parse_address(Host) of
|
|
{ok, _} ->
|
|
{error, nxdomain};
|
|
{error, _} ->
|
|
Timeout = get_dns_timeout(State),
|
|
Retries = get_dns_retries(State),
|
|
srv_lookup(Host, Timeout, Retries)
|
|
end
|
|
end.
|
|
|
|
-spec srv_lookup(string(), timeout(), integer()) ->
|
|
{ok, [host_port()]} | network_error().
|
|
srv_lookup(_Host, _Timeout, Retries) when Retries < 1 ->
|
|
{error, timeout};
|
|
srv_lookup(Host, Timeout, Retries) ->
|
|
SRVName = "_xmpp-server._tcp." ++ Host,
|
|
case inet_res:getbyname(SRVName, srv, Timeout) of
|
|
{ok, HostEntry} ->
|
|
host_entry_to_host_ports(HostEntry);
|
|
{error, timeout} ->
|
|
srv_lookup(Host, Timeout, Retries - 1);
|
|
{error, _} = Err ->
|
|
Err
|
|
end.
|
|
|
|
-spec a_lookup([{inet:hostname(), inet:port_number()}], state()) ->
|
|
{ok, [ip_port()]} | network_error().
|
|
a_lookup(HostPorts, State) ->
|
|
HostPortFamilies = [{Host, Port, Family}
|
|
|| {Host, Port} <- HostPorts,
|
|
Family <- get_address_families(State)],
|
|
a_lookup(HostPortFamilies, State, [], {error, nxdomain}).
|
|
|
|
-spec a_lookup([{inet:hostname(), inet:port_number(), inet:address_family()}],
|
|
state(), [ip_port()], network_error()) -> {ok, [ip_port()]} | network_error().
|
|
a_lookup([{Host, Port, Family}|HostPortFamilies], State, Acc, Err) ->
|
|
Timeout = get_dns_timeout(State),
|
|
Retries = get_dns_retries(State),
|
|
case a_lookup(Host, Port, Family, Timeout, Retries) of
|
|
{error, Reason} ->
|
|
a_lookup(HostPortFamilies, State, Acc, {error, Reason});
|
|
{ok, AddrPorts} ->
|
|
a_lookup(HostPortFamilies, State, Acc ++ AddrPorts, Err)
|
|
end;
|
|
a_lookup([], _State, [], Err) ->
|
|
Err;
|
|
a_lookup([], _State, Acc, _) ->
|
|
{ok, Acc}.
|
|
|
|
-spec a_lookup(inet:hostname(), inet:port_number(), inet:address_family(),
|
|
timeout(), integer()) -> {ok, [ip_port()]} | network_error().
|
|
a_lookup(_Host, _Port, _Family, _Timeout, Retries) when Retries < 1 ->
|
|
{error, timeout};
|
|
a_lookup(Host, Port, Family, Timeout, Retries) ->
|
|
Start = p1_time_compat:monotonic_time(milli_seconds),
|
|
case inet:gethostbyname(Host, Family, Timeout) of
|
|
{error, nxdomain} = Err ->
|
|
%% inet:gethostbyname/3 doesn't return {error, timeout},
|
|
%% so we should check if 'nxdomain' is in fact a result
|
|
%% of a timeout.
|
|
%% We also cannot use inet_res:gethostbyname/3 because
|
|
%% it ignores DNS configuration settings (/etc/hosts, etc)
|
|
End = p1_time_compat:monotonic_time(milli_seconds),
|
|
if (End - Start) >= Timeout ->
|
|
a_lookup(Host, Port, Family, Timeout, Retries - 1);
|
|
true ->
|
|
Err
|
|
end;
|
|
{error, _} = Err ->
|
|
Err;
|
|
{ok, HostEntry} ->
|
|
host_entry_to_addr_ports(HostEntry, Port)
|
|
end.
|
|
|
|
-spec host_entry_to_host_ports(inet:hostent()) -> {ok, [host_port()]} |
|
|
{error, nxdomain}.
|
|
host_entry_to_host_ports(#hostent{h_addr_list = AddrList}) ->
|
|
PrioHostPorts = lists:flatmap(
|
|
fun({Priority, Weight, Port, Host}) ->
|
|
N = case Weight of
|
|
0 -> 0;
|
|
_ -> (Weight + 1) * randoms:uniform()
|
|
end,
|
|
[{Priority * 65536 - N, Host, Port}];
|
|
(_) ->
|
|
[]
|
|
end, AddrList),
|
|
HostPorts = [{Host, Port}
|
|
|| {_Priority, Host, Port} <- lists:usort(PrioHostPorts)],
|
|
case HostPorts of
|
|
[] -> {error, nxdomain};
|
|
_ -> {ok, HostPorts}
|
|
end.
|
|
|
|
-spec host_entry_to_addr_ports(inet:hostent(), inet:port_number()) ->
|
|
{ok, [ip_port()]} | {error, nxdomain}.
|
|
host_entry_to_addr_ports(#hostent{h_addr_list = AddrList}, Port) ->
|
|
AddrPorts = lists:flatmap(
|
|
fun(Addr) ->
|
|
try get_addr_type(Addr) of
|
|
_ -> [{Addr, Port}]
|
|
catch _:_ ->
|
|
[]
|
|
end
|
|
end, AddrList),
|
|
case AddrPorts of
|
|
[] -> {error, nxdomain};
|
|
_ -> {ok, AddrPorts}
|
|
end.
|
|
|
|
-spec connect([ip_port()], state()) -> {ok, term(), ip_port()} | network_error().
|
|
connect(AddrPorts, #{sockmod := SockMod} = State) ->
|
|
Timeout = get_connect_timeout(State),
|
|
connect(AddrPorts, SockMod, Timeout, {error, nxdomain}).
|
|
|
|
-spec connect([ip_port()], module(), timeout(), network_error()) ->
|
|
{ok, term(), ip_port()} | network_error().
|
|
connect([{Addr, Port}|AddrPorts], SockMod, Timeout, _) ->
|
|
Type = get_addr_type(Addr),
|
|
try SockMod:connect(Addr, Port,
|
|
[binary, {packet, 0},
|
|
{send_timeout, ?TCP_SEND_TIMEOUT},
|
|
{send_timeout_close, true},
|
|
{active, false}, Type],
|
|
Timeout) of
|
|
{ok, Socket} ->
|
|
{ok, Socket, {Addr, Port}};
|
|
Err ->
|
|
connect(AddrPorts, SockMod, Timeout, Err)
|
|
catch _:badarg ->
|
|
connect(AddrPorts, SockMod, Timeout, {error, einval})
|
|
end;
|
|
connect([], _SockMod, _Timeout, Err) ->
|
|
Err.
|
|
|
|
-spec get_addr_type(inet:ip_address()) -> inet:address_family().
|
|
get_addr_type({_, _, _, _}) -> inet;
|
|
get_addr_type({_, _, _, _, _, _, _, _}) -> inet6.
|
|
|
|
-spec get_dns_timeout(state()) -> timeout().
|
|
get_dns_timeout(#{mod := Mod} = State) ->
|
|
try Mod:dns_timeout(State)
|
|
catch _:undef -> timer:seconds(10)
|
|
end.
|
|
|
|
-spec get_dns_retries(state()) -> non_neg_integer().
|
|
get_dns_retries(#{mod := Mod} = State) ->
|
|
try Mod:dns_retries(State)
|
|
catch _:undef -> 2
|
|
end.
|
|
|
|
-spec get_default_port(state()) -> inet:port_number().
|
|
get_default_port(#{mod := Mod, xmlns := NS} = State) ->
|
|
try Mod:default_port(State)
|
|
catch _:undef when NS == ?NS_SERVER -> 5269;
|
|
_:undef when NS == ?NS_CLIENT -> 5222
|
|
end.
|
|
|
|
-spec get_address_families(state()) -> [inet:address_family()].
|
|
get_address_families(#{mod := Mod} = State) ->
|
|
try Mod:address_families(State)
|
|
catch _:undef -> [inet, inet6]
|
|
end.
|
|
|
|
-spec get_connect_timeout(state()) -> timeout().
|
|
get_connect_timeout(#{mod := Mod} = State) ->
|
|
try Mod:connect_timeout(State)
|
|
catch _:undef -> timer:seconds(10)
|
|
end.
|