24
1
mirror of https://github.com/processone/ejabberd.git synced 2024-05-31 21:07:55 +02:00
xmpp.chapril.org-ejabberd/src/xmpp_stream_in.erl
Evgeniy Khramtsov 672c2f75d3 Introduce option 'validate_stream'
If set to `true`, all incoming XML packets are fully validated
against known schemas. If an error occurs, the packet will be bounced
with the corresponding error reason. The default value is `false`.
The option might be useful to protect client software from sophisticated
bugs related to XML validation as well as for client developers
who want to catch validation errors at an early stage of development.

Note that the option might have slight performance impact, so use it
with care on loaded machines.
2018-02-09 18:12:50 +03:00

1197 lines
40 KiB
Erlang

%%%-------------------------------------------------------------------
%%% Created : 26 Nov 2016 by Evgeny Khramtsov <ekhramtsov@process-one.net>
%%%
%%%
%%% ejabberd, Copyright (C) 2002-2018 ProcessOne
%%%
%%% This program is free software; you can redistribute it and/or
%%% modify it under the terms of the GNU General Public License as
%%% published by the Free Software Foundation; either version 2 of the
%%% License, or (at your option) any later version.
%%%
%%% This program is distributed in the hope that it will be useful,
%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
%%% General Public License for more details.
%%%
%%% You should have received a copy of the GNU General Public License along
%%% with this program; if not, write to the Free Software Foundation, Inc.,
%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
%%%
%%%-------------------------------------------------------------------
%% Behaviour module for incoming XMPP streams. It is itself a callback
%% module of the generic server named by the GEN_SERVER macro and, in turn,
%% declares the callbacks a stream handler module may implement.
-module(xmpp_stream_in).
%% Generic server implementation used for start/call/cast/reply.
-define(GEN_SERVER, p1_server).
-behaviour(?GEN_SERVER).
%% Specifications this module declares support for.
-protocol({rfc, 6120}).
-protocol({xep, 114, '1.6'}).
%% API
-export([start/3, start_link/3, call/3, cast/2, reply/2, stop/1,
send/2, close/1, close/2, send_error/3, establish/1,
get_transport/1, change_shaper/2, set_timeout/2, format_error/1]).
%% gen_server callbacks
-export([init/1, handle_cast/2, handle_call/3, handle_info/2,
terminate/2, code_change/3]).
%% Define DBGFSM (e.g. by enabling the line below) to have FSMOPTS carry
%% the [{debug, [trace]}] option; otherwise FSMOPTS is empty. FSMOPTS is
%% appended to the options given to start/3 and start_link/3.
%%-define(DBGFSM, true).
-ifdef(DBGFSM).
-define(FSMOPTS, [{debug, [trace]}]).
-else.
-define(FSMOPTS, []).
-endif.
-include("xmpp.hrl").
%% The stream state is a plain map; init/1 shows the keys it starts with.
-type state() :: map().
%% Reason handed to handle_stream_end/2 and accepted by format_error/1:
%%   {stream, reset}            - the peer closed the stream
%%   {stream, {in, Error}}      - a stream error received from the peer
%%   {stream, {out, Error}}     - a stream error sent by this side
%%   {tls, Reason}              - TLS-level failure
%%   {socket, Reason}           - socket-level failure or close
%%   internal_failure           - formatted as "Internal server error"
-type stop_reason() :: {stream, reset | {in | out, stream_error()}} |
{tls, inet:posix() | atom() | binary()} |
{socket, inet:posix() | atom()} |
internal_failure.
-export_type([state/0, stop_reason/0]).
%% Callbacks a stream handler module (the `mod' key of the state) may
%% implement. Most of them take the current state and return the new state.
%% Call sites in this module generally wrap the call in
%% `try ... catch _:undef' and fall back to a default when the handler does
%% not export the function.
%%
%% Generic server style callbacks, forwarded from this module's own
%% init/handle_cast/handle_call/handle_info/terminate/code_change.
-callback init(list()) -> {ok, state()} | {error, term()} | ignore.
-callback handle_cast(term(), state()) -> state().
-callback handle_call(term(), term(), state()) -> state().
-callback handle_info(term(), state()) -> state().
-callback terminate(term(), state()) -> any().
-callback code_change(term(), state(), term()) -> {ok, state()} | {error, term()}.
%% Stream life-cycle notifications.
-callback handle_stream_start(stream_start(), state()) -> state().
-callback handle_stream_established(state()) -> state().
-callback handle_stream_end(stop_reason(), state()) -> state().
%% Incoming data, dispatched according to the current stream state.
-callback handle_cdata(binary(), state()) -> state().
-callback handle_unauthenticated_packet(xmpp_element(), state()) -> state().
-callback handle_authenticated_packet(xmpp_element(), state()) -> state().
-callback handle_unbinded_packet(xmpp_element(), state()) -> state().
%% Authentication outcome notifications.
-callback handle_auth_success(binary(), binary(), module(), state()) -> state().
-callback handle_auth_failure(binary(), binary(), binary(), state()) -> state().
%% Invoked after every send attempt and for every received XML element
%% (decoded packet, or {error, Why} when decoding failed).
-callback handle_send(xmpp_element(), ok | {error, inet:posix()}, state()) -> state().
-callback handle_recv(fxml:xmlel(), xmpp_element() | {error, term()}, state()) -> state().
-callback handle_timeout(state()) -> state().
%% Credential lookup/verification funs used for SASL and handshake auth.
-callback get_password_fun(state()) -> fun().
-callback check_password_fun(state()) -> fun().
-callback check_password_digest_fun(state()) -> fun().
-callback bind(binary(), state()) -> {ok, state()} | {error, stanza_error(), state()}.
%% Feature negotiation knobs.
-callback compress_methods(state()) -> [binary()].
-callback tls_options(state()) -> [proplists:property()].
-callback tls_required(state()) -> boolean().
-callback tls_verify(state()) -> boolean().
-callback tls_enabled(state()) -> boolean().
-callback sasl_mechanisms([cyrsasl:mechanism()], state()) -> [cyrsasl:mechanism()].
-callback unauthenticated_stream_features(state()) -> [xmpp_element()].
-callback authenticated_stream_features(state()) -> [xmpp_element()].
%% All callbacks are optional
-optional_callbacks([init/1,
handle_cast/2,
handle_call/3,
handle_info/2,
terminate/2,
code_change/3,
handle_stream_start/2,
handle_stream_established/1,
handle_stream_end/2,
handle_cdata/2,
handle_authenticated_packet/2,
handle_unauthenticated_packet/2,
handle_unbinded_packet/2,
handle_auth_success/4,
handle_auth_failure/4,
handle_send/3,
handle_recv/3,
handle_timeout/1,
get_password_fun/1,
check_password_fun/1,
check_password_digest_fun/1,
bind/2,
compress_methods/1,
tls_options/1,
tls_required/1,
tls_verify/1,
tls_enabled/1,
sasl_mechanisms/2,
unauthenticated_stream_features/1,
authenticated_stream_features/1]).
%%%===================================================================
%%% API
%%%===================================================================
%% Starts a stream process through ?GEN_SERVER:start/3. The handler module
%% `Mod' is prepended to `Args' to form the argument list, and ?FSMOPTS is
%% appended to the caller-supplied options.
start(Mod, Args, Opts) ->
    InitArgs = [Mod | Args],
    ServerOpts = Opts ++ ?FSMOPTS,
    ?GEN_SERVER:start(?MODULE, InitArgs, ServerOpts).
%% Same as start/3 but goes through ?GEN_SERVER:start_link/3.
start_link(Mod, Args, Opts) ->
    InitArgs = [Mod | Args],
    ServerOpts = Opts ++ ?FSMOPTS,
    ?GEN_SERVER:start_link(?MODULE, InitArgs, ServerOpts).
%% Thin wrapper around ?GEN_SERVER:call/3.
call(Server, Request, Timeout) ->
    ?GEN_SERVER:call(Server, Request, Timeout).
%% Thin wrapper around ?GEN_SERVER:cast/2.
cast(Server, Request) ->
    ?GEN_SERVER:cast(Server, Request).
%% Thin wrapper around ?GEN_SERVER:reply/2.
reply(Client, Answer) ->
    ?GEN_SERVER:reply(Client, Answer).
%% Stops a stream.
%% Given a pid, asynchronously asks that process to stop.
%% Given a state owned by the calling process, runs terminate/2 and exits
%% the caller with reason `normal' (so this form never returns).
%% Anything else raises badarg.
-spec stop(pid()) -> ok;
          (state()) -> no_return().
stop(Target) when is_pid(Target) ->
    cast(Target, stop);
stop(#{owner := StreamOwner} = OwnState) when StreamOwner == self() ->
    terminate(normal, OwnState),
    exit(normal);
stop(_Other) ->
    erlang:error(badarg).
%% Sends a packet over the stream.
%% Given a pid, the packet is cast to that process and `ok' is returned.
%% Given a state owned by the calling process, the packet is sent right
%% away and the updated state is returned. Anything else raises badarg.
-spec send(pid(), xmpp_element()) -> ok;
          (state(), xmpp_element()) -> state().
send(Target, Packet) when is_pid(Target) ->
    cast(Target, {send, Packet});
send(#{owner := StreamOwner} = OwnState, Packet) when StreamOwner == self() ->
    send_pkt(OwnState, Packet);
send(_Other, _Packet) ->
    erlang:error(badarg).
%% Closes the stream's socket.
%% Given a pid, delegates to close/2 with reason `closed'.
%% Given a state owned by the calling process, closes the socket directly
%% and returns the updated state. Anything else raises badarg.
-spec close(pid()) -> ok;
           (state()) -> state().
close(Target) when is_pid(Target) ->
    close(Target, closed);
close(#{owner := StreamOwner} = OwnState) when StreamOwner == self() ->
    close_socket(OwnState);
close(_Other) ->
    erlang:error(badarg).
%% Asynchronously asks the stream process `Target' to close its socket;
%% `Why' is carried along in the {close, Why} cast.
-spec close(pid(), atom()) -> ok.
close(Target, Why) ->
    Request = {close, Why},
    cast(Target, Request).
%% Public entry point for marking the stream as established; simply
%% delegates to process_stream_established/1.
-spec establish(state()) -> state().
establish(StreamState) ->
    process_stream_established(StreamState).
%% Sets the stream timeout of a state owned by the calling process.
%% `infinity' is stored as is; any other value is stored together with the
%% current monotonic time (in milliseconds) as {Timeout, StartTime}.
%% Raises badarg when the state is not owned by the caller.
-spec set_timeout(state(), non_neg_integer() | infinity) -> state().
set_timeout(#{owner := StreamOwner} = State, infinity)
  when StreamOwner == self() ->
    State#{stream_timeout => infinity};
set_timeout(#{owner := StreamOwner} = State, MSecs)
  when StreamOwner == self() ->
    StartedAt = p1_time_compat:monotonic_time(milli_seconds),
    State#{stream_timeout => {MSecs, StartedAt}};
set_timeout(_State, _Timeout) ->
    erlang:error(badarg).
%% Returns whatever xmpp_socket:get_transport/1 reports for the socket of a
%% state owned by the calling process. Raises badarg for any other argument.
get_transport(#{socket := Sock, owner := StreamOwner})
  when StreamOwner == self() ->
    xmpp_socket:get_transport(Sock);
get_transport(_Other) ->
    erlang:error(badarg).
%% Replaces the shaper of the socket held by a state owned by the calling
%% process and stores the socket returned by xmpp_socket:change_shaper/2.
%% Raises badarg for any other argument.
-spec change_shaper(state(), shaper:shaper()) -> state().
change_shaper(#{socket := Sock, owner := StreamOwner} = State, NewShaper)
  when StreamOwner == self() ->
    ShapedSock = xmpp_socket:change_shaper(Sock, NewShaper),
    State#{socket => ShapedSock};
change_shaper(_State, _Shaper) ->
    erlang:error(badarg).
%% Renders a stop reason as human-readable text. Reasons that match none of
%% the known shapes are rendered with the "Unrecognized error" template.
-spec format_error(stop_reason()) -> binary().
format_error(StopReason) ->
    case StopReason of
        {socket, Why} ->
            format("Connection failed: ~s", [format_inet_error(Why)]);
        {stream, reset} ->
            <<"Stream reset by peer">>;
        {stream, {in, #stream_error{reason = Why, text = Text}}} ->
            format("Stream closed by peer: ~s",
                   [format_stream_error(Why, Text)]);
        {stream, {out, #stream_error{reason = Why, text = Text}}} ->
            format("Stream closed by us: ~s",
                   [format_stream_error(Why, Text)]);
        {tls, Why} ->
            format("TLS failed: ~s", [format_tls_error(Why)]);
        internal_failure ->
            <<"Internal server error">>;
        Unknown ->
            format("Unrecognized error: ~w", [Unknown])
    end.
%%%===================================================================
%%% gen_server callbacks
%%%===================================================================
%% Builds the initial stream state for a freshly accepted socket.
%% The socket is monitored and its peer name looked up; if the peer name is
%% unavailable the process start is ignored. Otherwise the handler's init/1
%% (if it exports one) may adjust the state. When the `tls' option is set
%% the socket is upgraded to TLS before the first timeout is armed.
init([Module, {_SockMod, Socket}, Opts]) ->
    Encrypted = proplists:get_bool(tls, Opts),
    SocketMonitor = xmpp_socket:monitor(Socket),
    case xmpp_socket:peername(Socket) of
        {ok, IP} ->
            StartedAt = p1_time_compat:monotonic_time(milli_seconds),
            InitialState = #{owner => self(),
                             mod => Module,
                             socket => Socket,
                             socket_monitor => SocketMonitor,
                             stream_timeout => {timer:seconds(30), StartedAt},
                             stream_direction => in,
                             stream_id => new_id(),
                             stream_state => wait_for_stream,
                             stream_header_sent => false,
                             stream_restarted => false,
                             stream_compressed => false,
                             stream_encrypted => Encrypted,
                             stream_version => {1,0},
                             stream_authenticated => false,
                             codec_options => [ignore_els],
                             xmlns => ?NS_CLIENT,
                             lang => <<"">>,
                             user => <<"">>,
                             server => <<"">>,
                             resource => <<"">>,
                             lserver => <<"">>,
                             ip => IP},
            %% A handler without init/1 keeps the default state.
            HandlerResult = try Module:init([InitialState, Opts])
                            catch _:undef -> {ok, InitialState}
                            end,
            case HandlerResult of
                {ok, HandlerState} when not Encrypted ->
                    {_, Armed, Timeout} = noreply(HandlerState),
                    {ok, Armed, Timeout};
                {ok, HandlerState} when Encrypted ->
                    TLSOpts = try Module:tls_options(HandlerState)
                              catch _:undef -> []
                              end,
                    case xmpp_socket:starttls(Socket, TLSOpts) of
                        {ok, TLSSocket} ->
                            Secured = HandlerState#{socket => TLSSocket},
                            {_, Armed, Timeout} = noreply(Secured),
                            {ok, Armed, Timeout};
                        {error, Why} ->
                            {stop, Why}
                    end;
                {error, Why} ->
                    {stop, Why};
                ignore ->
                    ignore
            end;
        {error, _Why} ->
            ignore
    end.
%% Handles asynchronous requests:
%%   {send, Pkt}     - send the packet over the stream
%%   stop            - stop the process normally
%%   {close, Reason} - close the socket and, unless the stream was already
%%                     disconnected, end the stream with {socket, Reason}
%%   anything else   - forwarded to Mod:handle_cast/2 (ignored if missing)
handle_cast({send, Packet}, State) ->
    Sent = send_pkt(State, Packet),
    noreply(Sent);
handle_cast(stop, State) ->
    {stop, normal, State};
handle_cast({close, Why}, State) ->
    Closed = close_socket(State),
    %% The check is made against the state as it was before closing.
    Next = case is_disconnected(State) of
               true -> Closed;
               false -> process_stream_end({socket, Why}, State)
           end,
    noreply(Next);
handle_cast(Request, #{mod := Mod} = State) ->
    Next = try Mod:handle_cast(Request, State)
           catch _:undef -> State
           end,
    noreply(Next).
%% Forwards every synchronous request to Mod:handle_call/3, which returns
%% the new state. No reply is produced here; when the handler does not
%% export handle_call/3 the state is left untouched.
handle_call(Request, From, #{mod := Mod} = State) ->
    Next = try Mod:handle_call(Request, From, State)
           catch _:undef -> State
           end,
    noreply(Next).
%% Handles raw messages delivered to the stream process: events from the
%% XML driver ('$gen_event' / '$gen_all_state_event'), the idle timeout,
%% the socket monitor's 'DOWN' message, tcp messages and, finally, anything
%% else, which is forwarded to Mod:handle_info/2.
%% Clause order matters: the generic wait_for_stream clause has to come
%% after the specific xmlstream* clauses and before the xmlstreamelement one.

%% Stream header received while waiting for one.
handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}},
            #{stream_state := wait_for_stream,
              xmlns := XMLNS, lang := MyLang} = State) ->
    HeaderEl = #xmlel{name = Name, attrs = Attrs},
    Next =
        try xmpp:decode(HeaderEl, XMLNS, []) of
            #stream_start{} = StreamStart ->
                Replied = send_header(State, StreamStart),
                case is_disconnected(Replied) of
                    true -> Replied;
                    false -> process_stream(StreamStart, Replied)
                end;
            _ ->
                Replied = send_header(State),
                case is_disconnected(Replied) of
                    true -> Replied;
                    false -> send_pkt(Replied, xmpp:serr_invalid_xml())
                end
        catch _:{xmpp_codec, Why} ->
                Replied = send_header(State),
                case is_disconnected(Replied) of
                    true -> Replied;
                    false ->
                        Txt = xmpp:io_format_error(Why),
                        Lang = select_lang(MyLang, xmpp:get_lang(HeaderEl)),
                        Err = xmpp:serr_invalid_xml(Txt, Lang),
                        send_pkt(Replied, Err)
                end
        end,
    noreply(Next);
%% The peer closed the stream.
handle_info({'$gen_event', {xmlstreamend, _}}, State) ->
    Ended = process_stream_end({stream, reset}, State),
    noreply(Ended);
handle_info({'$gen_event', closed}, State) ->
    Ended = process_stream_end({socket, closed}, State),
    noreply(Ended);
%% The XML driver reported an error: answer with a stream error.
handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang} = State) ->
    Replied = send_header(State),
    Next = case is_disconnected(Replied) of
               true -> Replied;
               false ->
                   Err = case Reason of
                             <<"XML stanza is too big">> ->
                                 xmpp:serr_policy_violation(Reason, Lang);
                             {_, Txt} ->
                                 xmpp:serr_not_well_formed(Txt, Lang)
                         end,
                   send_pkt(Replied, Err)
           end,
    noreply(Next);
%% Any other driver event while still waiting for the stream header.
handle_info({'$gen_event', Event}, #{stream_state := wait_for_stream} = State) ->
    error_logger:warning_msg("unexpected event from XML driver: ~p; "
                             "xmlstreamstart was expected", [Event]),
    Replied = send_header(State),
    Next = case is_disconnected(Replied) of
               true -> Replied;
               false -> send_pkt(Replied, xmpp:serr_invalid_xml())
           end,
    noreply(Next);
%% A complete top-level element: decode it, let the handler observe it via
%% handle_recv/3, then process it unless the handler disconnected.
handle_info({'$gen_event', {xmlstreamelement, El}},
            #{xmlns := NS, mod := Mod, codec_options := CodecOpts} = State) ->
    Next =
        try xmpp:decode(El, NS, CodecOpts) of
            Pkt ->
                Observed = try Mod:handle_recv(El, Pkt, State)
                           catch _:undef -> State
                           end,
                case is_disconnected(Observed) of
                    true -> Observed;
                    false -> process_element(Pkt, Observed)
                end
        catch _:{xmpp_codec, Why} ->
                Observed = try Mod:handle_recv(El, {error, Why}, State)
                           catch _:undef -> State
                           end,
                case is_disconnected(Observed) of
                    true -> Observed;
                    false -> process_invalid_xml(Observed, El, Why)
                end
        end,
    noreply(Next);
%% Character data between elements.
handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}},
            #{mod := Mod} = State) ->
    Next = try Mod:handle_cdata(Data, State)
           catch _:undef -> State
           end,
    noreply(Next);
%% Stream timeout. Without a handle_timeout/1 callback a connected stream
%% gets a connection-timeout stream error; a disconnected one is stopped.
handle_info(timeout, #{mod := Mod, lang := Lang} = State) ->
    Disconnected = is_disconnected(State),
    Next = try Mod:handle_timeout(State)
           catch _:undef when not Disconnected ->
                   Txt = <<"Idle connection">>,
                   send_pkt(State, xmpp:serr_connection_timeout(Txt, Lang));
                 _:undef ->
                   stop(State)
           end,
    noreply(Next);
%% The monitored socket went down.
handle_info({'DOWN', MRef, _Type, _Object, _Info},
            #{socket_monitor := MRef} = State) ->
    Ended = process_stream_end({socket, closed}, State),
    noreply(Ended);
%% Raw data from the socket.
handle_info({tcp, _, Data}, #{socket := Socket} = State) ->
    Next = case xmpp_socket:recv(Socket, Data) of
               {ok, NewSocket} ->
                   State#{socket => NewSocket};
               {error, Why} when is_atom(Why) ->
                   process_stream_end({socket, Why}, State);
               {error, Why} ->
                   %% TODO: make fast_tls return atoms
                   process_stream_end({tls, Why}, State)
           end,
    noreply(Next);
handle_info({tcp_closed, _}, State) ->
    handle_info({'$gen_event', closed}, State);
handle_info({tcp_error, _, Why}, State) ->
    Ended = process_stream_end({socket, Why}, State),
    noreply(Ended);
%% Everything else belongs to the handler module.
handle_info(Info, #{mod := Mod} = State) ->
    Next = try Mod:handle_info(Info, State)
           catch _:undef -> State
           end,
    noreply(Next).
%% Runs the handler's terminate/2 (if exported) and sends the stream
%% trailer. The process dictionary flag `already_terminated' makes this
%% run at most once per process: a repeated call just returns the state.
terminate(Reason, #{mod := Mod} = State) ->
    case get(already_terminated) =:= true of
        true ->
            State;
        false ->
            put(already_terminated, true),
            try Mod:terminate(Reason, State)
            catch _:undef -> ok
            end,
            send_trailer(State)
    end.
%% Delegates code upgrades to the handler module.
%% code_change/3 is an optional callback, so a handler that does not export
%% it must not crash the upgrade: in that case the state is kept unchanged
%% and {ok, State} is returned, mirroring how the other optional callbacks
%% are handled in this module.
code_change(OldVsn, #{mod := Mod} = State, Extra) ->
    try Mod:code_change(OldVsn, State, Extra)
    catch _:undef -> {ok, State}
    end.
%%%===================================================================
%%% Internal functions
%%%===================================================================
%% Wraps a state into a {noreply, State, Timeout} tuple.
%% With an `infinity' stream timeout the timeout is `infinity'; otherwise it
%% is the time left until the deadline (stored as {MSecs, StartTime}),
%% clamped so that it is never negative.
-spec noreply(state()) -> {noreply, state(), non_neg_integer() | infinity}.
noreply(#{stream_timeout := infinity} = State) ->
    {noreply, State, infinity};
noreply(#{stream_timeout := {Budget, StartedAt}} = State) ->
    Now = p1_time_compat:monotonic_time(milli_seconds),
    Remaining = max(0, Budget - Now + StartedAt),
    {noreply, State, Remaining}.
%% Produces a fresh stream identifier; the value comes straight from
%% randoms:get_string/0.
-spec new_id() -> binary().
new_id() ->
    StreamID = randoms:get_string(),
    StreamID.
%% Tells whether the stream state is `disconnected'.
-spec is_disconnected(state()) -> boolean().
is_disconnected(#{stream_state := disconnected}) ->
    true;
is_disconnected(#{stream_state := _Other}) ->
    false.
%% Answers an element that failed to decode.
%% Stanzas get a bad-request error. For the negotiation elements recognised
%% below (SASL auth/response/abort, starttls, compress) the matching failure
%% element is sent. Everything else is silently ignored.
-spec process_invalid_xml(state(), fxml:xmlel(), term()) -> state().
process_invalid_xml(#{lang := MyLang} = State, El, Reason) ->
    case xmpp:is_stanza(El) of
        true ->
            Txt = xmpp:io_format_error(Reason),
            Lang = select_lang(MyLang, xmpp:get_lang(El)),
            send_error(State, El, xmpp:err_bad_request(Txt, Lang));
        false ->
            case {xmpp:get_name(El), xmpp:get_ns(El)} of
                {Tag, ?NS_SASL} when Tag == <<"auth">>;
                                     Tag == <<"response">>;
                                     Tag == <<"abort">> ->
                    Txt = xmpp:io_format_error(Reason),
                    Failure = #sasl_failure{reason = 'malformed-request',
                                            text = xmpp:mk_text(Txt, MyLang)},
                    send_pkt(State, Failure);
                {<<"starttls">>, ?NS_TLS} ->
                    send_pkt(State, #starttls_failure{});
                {<<"compress">>, ?NS_COMPRESS} ->
                    Failure = #compress_failure{reason = 'setup-failed'},
                    send_pkt(State, Failure);
                _ ->
                    %% Maybe add something more?
                    State
            end
    end.
%% Ends the stream for the given reason.
%% Does nothing when the stream is already disconnected. Otherwise the state
%% is marked disconnected (with an infinite timeout) and handed to
%% Mod:handle_stream_end/2; a handler without that callback gets stopped.
-spec process_stream_end(stop_reason(), state()) -> state().
process_stream_end(_Reason, #{stream_state := disconnected} = State) ->
    State;
process_stream_end(Reason, #{mod := Mod} = State) ->
    Disconnected = State#{stream_state => disconnected,
                          stream_timeout => infinity},
    try Mod:handle_stream_end(Reason, Disconnected)
    catch _:undef -> stop(Disconnected)
    end.
%% Validates a decoded stream header and moves the stream to its next state.
%% The clauses are tried in order:
%%   1. namespace mismatch          -> invalid-namespace stream error
%%   2. major version above 1       -> unsupported-version stream error
%%   3. overlong xml:lang (c2s)     -> policy-violation stream error
%%   4. missing 'to'                -> filled in from the state's server for
%%                                     pre-1.0 non-component streams,
%%                                     otherwise improper-addressing
%%   5. 'to' with user or resource  -> improper-addressing stream error
%%   6. component stream            -> wait_for_handshake
%%   7. anything else               -> features are sent and the next state
%%                                     is picked from the TLS/auth/bind status
%% Returns the updated state (possibly a disconnected one).
-spec process_stream(stream_start(), state()) -> state().
process_stream(#stream_start{xmlns = XML_NS,
                             stream_xmlns = STREAM_NS},
               #{xmlns := NS} = State)
  when XML_NS /= NS; STREAM_NS /= ?NS_STREAM ->
    send_pkt(State, xmpp:serr_invalid_namespace());
process_stream(#stream_start{version = {N, _}}, State) when N > 1 ->
    send_pkt(State, xmpp:serr_unsupported_version());
process_stream(#stream_start{lang = Lang},
               #{xmlns := ?NS_CLIENT, lang := DefaultLang} = State)
  when byte_size(Lang) > 35 ->
    %% As stated in BCP47, 4.4.1:
    %% Protocols or specifications that specify limited buffer sizes for
    %% language tags MUST allow for language tags of at least 35 characters.
    %% Do not store long language tag to avoid possible DoS/flood attacks
    Txt = <<"Too long value of 'xml:lang' attribute">>,
    send_pkt(State, xmpp:serr_policy_violation(Txt, DefaultLang));
process_stream(#stream_start{to = undefined, version = Version} = StreamStart,
               #{lang := Lang, server := Server, xmlns := NS} = State) ->
    if Version < {1,0} andalso NS /= ?NS_COMPONENT ->
            %% Work-around for gmail servers
            To = jid:make(Server),
            process_stream(StreamStart#stream_start{to = To}, State);
       true ->
            Txt = <<"Missing 'to' attribute">>,
            send_pkt(State, xmpp:serr_improper_addressing(Txt, Lang))
    end;
process_stream(#stream_start{to = #jid{luser = U, lresource = R}},
               #{lang := Lang} = State) when U /= <<"">>; R /= <<"">> ->
    Txt = <<"Improper 'to' attribute">>,
    send_pkt(State, xmpp:serr_improper_addressing(Txt, Lang));
process_stream(#stream_start{to = #jid{lserver = RemoteServer}} = StreamStart,
               #{xmlns := ?NS_COMPONENT, mod := Mod} = State) ->
    State1 = State#{remote_server => RemoteServer,
                    stream_state => wait_for_handshake},
    try Mod:handle_stream_start(StreamStart, State1)
    catch _:undef -> State1
    end;
process_stream(#stream_start{to = #jid{server = Server, lserver = LServer},
                             from = From} = StreamStart,
               #{stream_authenticated := Authenticated,
                 stream_restarted := StreamWasRestarted,
                 mod := Mod, xmlns := NS, resource := Resource,
                 stream_encrypted := Encrypted} = State) ->
    %% The addressed server is recorded only on the first stream header,
    %% not on restarts.
    State1 = if not StreamWasRestarted ->
                     State#{server => Server, lserver => LServer};
                true ->
                     State
             end,
    %% For server-to-server streams remember who is talking to us.
    State2 = case From of
                 #jid{lserver = RemoteServer} when NS == ?NS_SERVER ->
                     State1#{remote_server => RemoteServer};
                 _ ->
                     State1
             end,
    State3 = try Mod:handle_stream_start(StreamStart, State2)
             catch _:undef -> State2
             end,
    case is_disconnected(State3) of
        true -> State3;
        false ->
            State4 = send_features(State3),
            case is_disconnected(State4) of
                true -> State4;
                false ->
                    TLSRequired = is_starttls_required(State4),
                    if not Authenticated and (TLSRequired and not Encrypted) ->
                            State4#{stream_state => wait_for_starttls};
                       not Authenticated ->
                            State4#{stream_state => wait_for_sasl_request};
                       (NS == ?NS_CLIENT) and (Resource == <<"">>) ->
                            State4#{stream_state => wait_for_bind};
                       true ->
                            process_stream_established(State4)
                    end
            end
    end.
%% Routes a decoded element according to its type and the current stream
%% state. Negotiation elements (STARTTLS, SASL, compression, handshake) and
%% stream errors are handled first; any other packet is dispatched by state.
-spec process_element(xmpp_element(), state()) -> state().
process_element(Pkt, #{stream_state := Phase, lang := Lang} = State) ->
    case Pkt of
        %% STARTTLS
        #starttls{} when Phase == wait_for_starttls;
                         Phase == wait_for_sasl_request ->
            process_starttls(State);
        #starttls{} ->
            process_starttls_failure(unexpected_starttls_request, State);
        %% SASL <auth/>
        #sasl_auth{} when Phase == wait_for_starttls ->
            send_pkt(State, #sasl_failure{reason = 'encryption-required'});
        #sasl_auth{} when Phase == wait_for_sasl_request ->
            process_sasl_request(Pkt, State);
        #sasl_auth{} when Phase == wait_for_sasl_response ->
            %% A new <auth/> in the middle of an exchange starts over.
            process_sasl_request(Pkt, maps:remove(sasl_state, State));
        #sasl_auth{} ->
            Txt = <<"SASL negotiation is not allowed in this state">>,
            send_pkt(State, #sasl_failure{reason = 'not-authorized',
                                          text = xmpp:mk_text(Txt, Lang)});
        %% SASL <response/>
        #sasl_response{} when Phase == wait_for_starttls ->
            send_pkt(State, #sasl_failure{reason = 'encryption-required'});
        #sasl_response{} when Phase == wait_for_sasl_response ->
            process_sasl_response(Pkt, State);
        #sasl_response{} ->
            Txt = <<"SASL negotiation is not allowed in this state">>,
            send_pkt(State, #sasl_failure{reason = 'not-authorized',
                                          text = xmpp:mk_text(Txt, Lang)});
        %% SASL <abort/>
        #sasl_abort{} when Phase == wait_for_sasl_response ->
            process_sasl_abort(State);
        #sasl_abort{} ->
            send_pkt(State, #sasl_failure{reason = 'aborted'});
        #sasl_success{} ->
            State;
        %% Compression, component handshake, stream errors
        #compress{} ->
            process_compress(Pkt, State);
        #handshake{} when Phase == wait_for_handshake ->
            process_handshake(Pkt, State);
        #handshake{} ->
            State;
        #stream_error{} ->
            process_stream_end({stream, {in, Pkt}}, State);
        %% Everything else depends only on the stream state
        _ when Phase == wait_for_sasl_request;
               Phase == wait_for_handshake;
               Phase == wait_for_sasl_response ->
            process_unauthenticated_packet(Pkt, State);
        _ when Phase == wait_for_starttls ->
            Txt = <<"Use of STARTTLS required">>,
            Err = xmpp:serr_policy_violation(Txt, Lang),
            send_pkt(State, Err);
        _ when Phase == wait_for_bind ->
            process_bind(Pkt, State);
        _ when Phase == established ->
            process_authenticated_packet(Pkt, State)
    end.
%% Hands a packet received before authentication to the handler module,
%% after stamping it with the stream language via set_lang/2.
%% A handler without handle_unauthenticated_packet/2 makes the stream fail
%% with a not-authorized stream error. The error is sent with the internal
%% send_pkt/2, like every other internal send in this module, rather than
%% through the public send/2 API (which re-checks ownership of the state).
-spec process_unauthenticated_packet(xmpp_element(), state()) -> state().
process_unauthenticated_packet(Pkt, #{mod := Mod} = State) ->
    NewPkt = set_lang(Pkt, State),
    try Mod:handle_unauthenticated_packet(NewPkt, State)
    catch _:undef ->
            Err = xmpp:serr_not_authorized(),
            send_pkt(State, Err)
    end.
%% Hands a packet received on an established stream to the handler module.
%% The packet is first stamped with the stream language and its addressing
%% is checked/filled in by set_from_to/2; an addressing failure is sent back
%% as is. A handler without handle_authenticated_packet/2 results in a
%% service-unavailable error reply to the original packet.
-spec process_authenticated_packet(xmpp_element(), state()) -> state().
process_authenticated_packet(Pkt, #{mod := Mod} = State) ->
    Localized = set_lang(Pkt, State),
    case set_from_to(Localized, State) of
        {error, StreamErr} ->
            send_pkt(State, StreamErr);
        {ok, Addressed} ->
            try Mod:handle_authenticated_packet(Addressed, State)
            catch _:undef ->
                    send_error(State, Pkt, xmpp:err_service_unavailable())
            end
    end.
%% Handles packets received while waiting for resource binding.
%% An IQ set with exactly one sub-element on a client stream is checked for
%% a bind request: if it is one, Mod:bind/2 decides the outcome and success
%% establishes the stream. Every other packet goes to
%% Mod:handle_unbinded_packet/2, or is rejected with not-authorized when
%% the handler lacks that callback.
-spec process_bind(xmpp_element(), state()) -> state().
process_bind(#iq{type = set, sub_els = [_]} = IQ,
             #{xmlns := ?NS_CLIENT, mod := Mod, lang := MyLang} = State) ->
    try xmpp:try_subtag(IQ, #bind{}) of
        #bind{resource = Requested} ->
            case Mod:bind(Requested, State) of
                {ok, #{user := U, server := S, resource := Bound} = BoundState}
                  when Bound /= <<"">> ->
                    Result = #bind{jid = jid:make(U, S, Bound)},
                    Replied = send_pkt(BoundState,
                                       xmpp:make_iq_result(IQ, Result)),
                    process_stream_established(Replied);
                {error, #stanza_error{} = BindErr, FailedState} ->
                    send_error(FailedState, IQ, BindErr)
            end;
        _ ->
            try Mod:handle_unbinded_packet(IQ, State)
            catch _:undef ->
                    send_error(State, IQ, xmpp:err_not_authorized())
            end
    catch _:{xmpp_codec, Why} ->
            Txt = xmpp:io_format_error(Why),
            Lang = select_lang(MyLang, xmpp:get_lang(IQ)),
            send_error(State, IQ, xmpp:err_bad_request(Txt, Lang))
    end;
process_bind(Pkt, #{mod := Mod} = State) ->
    try Mod:handle_unbinded_packet(Pkt, State)
    catch _:undef ->
            send_error(State, Pkt, xmpp:err_not_authorized())
    end.
%% Verifies a component handshake.
%% The expected digest is the SHA of the stream id followed by the password
%% obtained from the handler's password fun. On success the handler is
%% notified, an empty handshake is sent back and the stream is established;
%% on failure the handler is notified and a not-authorized stream error is
%% sent.
-spec process_handshake(handshake(), state()) -> state().
process_handshake(#handshake{data = Digest},
                  #{mod := Mod, stream_id := StreamID,
                    remote_server := RemoteServer} = State) ->
    PasswordFun = try Mod:get_password_fun(State)
                  catch _:undef -> fun(_) -> {false, undefined} end
                  end,
    Authorized = case PasswordFun(<<"">>) of
                     {false, _} ->
                         false;
                     {Password, _} ->
                         str:sha(<<StreamID/binary, Password/binary>>) == Digest
                 end,
    case Authorized of
        true ->
            Notified = try Mod:handle_auth_success(
                             RemoteServer, <<"handshake">>, undefined, State)
                       catch _:undef -> State
                       end,
            case is_disconnected(Notified) of
                true -> Notified;
                false ->
                    Acked = send_pkt(Notified, #handshake{}),
                    process_stream_established(Acked)
            end;
        false ->
            Notified = try Mod:handle_auth_failure(
                             RemoteServer, <<"handshake">>,
                             <<"not authorized">>, State)
                       catch _:undef -> State
                       end,
            case is_disconnected(Notified) of
                true -> Notified;
                false -> send_pkt(Notified, xmpp:serr_not_authorized())
            end
    end.
%% Marks the stream as established and authenticated, disables the stream
%% timeout and notifies Mod:handle_stream_established/1 when it exists.
%% A stream that is already established or disconnected is left untouched.
-spec process_stream_established(state()) -> state().
process_stream_established(#{stream_state := Phase} = State)
  when Phase == disconnected; Phase == established ->
    State;
process_stream_established(#{mod := Mod} = State) ->
    Established = State#{stream_state => established,
                         stream_authenticated => true,
                         stream_timeout => infinity},
    try Mod:handle_stream_established(Established)
    catch _:undef -> Established
    end.
%% Handles a stream compression request.
%% Refused with setup-failed when the stream is already compressed or not
%% yet authenticated. Otherwise zlib must be offered by both sides; on
%% success <compressed/> is sent and the stream is reset to wait for a new
%% header on the compressed socket.
-spec process_compress(compress(), state()) -> state().
process_compress(#compress{},
                 #{stream_compressed := Compressed,
                   stream_authenticated := Authenticated} = State)
  when Compressed or not Authenticated ->
    send_pkt(State, #compress_failure{reason = 'setup-failed'});
process_compress(#compress{methods = Requested},
                 #{socket := Socket, mod := Mod} = State) ->
    Offered = try Mod:compress_methods(State)
              catch _:undef -> []
              end,
    Shared = lists_intersection(Offered, Requested),
    case lists:member(<<"zlib">>, Shared) of
        false ->
            send_pkt(State, #compress_failure{reason = 'unsupported-method'});
        true ->
            case xmpp_socket:compress(Socket) of
                {error, _} ->
                    send_pkt(State, #compress_failure{reason = 'setup-failed'});
                {ok, ZlibSocket} ->
                    Confirmed = send_pkt(State, #compressed{}),
                    case is_disconnected(Confirmed) of
                        true -> Confirmed;
                        false ->
                            Confirmed#{socket => ZlibSocket,
                                       stream_id => new_id(),
                                       stream_header_sent => false,
                                       stream_restarted => true,
                                       stream_state => wait_for_stream,
                                       stream_compressed => true}
                    end
            end
    end.
%% Handles a STARTTLS request.
%% Fails when the stream is already encrypted or TLS is not available.
%% Otherwise the socket is upgraded with the handler's TLS options,
%% <proceed/> is sent and the stream is reset to wait for a new header.
-spec process_starttls(state()) -> state().
process_starttls(#{stream_encrypted := true} = State) ->
    process_starttls_failure(already_encrypted, State);
process_starttls(#{socket := Socket, mod := Mod} = State) ->
    case is_starttls_available(State) of
        false ->
            process_starttls_failure(starttls_unsupported, State);
        true ->
            TLSOpts = try Mod:tls_options(State)
                      catch _:undef -> []
                      end,
            case xmpp_socket:starttls(Socket, TLSOpts) of
                {error, Why} ->
                    process_starttls_failure(Why, State);
                {ok, TLSSocket} ->
                    Confirmed = send_pkt(State, #starttls_proceed{}),
                    case is_disconnected(Confirmed) of
                        true -> Confirmed;
                        false ->
                            Confirmed#{socket => TLSSocket,
                                       stream_id => new_id(),
                                       stream_header_sent => false,
                                       stream_restarted => true,
                                       stream_state => wait_for_stream,
                                       stream_encrypted => true}
                    end
            end
    end.
%% Reports a STARTTLS failure to the peer and, unless sending already
%% disconnected the stream, ends the stream with reason {tls, Why}.
-spec process_starttls_failure(term(), state()) -> state().
process_starttls_failure(Why, State) ->
    Notified = send_pkt(State, #starttls_failure{}),
    case is_disconnected(Notified) of
        false -> process_stream_end({tls, Why}, Notified);
        true -> Notified
    end.
%% Starts a SASL exchange for the requested mechanism.
%% The mechanism is remembered in the state and checked against the list of
%% mechanisms currently offered. EXTERNAL is authenticated through
%% xmpp_stream_pkix; every other offered mechanism goes through cyrsasl
%% using the handler's password funs (with refusing defaults when the
%% handler does not provide them). Unknown mechanisms fail right away.
-spec process_sasl_request(sasl_auth(), state()) -> state().
process_sasl_request(#sasl_auth{mechanism = Mech, text = ClientIn},
                     #{mod := Mod, lserver := LServer} = State) ->
    WithMech = State#{sasl_mech => Mech},
    Offered = get_sasl_mechanisms(WithMech),
    case lists:member(Mech, Offered) of
        true when Mech == <<"EXTERNAL">> ->
            Result = case xmpp_stream_pkix:authenticate(WithMech, ClientIn) of
                         {ok, Peer} ->
                             {ok, [{auth_module, pkix}, {username, Peer}]};
                         {error, Reason, Peer} ->
                             {error, Reason, Peer}
                     end,
            process_sasl_result(Result, WithMech);
        true ->
            GetPW = try Mod:get_password_fun(WithMech)
                    catch _:undef -> fun(_) -> false end
                    end,
            CheckPW = try Mod:check_password_fun(WithMech)
                      catch _:undef -> fun(_, _, _) -> false end
                      end,
            CheckPWDigest = try Mod:check_password_digest_fun(WithMech)
                            catch _:undef -> fun(_, _, _, _, _) -> false end
                            end,
            SASLState = cyrsasl:server_new(<<"jabber">>, LServer, <<"">>, [],
                                           GetPW, CheckPW, CheckPWDigest),
            Result = cyrsasl:server_start(SASLState, Mech, ClientIn),
            process_sasl_result(Result, WithMech#{sasl_state => SASLState});
        false ->
            process_sasl_result({error, unsupported_mechanism, <<"">>},
                                WithMech)
    end.
%% Feeds the client's SASL response into the exchange stored in the state
%% and acts on the outcome.
-spec process_sasl_response(sasl_response(), state()) -> state().
process_sasl_response(#sasl_response{text = ClientIn},
                      #{sasl_state := Exchange} = State) ->
    Outcome = cyrsasl:server_step(Exchange, ClientIn),
    process_sasl_result(Outcome, State).
%% Dispatches on the outcome of a SASL step: success (with or without
%% additional server data), another challenge, or failure.
-spec process_sasl_result(cyrsasl:sasl_return(), state()) -> state().
process_sasl_result({ok, AuthProps}, State) ->
    process_sasl_success(AuthProps, <<"">>, State);
process_sasl_result({ok, AuthProps, ServerData}, State) ->
    process_sasl_success(AuthProps, ServerData, State);
process_sasl_result({continue, Challenge, NextExchange}, State) ->
    process_sasl_continue(Challenge, NextExchange, State);
process_sasl_result({error, Why, Username}, State) ->
    process_sasl_failure(Why, Username, State).
%% Completes a successful SASL exchange.
%% The socket's stream is reset, <success/> is sent and the handler is told
%% about the authenticated user. Unless something disconnected the stream
%% along the way, the SASL bookkeeping is dropped and the stream is reset
%% to wait for a new header as an authenticated stream.
-spec process_sasl_success([cyrsasl:sasl_property()], binary(), state()) -> state().
process_sasl_success(Props, ServerOut,
                     #{socket := Socket,
                       mod := Mod, sasl_mech := Mech} = State) ->
    User = identity(Props),
    AuthModule = proplists:get_value(auth_module, Props),
    FreshSocket = xmpp_socket:reset_stream(Socket),
    Confirmed = send_pkt(State#{socket => FreshSocket},
                         #sasl_success{text = ServerOut}),
    case is_disconnected(Confirmed) of
        true -> Confirmed;
        false ->
            Notified = try Mod:handle_auth_success(User, Mech, AuthModule,
                                                   Confirmed)
                       catch _:undef -> Confirmed
                       end,
            case is_disconnected(Notified) of
                true -> Notified;
                false ->
                    NoMech = maps:remove(sasl_mech, Notified),
                    Clean = maps:remove(sasl_state, NoMech),
                    Clean#{stream_id => new_id(),
                           stream_authenticated => true,
                           stream_header_sent => false,
                           stream_restarted => true,
                           stream_state => wait_for_stream,
                           user => User}
            end
    end.
%% Stores the updated SASL exchange, moves to wait_for_sasl_response and
%% sends the next challenge to the client.
-spec process_sasl_continue(binary(), cyrsasl:sasl_state(), state()) -> state().
process_sasl_continue(Challenge, NextExchange, State) ->
    Waiting = State#{stream_state => wait_for_sasl_response,
                     sasl_state => NextExchange},
    send_pkt(Waiting, #sasl_challenge{text = Challenge}).
%% Reports a SASL failure to the client and the handler, then drops the
%% SASL bookkeeping and goes back to waiting for a new SASL request.
-spec process_sasl_failure(atom(), binary(), state()) -> state().
process_sasl_failure(Err, User,
                     #{mod := Mod, sasl_mech := Mech, lang := Lang} = State) ->
    {Reason, Text} = format_sasl_error(Mech, Err),
    Failure = #sasl_failure{reason = Reason,
                            text = xmpp:mk_text(Text, Lang)},
    Informed = send_pkt(State, Failure),
    case is_disconnected(Informed) of
        true -> Informed;
        false ->
            Notified = try Mod:handle_auth_failure(User, Mech, Text, Informed)
                       catch _:undef -> Informed
                       end,
            NoMech = maps:remove(sasl_mech, Notified),
            Clean = maps:remove(sasl_state, NoMech),
            Clean#{stream_state => wait_for_sasl_request}
    end.
%% A client abort is treated as a SASL failure with reason `aborted' and
%% an empty user name.
-spec process_sasl_abort(state()) -> state().
process_sasl_abort(State) ->
    NoUser = <<"">>,
    process_sasl_failure(aborted, NoUser, State).
%% Sends the stream features on a version 1.0 stream.
%% While TLS is required but not yet active only the STARTTLS feature is
%% advertised; otherwise every applicable feature list is concatenated.
%% Streams of any other version get no features at all.
-spec send_features(state()) -> state().
send_features(#{stream_version := {1,0},
                stream_encrypted := Encrypted} = State) ->
    TLSRequired = is_starttls_required(State),
    Advertised =
        if TLSRequired and not Encrypted ->
                get_tls_feature(State);
           true ->
                get_sasl_feature(State) ++ get_compress_feature(State)
                    ++ get_tls_feature(State) ++ get_bind_feature(State)
                    ++ get_session_feature(State) ++ get_other_features(State)
        end,
    send_pkt(State, #stream_features{sub_els = Advertised});
send_features(State) ->
    %% clients and servers from stone age
    State.
%% Computes the SASL mechanisms to offer.
%% Client streams start from cyrsasl's list for the local server, all other
%% streams from an empty list. EXTERNAL is put in front on encrypted
%% streams when the handler verifies TLS or the stream is server-to-server.
%% The handler may finally rewrite the list via sasl_mechanisms/2.
-spec get_sasl_mechanisms(state()) -> [cyrsasl:mechanism()].
get_sasl_mechanisms(#{stream_encrypted := Encrypted, mod := Mod,
                      xmlns := NS, lserver := LServer} = State) ->
    Base = if NS == ?NS_CLIENT -> cyrsasl:listmech(LServer);
              true -> []
           end,
    TLSVerify = try Mod:tls_verify(State)
                catch _:undef -> false
                end,
    Candidates = if Encrypted andalso (TLSVerify orelse NS == ?NS_SERVER) ->
                         [<<"EXTERNAL">> | Base];
                    true ->
                         Base
                 end,
    try Mod:sasl_mechanisms(Candidates, State)
    catch _:undef -> Candidates
    end.
%% SASL mechanisms feature: advertised only on unauthenticated streams, and
%% only once TLS is active or not required.
-spec get_sasl_feature(state()) -> [sasl_mechanisms()].
get_sasl_feature(#{stream_authenticated := false,
                   stream_encrypted := Encrypted} = State) ->
    TLSRequired = is_starttls_required(State),
    if Encrypted or not TLSRequired ->
            [#sasl_mechanisms{list = get_sasl_mechanisms(State)}];
       true ->
            []
    end;
get_sasl_feature(_State) ->
    [].
%% Compression feature: advertised only on authenticated, not yet
%% compressed streams whose handler reports at least one method.
-spec get_compress_feature(state()) -> [compression()].
get_compress_feature(#{stream_compressed := false, mod := Mod,
                       stream_authenticated := true} = State) ->
    try Mod:compress_methods(State) of
        [] ->
            [];
        Methods ->
            [#compression{methods = Methods}]
    catch _:undef ->
            []
    end;
get_compress_feature(_State) ->
    [].
%% STARTTLS feature: advertised only on unauthenticated, unencrypted
%% streams when TLS is available; the `required' flag mirrors
%% is_starttls_required/1.
-spec get_tls_feature(state()) -> [starttls()].
get_tls_feature(#{stream_authenticated := false,
                  stream_encrypted := false} = State) ->
    case is_starttls_available(State) of
        false ->
            [];
        true ->
            Required = is_starttls_required(State),
            [#starttls{required = Required}]
    end;
get_tls_feature(_State) ->
    [].
%% Resource binding feature: advertised only on authenticated client
%% streams that have no resource yet.
-spec get_bind_feature(state()) -> [bind()].
get_bind_feature(#{resource := <<"">>,
                   stream_authenticated := true,
                   xmlns := ?NS_CLIENT}) ->
    [#bind{}];
get_bind_feature(_State) ->
    [].
%% Session feature (marked optional): advertised under the same conditions
%% as resource binding, i.e. authenticated client stream without resource.
-spec get_session_feature(state()) -> [xmpp_session()].
get_session_feature(#{resource := <<"">>,
                      stream_authenticated := true,
                      xmlns := ?NS_CLIENT}) ->
    [#xmpp_session{optional = true}];
get_session_feature(_State) ->
    [].
%% Additional features supplied by the handler module, taken from the
%% authenticated or unauthenticated callback depending on the stream's
%% authentication status. A missing callback yields no extra features.
-spec get_other_features(state()) -> [xmpp_element()].
get_other_features(#{stream_authenticated := Auth, mod := Mod} = State) ->
    try
        case Auth of
            true -> Mod:authenticated_stream_features(State);
            _ -> Mod:unauthenticated_stream_features(State)
        end
    catch _:undef ->
            []
    end.
%% Asks the handler whether TLS is enabled; a handler without tls_enabled/1
%% counts as having TLS available.
-spec is_starttls_available(state()) -> boolean().
is_starttls_available(#{mod := Handler} = State) ->
    try
        Handler:tls_enabled(State)
    catch
        _:undef -> true
    end.
%% Asks the handler whether TLS is mandatory; a handler without
%% tls_required/1 counts as not requiring TLS.
-spec is_starttls_required(state()) -> boolean().
is_starttls_required(#{mod := Handler} = State) ->
    try
        Handler:tls_required(State)
    catch
        _:undef -> false
    end.
%% Checks and completes the addressing of an incoming packet.
%% Non-stanzas pass through untouched. On client streams a missing 'from'
%% defaults to the session JID and a missing 'to' to the bare account JID;
%% a 'from' that does not belong to the session (the resource may be empty)
%% is rejected with invalid-from. On all other streams both 'from' and 'to'
%% must be present.
-spec set_from_to(xmpp_element(), state()) -> {ok, xmpp_element()} |
                                              {error, stream_error()}.
set_from_to(Pkt, _State) when not ?is_stanza(Pkt) ->
    {ok, Pkt};
set_from_to(Pkt, #{user := U, server := S, resource := R,
                   lang := Lang, xmlns := ?NS_CLIENT}) ->
    SessionJID = jid:make(U, S, R),
    Sender = case xmpp:get_from(Pkt) of
                 undefined -> SessionJID;
                 Given -> Given
             end,
    %% Kept as guard tests: a malformed sender falls through to the
    %% rejection branch.
    if SessionJID#jid.luser == Sender#jid.luser andalso
       SessionJID#jid.lserver == Sender#jid.lserver andalso
       (SessionJID#jid.lresource == Sender#jid.lresource
        orelse Sender#jid.lresource == <<"">>) ->
            Recipient = case xmpp:get_to(Pkt) of
                            undefined -> jid:make(U, S);
                            Addressee -> Addressee
                        end,
            {ok, xmpp:set_from_to(Pkt, SessionJID, Recipient)};
       true ->
            Txt = <<"Improper 'from' attribute">>,
            {error, xmpp:serr_invalid_from(Txt, Lang)}
    end;
set_from_to(Pkt, #{lang := Lang}) ->
    Sender = xmpp:get_from(Pkt),
    Recipient = xmpp:get_to(Pkt),
    if Sender == undefined ->
            Txt = <<"Missing 'from' attribute">>,
            {error, xmpp:serr_improper_addressing(Txt, Lang)};
       Recipient == undefined ->
            Txt = <<"Missing 'to' attribute">>,
            {error, xmpp:serr_improper_addressing(Txt, Lang)};
       true ->
            {ok, Pkt}
    end.
%% Sends our stream header when there is no usable header from the peer:
%% a stream_start carrying only our own version stands in for it.
-spec send_header(state()) -> state().
send_header(#{stream_version := OwnVersion} = State) ->
    Placeholder = #stream_start{version = OwnVersion},
    send_header(State, Placeholder).
%% Sends our stream header in response to the peer's one, unless a header
%% was already sent for this stream.
%% The language is negotiated with select_lang/2, the peer's from/to are
%% swapped, and the version is ours unless the peer sent none or a 0.x one.
%% The negotiated language and version are stored in the state. A send
%% failure ends the stream with a socket error.
-spec send_header(state(), stream_start()) -> state().
send_header(#{stream_id := StreamID,
              stream_version := MyVersion,
              stream_header_sent := false,
              lang := MyLang,
              xmlns := NS} = State,
            #stream_start{to = HisTo, from = HisFrom,
                          lang = HisLang, version = HisVersion}) ->
    Lang = select_lang(MyLang, HisLang),
    DialbackNS = if NS == ?NS_SERVER -> ?NS_SERVER_DIALBACK;
                    true -> <<"">>
                 end,
    Version = case HisVersion of
                  undefined -> undefined;
                  {0,_} -> HisVersion;
                  _ -> MyVersion
              end,
    Header = #stream_start{version = Version,
                           lang = Lang,
                           xmlns = NS,
                           stream_xmlns = ?NS_STREAM,
                           db_xmlns = DialbackNS,
                           id = StreamID,
                           to = HisFrom,
                           from = HisTo},
    Updated = State#{lang => Lang,
                     stream_version => Version,
                     stream_header_sent => true},
    case socket_send(Updated, Header) of
        ok -> Updated;
        {error, Why} -> process_stream_end({socket, Why}, Updated)
    end;
send_header(State, _StreamStart) ->
    State.
%% Writes a packet to the socket and reports the outcome to the callback
%% module through handle_send/3 (the state is left unchanged when that call
%% raises `undef').
%%
%% If the packet is a #stream_error{} the stream is ended afterwards no
%% matter what the socket returned. Otherwise a socket error ends the
%% stream and a successful send returns the updated state.
-spec send_pkt(state(), xmpp_element() | xmlel()) -> state().
send_pkt(#{mod := Callback} = State, Pkt) ->
    SendRes = socket_send(State, Pkt),
    NewState = try
                   Callback:handle_send(Pkt, SendRes, State)
               catch
                   _:undef -> State
               end,
    if is_record(Pkt, stream_error) ->
           process_stream_end({stream, {out, Pkt}}, NewState);
       true ->
           case SendRes of
               ok ->
                   NewState;
               {error, Reason} ->
                   process_stream_end({socket, Reason}, NewState)
           end
    end.
%% Bounces a stanza back with the given stanza error.
%%
%% Nothing is sent (and the state is returned as is) when the packet is not
%% a stanza or when its type is already `result' or `error', given either
%% as an atom or as a binary.
-spec send_error(state(), xmpp_element() | xmlel(), stanza_error()) -> state().
send_error(State, Pkt, Err) ->
    NoReplyTypes = [result, error, <<"result">>, <<"error">>],
    %% The type is only inspected once the packet is known to be a stanza.
    Bounce = xmpp:is_stanza(Pkt)
        andalso not lists:member(xmpp:get_type(Pkt), NoReplyTypes),
    case Bounce of
        true -> send_pkt(State, xmpp:make_error(Pkt, Err));
        false -> State
    end.
%% Sends the stream trailer and then closes the socket. The result of the
%% send is deliberately ignored.
-spec send_trailer(state()) -> state().
send_trailer(State) ->
    _ = socket_send(State, trailer),
    close_socket(State).
%% Low-level write to the XMPP socket.
%%
%% Nothing is written, and {error, closed} is returned, unless the state
%% holds `socket', `stream_state' and `xmlns' and `stream_header_sent' is
%% `true'. The trailer is sent regardless of the stream state; a
%% #stream_start{} or any other element is only sent while the stream state
%% is not `disconnected'.
-spec socket_send(state(), xmpp_element() | xmlel() | trailer) -> ok | {error, inet:posix()}.
socket_send(#{socket := Sock,
              stream_state := _,
              xmlns := _,
              stream_header_sent := true}, trailer) ->
    xmpp_socket:send_trailer(Sock);
socket_send(#{socket := Sock,
              stream_state := StreamState,
              xmlns := _,
              stream_header_sent := true}, #stream_start{} = Header)
  when StreamState /= disconnected ->
    xmpp_socket:send_header(Sock, xmpp:encode(Header));
socket_send(#{socket := Sock,
              stream_state := StreamState,
              xmlns := NS,
              stream_header_sent := true}, El)
  when StreamState /= disconnected ->
    %% Elements are encoded relative to the stream's default namespace.
    xmpp_socket:send_element(Sock, xmpp:encode(El, NS));
socket_send(_State, _Pkt) ->
    {error, closed}.
%% Closes the socket and returns the state with `stream_state' set to
%% `disconnected' and `stream_timeout' set to `infinity'.
-spec close_socket(state()) -> state().
close_socket(#{socket := Sock} = State) ->
    _ = xmpp_socket:close(Sock),
    Disconnected = State#{stream_state => disconnected},
    Disconnected#{stream_timeout => infinity}.
%% Picks the language to use: the second argument wins unless it is the
%% empty binary, in which case the first argument is returned.
-spec select_lang(binary(), binary()) -> binary().
select_lang(Fallback, Preferred) ->
    case Preferred of
        <<"">> -> Fallback;
        _ -> Preferred
    end.
%% Stamps a language on stanzas received on a client-namespace stream: the
%% stanza's own language if it is non-empty, otherwise the language stored
%% in the state. Every other packet is returned unchanged.
-spec set_lang(xmpp_element(), state()) -> xmpp_element().
set_lang(Pkt, #{lang := StreamLang, xmlns := ?NS_CLIENT}) when ?is_stanza(Pkt) ->
    Chosen = select_lang(StreamLang, xmpp:get_lang(Pkt)),
    xmpp:set_lang(Pkt, Chosen);
set_lang(Pkt, _State) ->
    Pkt.
%% Renders a socket error reason as human-readable text.
%%
%% `closed' is reported as "connection closed". Any other reason is passed
%% to inet:format_error/1; if that function does not know the reason, the
%% atom's own name is returned instead of the generic placeholder.
-spec format_inet_error(atom()) -> string().
format_inet_error(closed) ->
    "connection closed";
format_inet_error(Reason) ->
    case inet:format_error(Reason) of
        %% Match on the prefix rather than on the exact string: the
        %% placeholder text is not a documented contract, and a variant
        %% that appends extra detail (e.g. ": " ++ Reason) would otherwise
        %% slip through and defeat this fallback.
        "unknown POSIX error" ++ _ -> atom_to_list(Reason);
        Txt -> Txt
    end.
%% Builds a printable description of a stream error.
%%
%% The slogan is "no reason" for `undefined', "see-other-host" for a
%% #'see-other-host'{} record and the atom's name otherwise. When
%% xmpp:get_text/1 yields a non-empty text the result is that text followed
%% by the slogan in parentheses; otherwise just the slogan.
-spec format_stream_error(atom() | 'see-other-host'(), [text()]) -> string().
format_stream_error(Reason, Txt) ->
    Slogan = if Reason == undefined -> "no reason";
                is_record(Reason, 'see-other-host') -> "see-other-host";
                true -> atom_to_list(Reason)
             end,
    case xmpp:get_text(Txt) of
        <<"">> ->
            Slogan;
        Descr ->
            Suffix = " (" ++ Slogan ++ ")",
            binary_to_list(Descr) ++ Suffix
    end.
%% Formats a SASL failure. Errors of the EXTERNAL mechanism are delegated
%% to xmpp_stream_pkix:format_error/1, all other mechanisms to
%% cyrsasl:format_error/2.
-spec format_sasl_error(cyrsasl:mechanism(), atom()) -> {atom(), binary()}.
format_sasl_error(Mech, Err) ->
    case Mech of
        <<"EXTERNAL">> -> xmpp_stream_pkix:format_error(Err);
        _ -> cyrsasl:format_error(Mech, Err)
    end.
%% Formats a TLS failure reason: atoms go through format_inet_error/1,
%% anything else is returned unchanged.
-spec format_tls_error(atom() | binary()) -> list().
format_tls_error(Reason) ->
    if is_atom(Reason) -> format_inet_error(Reason);
       true -> Reason
    end.
%% io_lib:format/2 with the result collapsed into a single binary.
-spec format(io:format(), list()) -> binary().
format(Fmt, Args) ->
    Chars = io_lib:format(Fmt, Args),
    erlang:iolist_to_binary(Chars).
%% Returns the elements of the first list that also occur in the second
%% one, keeping the order (and any duplicates) of the first list.
-spec lists_intersection(list(), list()) -> list().
lists_intersection(L1, L2) ->
    [Elem || Elem <- L1, lists:member(Elem, L2)].
%% Extracts the identity from a SASL property list: the `authzid' value when
%% it is present and non-empty, otherwise the `username' value (or the empty
%% binary when that is missing as well).
-spec identity([cyrsasl:sasl_property()]) -> binary().
identity(Props) ->
    AuthzId = proplists:get_value(authzid, Props, <<>>),
    if AuthzId == <<>> -> proplists:get_value(username, Props, <<>>);
       true -> AuthzId
    end.