25
1
mirror of https://github.com/processone/ejabberd.git synced 2024-10-13 15:16:49 +02:00
xmpp.chapril.org-ejabberd/src/xmpp_stream_in.erl
Evgeniy Khramtsov fc77051b68 Don't call Mod:function() in xmpp_stream callbacks
If a callback function is not defined by the `Mod` then
a call to code_server process is performed. Under heavy load
this may cause code_server to get overloaded. We now avoid this.
2018-05-26 09:06:24 +03:00

1235 lines
42 KiB
Erlang

%%%-------------------------------------------------------------------
%%% Created : 26 Nov 2016 by Evgeny Khramtsov <ekhramtsov@process-one.net>
%%%
%%%
%%% ejabberd, Copyright (C) 2002-2018 ProcessOne
%%%
%%% This program is free software; you can redistribute it and/or
%%% modify it under the terms of the GNU General Public License as
%%% published by the Free Software Foundation; either version 2 of the
%%% License, or (at your option) any later version.
%%%
%%% This program is distributed in the hope that it will be useful,
%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
%%% General Public License for more details.
%%%
%%% You should have received a copy of the GNU General Public License along
%%% with this program; if not, write to the Free Software Foundation, Inc.,
%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
%%%
%%%-------------------------------------------------------------------
-module(xmpp_stream_in).
-define(GEN_SERVER, p1_server).
-behaviour(?GEN_SERVER).
-protocol({rfc, 6120}).
-protocol({xep, 114, '1.6'}).
%% API
-export([start/3, start_link/3, call/3, cast/2, reply/2, stop/1,
send/2, close/1, close/2, send_error/3, establish/1,
get_transport/1, change_shaper/2, set_timeout/2, format_error/1]).
%% gen_server callbacks
-export([init/1, handle_cast/2, handle_call/3, handle_info/2,
terminate/2, code_change/3]).
%%-define(DBGFSM, true).
-ifdef(DBGFSM).
-define(FSMOPTS, [{debug, [trace]}]).
-else.
-define(FSMOPTS, []).
-endif.
-include("xmpp.hrl").
-type state() :: map().
-type stop_reason() :: {stream, reset | {in | out, stream_error()}} |
{tls, inet:posix() | atom() | binary()} |
{socket, inet:posix() | atom()} |
internal_failure.
-export_type([state/0, stop_reason/0]).
-callback init(list()) -> {ok, state()} | {error, term()} | ignore.
-callback handle_cast(term(), state()) -> state().
-callback handle_call(term(), term(), state()) -> state().
-callback handle_info(term(), state()) -> state().
-callback terminate(term(), state()) -> any().
-callback code_change(term(), state(), term()) -> {ok, state()} | {error, term()}.
-callback handle_stream_start(stream_start(), state()) -> state().
-callback handle_stream_established(state()) -> state().
-callback handle_stream_end(stop_reason(), state()) -> state().
-callback handle_cdata(binary(), state()) -> state().
-callback handle_unauthenticated_packet(xmpp_element(), state()) -> state().
-callback handle_authenticated_packet(xmpp_element(), state()) -> state().
-callback handle_unbinded_packet(xmpp_element(), state()) -> state().
-callback handle_auth_success(binary(), binary(), module(), state()) -> state().
-callback handle_auth_failure(binary(), binary(), binary(), state()) -> state().
-callback handle_send(xmpp_element(), ok | {error, inet:posix()}, state()) -> state().
-callback handle_recv(fxml:xmlel(), xmpp_element() | {error, term()}, state()) -> state().
-callback handle_timeout(state()) -> state().
-callback get_password_fun(state()) -> fun().
-callback check_password_fun(state()) -> fun().
-callback check_password_digest_fun(state()) -> fun().
-callback bind(binary(), state()) -> {ok, state()} | {error, stanza_error(), state()}.
-callback compress_methods(state()) -> [binary()].
-callback tls_options(state()) -> [proplists:property()].
-callback tls_required(state()) -> boolean().
-callback tls_verify(state()) -> boolean().
-callback tls_enabled(state()) -> boolean().
-callback sasl_mechanisms([cyrsasl:mechanism()], state()) -> [cyrsasl:mechanism()].
-callback unauthenticated_stream_features(state()) -> [xmpp_element()].
-callback authenticated_stream_features(state()) -> [xmpp_element()].
%% All callbacks are optional
-optional_callbacks([init/1,
handle_cast/2,
handle_call/3,
handle_info/2,
terminate/2,
code_change/3,
handle_stream_start/2,
handle_stream_established/1,
handle_stream_end/2,
handle_cdata/2,
handle_authenticated_packet/2,
handle_unauthenticated_packet/2,
handle_unbinded_packet/2,
handle_auth_success/4,
handle_auth_failure/4,
handle_send/3,
handle_recv/3,
handle_timeout/1,
get_password_fun/1,
check_password_fun/1,
check_password_digest_fun/1,
bind/2,
compress_methods/1,
tls_options/1,
tls_required/1,
tls_verify/1,
tls_enabled/1,
sasl_mechanisms/2,
unauthenticated_stream_features/1,
authenticated_stream_features/1]).
%%%===================================================================
%%% API
%%%===================================================================
start(Mod, Args, Opts) ->
?GEN_SERVER:start(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS).
start_link(Mod, Args, Opts) ->
?GEN_SERVER:start_link(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS).
call(Ref, Msg, Timeout) ->
?GEN_SERVER:call(Ref, Msg, Timeout).
cast(Ref, Msg) ->
?GEN_SERVER:cast(Ref, Msg).
reply(Ref, Reply) ->
?GEN_SERVER:reply(Ref, Reply).
-spec stop(pid()) -> ok;
(state()) -> no_return().
stop(Pid) when is_pid(Pid) ->
cast(Pid, stop);
stop(#{owner := Owner} = State) when Owner == self() ->
terminate(normal, State),
exit(normal);
stop(_) ->
erlang:error(badarg).
-spec send(pid(), xmpp_element()) -> ok;
(state(), xmpp_element()) -> state().
send(Pid, Pkt) when is_pid(Pid) ->
cast(Pid, {send, Pkt});
send(#{owner := Owner} = State, Pkt) when Owner == self() ->
send_pkt(State, Pkt);
send(_, _) ->
erlang:error(badarg).
-spec close(pid()) -> ok;
(state()) -> state().
close(Pid) when is_pid(Pid) ->
close(Pid, closed);
close(#{owner := Owner} = State) when Owner == self() ->
close_socket(State);
close(_) ->
erlang:error(badarg).
-spec close(pid(), atom()) -> ok.
close(Pid, Reason) ->
cast(Pid, {close, Reason}).
-spec establish(state()) -> state().
establish(State) ->
process_stream_established(State).
-spec set_timeout(state(), non_neg_integer() | infinity) -> state().
set_timeout(#{owner := Owner} = State, Timeout) when Owner == self() ->
case Timeout of
infinity -> State#{stream_timeout => infinity};
_ ->
Time = p1_time_compat:monotonic_time(milli_seconds),
State#{stream_timeout => {Timeout, Time}}
end;
set_timeout(_, _) ->
erlang:error(badarg).
get_transport(#{socket := Socket, owner := Owner})
when Owner == self() ->
xmpp_socket:get_transport(Socket);
get_transport(_) ->
erlang:error(badarg).
-spec change_shaper(state(), shaper:shaper()) -> state().
change_shaper(#{socket := Socket, owner := Owner} = State, Shaper)
when Owner == self() ->
Socket1 = xmpp_socket:change_shaper(Socket, Shaper),
State#{socket => Socket1};
change_shaper(_, _) ->
erlang:error(badarg).
-spec format_error(stop_reason()) -> binary().
format_error({socket, Reason}) ->
format("Connection failed: ~s", [format_inet_error(Reason)]);
format_error({stream, reset}) ->
<<"Stream reset by peer">>;
format_error({stream, {in, #stream_error{reason = Reason, text = Txt}}}) ->
format("Stream closed by peer: ~s", [format_stream_error(Reason, Txt)]);
format_error({stream, {out, #stream_error{reason = Reason, text = Txt}}}) ->
format("Stream closed by us: ~s", [format_stream_error(Reason, Txt)]);
format_error({tls, Reason}) ->
format("TLS failed: ~s", [format_tls_error(Reason)]);
format_error(internal_failure) ->
<<"Internal server error">>;
format_error(Err) ->
format("Unrecognized error: ~w", [Err]).
%%%===================================================================
%%% gen_server callbacks
%%%===================================================================
init([Mod, {_SockMod, Socket}, Opts]) ->
Encrypted = proplists:get_bool(tls, Opts),
SocketMonitor = xmpp_socket:monitor(Socket),
case xmpp_socket:peername(Socket) of
{ok, IP} ->
Time = p1_time_compat:monotonic_time(milli_seconds),
State = #{owner => self(),
mod => Mod,
socket => Socket,
socket_monitor => SocketMonitor,
stream_timeout => {timer:seconds(30), Time},
stream_direction => in,
stream_id => new_id(),
stream_state => wait_for_stream,
stream_header_sent => false,
stream_restarted => false,
stream_compressed => false,
stream_encrypted => Encrypted,
stream_version => {1,0},
stream_authenticated => false,
codec_options => [ignore_els],
xmlns => ?NS_CLIENT,
lang => <<"">>,
user => <<"">>,
server => <<"">>,
resource => <<"">>,
lserver => <<"">>,
ip => IP},
case try Mod:init([State, Opts])
catch _:undef -> {ok, State}
end of
{ok, State1} when not Encrypted ->
{_, State2, Timeout} = noreply(State1),
{ok, State2, Timeout};
{ok, State1} when Encrypted ->
TLSOpts = try callback(tls_options, State1)
catch _:{?MODULE, undef} -> []
end,
case xmpp_socket:starttls(Socket, TLSOpts) of
{ok, TLSSocket} ->
State2 = State1#{socket => TLSSocket},
{_, State3, Timeout} = noreply(State2),
{ok, State3, Timeout};
{error, Reason} ->
{stop, Reason}
end;
{error, Reason} ->
{stop, Reason};
ignore ->
ignore
end;
{error, _Reason} ->
ignore
end.
handle_cast({send, Pkt}, State) ->
noreply(send_pkt(State, Pkt));
handle_cast(stop, State) ->
{stop, normal, State};
handle_cast({close, Reason}, State) ->
State1 = close_socket(State),
noreply(
case is_disconnected(State) of
true -> State1;
false -> process_stream_end({socket, Reason}, State)
end);
handle_cast(Cast, State) ->
noreply(try callback(handle_cast, Cast, State)
catch _:{?MODULE, undef} -> State
end).
handle_call(Call, From, State) ->
noreply(try callback(handle_call, Call, From, State)
catch _:{?MODULE, undef} -> State
end).
handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}},
#{stream_state := wait_for_stream,
xmlns := XMLNS, lang := MyLang} = State) ->
El = #xmlel{name = Name, attrs = Attrs},
noreply(
try xmpp:decode(El, XMLNS, []) of
#stream_start{} = Pkt ->
State1 = send_header(State, Pkt),
case is_disconnected(State1) of
true -> State1;
false -> process_stream(Pkt, State1)
end;
_ ->
State1 = send_header(State),
case is_disconnected(State1) of
true -> State1;
false -> send_pkt(State1, xmpp:serr_invalid_xml())
end
catch _:{xmpp_codec, Why} ->
State1 = send_header(State),
case is_disconnected(State1) of
true -> State1;
false ->
Txt = xmpp:io_format_error(Why),
Lang = select_lang(MyLang, xmpp:get_lang(El)),
Err = xmpp:serr_invalid_xml(Txt, Lang),
send_pkt(State1, Err)
end
end);
handle_info({'$gen_event', {xmlstreamend, _}}, State) ->
noreply(process_stream_end({stream, reset}, State));
handle_info({'$gen_event', closed}, State) ->
noreply(process_stream_end({socket, closed}, State));
handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) ->
State1 = send_header(State),
noreply(
case is_disconnected(State1) of
true -> State1;
false ->
Err = case Reason of
<<"XML stanza is too big">> ->
xmpp:serr_policy_violation(Reason, Lang);
{_, Txt} ->
xmpp:serr_not_well_formed(Txt, Lang)
end,
send_pkt(State1, Err)
end);
handle_info({'$gen_event', El}, #{stream_state := wait_for_stream} = State) ->
error_logger:warning_msg("unexpected event from XML driver: ~p; "
"xmlstreamstart was expected", [El]),
State1 = send_header(State),
noreply(
case is_disconnected(State1) of
true -> State1;
false -> send_pkt(State1, xmpp:serr_invalid_xml())
end);
handle_info({'$gen_event', {xmlstreamelement, El}},
#{xmlns := NS, codec_options := Opts} = State) ->
noreply(
try xmpp:decode(El, NS, Opts) of
Pkt ->
State1 = try callback(handle_recv, El, Pkt, State)
catch _:{?MODULE, undef} -> State
end,
case is_disconnected(State1) of
true -> State1;
false -> process_element(Pkt, State1)
end
catch _:{xmpp_codec, Why} ->
State1 = try callback(handle_recv, El, {error, Why}, State)
catch _:{?MODULE, undef} -> State
end,
case is_disconnected(State1) of
true -> State1;
false -> process_invalid_xml(State1, El, Why)
end
end);
handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}},
State) ->
noreply(try callback(handle_cdata, Data, State)
catch _:{?MODULE, undef} -> State
end);
handle_info(timeout, #{lang := Lang} = State) ->
Disconnected = is_disconnected(State),
noreply(try callback(handle_timeout, State)
catch _:{?MODULE, undef} when not Disconnected ->
Txt = <<"Idle connection">>,
send_pkt(State, xmpp:serr_connection_timeout(Txt, Lang));
_:{?MODULE, undef} ->
stop(State)
end);
handle_info({'DOWN', MRef, _Type, _Object, _Info},
#{socket_monitor := MRef} = State) ->
noreply(process_stream_end({socket, closed}, State));
handle_info({tcp, _, Data}, #{socket := Socket} = State) ->
noreply(
case xmpp_socket:recv(Socket, Data) of
{ok, NewSocket} ->
State#{socket => NewSocket};
{error, Reason} when is_atom(Reason) ->
process_stream_end({socket, Reason}, State);
{error, Reason} ->
%% TODO: make fast_tls return atoms
process_stream_end({tls, Reason}, State)
end);
handle_info({tcp_closed, _}, State) ->
handle_info({'$gen_event', closed}, State);
handle_info({tcp_error, _, Reason}, State) ->
noreply(process_stream_end({socket, Reason}, State));
handle_info(Info, State) ->
noreply(try callback(handle_info, Info, State)
catch _:{?MODULE, undef} -> State
end).
terminate(Reason, State) ->
case get(already_terminated) of
true ->
State;
_ ->
put(already_terminated, true),
try callback(terminate, Reason, State)
catch _:{?MODULE, undef} -> ok
end,
send_trailer(State)
end.
code_change(OldVsn, State, Extra) ->
callback(code_change, OldVsn, State, Extra).
%%%===================================================================
%%% Internal functions
%%%===================================================================
-spec noreply(state()) -> {noreply, state(), non_neg_integer() | infinity}.
noreply(#{stream_timeout := infinity} = State) ->
{noreply, State, infinity};
noreply(#{stream_timeout := {MSecs, StartTime}} = State) ->
CurrentTime = p1_time_compat:monotonic_time(milli_seconds),
Timeout = max(0, MSecs - CurrentTime + StartTime),
{noreply, State, Timeout}.
-spec new_id() -> binary().
new_id() ->
randoms:get_string().
-spec is_disconnected(state()) -> boolean().
is_disconnected(#{stream_state := StreamState}) ->
StreamState == disconnected.
-spec process_invalid_xml(state(), fxml:xmlel(), term()) -> state().
process_invalid_xml(#{lang := MyLang} = State, El, Reason) ->
case xmpp:is_stanza(El) of
true ->
Txt = xmpp:io_format_error(Reason),
Lang = select_lang(MyLang, xmpp:get_lang(El)),
send_error(State, El, xmpp:err_bad_request(Txt, Lang));
false ->
case {xmpp:get_name(El), xmpp:get_ns(El)} of
{Tag, ?NS_SASL} when Tag == <<"auth">>;
Tag == <<"response">>;
Tag == <<"abort">> ->
Txt = xmpp:io_format_error(Reason),
Err = #sasl_failure{reason = 'malformed-request',
text = xmpp:mk_text(Txt, MyLang)},
send_pkt(State, Err);
{<<"starttls">>, ?NS_TLS} ->
send_pkt(State, #starttls_failure{});
{<<"compress">>, ?NS_COMPRESS} ->
Err = #compress_failure{reason = 'setup-failed'},
send_pkt(State, Err);
_ ->
%% Maybe add something more?
State
end
end.
-spec process_stream_end(stop_reason(), state()) -> state().
process_stream_end(_, #{stream_state := disconnected} = State) ->
State;
process_stream_end(Reason, State) ->
State1 = State#{stream_timeout => infinity,
stream_state => disconnected},
try callback(handle_stream_end, Reason, State1)
catch _:{?MODULE, undef} -> stop(State1)
end.
-spec process_stream(stream_start(), state()) -> state().
process_stream(#stream_start{xmlns = XML_NS,
stream_xmlns = STREAM_NS},
#{xmlns := NS} = State)
when XML_NS /= NS; STREAM_NS /= ?NS_STREAM ->
send_pkt(State, xmpp:serr_invalid_namespace());
process_stream(#stream_start{version = {N, _}}, State) when N > 1 ->
send_pkt(State, xmpp:serr_unsupported_version());
process_stream(#stream_start{lang = Lang},
#{xmlns := ?NS_CLIENT, lang := DefaultLang} = State)
when size(Lang) > 35 ->
%% As stated in BCP47, 4.4.1:
%% Protocols or specifications that specify limited buffer sizes for
%% language tags MUST allow for language tags of at least 35 characters.
%% Do not store long language tag to avoid possible DoS/flood attacks
Txt = <<"Too long value of 'xml:lang' attribute">>,
send_pkt(State, xmpp:serr_policy_violation(Txt, DefaultLang));
process_stream(#stream_start{to = undefined, version = Version} = StreamStart,
#{lang := Lang, server := Server, xmlns := NS} = State) ->
if Version < {1,0} andalso NS /= ?NS_COMPONENT ->
%% Work-around for gmail servers
To = jid:make(Server),
process_stream(StreamStart#stream_start{to = To}, State);
true ->
Txt = <<"Missing 'to' attribute">>,
send_pkt(State, xmpp:serr_improper_addressing(Txt, Lang))
end;
process_stream(#stream_start{to = #jid{luser = U, lresource = R}},
#{lang := Lang} = State) when U /= <<"">>; R /= <<"">> ->
Txt = <<"Improper 'to' attribute">>,
send_pkt(State, xmpp:serr_improper_addressing(Txt, Lang));
process_stream(#stream_start{to = #jid{lserver = RemoteServer}} = StreamStart,
#{xmlns := ?NS_COMPONENT} = State) ->
State1 = State#{remote_server => RemoteServer,
stream_state => wait_for_handshake},
try callback(handle_stream_start, StreamStart, State1)
catch _:{?MODULE, undef} -> State1
end;
process_stream(#stream_start{to = #jid{server = Server, lserver = LServer},
from = From} = StreamStart,
#{stream_authenticated := Authenticated,
stream_restarted := StreamWasRestarted,
xmlns := NS, resource := Resource,
stream_encrypted := Encrypted} = State) ->
State1 = if not StreamWasRestarted ->
State#{server => Server, lserver => LServer};
true ->
State
end,
State2 = case From of
#jid{lserver = RemoteServer} when NS == ?NS_SERVER ->
State1#{remote_server => RemoteServer};
_ ->
State1
end,
State3 = try callback(handle_stream_start, StreamStart, State2)
catch _:{?MODULE, undef} -> State2
end,
case is_disconnected(State3) of
true -> State3;
false ->
State4 = send_features(State3),
case is_disconnected(State4) of
true -> State4;
false ->
TLSRequired = is_starttls_required(State4),
if not Authenticated and (TLSRequired and not Encrypted) ->
State4#{stream_state => wait_for_starttls};
not Authenticated ->
State4#{stream_state => wait_for_sasl_request};
(NS == ?NS_CLIENT) and (Resource == <<"">>) ->
State4#{stream_state => wait_for_bind};
true ->
process_stream_established(State4)
end
end
end.
-spec process_element(xmpp_element(), state()) -> state().
process_element(Pkt, #{stream_state := StateName, lang := Lang} = State) ->
case Pkt of
#starttls{} when StateName == wait_for_starttls;
StateName == wait_for_sasl_request ->
process_starttls(State);
#starttls{} ->
process_starttls_failure(unexpected_starttls_request, State);
#sasl_auth{} when StateName == wait_for_starttls ->
send_pkt(State, #sasl_failure{reason = 'encryption-required'});
#sasl_auth{} when StateName == wait_for_sasl_request ->
process_sasl_request(Pkt, State);
#sasl_auth{} when StateName == wait_for_sasl_response ->
process_sasl_request(Pkt, maps:remove(sasl_state, State));
#sasl_auth{} ->
Txt = <<"SASL negotiation is not allowed in this state">>,
send_pkt(State, #sasl_failure{reason = 'not-authorized',
text = xmpp:mk_text(Txt, Lang)});
#sasl_response{} when StateName == wait_for_starttls ->
send_pkt(State, #sasl_failure{reason = 'encryption-required'});
#sasl_response{} when StateName == wait_for_sasl_response ->
process_sasl_response(Pkt, State);
#sasl_response{} ->
Txt = <<"SASL negotiation is not allowed in this state">>,
send_pkt(State, #sasl_failure{reason = 'not-authorized',
text = xmpp:mk_text(Txt, Lang)});
#sasl_abort{} when StateName == wait_for_sasl_response ->
process_sasl_abort(State);
#sasl_abort{} ->
send_pkt(State, #sasl_failure{reason = 'aborted'});
#sasl_success{} ->
State;
#compress{} ->
process_compress(Pkt, State);
#handshake{} when StateName == wait_for_handshake ->
process_handshake(Pkt, State);
#handshake{} ->
State;
#stream_error{} ->
process_stream_end({stream, {in, Pkt}}, State);
_ when StateName == wait_for_sasl_request;
StateName == wait_for_handshake;
StateName == wait_for_sasl_response ->
process_unauthenticated_packet(Pkt, State);
_ when StateName == wait_for_starttls ->
Txt = <<"Use of STARTTLS required">>,
Err = xmpp:serr_policy_violation(Txt, Lang),
send_pkt(State, Err);
_ when StateName == wait_for_bind ->
process_bind(Pkt, State);
_ when StateName == established ->
process_authenticated_packet(Pkt, State)
end.
-spec process_unauthenticated_packet(xmpp_element(), state()) -> state().
process_unauthenticated_packet(Pkt, State) ->
NewPkt = set_lang(Pkt, State),
try callback(handle_unauthenticated_packet, NewPkt, State)
catch _:{?MODULE, undef} ->
Err = xmpp:serr_not_authorized(),
send(State, Err)
end.
-spec process_authenticated_packet(xmpp_element(), state()) -> state().
process_authenticated_packet(Pkt, State) ->
Pkt1 = set_lang(Pkt, State),
case set_from_to(Pkt1, State) of
{ok, Pkt2} ->
try callback(handle_authenticated_packet, Pkt2, State)
catch _:{?MODULE, undef} ->
Err = xmpp:err_service_unavailable(),
send_error(State, Pkt, Err)
end;
{error, Err} ->
send_pkt(State, Err)
end.
-spec process_bind(xmpp_element(), state()) -> state().
process_bind(#iq{type = set, sub_els = [_]} = Pkt,
#{xmlns := ?NS_CLIENT, lang := MyLang} = State) ->
try xmpp:try_subtag(Pkt, #bind{}) of
#bind{resource = R} ->
case callback(bind, R, State) of
{ok, #{user := U, server := S, resource := NewR} = State1}
when NewR /= <<"">> ->
Reply = #bind{jid = jid:make(U, S, NewR)},
State2 = send_pkt(State1, xmpp:make_iq_result(Pkt, Reply)),
process_stream_established(State2);
{error, #stanza_error{} = Err, State1} ->
send_error(State1, Pkt, Err)
end;
_ ->
try callback(handle_unbinded_packet, Pkt, State)
catch _:{?MODULE, undef} ->
Err = xmpp:err_not_authorized(),
send_error(State, Pkt, Err)
end
catch _:{xmpp_codec, Why} ->
Txt = xmpp:io_format_error(Why),
Lang = select_lang(MyLang, xmpp:get_lang(Pkt)),
Err = xmpp:err_bad_request(Txt, Lang),
send_error(State, Pkt, Err)
end;
process_bind(Pkt, State) ->
try callback(handle_unbinded_packet, Pkt, State)
catch _:{?MODULE, undef} ->
Err = xmpp:err_not_authorized(),
send_error(State, Pkt, Err)
end.
-spec process_handshake(handshake(), state()) -> state().
process_handshake(#handshake{data = Digest},
#{stream_id := StreamID,
remote_server := RemoteServer} = State) ->
GetPW = try callback(get_password_fun, State)
catch _:{?MODULE, undef} -> fun(_) -> {false, undefined} end
end,
AuthRes = case GetPW(<<"">>) of
{false, _} ->
false;
{Password, _} ->
str:sha(<<StreamID/binary, Password/binary>>) == Digest
end,
case AuthRes of
true ->
State1 = try callback(handle_auth_success,
RemoteServer, <<"handshake">>, undefined, State)
catch _:{?MODULE, undef} -> State
end,
case is_disconnected(State1) of
true -> State1;
false ->
State2 = send_pkt(State1, #handshake{}),
process_stream_established(State2)
end;
false ->
State1 = try callback(handle_auth_failure,
RemoteServer, <<"handshake">>, <<"not authorized">>, State)
catch _:{?MODULE, undef} -> State
end,
case is_disconnected(State1) of
true -> State1;
false -> send_pkt(State1, xmpp:serr_not_authorized())
end
end.
-spec process_stream_established(state()) -> state().
process_stream_established(#{stream_state := StateName} = State)
when StateName == disconnected; StateName == established ->
State;
process_stream_established(State) ->
State1 = State#{stream_authenticated => true,
stream_state => established,
stream_timeout => infinity},
try callback(handle_stream_established, State1)
catch _:{?MODULE, undef} -> State1
end.
-spec process_compress(compress(), state()) -> state().
process_compress(#compress{},
#{stream_compressed := Compressed,
stream_authenticated := Authenticated} = State)
when Compressed or not Authenticated ->
send_pkt(State, #compress_failure{reason = 'setup-failed'});
process_compress(#compress{methods = HisMethods},
#{socket := Socket} = State) ->
MyMethods = try callback(compress_methods, State)
catch _:{?MODULE, undef} -> []
end,
CommonMethods = lists_intersection(MyMethods, HisMethods),
case lists:member(<<"zlib">>, CommonMethods) of
true ->
case xmpp_socket:compress(Socket) of
{ok, ZlibSocket} ->
State1 = send_pkt(State, #compressed{}),
case is_disconnected(State1) of
true -> State1;
false ->
State1#{socket => ZlibSocket,
stream_id => new_id(),
stream_header_sent => false,
stream_restarted => true,
stream_state => wait_for_stream,
stream_compressed => true}
end;
{error, _} ->
Err = #compress_failure{reason = 'setup-failed'},
send_pkt(State, Err)
end;
false ->
send_pkt(State, #compress_failure{reason = 'unsupported-method'})
end.
-spec process_starttls(state()) -> state().
process_starttls(#{stream_encrypted := true} = State) ->
process_starttls_failure(already_encrypted, State);
process_starttls(#{socket := Socket} = State) ->
case is_starttls_available(State) of
true ->
TLSOpts = try callback(tls_options, State)
catch _:{?MODULE, undef} -> []
end,
case xmpp_socket:starttls(Socket, TLSOpts) of
{ok, TLSSocket} ->
State1 = send_pkt(State, #starttls_proceed{}),
case is_disconnected(State1) of
true -> State1;
false ->
State1#{socket => TLSSocket,
stream_id => new_id(),
stream_header_sent => false,
stream_restarted => true,
stream_state => wait_for_stream,
stream_encrypted => true}
end;
{error, Reason} ->
process_starttls_failure(Reason, State)
end;
false ->
process_starttls_failure(starttls_unsupported, State)
end.
-spec process_starttls_failure(term(), state()) -> state().
process_starttls_failure(Why, State) ->
State1 = send_pkt(State, #starttls_failure{}),
case is_disconnected(State1) of
true -> State1;
false -> process_stream_end({tls, Why}, State1)
end.
-spec process_sasl_request(sasl_auth(), state()) -> state().
process_sasl_request(#sasl_auth{mechanism = Mech, text = ClientIn},
#{lserver := LServer} = State) ->
State1 = State#{sasl_mech => Mech},
Mechs = get_sasl_mechanisms(State1),
case lists:member(Mech, Mechs) of
true when Mech == <<"EXTERNAL">> ->
Res = case xmpp_stream_pkix:authenticate(State1, ClientIn) of
{ok, Peer} ->
{ok, [{auth_module, pkix}, {username, Peer}]};
{error, Reason, Peer} ->
{error, Reason, Peer}
end,
process_sasl_result(Res, State1);
true ->
GetPW = try callback(get_password_fun, State1)
catch _:{?MODULE, undef} -> fun(_) -> false end
end,
CheckPW = try callback(check_password_fun, State1)
catch _:{?MODULE, undef} -> fun(_, _, _) -> false end
end,
CheckPWDigest = try callback(check_password_digest_fun, State1)
catch _:{?MODULE, undef} -> fun(_, _, _, _, _) -> false end
end,
SASLState = cyrsasl:server_new(<<"jabber">>, LServer, <<"">>, [],
GetPW, CheckPW, CheckPWDigest),
Res = cyrsasl:server_start(SASLState, Mech, ClientIn),
process_sasl_result(Res, State1#{sasl_state => SASLState});
false ->
process_sasl_result({error, unsupported_mechanism, <<"">>}, State1)
end.
-spec process_sasl_response(sasl_response(), state()) -> state().
process_sasl_response(#sasl_response{text = ClientIn},
#{sasl_state := SASLState} = State) ->
SASLResult = cyrsasl:server_step(SASLState, ClientIn),
process_sasl_result(SASLResult, State).
-spec process_sasl_result(cyrsasl:sasl_return(), state()) -> state().
process_sasl_result({ok, Props}, State) ->
process_sasl_success(Props, <<"">>, State);
process_sasl_result({ok, Props, ServerOut}, State) ->
process_sasl_success(Props, ServerOut, State);
process_sasl_result({continue, ServerOut, NewSASLState}, State) ->
process_sasl_continue(ServerOut, NewSASLState, State);
process_sasl_result({error, Reason, User}, State) ->
process_sasl_failure(Reason, User, State).
-spec process_sasl_success([cyrsasl:sasl_property()], binary(), state()) -> state().
process_sasl_success(Props, ServerOut,
#{socket := Socket,
sasl_mech := Mech} = State) ->
User = identity(Props),
AuthModule = proplists:get_value(auth_module, Props),
Socket1 = xmpp_socket:reset_stream(Socket),
State0 = State#{socket => Socket1},
State1 = try callback(handle_auth_success, User, Mech, AuthModule, State0)
catch _:{?MODULE, undef} -> State
end,
case is_disconnected(State1) of
true -> State1;
false ->
State2 = send_pkt(State1, #sasl_success{text = ServerOut}),
case is_disconnected(State2) of
true -> State2;
false ->
State3 = maps:remove(sasl_state,
maps:remove(sasl_mech, State2)),
State3#{stream_id => new_id(),
stream_authenticated => true,
stream_header_sent => false,
stream_restarted => true,
stream_state => wait_for_stream,
user => User}
end
end.
-spec process_sasl_continue(binary(), cyrsasl:sasl_state(), state()) -> state().
process_sasl_continue(ServerOut, NewSASLState, State) ->
State1 = State#{sasl_state => NewSASLState,
stream_state => wait_for_sasl_response},
send_pkt(State1, #sasl_challenge{text = ServerOut}).
-spec process_sasl_failure(atom(), binary(), state()) -> state().
process_sasl_failure(Err, User,
#{sasl_mech := Mech, lang := Lang} = State) ->
{Reason, Text} = format_sasl_error(Mech, Err),
State1 = try callback(handle_auth_failure, User, Mech, Text, State)
catch _:{?MODULE, undef} -> State
end,
case is_disconnected(State1) of
true -> State1;
false ->
State2 = send_pkt(State1,
#sasl_failure{reason = Reason,
text = xmpp:mk_text(Text, Lang)}),
case is_disconnected(State2) of
true -> State2;
false ->
State3 = maps:remove(sasl_state,
maps:remove(sasl_mech, State2)),
State3#{stream_state => wait_for_sasl_request}
end
end.
-spec process_sasl_abort(state()) -> state().
process_sasl_abort(State) ->
process_sasl_failure(aborted, <<"">>, State).
-spec send_features(state()) -> state().
send_features(#{stream_version := {1,0},
stream_encrypted := Encrypted} = State) ->
TLSRequired = is_starttls_required(State),
Features = if TLSRequired and not Encrypted ->
get_tls_feature(State);
true ->
get_sasl_feature(State) ++ get_compress_feature(State)
++ get_tls_feature(State) ++ get_bind_feature(State)
++ get_session_feature(State) ++ get_other_features(State)
end,
send_pkt(State, #stream_features{sub_els = Features});
send_features(State) ->
%% clients and servers from stone age
State.
-spec get_sasl_mechanisms(state()) -> [cyrsasl:mechanism()].
get_sasl_mechanisms(#{stream_encrypted := Encrypted,
xmlns := NS, lserver := LServer} = State) ->
Mechs = if NS == ?NS_CLIENT -> cyrsasl:listmech(LServer);
true -> []
end,
TLSVerify = try callback(tls_verify, State)
catch _:{?MODULE, undef} -> false
end,
Mechs1 = if Encrypted andalso (TLSVerify orelse NS == ?NS_SERVER) ->
[<<"EXTERNAL">>|Mechs];
true ->
Mechs
end,
try callback(sasl_mechanisms, Mechs1, State)
catch _:{?MODULE, undef} -> Mechs1
end.
-spec get_sasl_feature(state()) -> [sasl_mechanisms()].
get_sasl_feature(#{stream_authenticated := false,
stream_encrypted := Encrypted} = State) ->
TLSRequired = is_starttls_required(State),
if Encrypted or not TLSRequired ->
Mechs = get_sasl_mechanisms(State),
[#sasl_mechanisms{list = Mechs}];
true ->
[]
end;
get_sasl_feature(_) ->
[].
-spec get_compress_feature(state()) -> [compression()].
get_compress_feature(#{stream_compressed := false,
stream_authenticated := true} = State) ->
try callback(compress_methods, State) of
[] -> [];
Ms -> [#compression{methods = Ms}]
catch _:{?MODULE, undef} ->
[]
end;
get_compress_feature(_) ->
[].
-spec get_tls_feature(state()) -> [starttls()].
get_tls_feature(#{stream_authenticated := false,
stream_encrypted := false} = State) ->
case is_starttls_available(State) of
true ->
TLSRequired = is_starttls_required(State),
[#starttls{required = TLSRequired}];
false ->
[]
end;
get_tls_feature(_) ->
[].
-spec get_bind_feature(state()) -> [bind()].
get_bind_feature(#{xmlns := ?NS_CLIENT,
stream_authenticated := true,
resource := <<"">>}) ->
[#bind{}];
get_bind_feature(_) ->
[].
-spec get_session_feature(state()) -> [xmpp_session()].
get_session_feature(#{xmlns := ?NS_CLIENT,
stream_authenticated := true,
resource := <<"">>}) ->
[#xmpp_session{optional = true}];
get_session_feature(_) ->
[].
-spec get_other_features(state()) -> [xmpp_element()].
get_other_features(#{stream_authenticated := Auth} = State) ->
try
if Auth -> callback(authenticated_stream_features, State);
true -> callback(unauthenticated_stream_features, State)
end
catch _:{?MODULE, undef} ->
[]
end.
-spec is_starttls_available(state()) -> boolean().
is_starttls_available(State) ->
try callback(tls_enabled, State)
catch _:{?MODULE, undef} -> true
end.
-spec is_starttls_required(state()) -> boolean().
is_starttls_required(State) ->
try callback(tls_required, State)
catch _:{?MODULE, undef} -> false
end.
-spec set_from_to(xmpp_element(), state()) -> {ok, xmpp_element()} |
{error, stream_error()}.
set_from_to(Pkt, _State) when not ?is_stanza(Pkt) ->
{ok, Pkt};
set_from_to(Pkt, #{user := U, server := S, resource := R,
lang := Lang, xmlns := ?NS_CLIENT}) ->
JID = jid:make(U, S, R),
From = case xmpp:get_from(Pkt) of
undefined -> JID;
F -> F
end,
if JID#jid.luser == From#jid.luser andalso
JID#jid.lserver == From#jid.lserver andalso
(JID#jid.lresource == From#jid.lresource
orelse From#jid.lresource == <<"">>) ->
To = case xmpp:get_to(Pkt) of
undefined -> jid:make(U, S);
T -> T
end,
{ok, xmpp:set_from_to(Pkt, JID, To)};
true ->
Txt = <<"Improper 'from' attribute">>,
{error, xmpp:serr_invalid_from(Txt, Lang)}
end;
set_from_to(Pkt, #{lang := Lang}) ->
From = xmpp:get_from(Pkt),
To = xmpp:get_to(Pkt),
if From == undefined ->
Txt = <<"Missing 'from' attribute">>,
{error, xmpp:serr_improper_addressing(Txt, Lang)};
To == undefined ->
Txt = <<"Missing 'to' attribute">>,
{error, xmpp:serr_improper_addressing(Txt, Lang)};
true ->
{ok, Pkt}
end.
-spec send_header(state()) -> state().
send_header(#{stream_version := Version} = State) ->
send_header(State, #stream_start{version = Version}).
-spec send_header(state(), stream_start()) -> state().
send_header(#{stream_id := StreamID,
stream_version := MyVersion,
stream_header_sent := false,
lang := MyLang,
xmlns := NS} = State,
#stream_start{to = HisTo, from = HisFrom,
lang = HisLang, version = HisVersion}) ->
Lang = select_lang(MyLang, HisLang),
NS_DB = if NS == ?NS_SERVER -> ?NS_SERVER_DIALBACK;
true -> <<"">>
end,
Version = case HisVersion of
undefined -> undefined;
{0,_} -> HisVersion;
_ -> MyVersion
end,
StreamStart = #stream_start{version = Version,
lang = Lang,
xmlns = NS,
stream_xmlns = ?NS_STREAM,
db_xmlns = NS_DB,
id = StreamID,
to = HisFrom,
from = HisTo},
State1 = State#{lang => Lang,
stream_version => Version,
stream_header_sent => true},
case socket_send(State1, StreamStart) of
ok -> State1;
{error, Why} -> process_stream_end({socket, Why}, State1)
end;
send_header(State, _) ->
State.
-spec send_pkt(state(), xmpp_element() | xmlel()) -> state().
send_pkt(State, Pkt) ->
Result = socket_send(State, Pkt),
State1 = try callback(handle_send, Pkt, Result, State)
catch _:{?MODULE, undef} -> State
end,
case Result of
_ when is_record(Pkt, stream_error) ->
process_stream_end({stream, {out, Pkt}}, State1);
ok ->
State1;
{error, Why} ->
process_stream_end({socket, Why}, State1)
end.
-spec send_error(state(), xmpp_element() | xmlel(), stanza_error()) -> state().
send_error(State, Pkt, Err) ->
case xmpp:is_stanza(Pkt) of
true ->
case xmpp:get_type(Pkt) of
result -> State;
error -> State;
<<"result">> -> State;
<<"error">> -> State;
_ ->
ErrPkt = xmpp:make_error(Pkt, Err),
send_pkt(State, ErrPkt)
end;
false ->
State
end.
-spec send_trailer(state()) -> state().
send_trailer(State) ->
socket_send(State, trailer),
close_socket(State).
-spec socket_send(state(), xmpp_element() | xmlel() | trailer) -> ok | {error, inet:posix()}.
socket_send(#{socket := Sock,
stream_state := StateName,
xmlns := NS,
stream_header_sent := true}, Pkt) ->
case Pkt of
trailer ->
xmpp_socket:send_trailer(Sock);
#stream_start{} when StateName /= disconnected ->
xmpp_socket:send_header(Sock, xmpp:encode(Pkt));
_ when StateName /= disconnected ->
xmpp_socket:send_element(Sock, xmpp:encode(Pkt, NS));
_ ->
{error, closed}
end;
socket_send(_, _) ->
{error, closed}.
-spec close_socket(state()) -> state().
close_socket(#{socket := Socket} = State) ->
xmpp_socket:close(Socket),
State#{stream_timeout => infinity,
stream_state => disconnected}.
-spec select_lang(binary(), binary()) -> binary().
select_lang(Lang, <<"">>) -> Lang;
select_lang(_, Lang) -> Lang.
-spec set_lang(xmpp_element(), state()) -> xmpp_element().
set_lang(Pkt, #{lang := MyLang, xmlns := ?NS_CLIENT}) when ?is_stanza(Pkt) ->
HisLang = xmpp:get_lang(Pkt),
Lang = select_lang(MyLang, HisLang),
xmpp:set_lang(Pkt, Lang);
set_lang(Pkt, _) ->
Pkt.
-spec format_inet_error(atom()) -> string().
format_inet_error(closed) ->
"connection closed";
format_inet_error(Reason) ->
case inet:format_error(Reason) of
"unknown POSIX error" -> atom_to_list(Reason);
Txt -> Txt
end.
-spec format_stream_error(atom() | 'see-other-host'(), [text()]) -> string().
format_stream_error(Reason, Txt) ->
Slogan = case Reason of
undefined -> "no reason";
#'see-other-host'{} -> "see-other-host";
_ -> atom_to_list(Reason)
end,
case xmpp:get_text(Txt) of
<<"">> ->
Slogan;
Data ->
binary_to_list(Data) ++ " (" ++ Slogan ++ ")"
end.
-spec format_sasl_error(cyrsasl:mechanism(), atom()) -> {atom(), binary()}.
format_sasl_error(<<"EXTERNAL">>, Err) ->
xmpp_stream_pkix:format_error(Err);
format_sasl_error(Mech, Err) ->
cyrsasl:format_error(Mech, Err).
-spec format_tls_error(atom() | binary()) -> list().
format_tls_error(Reason) when is_atom(Reason) ->
format_inet_error(Reason);
format_tls_error(Reason) ->
Reason.
-spec format(io:format(), list()) -> binary().
format(Fmt, Args) ->
iolist_to_binary(io_lib:format(Fmt, Args)).
-spec lists_intersection(list(), list()) -> list().
lists_intersection(L1, L2) ->
lists:filter(
fun(E) ->
lists:member(E, L2)
end, L1).
-spec identity([cyrsasl:sasl_property()]) -> binary().
identity(Props) ->
case proplists:get_value(authzid, Props, <<>>) of
<<>> -> proplists:get_value(username, Props, <<>>);
AuthzId -> AuthzId
end.
%%%===================================================================
%%% Callbacks
%%%===================================================================
callback(F, #{mod := Mod} = State) ->
case erlang:function_exported(Mod, F, 1) of
true -> Mod:F(State);
false -> erlang:error({?MODULE, undef})
end.
callback(F, Arg1, #{mod := Mod} = State) ->
case erlang:function_exported(Mod, F, 2) of
true -> Mod:F(Arg1, State);
false -> erlang:error({?MODULE, undef})
end.
callback(code_change, OldVsn, #{mod := Mod} = State, Extra) ->
%% code_change/3 callback is a special snowflake
case erlang:function_exported(Mod, code_change, 3) of
true -> Mod:code_change(OldVsn, State, Extra);
false -> {ok, State}
end;
callback(F, Arg1, Arg2, #{mod := Mod} = State) ->
case erlang:function_exported(Mod, F, 3) of
true -> Mod:F(Arg1, Arg2, State);
false -> erlang:error({?MODULE, undef})
end.
callback(F, Arg1, Arg2, Arg3, #{mod := Mod} = State) ->
case erlang:function_exported(Mod, F, 4) of
true -> Mod:F(Arg1, Arg2, Arg3, State);
false -> erlang:error({?MODULE, undef})
end.