xmpp.chapril.org-ejabberd/src/xmpp_stream_in.erl

1082 lines
36 KiB
Erlang

%%%-------------------------------------------------------------------
%%% Created : 26 Nov 2016 by Evgeny Khramtsov <ekhramtsov@process-one.net>
%%%
%%%
%%% ejabberd, Copyright (C) 2002-2016 ProcessOne
%%%
%%% This program is free software; you can redistribute it and/or
%%% modify it under the terms of the GNU General Public License as
%%% published by the Free Software Foundation; either version 2 of the
%%% License, or (at your option) any later version.
%%%
%%% This program is distributed in the hope that it will be useful,
%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
%%% General Public License for more details.
%%%
%%% You should have received a copy of the GNU General Public License along
%%% with this program; if not, write to the Free Software Foundation, Inc.,
%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
%%%
%%%-------------------------------------------------------------------
-module(xmpp_stream_in).
-behaviour(gen_server).
-protocol({rfc, 6120}).
%% API
-export([start/3, start_link/3, call/3, cast/2, reply/2, stop/1,
send/2, close/1, close/2, send_error/3, establish/1,
get_transport/1, change_shaper/2, set_timeout/2, format_error/1]).
%% gen_server callbacks
-export([init/1, handle_cast/2, handle_call/3, handle_info/2,
terminate/2, code_change/3]).
%%-define(DBGFSM, true).
-ifdef(DBGFSM).
-define(FSMOPTS, [{debug, [trace]}]).
-else.
-define(FSMOPTS, []).
-endif.
-include("xmpp.hrl").
%% Connection state; init/1 builds it as a map and every callback receives it.
-type state() :: map().
%% Reason a stream ended, as handed to process_stream_end/2 and rendered
%% into text by format_error/1.
-type stop_reason() :: {stream, reset | stream_error()} |
{tls, term()} |
{socket, inet:posix() | closed | timeout} |
internal_failure.
%% Behaviour callbacks implemented by the module stored under `mod` in the
%% state. All of them are declared optional below.
-callback init(list()) -> {ok, state()} | {stop, term()} | ignore.
-callback handle_cast(term(), state()) -> state().
-callback handle_call(term(), term(), state()) -> state().
-callback handle_info(term(), state()) -> state().
-callback terminate(term(), state()) -> any().
-callback code_change(term(), state(), term()) -> {ok, state()} | {error, term()}.
%% Called as Mod:handle_stream_start(StreamStart, State) by process_stream/2,
%% hence arity 2 (the stream header comes first).
-callback handle_stream_start(stream_start(), state()) -> state().
%% Called by process_stream_established/1 once the stream becomes established.
-callback handle_stream_established(state()) -> state().
-callback handle_stream_end(stop_reason(), state()) -> state().
%% Called from handle_info(timeout, State) when the stream timeout fires.
-callback handle_timeout(state()) -> state().
-callback handle_cdata(binary(), state()) -> state().
-callback handle_unauthenticated_packet(xmpp_element(), state()) -> state().
-callback handle_authenticated_packet(xmpp_element(), state()) -> state().
-callback handle_unbinded_packet(xmpp_element(), state()) -> state().
-callback handle_auth_success(binary(), binary(), module(), state()) -> state().
-callback handle_auth_failure(binary(), binary(), atom(), state()) -> state().
-callback handle_send(xmpp_element(), ok | {error, inet:posix()}, state()) -> state().
-callback handle_recv(fxml:xmlel(), xmpp_element() | {error, term()}, state()) -> state().
-callback get_password_fun(state()) -> fun().
-callback check_password_fun(state()) -> fun().
-callback check_password_digest_fun(state()) -> fun().
-callback bind(binary(), state()) -> {ok, state()} | {error, stanza_error(), state()}.
-callback compress_methods(state()) -> [binary()].
-callback tls_options(state()) -> [proplists:property()].
-callback tls_required(state()) -> boolean().
-callback tls_verify(state()) -> boolean().
-callback unauthenticated_stream_features(state()) -> [xmpp_element()].
-callback authenticated_stream_features(state()) -> [xmpp_element()].
%% All callbacks are optional
-optional_callbacks([init/1,
                     handle_cast/2,
                     handle_call/3,
                     handle_info/2,
                     terminate/2,
                     code_change/3,
                     handle_stream_start/2,
                     handle_stream_established/1,
                     handle_stream_end/2,
                     handle_timeout/1,
                     handle_cdata/2,
                     handle_authenticated_packet/2,
                     handle_unauthenticated_packet/2,
                     handle_unbinded_packet/2,
                     handle_auth_success/4,
                     handle_auth_failure/4,
                     handle_send/3,
                     handle_recv/3,
                     get_password_fun/1,
                     check_password_fun/1,
                     check_password_digest_fun/1,
                     bind/2,
                     compress_methods/1,
                     tls_options/1,
                     tls_required/1,
                     tls_verify/1,
                     unauthenticated_stream_features/1,
                     authenticated_stream_features/1]).
%%%===================================================================
%%% API
%%%===================================================================
%% Start an unlinked stream process. The callback module is prepended to
%% Args so that init/1 receives it as the head of its argument list.
start(Mod, Args, Opts) ->
    InitArgs = [Mod | Args],
    ServerOpts = Opts ++ ?FSMOPTS,
    gen_server:start(?MODULE, InitArgs, ServerOpts).
%% Same as start/3, but the new process is linked to the caller.
start_link(Mod, Args, Opts) ->
    InitArgs = [Mod | Args],
    ServerOpts = Opts ++ ?FSMOPTS,
    gen_server:start_link(?MODULE, InitArgs, ServerOpts).
%% Synchronous request to a stream process; delegates to gen_server:call/3.
call(Ref, Msg, Timeout) ->
gen_server:call(Ref, Msg, Timeout).
%% Asynchronous request to a stream process; delegates to gen_server:cast/2.
cast(Ref, Msg) ->
gen_server:cast(Ref, Msg).
%% Send a reply to a caller of call/3; delegates to gen_server:reply/2.
reply(Ref, Reply) ->
gen_server:reply(Ref, Reply).
%% Stop a stream.
%% With a pid: asynchronously ask that process to stop.
%% With a state owned by the calling process: run terminate/2 and exit
%% with reason `normal` (never returns).
%% Anything else, including a state owned by another process, is badarg.
-spec stop(pid()) -> ok;
          (state()) -> no_return().
stop(Pid) when is_pid(Pid) ->
    cast(Pid, stop);
