2016-12-11 13:03:37 +01:00
|
|
|
%%%-------------------------------------------------------------------
%%% Created : 26 Nov 2016 by Evgeny Khramtsov <ekhramtsov@process-one.net>
%%%
%%%
%%% ejabberd, Copyright (C) 2002-2016 ProcessOne
%%%
%%% This program is free software; you can redistribute it and/or
%%% modify it under the terms of the GNU General Public License as
%%% published by the Free Software Foundation; either version 2 of the
%%% License, or (at your option) any later version.
%%%
%%% This program is distributed in the hope that it will be useful,
%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
%%% General Public License for more details.
%%%
%%% You should have received a copy of the GNU General Public License along
%%% with this program; if not, write to the Free Software Foundation, Inc.,
%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
%%%
%%%-------------------------------------------------------------------
|
|
|
|
-module(xmpp_stream_in).
-behaviour(gen_server).

-protocol({rfc, 6120}).

%% API
-export([start/3, call/3, cast/2, reply/2,
         send/2, send_error/3,
         get_transport/1, change_shaper/2]).

%% gen_server callbacks
-export([init/1, handle_cast/2, handle_call/3, handle_info/2,
         terminate/2, code_change/3]).

-include("xmpp.hrl").

%% The stream state is a plain map shared between this module and the
%% callback module stored under the `mod` key.
-type state() :: map().
-type next_state() :: {noreply, state()} | {stop, term(), state()}.

%% Called from init/1 with [State, Opts].
-callback init(list()) -> {ok, state()} | {stop, term()} | ignore.

%% Stream lifecycle notifications: stream header accepted, closing stream
%% tag received, socket closed.
-callback handle_stream_start(state()) -> next_state().
-callback handle_stream_end(state()) -> next_state().
-callback handle_stream_close(state()) -> next_state().

%% Character data received between top-level stream elements.
-callback handle_cdata(binary(), state()) -> next_state().

%% Decoded elements, routed according to the current `stream_state`.
-callback handle_unauthenticated_packet(xmpp_element(), state()) -> next_state().
-callback handle_authenticated_packet(xmpp_element(), state()) -> next_state().
-callback handle_unbinded_packet(xmpp_element(), state()) -> next_state().

%% Outcome of a SASL negotiation: (User, Mechanism, AuthModule | Reason, State).
-callback handle_auth_success(binary(), binary(), module(), state()) -> next_state().
-callback handle_auth_failure(binary(), binary(), atom(), state()) -> next_state().

%% Result of writing an element to the socket, together with the packet,
%% its encoded XML element and the serialized bytes.
-callback handle_send(ok | {error, atom()},
                      xmpp_element(), fxml:xmlel(), binary(), state()) -> next_state().

%% Negotiation helpers queried while building features and processing
%% negotiation elements.
-callback init_sasl(state()) -> cyrsasl:sasl_state().
-callback bind(binary(), state()) -> {ok, state()} | {error, stanza_error(), state()}.
-callback handshake(binary(), state()) -> {ok, state()} | {error, stream_error(), state()}.
-callback compress_methods(state()) -> [binary()].
-callback tls_options(state()) -> [proplists:property()].
-callback tls_required(state()) -> boolean().
-callback sasl_mechanisms(state()) -> [binary()].
-callback unauthenticated_stream_features(state()) -> [xmpp_element()].
-callback authenticated_stream_features(state()) -> [xmpp_element()].

%% All callbacks are optional
-optional_callbacks([init/1,
                     handle_stream_start/1,
                     handle_stream_end/1,
                     handle_stream_close/1,
                     handle_cdata/2,
                     handle_authenticated_packet/2,
                     handle_unauthenticated_packet/2,
                     handle_unbinded_packet/2,
                     handle_auth_success/4,
                     handle_auth_failure/4,
                     handle_send/5,
                     init_sasl/1,
                     bind/2,
                     handshake/2,
                     compress_methods/1,
                     tls_options/1,
                     tls_required/1,
                     sasl_mechanisms/1,
                     unauthenticated_stream_features/1,
                     authenticated_stream_features/1]).
|
2016-12-11 13:03:37 +01:00
|
|
|
|
|
|
|
%%%===================================================================
|
|
|
|
%%% API
|
|
|
|
%%%===================================================================
|
|
|
|
%% Start an unlinked stream process. The callback module is prepended to
%% Args, so init/1 receives [Mod | Args].
start(Mod, Args, Opts) ->
    InitArgs = [Mod | Args],
    gen_server:start(?MODULE, InitArgs, Opts).
|
|
|
|
|
|
|
|
%% Synchronous request to the stream process; thin wrapper over
%% gen_server:call/3.
call(Ref, Msg, Timeout) ->
    gen_server:call(Ref, Msg, Timeout).
|
|
|
|
|
|
|
|
%% Asynchronous request to the stream process; thin wrapper over
%% gen_server:cast/2.
cast(Ref, Msg) ->
    gen_server:cast(Ref, Msg).
|
|
|
|
|
|
|
|
%% Deliver a reply to a caller; thin wrapper over gen_server:reply/2.
reply(Ref, Reply) ->
    gen_server:reply(Ref, Reply).
|
|
|
|
|
|
|
|
%% Public entry point for writing a packet to the stream; delegates to
%% send_element/2 and returns its gen_server-style result.
-spec send(state(), xmpp_element()) -> next_state().
send(State, Pkt) ->
    send_element(State, Pkt).
|
|
|
|
|
|
|
|
%% Ask the socket module which transport the stream's socket uses.
get_transport(#{sockmod := SockMod,
                socket := Socket}) ->
    SockMod:get_transport(Socket).
|
|
|
|
|
|
|
|
%% Install a new traffic shaper on the stream's socket via the socket module.
-spec change_shaper(state(), shaper:shaper()) -> ok.
change_shaper(#{sockmod := SockMod,
                socket := Socket},
              Shaper) ->
    SockMod:change_shaper(Socket, Shaper).
|
|
|
|
|
|
|
|
%%%===================================================================
|
|
|
|
%%% gen_server callbacks
|
|
|
|
%%%===================================================================
|
|
|
|
%% gen_server init: build the initial stream state for a freshly accepted
%% socket and hand it to the callback module's init/1 (if it exports one).
%% Stops with the peername error when the peer address cannot be obtained.
init([Module, {SockMod, Socket}, Opts]) ->
    UseXMLSocket = case lists:keyfind(xml_socket, 1, Opts) of
                       false -> false;
                       {_, Flag} -> Flag
                   end,
    TLSEnabled = proplists:get_bool(tls, Opts),
    %% The monitor is set up before the peername lookup; its 'DOWN'
    %% message is handled in handle_info/2.
    SocketMonitor = SockMod:monitor(Socket),
    case peername(SockMod, Socket) of
        {error, Reason} ->
            {stop, Reason};
        {ok, IP} ->
            State = #{mod => Module,
                      socket => Socket,
                      sockmod => SockMod,
                      socket_monitor => SocketMonitor,
                      stream_id => new_id(),
                      stream_state => wait_for_stream,
                      stream_restarted => false,
                      stream_compressed => false,
                      stream_tlsed => TLSEnabled,
                      stream_version => {1,0},
                      stream_authenticated => false,
                      xml_socket => UseXMLSocket,
                      xmlns => ?NS_CLIENT,
                      lang => <<"">>,
                      user => <<"">>,
                      server => <<"">>,
                      resource => <<"">>,
                      lserver => <<"">>,
                      ip => IP},
            try
                Module:init([State, Opts])
            catch
                _:undef -> {ok, State}
            end
    end.
|
|
|
|
|
|
|
|
%% Forward casts to the callback module; ignored when the callback is
%% not implemented.
handle_cast(Cast, #{mod := Mod} = State) ->
    try
        Mod:handle_cast(Cast, State)
    catch
        _:undef -> {noreply, State}
    end.
|
2016-12-11 13:03:37 +01:00
|
|
|
|
|
|
|
%% Forward calls to the callback module; replies `ok` when the callback
%% is not implemented.
handle_call(Call, From, #{mod := Mod} = State) ->
    try
        Mod:handle_call(Call, From, State)
    catch
        _:undef -> {reply, ok, State}
    end.
|
2016-12-11 13:03:37 +01:00
|
|
|
|
|
|
|
%% Dispatch raw messages: XML stream events, socket closure notifications,
%% and anything else (forwarded to the callback module).
%%
%% Stream opening tag; only handled here while waiting for a stream header.
handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}},
            #{stream_state := wait_for_stream, xmlns := XMLNS} = State) ->
    try xmpp:decode(#xmlel{name = Name, attrs = Attrs}, XMLNS, []) of
        #stream_start{} = StreamStart ->
            case send_header(State, StreamStart) of
                {noreply, State1} -> process_stream(StreamStart, State1);
                Stop -> Stop
            end;
        _ ->
            %% Decoded to something other than a stream header.
            case send_header(State) of
                {noreply, State1} ->
                    send_element(State1, xmpp:serr_invalid_xml());
                Stop ->
                    Stop
            end
    catch
        _:{xmpp_codec, Why} ->
            case send_header(State) of
                {noreply, State1} -> process_invalid_xml(Why, State1);
                Stop -> Stop
            end
    end;
%% Closing stream tag.
handle_info({'$gen_event', {xmlstreamend, _}}, #{mod := Mod} = State) ->
    try
        Mod:handle_stream_end(State)
    catch
        _:undef -> {stop, normal, State}
    end;
%% Parser error: answer with a stream error (after our own header, which
%% send_header/1 emits only while still in wait_for_stream).
handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang} = State) ->
    case send_header(State) of
        {noreply, State1} ->
            StreamErr = case Reason of
                            <<"XML stanza is too big">> ->
                                xmpp:serr_policy_violation(Reason, Lang);
                            _ ->
                                xmpp:serr_not_well_formed()
                        end,
            send_element(State1, StreamErr);
        Stop ->
            Stop
    end;
%% Top-level stream element.
handle_info({'$gen_event', {xmlstreamelement, El}},
            #{xmlns := NS} = State) ->
    try xmpp:decode(El, NS, [ignore_els]) of
        Pkt -> process_element(Pkt, State)
    catch
        _:{xmpp_codec, Why} -> process_invalid_xml(Why, State)
    end;
%% Character data between elements.
handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}},
            #{mod := Mod} = State) ->
    try
        Mod:handle_cdata(Data, State)
    catch
        _:undef -> {noreply, State}
    end;