stop(#{owner := Owner} = State) ->
    case Owner == self() of
        true ->
            terminate(normal, State),
            exit(normal);
        false ->
            erlang:error(badarg)
    end;
stop(_) ->
    erlang:error(badarg).
%% Send an XMPP element over a stream.
%% With a pid: queue the packet in that process via a cast.
%% With a state owned by the calling process: send immediately and return
%% the updated state.
%% Anything else is badarg.
-spec send(pid(), xmpp_element()) -> ok;
          (state(), xmpp_element()) -> state().
send(Pid, Pkt) when is_pid(Pid) ->
    cast(Pid, {send, Pkt});
send(#{owner := Owner} = State, Pkt) ->
    case Owner == self() of
        true -> send_element(State, Pkt);
        false -> erlang:error(badarg)
    end;
send(_, _) ->
    erlang:error(badarg).
%% Close the stream, sending the stream trailer first (close/2 with
%% SendTrailer = true).
-spec close(pid()) -> ok;
(state()) -> state().
close(Ref) ->
close(Ref, true).
%% Close the stream.
%% With a pid: cast {close, SendTrailer} to that process.
%% With a state owned by the calling process: either send the closing
%% stream tag and close the socket (SendTrailer is true), or just close
%% the socket; the updated state is returned.
%% Anything else is badarg.
-spec close(pid(), boolean()) -> ok;
           (state(), boolean()) -> state().
close(Pid, SendTrailer) when is_pid(Pid) ->
    cast(Pid, {close, SendTrailer});
close(#{owner := Owner} = State, SendTrailer) when Owner == self() ->
    case SendTrailer of
        true -> send_trailer(State);
        _ -> close_socket(State)
    end;
close(_, _) ->
    erlang:error(badarg).
%% Move the stream into the established state; see
%% process_stream_established/1 for the state changes involved.
-spec establish(state()) -> state().
establish(State) ->
process_stream_established(State).
%% Set the stream inactivity timeout in milliseconds, or disable it with
%% `infinity`. A finite timeout is stored together with the current
%% monotonic time so that noreply/1 can compute the remaining time.
%% Only the owner process may call this; anything else is badarg.
-spec set_timeout(state(), non_neg_integer() | infinity) -> state().
set_timeout(#{owner := Owner} = State, infinity) when Owner == self() ->
    State#{stream_timeout => infinity};
set_timeout(#{owner := Owner} = State, Timeout) when Owner == self() ->
    StartedAt = p1_time_compat:monotonic_time(milli_seconds),
    State#{stream_timeout => {Timeout, StartedAt}};
set_timeout(_, _) ->
    erlang:error(badarg).
%% Ask the socket module for the transport of this stream's socket.
%% Only valid when called by the process that owns the state; otherwise
%% (or for a non-state argument) raises badarg.
get_transport(#{sockmod := SockMod, socket := Socket, owner := Owner})
when Owner == self() ->
SockMod:get_transport(Socket);
get_transport(_) ->
erlang:error(badarg).
%% Install a new traffic shaper on this stream's socket via the socket
%% module. Only valid when called by the process that owns the state;
%% otherwise raises badarg.
-spec change_shaper(state(), shaper:shaper()) -> ok.
change_shaper(#{sockmod := SockMod, socket := Socket, owner := Owner}, Shaper)
when Owner == self() ->
SockMod:change_shaper(Socket, Shaper);
change_shaper(_, _) ->
erlang:error(badarg).
%% Render a stop reason as human-readable text (a binary).
%% Terms that are not a known stop_reason() are reported as
%% "Unrecognized error" rather than crashing.
-spec format_error(stop_reason()) -> binary().
format_error(StopReason) ->
    case StopReason of
        {socket, Why} ->
            format("Connection failed: ~s", [format_inet_error(Why)]);
        {stream, reset} ->
            <<"Stream reset by peer">>;
        {stream, #stream_error{reason = Why, text = Txt}} ->
            format("Stream failed: ~s", [format_stream_error(Why, Txt)]);
        {tls, Why} ->
            format("TLS failed: ~w", [Why]);
        internal_failure ->
            <<"Internal server error">>;
        Other ->
            format("Unrecognized error: ~w", [Other])
    end.
%%%===================================================================
%%% gen_server callbacks
%%%===================================================================
%% gen_server init callback. The argument list is [Module | Args] as built
%% by start/3 and start_link/3, with Args = [{SockMod, Socket}, Opts].
%% Builds the initial state map and lets Module:init/1 amend it.
%% Returns {ok, State, Timeout} on success, {stop, Reason} when the peer
%% address cannot be obtained, or whatever non-{ok, _} value Module:init/1
%% returned.
init([Module, {SockMod, Socket}, Opts]) ->
%% xml_socket is read from a {xml_socket, Value} option; defaults to false
XMLSocket = case lists:keyfind(xml_socket, 1, Opts) of
{_, XS} -> XS;
false -> false
end,
%% the `tls` option marks the stream as encrypted from the start
Encrypted = proplists:get_bool(tls, Opts),
%% the monitor reference is matched against 'DOWN' messages in handle_info/2
SocketMonitor = SockMod:monitor(Socket),
case peername(SockMod, Socket) of
{ok, IP} ->
Time = p1_time_compat:monotonic_time(milli_seconds),
State = #{owner => self(),
mod => Module,
socket => Socket,
sockmod => SockMod,
socket_monitor => SocketMonitor,
%% initial timeout: 30 seconds counted from Time
stream_timeout => {timer:seconds(30), Time},
stream_direction => in,
stream_id => new_id(),
stream_state => wait_for_stream,
stream_header_sent => false,
stream_restarted => false,
stream_compressed => false,
stream_encrypted => Encrypted,
stream_version => {1,0},
stream_authenticated => false,
xml_socket => XMLSocket,
xmlns => ?NS_CLIENT,
lang => <<"">>,
user => <<"">>,
server => <<"">>,
resource => <<"">>,
lserver => <<"">>,
ip => IP},
%% Module:init/1 is optional: an undef error leaves the state as built above
case try Module:init([State, Opts])
catch _:undef -> {ok, State}
end of
{ok, State1} ->
%% reuse noreply/1 only to compute the remaining timeout
{_, State2, Timeout} = noreply(State1),
{ok, State2, Timeout};
Err ->
Err
end;
{error, Reason} ->
{stop, Reason}
end.
%% gen_server cast handler.
%%   {send, Pkt} - send the packet over the stream (see send/2)
%%   stop        - terminate the process normally (see stop/1)
%% Every other cast, including the {close, _} message produced by
%% close/2, is forwarded to Mod:handle_cast/2; if the callback module
%% does not define it the cast is ignored.
handle_cast({send, Pkt}, State) ->
    NewState = send_element(State, Pkt),
    noreply(NewState);
handle_cast(stop, State) ->
    {stop, normal, State};
handle_cast(Cast, #{mod := Mod} = State) ->
    NewState = try Mod:handle_cast(Cast, State)
               catch _:undef -> State
               end,
    noreply(NewState).
%% gen_server call handler. Every call is forwarded to Mod:handle_call/3,
%% which returns the new state; the result here is always a noreply
%% tuple, so no reply is produced by this function itself. If the
%% callback module does not define handle_call/3 the state is unchanged.
handle_call(Call, From, #{mod := Mod} = State) ->
    NewState = try Mod:handle_call(Call, From, State)
               catch _:undef -> State
               end,
    noreply(NewState).
%% gen_server info handler: XML stream events from the socket, the stream
%% timeout, socket monitor notifications and any other message.
%%
%% Stream header received while waiting for one: decode it, answer with our
%% own header and continue in process_stream/2. A header that decodes to
%% something other than #stream_start{}, or fails to decode, is answered
%% with an invalid-xml stream error.
handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}},
#{stream_state := wait_for_stream,
xmlns := XMLNS, lang := MyLang} = State) ->
El = #xmlel{name = Name, attrs = Attrs},
noreply(
try xmpp:decode(El, XMLNS, []) of
#stream_start{} = Pkt ->
State1 = send_header(State, Pkt),
case is_disconnected(State1) of
true -> State1;
false -> process_stream(Pkt, State1)
end;
_ ->
State1 = send_header(State),
case is_disconnected(State1) of
true -> State1;
false -> send_element(State1, xmpp:serr_invalid_xml())
end
catch _:{xmpp_codec, Why} ->
State1 = send_header(State),
case is_disconnected(State1) of
true -> State1;
false ->
Txt = xmpp:io_format_error(Why),
Lang = select_lang(MyLang, xmpp:get_lang(El)),
Err = xmpp:serr_invalid_xml(Txt, Lang),
send_element(State1, Err)
end
end);
%% XML stream error: make sure our header has been sent, then reply with
%% policy-violation for an oversized stanza and not-well-formed otherwise.
handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) ->
State1 = send_header(State),
noreply(
case is_disconnected(State1) of
true -> State1;
false ->
Err = case Reason of
<<"XML stanza is too big">> ->
xmpp:serr_policy_violation(Reason, Lang);
_ ->
xmpp:serr_not_well_formed()
end,
send_element(State1, Err)
end);
%% Top-level element: decode it, report it to Mod:handle_recv/3 (either the
%% decoded packet or {error, Why}), then process it. Elements that fail to
%% decode are answered with a bad-request stanza error.
handle_info({'$gen_event', {xmlstreamelement, El}},
#{xmlns := NS, lang := MyLang, mod := Mod} = State) ->
noreply(
try xmpp:decode(El, NS, [ignore_els]) of
Pkt ->
State1 = try Mod:handle_recv(El, Pkt, State)
catch _:undef -> State
end,
case is_disconnected(State1) of
true -> State1;
false -> process_element(Pkt, State1)
end
catch _:{xmpp_codec, Why} ->
State1 = try Mod:handle_recv(El, {error, Why}, State)
catch _:undef -> State
end,
case is_disconnected(State1) of
true -> State1;
false ->
Txt = xmpp:io_format_error(Why),
Lang = select_lang(MyLang, xmpp:get_lang(El)),
send_error(State1, El, xmpp:err_bad_request(Txt, Lang))
end
end);
%% Character data between elements goes to the optional Mod:handle_cdata/2.
%% NOTE(review): this clause matches '$gen_all_state_event' while the other
%% stream events use '$gen_event' - confirm this is what the socket sends.
handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}},
#{mod := Mod} = State) ->
noreply(try Mod:handle_cdata(Data, State)
catch _:undef -> State
end);
%% Peer closed the stream.
handle_info({'$gen_event', {xmlstreamend, _}}, State) ->
noreply(process_stream_end({stream, reset}, State));
%% Socket reported as closed.
handle_info({'$gen_event', closed}, State) ->
noreply(process_stream_end({socket, closed}, State));
%% Stream timeout (the gen_server timeout computed by noreply/1). The
%% callback module may handle it; without a handler a live stream gets a
%% connection-timeout stream error and a disconnected one is stopped.
handle_info(timeout, #{mod := Mod} = State) ->
Disconnected = is_disconnected(State),
noreply(try Mod:handle_timeout(State)
catch _:undef when not Disconnected ->
send_element(State, xmpp:serr_connection_timeout());
_:undef ->
stop(State)
end);
%% The socket monitor set up in init/1 fired.
handle_info({'DOWN', MRef, _Type, _Object, _Info},
#{socket_monitor := MRef} = State) ->
noreply(process_stream_end({socket, closed}, State));
%% Anything else is offered to the optional Mod:handle_info/2.
handle_info(Info, #{mod := Mod} = State) ->
noreply(try Mod:handle_info(Info, State)
catch _:undef -> State
end).
%% gen_server terminate callback, also invoked directly by stop/1.
%% A flag in the process dictionary makes sure the callback module's
%% terminate/2 and the stream trailer are processed only once even when
%% this function is entered twice for the same process.
terminate(Reason, #{mod := Mod} = State) ->
    AlreadyTerminated = get(already_terminated) =:= true,
    case AlreadyTerminated of
        true ->
            State;
        false ->
            put(already_terminated, true),
            %% Mod:terminate/2 is optional
            try Mod:terminate(Reason, State)
            catch _:undef -> ok
            end,
            send_trailer(State)
    end.
%% gen_server code_change callback: delegates the state upgrade to the
%% callback module. code_change/3 is declared as an optional callback, so a
%% module that does not export it keeps the state unchanged, consistent with
%% how the other optional callbacks in this file are dispatched.
code_change(OldVsn, #{mod := Mod} = State, Extra) ->
    try Mod:code_change(OldVsn, State, Extra)
    catch _:undef -> {ok, State}
    end.
%%%===================================================================
%%% Internal functions
%%%===================================================================
%% Build the {noreply, State, Timeout} return value for gen_server
%% callbacks. For a finite stream timeout the value is the time left until
%% the deadline (never below zero), measured on the monotonic clock.
-spec noreply(state()) -> {noreply, state(), non_neg_integer() | infinity}.
noreply(#{stream_timeout := infinity} = State) ->
    {noreply, State, infinity};
noreply(#{stream_timeout := {MSecs, StartTime}} = State) ->
    Now = p1_time_compat:monotonic_time(milli_seconds),
    Elapsed = Now - StartTime,
    Remaining = max(0, MSecs - Elapsed),
    {noreply, State, Remaining}.
%% Generate a fresh stream identifier (stored as stream_id in the state).
-spec new_id() -> binary().
new_id() ->
randoms:get_string().
%% True when the stream state is `disconnected`, false for any other state.
-spec is_disconnected(state()) -> boolean().
is_disconnected(#{stream_state := disconnected}) ->
    true;
is_disconnected(#{stream_state := _}) ->
    false.
%% Look up the peer address of the socket. Plain gen_tcp sockets are
%% queried through inet; any other socket module is asked directly.
-spec peername(term(), term()) -> {ok, {inet:ip_address(), inet:port_number()}}|
                                  {error, inet:posix()}.
peername(gen_tcp, Socket) ->
    inet:peername(Socket);
peername(SockMod, Socket) ->
    SockMod:peername(Socket).
%% Finish the stream for the given reason. Does nothing when the stream is
%% already disconnected. Otherwise the trailer is sent and the socket
%% closed, then the callback module is told via handle_stream_end/2; if it
%% does not define that callback the process is stopped.
-spec process_stream_end(stop_reason(), state()) -> state().
process_stream_end(_, #{stream_state := disconnected} = State) ->
    State;
process_stream_end(Reason, #{mod := Mod} = State) ->
    ClosedState = send_trailer(State),
    try
        Mod:handle_stream_end(Reason, ClosedState)
    catch
        _:undef -> stop(ClosedState)
    end.
%% Validate a decoded stream header and advance the stream state.
%% The clauses are tried in order; each of the early ones rejects the
%% header with a stream error.
-spec process_stream(stream_start(), state()) -> state().
%% Namespace mismatch: the content namespace must equal ours and the stream
%% namespace must be ?NS_STREAM.
process_stream(#stream_start{xmlns = XML_NS,
stream_xmlns = STREAM_NS},
#{xmlns := NS} = State)
when XML_NS /= NS; STREAM_NS /= ?NS_STREAM ->
send_element(State, xmpp:serr_invalid_namespace());
%% Major version above 1 is not supported.
process_stream(#stream_start{version = {N, _}}, State) when N > 1 ->
send_element(State, xmpp:serr_unsupported_version());
process_stream(#stream_start{lang = Lang},
#{xmlns := ?NS_CLIENT, lang := DefaultLang} = State)
when size(Lang) > 35 ->
%% As stated in BCP47, 4.4.1:
%% Protocols or specifications that specify limited buffer sizes for
%% language tags MUST allow for language tags of at least 35 characters.
%% Do not store long language tag to avoid possible DoS/flood attacks
Txt = <<"Too long value of 'xml:lang' attribute">>,
send_element(State, xmpp:serr_policy_violation(Txt, DefaultLang));
%% The 'to' attribute is mandatory.
process_stream(#stream_start{to = undefined}, #{lang := Lang} = State) ->
Txt = <<"Missing 'to' attribute">>,
send_element(State, xmpp:serr_improper_addressing(Txt, Lang));
%% Encrypted server-to-server streams of version 1.0 must carry 'from'.
process_stream(#stream_start{from = undefined, version = {1,0}},
#{lang := Lang, xmlns := ?NS_SERVER,
stream_encrypted := true} = State) ->
Txt = <<"Missing 'from' attribute">>,
send_element(State, xmpp:serr_invalid_from(Txt, Lang));
%% 'to' must be a bare domain: no user part and no resource part.
process_stream(#stream_start{to = #jid{luser = U, lresource = R}},
#{lang := Lang} = State) when U /= <<"">>; R /= <<"">> ->
Txt = <<"Improper 'to' attribute">>,
send_element(State, xmpp:serr_improper_addressing(Txt, Lang));
%% Component streams: remember the addressed domain and wait for the
%% handshake; no stream features are sent in this clause.
process_stream(#stream_start{to = #jid{lserver = RemoteServer}} = StreamStart,
#{xmlns := ?NS_COMPONENT, mod := Mod} = State) ->
State1 = State#{remote_server => RemoteServer,
stream_state => wait_for_handshake},
try Mod:handle_stream_start(StreamStart, State1)
catch _:undef -> State1
end;
%% Client and server streams.
process_stream(#stream_start{to = #jid{server = Server, lserver = LServer},
from = From} = StreamStart,
#{stream_authenticated := Authenticated,
stream_restarted := StreamWasRestarted,
mod := Mod, xmlns := NS, resource := Resource,
stream_encrypted := Encrypted} = State) ->
%% the server name is only taken from the first header, not from restarts
State1 = if not StreamWasRestarted ->
State#{server => Server, lserver => LServer};
true ->
State
end,
%% NOTE(review): From#jid.lserver assumes From is a #jid{}; the earlier
%% clause only rejects a missing 'from' for version {1,0} - confirm other
%% versions cannot reach this with From == undefined.
State2 = if NS == ?NS_SERVER andalso Encrypted ->
State1#{remote_server => From#jid.lserver};
true ->
State1
end,
State3 = try Mod:handle_stream_start(StreamStart, State2)
catch _:undef -> State2
end,
case is_disconnected(State3) of
true -> State3;
false ->
State4 = send_features(State3),
case is_disconnected(State4) of
true -> State4;
false ->
%% choose the next state from what is still missing:
%% TLS, then authentication, then (clients only) resource binding
TLSRequired = is_starttls_required(State4),
if not Authenticated and (TLSRequired and not Encrypted) ->
State4#{stream_state => wait_for_starttls};
not Authenticated ->
State4#{stream_state => wait_for_sasl_request};
(NS == ?NS_CLIENT) and (Resource == <<"">>) ->
State4#{stream_state => wait_for_bind};
true ->
process_stream_established(State4)
end
end
end.
%% Dispatch a decoded top-level element according to its type and the
%% current stream state. Negotiation elements (STARTTLS, SASL, compression,
%% handshake) are handled here; everything else is routed by stream state.
%% NOTE(review): the final catch-all clauses cover only some stream states;
%% a non-negotiation element arriving in any other state (for example
%% wait_for_stream) matches no clause - confirm that cannot happen.
-spec process_element(xmpp_element(), state()) -> state().
process_element(Pkt, #{stream_state := StateName, lang := Lang} = State) ->
case Pkt of
%% STARTTLS is accepted before authentication starts
#starttls{} when StateName == wait_for_starttls;
StateName == wait_for_sasl_request ->
process_starttls(State);
#starttls{} ->
process_starttls_failure(unexpected_starttls_request, State);
%% SASL is refused until the required TLS negotiation has happened
#sasl_auth{} when StateName == wait_for_starttls ->
send_element(State, #sasl_failure{reason = 'encryption-required'});
#sasl_auth{} when StateName == wait_for_sasl_request ->
process_sasl_request(Pkt, State);
#sasl_auth{} ->
Txt = <<"SASL negotiation is not allowed in this state">>,
send_element(State, #sasl_failure{reason = 'not-authorized',
text = xmpp:mk_text(Txt, Lang)});
#sasl_response{} when StateName == wait_for_starttls ->
send_element(State, #sasl_failure{reason = 'encryption-required'});
#sasl_response{} when StateName == wait_for_sasl_response ->
process_sasl_response(Pkt, State);
#sasl_response{} ->
Txt = <<"SASL negotiation is not allowed in this state">>,
send_element(State, #sasl_failure{reason = 'not-authorized',
text = xmpp:mk_text(Txt, Lang)});
#sasl_abort{} when StateName == wait_for_sasl_response ->
process_sasl_abort(State);
#sasl_abort{} ->
send_element(State, #sasl_failure{reason = 'aborted'});
%% a <success/> sent by the peer is ignored
#sasl_success{} ->
State;
%% compression cannot be requested in the middle of a SASL exchange
#compress{} when StateName == wait_for_sasl_response ->
send_element(State, #compress_failure{reason = 'setup-failed'});
#compress{} ->
process_compress(Pkt, State);
#handshake{} when StateName == wait_for_handshake ->
process_handshake(Pkt, State);
%% a handshake in any other state is ignored
#handshake{} ->
State;
%% a stream error from the peer ends the stream
#stream_error{} ->
process_stream_end({stream, Pkt}, State);
_ when StateName == wait_for_sasl_request;
StateName == wait_for_handshake;
StateName == wait_for_sasl_response ->
process_unauthenticated_packet(Pkt, State);
_ when StateName == wait_for_starttls ->
Txt = <<"Use of STARTTLS required">>,
Err = xmpp:err_policy_violation(Txt, Lang),
send_error(State, Pkt, Err);
_ when StateName == wait_for_bind ->
process_bind(Pkt, State);
_ when StateName == established ->
process_authenticated_packet(Pkt, State)
end.
%% Hand a packet received before authentication to the callback module.
%% The packet passed on has its language filled in by set_lang/2. Without
%% a handle_unauthenticated_packet/2 callback the sender gets a
%% not-authorized error built from the packet as received.
-spec process_unauthenticated_packet(xmpp_element(), state()) -> state().
process_unauthenticated_packet(Pkt, #{mod := Mod} = State) ->
    LangPkt = set_lang(Pkt, State),
    try
        Mod:handle_unauthenticated_packet(LangPkt, State)
    catch
        _:undef ->
            send_error(State, Pkt, xmpp:err_not_authorized())
    end.
%% Process a packet received on an established stream.
%% The packet gets its language and from/to addresses filled in first; an
%% addressing problem is answered with the stream error produced by
%% set_from_to/2. On client streams an IQ-set carrying a session request
%% is answered directly with an empty result; everything else goes to the
%% callback module.
-spec process_authenticated_packet(xmpp_element(), state()) -> state().
process_authenticated_packet(Pkt, #{xmlns := NS, mod := Mod} = State) ->
    LangPkt = set_lang(Pkt, State),
    case set_from_to(LangPkt, State) of
        {error, StreamErr} ->
            send_element(State, StreamErr);
        {ok, #iq{type = set, sub_els = [_]} = Addressed} when NS == ?NS_CLIENT ->
            case xmpp:get_subtag(Addressed, #xmpp_session{}) of
                #xmpp_session{} ->
                    send_element(State, xmpp:make_iq_result(Addressed));
                _ ->
                    route_authenticated_packet(Mod, Addressed, Pkt, State)
            end;
        {ok, Addressed} ->
            route_authenticated_packet(Mod, Addressed, Pkt, State)
    end.

%% Pass the addressed packet to Mod:handle_authenticated_packet/2; when the
%% callback is not defined, reply to the packet as originally received
%% with service-unavailable.
route_authenticated_packet(Mod, Addressed, OrigPkt, State) ->
    try
        Mod:handle_authenticated_packet(Addressed, State)
    catch
        _:undef ->
            send_error(State, OrigPkt, xmpp:err_service_unavailable())
    end.
%% Handle a packet received while waiting for resource binding.
%% On client streams an IQ-set with a <bind/> child is validated and passed
%% to Mod:bind/2; success is answered with the bound JID and the stream
%% becomes established. Any other packet goes to the optional
%% Mod:handle_unbinded_packet/2, or is rejected with not-authorized when
%% that callback is not defined.
-spec process_bind(xmpp_element(), state()) -> state().
process_bind(#iq{type = set, sub_els = [_]} = Pkt,
             #{xmlns := ?NS_CLIENT, mod := Mod, lang := Lang} = State) ->
    case xmpp:get_subtag(Pkt, #bind{}) of
        #bind{resource = R} ->
            case jid:resourceprep(R) of
                error ->
                    Txt = <<"Malformed resource">>,
                    Err = xmpp:err_bad_request(Txt, Lang),
                    send_error(State, Pkt, Err);
                _ ->
                    case Mod:bind(R, State) of
                        {ok, #{user := U,
                               server := S,
                               resource := NewR} = State1} when NewR /= <<"">> ->
                            Reply = #bind{jid = jid:make(U, S, NewR)},
                            State2 = send_element(State1, xmpp:make_iq_result(Pkt, Reply)),
                            process_stream_established(State2);
                        %% Bind the stanza error itself, not the whole
                        %% {error, _, _} tuple: send_error/3 expects a
                        %% stanza_error() as its third argument.
                        {error, #stanza_error{} = Err, State1} ->
                            send_error(State1, Pkt, Err)
                    end
            end;
        _ ->
            try Mod:handle_unbinded_packet(Pkt, State)
            catch _:undef ->
                    Err = xmpp:err_not_authorized(),
                    send_error(State, Pkt, Err)
            end
    end;
process_bind(Pkt, #{mod := Mod} = State) ->
    try Mod:handle_unbinded_packet(Pkt, State)
    catch _:undef ->
            Err = xmpp:err_not_authorized(),
            send_error(State, Pkt, Err)
    end.
%% Verify a component handshake. The digest sent by the peer must equal
%% p1_sha:sha/1 of the stream id concatenated with the password obtained
%% from the function returned by Mod:get_password_fun/1. On success an
%% empty handshake is sent back and the stream becomes established; on
%% failure a not-authorized stream error is sent. The callback module is
%% informed through handle_auth_success/4 or handle_auth_failure/4.
-spec process_handshake(handshake(), state()) -> state().
process_handshake(#handshake{data = Digest},
#{mod := Mod, stream_id := StreamID,
remote_server := RemoteServer} = State) ->
%% without a get_password_fun/1 callback no password is ever found
GetPW = try Mod:get_password_fun(State)
catch _:undef -> fun(_) -> {false, undefined} end
end,
%% the password function is called with an empty binary as its argument
AuthRes = case GetPW(<<"">>) of
{false, _} ->
false;
{Password, _} ->
p1_sha:sha(<<StreamID/binary, Password/binary>>) == Digest
end,
case AuthRes of
true ->
State1 = try Mod:handle_auth_success(
RemoteServer, <<"handshake">>, undefined, State)
catch _:undef -> State
end,
%% the callback may have closed the stream
case is_disconnected(State1) of
true -> State1;
false ->
State2 = send_element(State1, #handshake{}),
process_stream_established(State2)
end;
false ->
State1 = try Mod:handle_auth_failure(
RemoteServer, <<"handshake">>, 'not-authorized', State)
catch _:undef -> State
end,
case is_disconnected(State1) of
true -> State1;
false -> send_element(State1, xmpp:serr_not_authorized())
end
end.
%% Mark the stream as authenticated and established and switch off the
%% stream timeout, then notify the optional
%% Mod:handle_stream_established/1. A stream that is already established
%% or disconnected is returned untouched.
-spec process_stream_established(state()) -> state().
process_stream_established(#{stream_state := disconnected} = State) ->
    State;
process_stream_established(#{stream_state := established} = State) ->
    State;
process_stream_established(#{mod := Mod} = State) ->
    Established = State#{stream_authenticated := true,
                         stream_state => established,
                         stream_timeout => infinity},
    try
        Mod:handle_stream_established(Established)
    catch
        _:undef -> Established
    end.
%% Handle a stream compression request.
%% A second request on an already compressed stream fails with
%% setup-failed. Otherwise zlib must be offered both by the peer and by
%% Mod:compress_methods/1 (no methods when the callback is undefined);
%% the socket is then switched to compression and the stream restarted
%% with a new id. Any other outcome is answered with unsupported-method.
-spec process_compress(compress(), state()) -> state().
process_compress(#compress{}, #{stream_compressed := true} = State) ->
    send_element(State, #compress_failure{reason = 'setup-failed'});
process_compress(#compress{methods = HisMethods},
                 #{socket := Socket, sockmod := SockMod, mod := Mod} = State) ->
    MyMethods = try Mod:compress_methods(State)
                catch _:undef -> []
                end,
    Shared = lists_intersection(MyMethods, HisMethods),
    case lists:member(<<"zlib">>, Shared) of
        false ->
            send_element(State, #compress_failure{reason = 'unsupported-method'});
        true ->
            %% the <compressed/> reply is handed to the socket module
            %% together with the request to start compressing
            Reply = fxml:element_to_binary(xmpp:encode(#compressed{})),
            ZlibSocket = SockMod:compress(Socket, Reply),
            State#{socket => ZlibSocket,
                   stream_id => new_id(),
                   stream_header_sent => false,
                   stream_restarted => true,
                   stream_state => wait_for_stream,
                   stream_compressed => true}
    end.
%% Handle a STARTTLS request: upgrade the socket using the options from
%% Mod:tls_options/1 (none when the callback is undefined), send
%% <proceed/>, and restart the stream as encrypted with a new id.
%% A failed upgrade is reported through process_starttls_failure/2.
-spec process_starttls(state()) -> state().
process_starttls(#{socket := Socket,
                   sockmod := SockMod, mod := Mod} = State) ->
    TLSOpts = try Mod:tls_options(State)
              catch _:undef -> []
              end,
    case SockMod:starttls(Socket, TLSOpts) of
        {error, Reason} ->
            process_starttls_failure(Reason, State);
        {ok, TLSSocket} ->
            Proceeded = send_element(State, #starttls_proceed{}),
            case is_disconnected(Proceeded) of
                true ->
                    Proceeded;
                false ->
                    Proceeded#{socket => TLSSocket,
                               stream_id => new_id(),
                               stream_header_sent => false,
                               stream_restarted => true,
                               stream_state => wait_for_stream,
                               stream_encrypted => true}
            end
    end.
%% Report a failed STARTTLS negotiation to the peer and, unless sending
%% already disconnected the stream, end it with reason {tls, Why}.
-spec process_starttls_failure(term(), state()) -> state().
process_starttls_failure(Why, State) ->
    Notified = send_element(State, #starttls_failure{}),
    case is_disconnected(Notified) of
        false -> process_stream_end({tls, Why}, Notified);
        true -> Notified
    end.
%% Start a SASL exchange for the mechanism requested by the peer.
%% The password callbacks are obtained from the callback module, a new
%% cyrsasl server state is created, and the mechanism is checked against
%% get_sasl_mechanisms/1. EXTERNAL is authenticated through
%% xmpp_stream_pkix, any other offered mechanism through cyrsasl, and an
%% unknown mechanism yields invalid-mechanism. The outcome is handled by
%% process_sasl_result/2.
-spec process_sasl_request(sasl_auth(), state()) -> state().
process_sasl_request(#sasl_auth{mechanism = Mech, text = ClientIn},
#{mod := Mod, lserver := LServer} = State) ->
%% each password callback is optional; the fallbacks always answer false
GetPW = try Mod:get_password_fun(State)
catch _:undef -> fun(_) -> false end
end,
CheckPW = try Mod:check_password_fun(State)
catch _:undef -> fun(_, _, _) -> false end
end,
CheckPWDigest = try Mod:check_password_digest_fun(State)
catch _:undef -> fun(_, _, _, _, _) -> false end
end,
SASLState = cyrsasl:server_new(<<"jabber">>, LServer, <<"">>, [],
GetPW, CheckPW, CheckPWDigest),
%% sasl_mech is needed later by process_sasl_success/3 and
%% process_sasl_failure/3
State1 = State#{sasl_state => SASLState, sasl_mech => Mech},
Mechs = get_sasl_mechanisms(State1),
SASLResult = case lists:member(Mech, Mechs) of
true when Mech == <<"EXTERNAL">> ->
case xmpp_stream_pkix:authenticate(State1, ClientIn) of
{ok, Peer} ->
{ok, [{auth_module, pkix},
{username, Peer}]};
{error, _Reason, Peer} ->
%% TODO: return meaningful error
{error, 'not-authorized', Peer}
end;
true ->
cyrsasl:server_start(SASLState, Mech, ClientIn);
false ->
{error, 'invalid-mechanism'}
end,
process_sasl_result(SASLResult, State1).
%% Feed the peer's response into the ongoing SASL exchange and act on the
%% result of that step.
-spec process_sasl_response(sasl_response(), state()) -> state().
process_sasl_response(#sasl_response{text = ClientIn},
                      #{sasl_state := SASLState} = State) ->
    process_sasl_result(cyrsasl:server_step(SASLState, ClientIn), State).
%% Dispatch on the outcome of a SASL step: success (with or without data
%% for the peer), a further challenge, or failure (with or without the
%% user name involved).
-spec process_sasl_result(cyrsasl:sasl_return(), state()) -> state().
process_sasl_result({ok, Props}, State) ->
%% no additional data for the peer: send an empty success payload
process_sasl_success(Props, <<"">>, State);
process_sasl_result({ok, Props, ServerOut}, State) ->
process_sasl_success(Props, ServerOut, State);
process_sasl_result({continue, ServerOut, NewSASLState}, State) ->
process_sasl_continue(ServerOut, NewSASLState, State);
process_sasl_result({error, Reason, User}, State) ->
process_sasl_failure(Reason, User, State);
process_sasl_result({error, Reason}, State) ->
%% user unknown: report the failure with an empty user name
process_sasl_failure(Reason, <<"">>, State).
%% Complete a successful SASL exchange: tell the callback module, reset the
%% XML stream on the socket, send <success/> and restart the stream as
%% authenticated with a new id. The SASL bookkeeping keys are dropped from
%% the state.
-spec process_sasl_success([cyrsasl:sasl_property()], binary(), state()) -> state().
process_sasl_success(Props, ServerOut,
#{socket := Socket, sockmod := SockMod,
mod := Mod, sasl_mech := Mech} = State) ->
%% authzid when present, username otherwise (see identity/1)
User = identity(Props),
AuthModule = proplists:get_value(auth_module, Props),
State1 = try Mod:handle_auth_success(User, Mech, AuthModule, State)
catch _:undef -> State
end,
%% the callback may have closed the stream
case is_disconnected(State1) of
true -> State1;
false ->
SockMod:reset_stream(Socket),
State2 = send_element(State1, #sasl_success{text = ServerOut}),
case is_disconnected(State2) of
true -> State2;
false ->
State3 = maps:remove(sasl_state,
maps:remove(sasl_mech, State2)),
State3#{stream_id => new_id(),
stream_authenticated => true,
stream_header_sent => false,
stream_restarted => true,
stream_state => wait_for_stream,
user => User}
end
end.
%% Continue a multi-step SASL exchange: remember the new SASL state, wait
%% for the peer's response, and send it the next challenge.
-spec process_sasl_continue(binary(), cyrsasl:sasl_state(), state()) -> state().
process_sasl_continue(ServerOut, NewSASLState, State) ->
    Challenge = #sasl_challenge{text = ServerOut},
    Waiting = State#{sasl_state => NewSASLState,
                     stream_state => wait_for_sasl_response},
    send_element(Waiting, Challenge).
%% Handle a failed SASL exchange: tell the callback module (optional
%% handle_auth_failure/4), drop the SASL bookkeeping keys from the state,
%% go back to waiting for a new SASL request and send <failure/> with the
%% given reason.
-spec process_sasl_failure(atom(), binary(), state()) -> state().
process_sasl_failure(Reason, User,
                     #{mod := Mod, sasl_mech := Mech} = State) ->
    Notified = try Mod:handle_auth_failure(User, Mech, Reason, State)
               catch _:undef -> State
               end,
    NoMech = maps:remove(sasl_mech, Notified),
    NoSASL = maps:remove(sasl_state, NoMech),
    Failure = #sasl_failure{reason = Reason},
    send_element(NoSASL#{stream_state => wait_for_sasl_request}, Failure).
%% The peer aborted the SASL exchange: treat it as a failure with reason
%% 'aborted' and an empty user name.
-spec process_sasl_abort(state()) -> state().
process_sasl_abort(State) ->
process_sasl_failure('aborted', <<"">>, State).
%% Send <stream:features/> on version 1.0 streams. While TLS is required
%% but not yet active only the STARTTLS feature is offered; otherwise the
%% SASL, compression, STARTTLS, bind, session and callback-provided
%% features are listed in that order. Streams of any other version get no
%% features element.
-spec send_features(state()) -> state().
send_features(#{stream_version := {1,0},
                stream_encrypted := Encrypted} = State) ->
    TLSRequired = is_starttls_required(State),
    Features = if TLSRequired and not Encrypted ->
                       get_tls_feature(State);
                  true ->
                       SASL = get_sasl_feature(State),
                       Compression = get_compress_feature(State),
                       TLS = get_tls_feature(State),
                       Bind = get_bind_feature(State),
                       Session = get_session_feature(State),
                       Other = get_other_features(State),
                       SASL ++ Compression ++ TLS ++ Bind ++ Session ++ Other
               end,
    send_element(State, #stream_features{sub_els = Features});
send_features(State) ->
    %% clients and servers from stone age
    State.
%% List the SASL mechanisms available on this stream. Client streams get
%% the mechanisms cyrsasl lists for the local server; other streams start
%% from none. EXTERNAL is put in front on encrypted streams when the
%% callback module asks for TLS verification or the stream is
%% server-to-server.
-spec get_sasl_mechanisms(state()) -> [cyrsasl:mechanism()].
get_sasl_mechanisms(#{stream_encrypted := Encrypted, mod := Mod,
                      xmlns := NS, lserver := LServer} = State) ->
    BaseMechs = case NS of
                    ?NS_CLIENT -> cyrsasl:listmech(LServer);
                    _ -> []
                end,
    %% tls_verify/1 is optional and defaults to false
    TLSVerify = try Mod:tls_verify(State)
                catch _:undef -> false
                end,
    if Encrypted andalso (TLSVerify orelse NS == ?NS_SERVER) ->
            [<<"EXTERNAL">> | BaseMechs];
       true ->
            BaseMechs
    end.
%% SASL mechanisms feature for <stream:features/>. Offered only on streams
%% that are not yet authenticated, and only once the stream is encrypted or
%% TLS is not required; otherwise no feature is advertised.
-spec get_sasl_feature(state()) -> [sasl_mechanisms()].
get_sasl_feature(#{stream_authenticated := false,
stream_encrypted := Encrypted} = State) ->
TLSRequired = is_starttls_required(State),
if Encrypted or not TLSRequired ->
Mechs = get_sasl_mechanisms(State),
[#sasl_mechanisms{list = Mechs}];
true ->
[]
end;
get_sasl_feature(_) ->
[].
%% Compression feature for <stream:features/>. Offered only while the
%% stream is not compressed and the callback module reports at least one
%% method (compress_methods/1 is optional and defaults to none).
-spec get_compress_feature(state()) -> [compression()].
get_compress_feature(#{stream_compressed := false, mod := Mod} = State) ->
    Methods = try Mod:compress_methods(State)
              catch _:undef -> []
              end,
    case Methods of
        [] -> [];
        _ -> [#compression{methods = Methods}]
    end;
get_compress_feature(_) ->
    [].
%% STARTTLS feature for <stream:features/>. Offered only on streams that
%% are neither authenticated nor encrypted; the `required` flag reflects
%% is_starttls_required/1.
-spec get_tls_feature(state()) -> [starttls()].
get_tls_feature(#{stream_authenticated := false,
                  stream_encrypted := false} = State) ->
    [#starttls{required = is_starttls_required(State)}];
get_tls_feature(_) ->
    [].
%% Resource binding feature for <stream:features/>. Offered only on
%% authenticated client streams that have no resource yet.
-spec get_bind_feature(state()) -> [bind()].
get_bind_feature(#{xmlns := ?NS_CLIENT,
stream_authenticated := true,
resource := <<"">>}) ->
[#bind{}];
get_bind_feature(_) ->
[].
%% Session feature for <stream:features/>, marked optional. Offered under
%% the same conditions as the bind feature: authenticated client streams
%% with no resource yet.
-spec get_session_feature(state()) -> [xmpp_session()].
get_session_feature(#{xmlns := ?NS_CLIENT,
stream_authenticated := true,
resource := <<"">>}) ->
[#xmpp_session{optional = true}];
get_session_feature(_) ->
[].
%% Extra stream features supplied by the callback module:
%% authenticated_stream_features/1 once the stream is authenticated,
%% unauthenticated_stream_features/1 before that. Both callbacks are
%% optional and default to no features.
-spec get_other_features(state()) -> [xmpp_element()].
get_other_features(#{stream_authenticated := Auth, mod := Mod} = State) ->
    try
        case Auth of
            true -> Mod:authenticated_stream_features(State);
            _ -> Mod:unauthenticated_stream_features(State)
        end
    catch
        _:undef -> []
    end.
%% Ask the callback module whether TLS is mandatory on this stream.
%% tls_required/1 is optional; without it TLS is not required.
-spec is_starttls_required(state()) -> boolean().
is_starttls_required(#{mod := Mod} = State) ->
try Mod:tls_required(State)
catch _:undef -> false
end.
%% Fill in and validate the 'from' and 'to' addresses of a stanza.
%% Non-stanza elements are returned unchanged. On client streams a missing
%% 'from' defaults to the session JID and a missing 'to' to the user's bare
%% JID; a 'from' that does not belong to the session yields an invalid-from
%% stream error. On other streams both attributes must already be present.
-spec set_from_to(xmpp_element(), state()) -> {ok, xmpp_element()} |
{error, stream_error()}.
set_from_to(Pkt, _State) when not ?is_stanza(Pkt) ->
{ok, Pkt};
set_from_to(Pkt, #{user := U, server := S, resource := R,
lang := Lang, xmlns := ?NS_CLIENT}) ->
JID = jid:make(U, S, R),
From = case xmpp:get_from(Pkt) of
undefined -> JID;
F -> F
end,
%% 'from' must match the session's user and server; its resource must
%% either match as well or be empty (bare JID)
if JID#jid.luser == From#jid.luser andalso
JID#jid.lserver == From#jid.lserver andalso
(JID#jid.lresource == From#jid.lresource
orelse From#jid.lresource == <<"">>) ->
To = case xmpp:get_to(Pkt) of
undefined -> jid:make(U, S);
T -> T
end,
%% the full session JID is always used as 'from', even when the
%% client supplied a bare one
{ok, xmpp:set_from_to(Pkt, JID, To)};
true ->
Txt = <<"Improper 'from' attribute">>,
{error, xmpp:serr_invalid_from(Txt, Lang)}
end;
set_from_to(Pkt, #{lang := Lang}) ->
From = xmpp:get_from(Pkt),
To = xmpp:get_to(Pkt),
if From == undefined ->
Txt = <<"Missing 'from' attribute">>,
{error, xmpp:serr_invalid_from(Txt, Lang)};
To == undefined ->
Txt = <<"Missing 'to' attribute">>,
{error, xmpp:serr_improper_addressing(Txt, Lang)};
true ->
{ok, Pkt}
end.
%% Send our stream header when no usable header from the peer is available:
%% a #stream_start{} carrying only our own stream version stands in for it.
-spec send_header(state()) -> state().
send_header(#{stream_version := Version} = State) ->
send_header(State, #stream_start{version = Version}).
%% Send our stream header in reply to the peer's header, unless one has
%% already been sent for the current stream (then the state is returned
%% unchanged). The negotiated language and version are stored in the state.
%% A socket error while sending ends the stream.
-spec send_header(state(), stream_start()) -> state().
send_header(#{stream_id := StreamID,
stream_version := MyVersion,
stream_header_sent := false,
lang := MyLang,
xmlns := NS,
server := DefaultServer} = State,
#stream_start{to = To, lang = HisLang, version = HisVersion}) ->
%% the peer's language wins when it supplied one
Lang = select_lang(MyLang, HisLang),
%% the dialback namespace is only declared on server-to-server streams
NS_DB = if NS == ?NS_SERVER -> ?NS_SERVER_DIALBACK;
true -> <<"">>
end,
%% answer from the address the peer wrote to, falling back to the server
%% name stored in the state
From = case To of
#jid{} -> To;
undefined -> jid:make(DefaultServer)
end,
%% no version or a 0.x version from the peer is echoed back; anything
%% else is answered with our own version
Version = case HisVersion of
undefined -> undefined;
{0,_} -> HisVersion;
_ -> MyVersion
end,
Header = xmpp:encode(#stream_start{version = Version,
lang = Lang,
xmlns = NS,
stream_xmlns = ?NS_STREAM,
db_xmlns = NS_DB,
id = StreamID,
from = From}),
%% stream_header_sent must be true before send_text/2 is called, because
%% send_text/2 refuses to write until the header is marked as sent
State1 = State#{lang => Lang,
stream_version => Version,
stream_header_sent => true},
case send_text(State1, fxml:element_to_header(Header)) of
ok -> State1;
{error, Why} -> process_stream_end({socket, Why}, State1)
end;
send_header(State, _) ->
State.
%% Encode and send an element, then report the send result to the optional
%% Mod:handle_send/3. Sending a stream error always ends the stream,
%% whatever the send result was; otherwise a socket error ends the stream
%% with reason {socket, Why}.
-spec send_element(state(), xmpp_element()) -> state().
send_element(#{xmlns := NS, mod := Mod} = State, Pkt) ->
    Data = fxml:element_to_binary(xmpp:encode(Pkt, NS)),
    Result = send_text(State, Data),
    Notified = try Mod:handle_send(Pkt, Result, State)
               catch _:undef -> State
               end,
    case Pkt of
        #stream_error{} ->
            process_stream_end({stream, Pkt}, Notified);
        _ ->
            case Result of
                ok ->
                    Notified;
                {error, Why} ->
                    process_stream_end({socket, Why}, Notified)
            end
    end.
%% Reply to a stanza with a stanza error. Nothing is sent for non-stanza
%% elements, nor for stanzas of type result or error (given either as
%% atoms or as binaries).
-spec send_error(state(), xmpp_element(), stanza_error()) -> state().
send_error(State, Pkt, Err) ->
    case xmpp:is_stanza(Pkt) of
        false ->
            State;
        true ->
            case xmpp:get_type(Pkt) of
                Type when Type =:= result;
                          Type =:= error;
                          Type =:= <<"result">>;
                          Type =:= <<"error">> ->
                    State;
                _ ->
                    send_element(State, xmpp:make_error(Pkt, Err))
            end
    end.
%% Send the closing stream tag and close the socket. The result of the
%% write is ignored: the socket is closed either way.
-spec send_trailer(state()) -> state().
send_trailer(State) ->
send_text(State, <<"</stream:stream>">>),
close_socket(State).
%% Write raw data to the socket. Data is only written once the stream
%% header has been sent and while the stream is not disconnected; in any
%% other situation {error, closed} is returned without touching the socket.
-spec send_text(state(), binary()) -> ok | {error, inet:posix()}.
send_text(#{socket := Sock, sockmod := SockMod,
stream_state := StateName,
stream_header_sent := true}, Data) when StateName /= disconnected ->
SockMod:send(Sock, Data);
send_text(_, _) ->
{error, closed}.
%% Close the socket and mark the stream as disconnected, with the stream
%% timeout switched off.
-spec close_socket(state()) -> state().
close_socket(#{sockmod := SockMod, socket := Socket} = State) ->
SockMod:close(Socket),
State#{stream_timeout => infinity,
stream_state => disconnected}.
%% Pick the language to use: the peer's language when it is not empty,
%% our own otherwise.
-spec select_lang(binary(), binary()) -> binary().
select_lang(MyLang, HisLang) ->
    case HisLang of
        <<"">> -> MyLang;
        _ -> HisLang
    end.
%% On client streams, give a stanza the language chosen by select_lang/2:
%% its own language when it has one, the stream language otherwise.
%% Non-stanzas and packets on other stream types are returned unchanged.
-spec set_lang(xmpp_element(), state()) -> xmpp_element().
set_lang(Pkt, #{lang := MyLang, xmlns := ?NS_CLIENT}) when ?is_stanza(Pkt) ->
HisLang = xmpp:get_lang(Pkt),
Lang = select_lang(MyLang, HisLang),
xmpp:set_lang(Pkt, Lang);
set_lang(Pkt, _) ->
Pkt.
%% Describe a socket error reason. Reasons inet has no description for are
%% rendered as the reason atom itself.
-spec format_inet_error(atom()) -> string().
format_inet_error(Reason) ->
    Description = inet:format_error(Reason),
    if Description =:= "unknown POSIX error" ->
            atom_to_list(Reason);
       true ->
            Description
    end.
%% Describe a stream error: the error condition name, preceded by the
%% error's descriptive text when there is a non-empty one, in the form
%% "Text (condition)".
-spec format_stream_error(atom() | 'see-other-host'(), undefined | text()) -> string().
format_stream_error(Reason, Txt) ->
%% see-other-host is the only condition represented by a record here
Slogan = case Reason of
#'see-other-host'{} -> "see-other-host";
_ -> atom_to_list(Reason)
end,
case Txt of
undefined -> Slogan;
#text{data = <<"">>} -> Slogan;
#text{data = Data} ->
binary_to_list(Data) ++ " (" ++ Slogan ++ ")"
end.
%% io_lib:format/2 with the result flattened into a single binary.
-spec format(io:format(), list()) -> binary().
format(Fmt, Args) ->
    Formatted = io_lib:format(Fmt, Args),
    iolist_to_binary(Formatted).
%% Elements of the first list that also occur in the second one, in the
%% order (and with the multiplicity) they have in the first list.
-spec lists_intersection(list(), list()) -> list().
lists_intersection(L1, L2) ->
    [E || E <- L1, lists:member(E, L2)].
%% Identity established by SASL: the authorization id when one was given
%% and is not empty, the authenticated user name otherwise (an empty
%% binary when neither property is present).
-spec identity([cyrsasl:sasl_property()]) -> binary().
identity(Props) ->
    AuthzId = proplists:get_value(authzid, Props, <<>>),
    if AuthzId =:= <<>> ->
            proplists:get_value(username, Props, <<>>);
       true ->
            AuthzId
    end.