%% Socket closed.
handle_info(closed, #{mod := Mod} = State) ->
    try
        Mod:handle_stream_close(State)
    catch
        _:undef -> {stop, normal, State}
    end;
%% The monitored socket went down; treated like a close.
handle_info({'DOWN', MRef, _Type, _Object, _Info},
            #{socket_monitor := MRef, mod := Mod} = State) ->
    try
        Mod:handle_stream_close(State)
    catch
        _:undef -> {stop, normal, State}
    end;
%% Everything else belongs to the callback module.
handle_info(Info, #{mod := Mod} = State) ->
    try
        Mod:handle_info(Info, State)
    catch
        _:undef -> {noreply, State}
    end.
|
2016-12-11 13:03:37 +01:00
|
|
|
|
|
|
|
%% Give the callback module a chance to clean up, then write the closing
%% stream tag and close the socket. The result of the final write is
%% deliberately ignored.
terminate(Reason,
          #{mod := Mod,
            socket := Socket,
            sockmod := SockMod} = State) ->
    try
        Mod:terminate(Reason, State)
    catch
        _:undef -> ok
    end,
    send_text(State, <<"</stream:stream>">>),
    SockMod:close(Socket).
|
|
|
|
|
|
|
|
%% Delegate state conversion on code upgrade to the callback module.
%% Like every other callback in this module, code_change/3 is optional:
%% when Mod does not implement it the state is kept unchanged instead of
%% crashing the upgrade with `undef`.
code_change(OldVsn, #{mod := Mod} = State, Extra) ->
    try
        Mod:code_change(OldVsn, State, Extra)
    catch
        _:undef -> {ok, State}
    end.
|
|
|
|
|
|
|
|
%%%===================================================================
|
|
|
|
%%% Internal functions
|
|
|
|
%%%===================================================================
|
|
|
|
%% Generate a fresh stream identifier.
-spec new_id() -> binary().
new_id() ->
    StreamID = randoms:get_string(),
    StreamID.
|
|
|
|
|
|
|
|
%% Look up the peer address of Socket. Raw gen_tcp sockets are queried
%% through inet; every other socket module must export peername/1.
peername(gen_tcp, Socket) ->
    inet:peername(Socket);
peername(SockMod, Socket) ->
    SockMod:peername(Socket).
|
|
|
|
|
|
|
|
%% Report a codec failure to the peer as an <invalid-xml/> stream error
%% carrying the formatted reason.
process_invalid_xml(Reason, #{lang := Lang} = State) ->
    StreamErr = xmpp:serr_invalid_xml(xmpp:io_format_error(Reason), Lang),
    send_element(State, StreamErr).
|
|
|
|
|
|
|
|
%% Validate a received stream header and move the stream to its next
%% negotiation state. Our own header has already been sent by the caller.
%% Clauses are ordered from rejection cases to the accepting ones.
%%
%% Namespace mismatch.
process_stream(#stream_start{xmlns = XML_NS,
                             stream_xmlns = STREAM_NS},
               #{xmlns := NS} = State)
  when XML_NS /= NS; STREAM_NS /= ?NS_STREAM ->
    send_element(State, xmpp:serr_invalid_namespace());
process_stream(#stream_start{lang = Lang},
               #{xmlns := ?NS_CLIENT, lang := DefaultLang} = State)
  when byte_size(Lang) > 35 ->
    %% As stated in BCP47, 4.4.1:
    %% Protocols or specifications that specify limited buffer sizes for
    %% language tags MUST allow for language tags of at least 35 characters.
    %% Do not store long language tag to avoid possible DoS/flood attacks
    Txt = <<"Too long value of 'xml:lang' attribute">>,
    send_element(State, xmpp:serr_policy_violation(Txt, DefaultLang));
process_stream(#stream_start{to = undefined}, #{lang := Lang} = State) ->
    Txt = <<"Missing 'to' attribute">>,
    send_element(State, xmpp:serr_improper_addressing(Txt, Lang));
%% Server-to-server 1.0 streams over TLS must identify themselves.
process_stream(#stream_start{from = undefined, version = {1,0}},
               #{lang := Lang, xmlns := ?NS_SERVER,
                 stream_tlsed := true} = State) ->
    Txt = <<"Missing 'from' attribute">>,
    send_element(State, xmpp:serr_invalid_from(Txt, Lang));
%% 'to' must be a bare domain.
process_stream(#stream_start{to = #jid{luser = U, lresource = R}},
               #{lang := Lang} = State) when U /= <<"">>; R /= <<"">> ->
    Txt = <<"Improper 'to' attribute">>,
    send_element(State, xmpp:serr_improper_addressing(Txt, Lang));
%% Component streams: remember the addressed domain and wait for the
%% handshake element.
process_stream(#stream_start{to = #jid{lserver = RemoteServer}},
               #{xmlns := ?NS_COMPONENT, mod := Mod} = State) ->
    State1 = State#{remote_server => RemoteServer},
    case try Mod:handle_stream_start(State1)
         catch _:undef -> {noreply, State1}
         end of
        {noreply, State2} ->
            {noreply, State2#{stream_state => wait_for_handshake}};
        Err ->
            Err
    end;
%% Client and server streams: notify the callback module on the first
%% header only, advertise features and pick the next stream state.
process_stream(#stream_start{to = #jid{server = Server, lserver = LServer},
                             from = From},
               #{stream_authenticated := Authenticated,
                 stream_restarted := StreamWasRestarted,
                 mod := Mod, xmlns := NS, resource := Resource,
                 stream_tlsed := TLSEnabled} = State) ->
    case if not StreamWasRestarted ->
                 State1 = State#{server => Server, lserver => LServer},
                 try Mod:handle_stream_start(State1)
                 catch _:undef -> {noreply, State1}
                 end;
            true ->
                 {noreply, State}
         end of
        {noreply, State2} ->
            %% 'from' is optional on headers that are not version 1.0, so
            %% only record the remote server when the peer supplied one;
            %% a bare From#jid.lserver access would crash on `undefined`.
            State3 = case From of
                         #jid{lserver = RemoteServer}
                           when NS == ?NS_SERVER andalso TLSEnabled ->
                             State2#{remote_server => RemoteServer};
                         _ ->
                             State2
                     end,
            case send_features(State3) of
                {noreply, State4} ->
                    TLSRequired = is_starttls_required(State4),
                    NewStreamState =
                        if not Authenticated andalso
                           (not TLSEnabled andalso TLSRequired) ->
                                wait_for_starttls;
                           not Authenticated ->
                                wait_for_sasl_request;
                           (NS == ?NS_CLIENT) andalso (Resource == <<"">>) ->
                                wait_for_bind;
                           true ->
                                session_established
                        end,
                    {noreply, State4#{stream_state => NewStreamState}};
                Err ->
                    Err
            end;
        Err ->
            Err
    end.
|
|
|
|
|
|
|
|
%% Route a decoded top-level element according to its type and the
%% current negotiation state. Negotiation elements (STARTTLS, SASL,
%% compression, component handshake) are matched first; every other
%% packet is dispatched purely on the stream state.
process_element(Pkt, #{stream_state := StreamState, lang := Lang} = State) ->
    case Pkt of
        %% STARTTLS
        #starttls{} when StreamState == wait_for_starttls;
                         StreamState == wait_for_sasl_request ->
            process_starttls(State);
        #starttls{} ->
            send_element(State, #starttls_failure{});
        %% SASL <auth/>
        #sasl_auth{} when StreamState == wait_for_starttls ->
            send_element(State, #sasl_failure{reason = 'encryption-required'});
        #sasl_auth{} when StreamState == wait_for_sasl_request ->
            process_sasl_request(Pkt, State);
        #sasl_auth{} ->
            AuthTxt = <<"SASL negotiation is not allowed in this state">>,
            send_element(State, #sasl_failure{reason = 'not-authorized',
                                              text = xmpp:mk_text(AuthTxt, Lang)});
        %% SASL <response/>
        #sasl_response{} when StreamState == wait_for_starttls ->
            send_element(State, #sasl_failure{reason = 'encryption-required'});
        #sasl_response{} when StreamState == wait_for_sasl_response ->
            process_sasl_response(Pkt, State);
        #sasl_response{} ->
            RespTxt = <<"SASL negotiation is not allowed in this state">>,
            send_element(State, #sasl_failure{reason = 'not-authorized',
                                              text = xmpp:mk_text(RespTxt, Lang)});
        %% SASL <abort/>
        #sasl_abort{} when StreamState == wait_for_sasl_response ->
            process_sasl_abort(State);
        #sasl_abort{} ->
            send_element(State, #sasl_failure{reason = 'aborted'});
        %% A <success/> coming from the peer is silently ignored.
        #sasl_success{} ->
            {noreply, State};
        %% Stream compression
        #compress{} when StreamState == wait_for_sasl_response ->
            send_element(State, #compress_failure{reason = 'setup-failed'});
        #compress{} ->
            process_compress(Pkt, State);
        %% Component handshake
        #handshake{} when StreamState == wait_for_handshake ->
            process_handshake(Pkt, State);
        #handshake{} ->
            {noreply, State};
        %% Anything else: dispatch on the stream state alone.
        _ when StreamState == wait_for_sasl_request;
               StreamState == wait_for_handshake;
               StreamState == wait_for_sasl_response ->
            process_unauthenticated_packet(Pkt, State);
        _ when StreamState == wait_for_starttls ->
            TLSTxt = <<"Use of STARTTLS required">>,
            send_error(State, Pkt, xmpp:err_policy_violation(TLSTxt, Lang));
        _ when StreamState == wait_for_bind ->
            process_bind(Pkt, State);
        _ when StreamState == session_established ->
            process_authenticated_packet(Pkt, State)
    end.
|
|
|
|
|
|
|
|
%% Hand a packet received before authentication to the callback module
%% (with the language attribute filled in); when the callback is not
%% implemented the original packet is bounced with <not-authorized/>.
process_unauthenticated_packet(Pkt, #{mod := Mod} = State) ->
    LangPkt = set_lang(Pkt, State),
    try
        Mod:handle_unauthenticated_packet(LangPkt, State)
    catch
        _:undef ->
            send_error(State, Pkt, xmpp:err_not_authorized())
    end.
|
|
|
|
|
|
|
|
%% Process a packet on an established session: fill in language and
%% addressing, answer legacy session-establishment IQs locally, and pass
%% everything else to the callback module. Packets with invalid addressing
%% are answered with the stream error produced by set_from_to/2.
process_authenticated_packet(Pkt, #{xmlns := NS, mod := Mod} = State) ->
    %% Deliver to the callback module; bounce the *original* packet with
    %% <service-unavailable/> when the callback is not implemented.
    Deliver = fun(Stanza) ->
                      try
                          Mod:handle_authenticated_packet(Stanza, State)
                      catch
                          _:undef ->
                              Unavailable = xmpp:err_service_unavailable(),
                              send_error(State, Pkt, Unavailable)
                      end
              end,
    case set_from_to(set_lang(Pkt, State), State) of
        {ok, #iq{type = set, sub_els = [_]} = IQ} when NS == ?NS_CLIENT ->
            case xmpp:get_subtag(IQ, #xmpp_session{}) of
                #xmpp_session{} ->
                    send_element(State, xmpp:make_iq_result(IQ));
                _ ->
                    Deliver(IQ)
            end;
        {ok, Stanza} ->
            Deliver(Stanza);
        {error, StreamErr} ->
            send_element(State, StreamErr)
    end.
|
|
|
|
|
|
|
|
%% Handle packets received while waiting for resource binding.
%% A client <iq type='set'/> carrying a <bind/> element is validated and
%% passed to Mod:bind/2; on success the stream becomes session_established
%% and the bound JID is returned to the client. Any other packet goes to
%% Mod:handle_unbinded_packet/2, or is bounced with <not-authorized/> when
%% that callback is not implemented.
process_bind(#iq{type = set, sub_els = [_]} = Pkt,
             #{xmlns := ?NS_CLIENT, mod := Mod, lang := Lang} = State) ->
    case xmpp:get_subtag(Pkt, #bind{}) of
        #bind{resource = R} ->
            case jid:resourceprep(R) of
                error ->
                    Txt = <<"Malformed resource">>,
                    Err = xmpp:err_bad_request(Txt, Lang),
                    send_error(State, Pkt, Err);
                _ ->
                    case Mod:bind(R, State) of
                        {ok, #{user := U,
                               server := S,
                               resource := NewR} = State1} when NewR /= <<"">> ->
                            Reply = #bind{jid = jid:make(U, S, NewR)},
                            State2 = State1#{stream_state => session_established},
                            send_element(State2, xmpp:make_iq_result(Pkt, Reply));
                        %% Bind only the stanza error itself: send_error/3
                        %% expects the error record, not the whole
                        %% {error, _, _} tuple.
                        {error, #stanza_error{} = BindErr, State1} ->
                            send_error(State1, Pkt, BindErr)
                    end
            end;
        _ ->
            try Mod:handle_unbinded_packet(Pkt, State)
            catch _:undef ->
                    Err = xmpp:err_not_authorized(),
                    send_error(State, Pkt, Err)
            end
    end;
process_bind(Pkt, #{mod := Mod} = State) ->
    try Mod:handle_unbinded_packet(Pkt, State)
    catch _:undef ->
            Err = xmpp:err_not_authorized(),
            send_error(State, Pkt, Err)
    end.
|
|
|
|
|
2016-12-11 16:24:51 +01:00
|
|
|
%% Verify a component handshake through Mod:handshake/2. On success the
%% stream is marked authenticated and established and an empty
%% <handshake/> is returned; on failure the stream error is sent.
process_handshake(#handshake{data = Data}, #{mod := Mod} = State) ->
    case Mod:handshake(Data, State) of
        {error, #stream_error{} = StreamErr, State1} ->
            send_element(State1, StreamErr);
        {ok, State1} ->
            Established = State1#{stream_state => session_established,
                                  stream_authenticated => true},
            send_element(Established, #handshake{})
    end.
|
2016-12-11 13:03:37 +01:00
|
|
|
|
|
|
|
%% Negotiate stream compression. A second request on an already
%% compressed stream fails with 'setup-failed'. Otherwise zlib must be
%% offered by both sides; the socket is then switched to compression (the
%% <compressed/> reply is handed to the socket module together with the
%% switch) and the stream restarts.
process_compress(#compress{}, #{stream_compressed := true} = State) ->
    send_element(State, #compress_failure{reason = 'setup-failed'});
process_compress(#compress{methods = HisMethods},
                 #{socket := Socket, sockmod := SockMod, mod := Mod} = State) ->
    MyMethods = try
                    Mod:compress_methods(State)
                catch
                    _:undef -> []
                end,
    Shared = lists_intersection(MyMethods, HisMethods),
    case lists:member(<<"zlib">>, Shared) of
        false ->
            send_element(State, #compress_failure{reason = 'unsupported-method'});
        true ->
            Compressed = fxml:element_to_binary(xmpp:encode(#compressed{})),
            ZlibSocket = SockMod:compress(Socket, Compressed),
            {noreply, State#{socket => ZlibSocket,
                             stream_id => new_id(),
                             stream_restarted => true,
                             stream_state => wait_for_stream,
                             stream_compressed => true}}
    end.
|
|
|
|
|
|
|
|
%% Upgrade the socket to TLS using the options supplied by
%% Mod:tls_options/1 (none when not implemented). After a successful
%% upgrade <proceed/> is sent and the stream restarts on the TLS socket;
%% otherwise a STARTTLS failure is sent.
process_starttls(#{socket := Socket,
                   sockmod := SockMod,
                   mod := Mod} = State) ->
    TLSOpts = try
                  Mod:tls_options(State)
              catch
                  _:undef -> []
              end,
    case SockMod:starttls(Socket, TLSOpts) of
        {error, _Reason} ->
            send_element(State, #starttls_failure{});
        {ok, TLSSocket} ->
            case send_element(State, #starttls_proceed{}) of
                {noreply, State1} ->
                    Restarted = State1#{socket => TLSSocket,
                                        stream_id => new_id(),
                                        stream_restarted => true,
                                        stream_state => wait_for_stream,
                                        stream_tlsed => true},
                    {noreply, Restarted};
                Stop ->
                    Stop
            end
    end.
|
|
|
|
|
|
|
|
%% SASL negotiation.
%%
%% For the duration of one negotiation the state carries two extra keys:
%%   sasl_mech  - mechanism name requested by the client in <auth/>
%%   sasl_state - current cyrsasl state
%% Both are removed again once the negotiation succeeds or fails. Keeping
%% the mechanism in the state lets the success/failure handlers run even
%% when no cyrsasl state exists (early failures) or when cyrsasl finishes
%% in a single step.

%% Start a negotiation. EXTERNAL is refused on unencrypted streams; a
%% callback module without init_sasl/1 yields 'temporary-auth-failure'.
process_sasl_request(#sasl_auth{mechanism = <<"EXTERNAL">> = Mech},
                     #{stream_tlsed := false} = State) ->
    process_sasl_failure('encryption-required', <<"">>,
                         State#{sasl_mech => Mech});
process_sasl_request(#sasl_auth{mechanism = Mech, text = ClientIn},
                     #{mod := Mod} = State) ->
    State1 = State#{sasl_mech => Mech},
    try Mod:init_sasl(State) of
        SASLState ->
            SASLResult = cyrsasl:server_start(SASLState, Mech, ClientIn),
            process_sasl_result(SASLResult, State1#{sasl_state => SASLState})
    catch _:undef ->
            process_sasl_failure('temporary-auth-failure', <<"">>, State1)
    end.

%% Feed a client <response/> into the running negotiation.
process_sasl_response(#sasl_response{text = ClientIn},
                      #{sasl_state := SASLState} = State) ->
    SASLResult = cyrsasl:server_step(SASLState, ClientIn),
    process_sasl_result(SASLResult, State).

%% Translate a cyrsasl result into the matching handler.
process_sasl_result({ok, Props}, State) ->
    process_sasl_success(Props, <<"">>, State);
process_sasl_result({ok, Props, ServerOut}, State) ->
    process_sasl_success(Props, ServerOut, State);
process_sasl_result({continue, ServerOut, NewSASLState}, State) ->
    process_sasl_continue(ServerOut, NewSASLState, State);
process_sasl_result({error, Reason, User}, State) ->
    process_sasl_failure(Reason, User, State);
process_sasl_result({error, Reason}, State) ->
    process_sasl_failure(Reason, <<"">>, State).

%% Authentication succeeded: notify the callback module, reset the XML
%% stream parser, send <success/> and wait for the restarted stream.
process_sasl_success(Props, ServerOut,
                     #{socket := Socket, sockmod := SockMod,
                       mod := Mod, sasl_mech := Mech} = State) ->
    User = identity(Props),
    AuthModule = proplists:get_value(auth_module, Props),
    case try Mod:handle_auth_success(User, Mech, AuthModule, State)
         catch _:undef -> {noreply, State}
         end of
        {noreply, State1} ->
            SockMod:reset_stream(Socket),
            case send_element(State1, #sasl_success{text = ServerOut}) of
                {noreply, State2} ->
                    State3 = maps:without([sasl_state, sasl_mech], State2),
                    {noreply, State3#{stream_id => new_id(),
                                      stream_authenticated => true,
                                      stream_restarted => true,
                                      stream_state => wait_for_stream,
                                      user => User}};
                Err ->
                    Err
            end;
        Err ->
            Err
    end.

%% Another round trip is needed: send the challenge and remember the new
%% cyrsasl state. A failed or stopping send is propagated instead of
%% being discarded.
process_sasl_continue(ServerOut, NewSASLState, State) ->
    case send_element(State, #sasl_challenge{text = ServerOut}) of
        {noreply, State1} ->
            {noreply, State1#{sasl_state => NewSASLState,
                              stream_state => wait_for_sasl_response}};
        Err ->
            Err
    end.

%% Authentication failed: notify the callback module, forget the
%% negotiation and allow the client to try again.
process_sasl_failure(Reason, User,
                     #{mod := Mod, sasl_mech := Mech} = State) ->
    case try Mod:handle_auth_failure(User, Mech, Reason, State)
         catch _:undef -> {noreply, State}
         end of
        {noreply, State1} ->
            State2 = maps:without([sasl_state, sasl_mech], State1),
            State3 = State2#{stream_state => wait_for_sasl_request},
            send_element(State3, #sasl_failure{reason = Reason});
        Err ->
            Err
    end.

%% Client aborted a running negotiation.
process_sasl_abort(State) ->
    process_sasl_failure('aborted', <<"">>, State).
|
|
|
|
|
|
|
|
%% Advertise stream features on version 1.0 streams. When TLS is
%% required but not yet active only STARTTLS is offered; otherwise all
%% applicable features are listed. Streams with any other version get no
%% <stream:features/> at all.
send_features(#{stream_version := {1,0},
                stream_tlsed := TLSEnabled} = State) ->
    TLSRequired = is_starttls_required(State),
    Features = if TLSRequired and not TLSEnabled ->
                       get_tls_feature(State);
                  true ->
                       lists:append([get_sasl_feature(State),
                                     get_compress_feature(State),
                                     get_tls_feature(State),
                                     get_bind_feature(State),
                                     get_session_feature(State),
                                     get_other_features(State)])
               end,
    send_element(State, #stream_features{sub_els = Features});
send_features(State) ->
    %% clients from stone age
    {noreply, State}.
|
|
|
|
|
|
|
|
%% SASL mechanisms feature: offered only to unauthenticated streams, and
%% only once TLS is active or not required. Empty when the callback
%% module offers no mechanisms or does not implement sasl_mechanisms/1.
get_sasl_feature(#{stream_authenticated := false,
                   mod := Mod,
                   stream_tlsed := TLSEnabled} = State) ->
    TLSRequired = is_starttls_required(State),
    if TLSEnabled or not TLSRequired ->
            try Mod:sasl_mechanisms(State) of
                [] -> [];
                Mechs -> [#sasl_mechanisms{list = Mechs}]
            catch
                _:undef -> []
            end;
       true ->
            []
    end;
get_sasl_feature(_) ->
    [].
|
|
|
|
|
|
|
|
%% Compression feature: offered on uncompressed streams when the
%% callback module reports at least one method.
get_compress_feature(#{stream_compressed := false, mod := Mod} = State) ->
    try Mod:compress_methods(State) of
        [] -> [];
        Methods -> [#compression{methods = Methods}]
    catch
        _:undef -> []
    end;
get_compress_feature(_) ->
    [].
|
|
|
|
|
|
|
|
%% STARTTLS feature: offered on unauthenticated, unencrypted streams,
%% flagged as required when the callback module says so.
get_tls_feature(#{stream_authenticated := false,
                  stream_tlsed := false} = State) ->
    [#starttls{required = is_starttls_required(State)}];
get_tls_feature(_) ->
    [].
|
|
|
|
|
|
|
|
%% Resource binding feature: offered to authenticated streams that have
%% no resource yet.
get_bind_feature(#{stream_authenticated := true,
                   resource := <<"">>}) ->
    [#bind{}];
get_bind_feature(_) ->
    [].
|
|
|
|
|
|
|
|
%% Legacy session feature (marked optional): offered under the same
%% conditions as resource binding.
get_session_feature(#{stream_authenticated := true,
                      resource := <<"">>}) ->
    [#xmpp_session{optional = true}];
get_session_feature(_) ->
    [].
|
|
|
|
|
|
|
|
%% Extra features supplied by the callback module, chosen by whether the
%% stream is authenticated. Empty when the callback is not implemented.
get_other_features(#{stream_authenticated := Auth, mod := Mod} = State) ->
    try
        if Auth ->
                Mod:authenticated_stream_features(State);
           true ->
                Mod:unauthenticated_stream_features(State)
        end
    catch
        _:undef -> []
    end.
|
|
|
|
|
|
|
|
%% Ask the callback module whether STARTTLS is mandatory; defaults to
%% `false` when tls_required/1 is not implemented.
is_starttls_required(#{mod := Mod} = State) ->
    try
        Mod:tls_required(State)
    catch
        _:undef -> false
    end.
|
|
|
|
|
|
|
|
%% Validate and normalise stanza addressing.
%% Returns {ok, Pkt} or {error, StreamError}.
%%  * Non-stanzas pass through untouched.
%%  * Client streams: a supplied 'from' must match the session JID (the
%%    resource may be omitted); 'from' is then replaced by the session
%%    JID and a missing 'to' defaults to the bare account JID.
%%  * Other streams: both 'from' and 'to' must be present.
set_from_to(Pkt, _State) when not ?is_stanza(Pkt) ->
    {ok, Pkt};
set_from_to(Pkt, #{user := U, server := S, resource := R,
                   xmlns := ?NS_CLIENT}) ->
    OwnJID = jid:make(U, S, R),
    From = case xmpp:get_from(Pkt) of
               undefined -> OwnJID;
               Claimed -> Claimed
           end,
    if OwnJID#jid.luser == From#jid.luser andalso
       OwnJID#jid.lserver == From#jid.lserver andalso
       (OwnJID#jid.lresource == From#jid.lresource
        orelse From#jid.lresource == <<"">>) ->
            To = case xmpp:get_to(Pkt) of
                     undefined -> jid:make(U, S);
                     Given -> Given
                 end,
            {ok, xmpp:set_from_to(Pkt, OwnJID, To)};
       true ->
            {error, xmpp:serr_invalid_from()}
    end;
set_from_to(Pkt, #{lang := Lang}) ->
    From = xmpp:get_from(Pkt),
    To = xmpp:get_to(Pkt),
    if From == undefined ->
            {error, xmpp:serr_invalid_from(<<"Missing 'from' attribute">>, Lang)};
       To == undefined ->
            {error, xmpp:serr_improper_addressing(<<"Missing 'to' attribute">>, Lang)};
       true ->
            {ok, Pkt}
    end.
|
|
|
|
|
|
|
|
%% Send our stream header without any information from the peer's
%% header (an empty #stream_start{} is used in its place).
send_header(State) ->
    NoPeerHeader = #stream_start{},
    send_header(State, NoPeerHeader).
|
|
|
|
|
|
|
|
%% Send our stream header in reply to the peer's one. Only done while in
%% wait_for_stream; in any other state this is a no-op.
%% The peer's language and version take precedence over ours when given,
%% 'from' echoes the peer's 'to' (falling back to our server name), and
%% the negotiated language is stored in the state.
send_header(#{stream_state := wait_for_stream,
              stream_id := StreamID,
              stream_version := MyVersion,
              lang := MyLang,
              xmlns := NS,
              server := DefaultServer} = State,
            #stream_start{to = To, lang = HisLang, version = HisVersion}) ->
    Lang = choose_lang(MyLang, HisLang),
    From = case To of
               undefined -> jid:make(DefaultServer);
               #jid{} -> To
           end,
    Version = case HisVersion of
                  undefined -> MyVersion;
                  _ -> HisVersion
              end,
    StreamStart = #stream_start{version = Version,
                                lang = Lang,
                                xmlns = NS,
                                stream_xmlns = ?NS_STREAM,
                                id = StreamID,
                                from = From},
    Header = xmpp:encode(StreamStart),
    State1 = State#{lang => Lang},
    case send_text(State1, fxml:element_to_header(Header)) of
        ok -> {noreply, State1};
        {error, _} -> {stop, normal, State1}
    end;
send_header(State, _) ->
    {noreply, State}.
|
|
|
|
|
|
|
|
%% Encode and write a packet to the socket.
%% A successfully written stream error or STARTTLS failure ends the
%% stream. Every other outcome is reported to Mod:handle_send/5; without
%% that callback a successful write continues and a failed one stops.
send_element(#{xmlns := NS, mod := Mod} = State, Pkt) ->
    El = xmpp:encode(Pkt, NS),
    Data = fxml:element_to_binary(El),
    case send_text(State, Data) of
        ok when is_record(Pkt, stream_error);
                is_record(Pkt, starttls_failure) ->
            {stop, normal, State};
        Res ->
            try
                Mod:handle_send(Res, Pkt, El, Data, State)
            catch
                _:undef when Res == ok -> {noreply, State};
                _:undef -> {stop, normal, State}
            end
    end.
|
|
|
|
|
|
|
|
%% Bounce a stanza with the given stanza error. Stanzas of type `result`
%% or `error` are never answered, and non-stanzas are ignored.
send_error(State, Pkt, Err) when ?is_stanza(Pkt) ->
    case xmpp:get_type(Pkt) of
        Type when Type == result; Type == error ->
            {noreply, State};
        _ ->
            send_element(State, xmpp:make_error(Pkt, Err))
    end;
send_error(State, _, _) ->
    {noreply, State}.
|
|
|
|
|
|
|
|
%% Write raw data to the stream's socket through its socket module and
%% return that module's result.
send_text(#{socket := Sock,
            sockmod := SockMod},
          Data) ->
    SockMod:send(Sock, Data).
|
|
|
|
|
|
|
|
%% Pick the language to use: the peer's one unless it is empty, in which
%% case ours is kept.
choose_lang(MyLang, HisLang) ->
    case HisLang of
        <<"">> -> MyLang;
        _ -> HisLang
    end.
|
|
|
|
|
|
|
|
%% On client streams, stamp stanzas that carry no language of their own
%% with the stream language. Everything else is returned unchanged.
set_lang(Pkt, #{lang := MyLang, xmlns := ?NS_CLIENT}) when ?is_stanza(Pkt) ->
    Lang = choose_lang(MyLang, xmpp:get_lang(Pkt)),
    xmpp:set_lang(Pkt, Lang);
set_lang(Pkt, _) ->
    Pkt.
|
|
|
|
|
|
|
|
%% Elements of L1 that also occur in L2, in L1's order (duplicates in L1
%% are kept).
lists_intersection(L1, L2) ->
    [E || E <- L1, lists:member(E, L2)].
|
|
|
|
|
|
|
|
%% Extract the authenticated identity from SASL properties: the
%% authorization id when non-empty, otherwise the username (empty binary
%% when neither is present).
identity(Props) ->
    AuthzId = proplists:get_value(authzid, Props, <<>>),
    if AuthzId =:= <<>> ->
            proplists:get_value(username, Props, <<>>);
       true ->
            AuthzId
    end.
|