From cf87c5664f3abb9be035b92d762e6380985684cf Mon Sep 17 00:00:00 2001 From: Evgeniy Khramtsov Date: Sat, 31 Dec 2016 13:48:55 +0300 Subject: [PATCH] Reflect cyrsasl API changes in remaining code --- src/ejabberd_c2s.erl | 6 +- src/mod_s2s_dialback.erl | 27 ++++++- src/mod_sm.erl | 53 +++++++------- src/xmpp_stream_in.erl | 154 ++++++++++++++++++++++++--------------- src/xmpp_stream_out.erl | 58 +++++++++++---- src/xmpp_stream_pkix.erl | 39 +++++++--- 6 files changed, 220 insertions(+), 117 deletions(-) diff --git a/src/ejabberd_c2s.erl b/src/ejabberd_c2s.erl index f22960c50..a10ee59a5 100644 --- a/src/ejabberd_c2s.erl +++ b/src/ejabberd_c2s.erl @@ -221,7 +221,7 @@ process_closed(State, Reason) -> process_terminated(#{socket := Socket, jid := JID} = State, Reason) -> Status = format_reason(State, Reason), - ?INFO_MSG("(~s) Closing c2s connection for ~s: ~s", + ?INFO_MSG("(~s) Closing c2s session for ~s: ~s", [ejabberd_socket:pp(Socket), jid:to_string(JID), Status]), Pres = #presence{type = unavailable, status = xmpp:mk_text(Status), @@ -292,12 +292,12 @@ bind(R, #{user := U, server := S, access := Access, lang := Lang, State1 = open_session(State#{resource => Resource}), State2 = ejabberd_hooks:run_fold( c2s_session_opened, LServer, State1, []), - ?INFO_MSG("(~s) Opened session for ~s", + ?INFO_MSG("(~s) Opened c2s session for ~s", [ejabberd_socket:pp(Socket), jid:to_string(JID)]), {ok, State2}; deny -> ejabberd_hooks:run(forbidden_session_hook, LServer, [JID]), - ?INFO_MSG("(~s) Forbidden session for ~s", + ?INFO_MSG("(~s) Forbidden c2s session for ~s", [ejabberd_socket:pp(Socket), jid:to_string(JID)]), Txt = <<"Denied by ACL">>, {error, xmpp:err_not_allowed(Txt, Lang), State} diff --git a/src/mod_s2s_dialback.erl b/src/mod_s2s_dialback.erl index 4bdda2ca7..d0d78a30c 100644 --- a/src/mod_s2s_dialback.erl +++ b/src/mod_s2s_dialback.erl @@ -29,7 +29,7 @@ -export([start/2, stop/1, depends/2, mod_opt_type/1]). %% Hooks -export([s2s_out_auth_result/2, s2s_out_downgraded/2, - s2s_in_packet/2, s2s_out_packet/2, + s2s_in_packet/2, s2s_out_packet/2, s2s_in_recv/3, s2s_in_features/2, s2s_out_init/2, s2s_out_closed/2]). -include("ejabberd.hrl"). @@ -52,6 +52,8 @@ start(Host, _Opts) -> s2s_in_features, 50), ejabberd_hooks:add(s2s_in_post_auth_features, Host, ?MODULE, s2s_in_features, 50), + ejabberd_hooks:add(s2s_in_handle_recv, Host, ?MODULE, + s2s_in_recv, 50), ejabberd_hooks:add(s2s_in_unauthenticated_packet, Host, ?MODULE, s2s_in_packet, 50), ejabberd_hooks:add(s2s_in_authenticated_packet, Host, ?MODULE, @@ -71,6 +73,8 @@ stop(Host) -> s2s_in_features, 50), ejabberd_hooks:delete(s2s_in_post_auth_features, Host, ?MODULE, s2s_in_features, 50), + ejabberd_hooks:delete(s2s_in_handle_recv, Host, ?MODULE, + s2s_in_recv, 50), ejabberd_hooks:delete(s2s_in_unauthenticated_packet, Host, ?MODULE, s2s_in_packet, 50), ejabberd_hooks:delete(s2s_in_authenticated_packet, Host, ?MODULE, @@ -191,6 +195,25 @@ s2s_in_packet(State, Pkt) when is_record(Pkt, db_result); s2s_in_packet(State, _) -> State. +s2s_in_recv(State, El, {error, Why}) -> + case xmpp:get_name(El) of + Tag when Tag == <<"db:result">>; + Tag == <<"db:verify">> -> + case xmpp:get_type(El) of + T when T /= <<"valid">>, + T /= <<"invalid">>, + T /= <<"error">> -> + Err = xmpp:make_error(El, mk_error({codec_error, Why})), + {stop, ejabberd_s2s_in:send(State, Err)}; + _ -> + State + end; + _ -> + State + end; +s2s_in_recv(State, _El, _Pkt) -> + State. + s2s_out_packet(#{server := LServer, remote_server := RServer, db_verify := {StreamID, _Key, Pid}} = State, @@ -286,6 +309,8 @@ mk_error(forbidden) -> xmpp:err_forbidden(<<"Denied by ACL">>, ?MYLANG); mk_error(host_unknown) -> xmpp:err_not_allowed(<<"Host unknown">>, ?MYLANG); +mk_error({codec_error, Why}) -> + xmpp:err_bad_request(xmpp:io_format_error(Why), ?MYLANG); mk_error({_Class, _Reason} = Why) -> Txt = xmpp_stream_out:format_error(Why), xmpp:err_remote_server_not_found(Txt, ?MYLANG); diff --git a/src/mod_sm.erl b/src/mod_sm.erl index 703234419..7e64e6a00 100644 --- a/src/mod_sm.erl +++ b/src/mod_sm.erl @@ -179,16 +179,14 @@ c2s_handle_recv(#{lang := Lang} = State, El, {error, Why}) -> c2s_handle_recv(State, _, _) -> State. -c2s_handle_send(#{mgmt_state := MgmtState} = State, Pkt, Result) +c2s_handle_send(#{mgmt_state := MgmtState} = State, Pkt, _Result) when MgmtState == pending; MgmtState == active -> State1 = mgmt_queue_add(State, Pkt), - case Result of - ok when ?is_stanza(Pkt) -> + case xmpp:is_stanza(Pkt) of + true -> send_rack(State1); - ok -> - State1; - {error, _} -> - transition_to_pending(State1) + false -> + State1 end; c2s_handle_send(State, _Pkt, _Result) -> State. @@ -210,8 +208,9 @@ c2s_handle_info(#{mgmt_ack_timer := TRef, jid := JID} = State, {timeout, TRef, ack_timeout}) -> ?DEBUG("Timed out waiting for stream management acknowledgement of ~s", [jid:to_string(JID)]), - State1 = ejabberd_c2s:close(State, _SendTrailer = false), - {stop, transition_to_pending(State1)}; + State1 = State#{stop_reason => {socket, timeout}}, + State2 = ejabberd_c2s:close(State1, _SendTrailer = false), + {stop, transition_to_pending(State2)}; c2s_handle_info(#{mgmt_state := pending, jid := JID} = State, {timeout, _, pending_timeout}) -> ?DEBUG("Timed out waiting for resumption of stream for ~s", @@ -222,8 +221,8 @@ c2s_handle_info(State, _) -> c2s_closed(State, {stream, _}) -> State; -c2s_closed(#{mgmt_state := active} = State, Reason) -> - {stop, transition_to_pending(State#{stop_reason => Reason})}; +c2s_closed(#{mgmt_state := active} = State, _Reason) -> + {stop, transition_to_pending(State)}; c2s_closed(State, _Reason) -> State. @@ -368,10 +367,9 @@ transition_to_pending(#{mgmt_state := active, jid := JID, lserver := LServer, mgmt_timeout := Timeout} = State) -> State1 = cancel_ack_timer(State), ?INFO_MSG("Waiting for resumption of stream for ~s", [jid:to_string(JID)]), - State2 = ejabberd_hooks:run_fold(c2s_session_pending, LServer, State1, []), - State3 = ejabberd_c2s:close(State2, _SendTrailer = false), erlang:start_timer(timer:seconds(Timeout), self(), pending_timeout), - State3#{mgmt_state => pending}; + State2 = State1#{mgmt_state => pending}, + ejabberd_hooks:run_fold(c2s_session_pending, LServer, State2, []); transition_to_pending(State) -> State. @@ -405,8 +403,8 @@ update_num_stanzas_in(State, _El) -> send_rack(#{mgmt_ack_timer := _} = State) -> State; send_rack(#{mgmt_xmlns := Xmlns, - mgmt_stanzas_out := NumStanzasOut, - mgmt_ack_timeout := AckTimeout} = State) -> + mgmt_stanzas_out := NumStanzasOut, + mgmt_ack_timeout := AckTimeout} = State) -> State1 = send(State, #sm_r{xmlns = Xmlns}), TRef = erlang:start_timer(AckTimeout, self(), ack_timeout), State1#{mgmt_ack_timer => TRef, mgmt_stanzas_req => NumStanzasOut}. @@ -425,16 +423,19 @@ resend_rack(State) -> -spec mgmt_queue_add(state(), xmpp_element()) -> state(). mgmt_queue_add(#{mgmt_stanzas_out := NumStanzasOut, - mgmt_queue := Queue} = State, Stanza) when ?is_stanza(Stanza) -> - NewNum = case NumStanzasOut of - 4294967295 -> 0; - Num -> Num + 1 - end, - Queue1 = queue_in({NewNum, p1_time_compat:timestamp(), Stanza}, Queue), - State1 = State#{mgmt_queue => Queue1, mgmt_stanzas_out => NewNum}, - check_queue_length(State1); -mgmt_queue_add(State, _Nonza) -> - State. + mgmt_queue := Queue} = State, Pkt) -> + case xmpp:is_stanza(Pkt) of + true -> + NewNum = case NumStanzasOut of + 4294967295 -> 0; + Num -> Num + 1 + end, + Queue1 = queue_in({NewNum, p1_time_compat:timestamp(), Pkt}, Queue), + State1 = State#{mgmt_queue => Queue1, mgmt_stanzas_out => NewNum}, + check_queue_length(State1); + false -> + State + end. -spec mgmt_queue_drop(state(), non_neg_integer()) -> state(). mgmt_queue_drop(#{mgmt_queue := Queue} = State, NumHandled) -> diff --git a/src/xmpp_stream_in.erl b/src/xmpp_stream_in.erl index a03870643..1ad78d45b 100644 --- a/src/xmpp_stream_in.erl +++ b/src/xmpp_stream_in.erl @@ -42,7 +42,7 @@ -include("xmpp.hrl"). -type state() :: map(). --type stop_reason() :: {stream, reset | stream_error()} | +-type stop_reason() :: {stream, reset | {in | out, stream_error()}} | {tls, term()} | {socket, inet:posix() | closed | timeout} | internal_failure. @@ -188,8 +188,10 @@ format_error({socket, Reason}) -> format("Connection failed: ~s", [format_inet_error(Reason)]); format_error({stream, reset}) -> <<"Stream reset by peer">>; -format_error({stream, #stream_error{reason = Reason, text = Txt}}) -> - format("Stream failed: ~s", [format_stream_error(Reason, Txt)]); +format_error({stream, {in, #stream_error{reason = Reason, text = Txt}}}) -> + format("Stream closed by peer: ~s", [format_stream_error(Reason, Txt)]); +format_error({stream, {out, #stream_error{reason = Reason, text = Txt}}}) -> + format("Stream closed by us: ~s", [format_stream_error(Reason, Txt)]); format_error({tls, Reason}) -> format("TLS failed: ~w", [Reason]); format_error(internal_failure) -> @@ -304,7 +306,7 @@ handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) -> send_element(State1, Err) end); handle_info({'$gen_event', {xmlstreamelement, El}}, - #{xmlns := NS, lang := MyLang, mod := Mod} = State) -> + #{xmlns := NS, mod := Mod} = State) -> noreply( try xmpp:decode(El, NS, [ignore_els]) of Pkt -> @@ -321,10 +323,7 @@ handle_info({'$gen_event', {xmlstreamelement, El}}, end, case is_disconnected(State1) of true -> State1; - false -> - Txt = xmpp:io_format_error(Why), - Lang = select_lang(MyLang, xmpp:get_lang(El)), - send_error(State1, El, xmpp:err_bad_request(Txt, Lang)) + false -> process_invalid_xml(State1, El, Why) end end); handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}}, @@ -394,6 +393,33 @@ peername(SockMod, Socket) -> _ -> SockMod:peername(Socket) end. +-spec process_invalid_xml(state(), fxml:xmlel(), term()) -> state(). +process_invalid_xml(#{lang := MyLang} = State, El, Reason) -> + case xmpp:is_stanza(El) of + true -> + Txt = xmpp:io_format_error(Reason), + Lang = select_lang(MyLang, xmpp:get_lang(El)), + send_error(State, El, xmpp:err_bad_request(Txt, Lang)); + false -> + case {xmpp:get_name(El), xmpp:get_ns(El)} of + {Tag, ?NS_SASL} when Tag == <<"auth">>; + Tag == <<"response">>; + Tag == <<"abort">> -> + Txt = xmpp:io_format_error(Reason), + Err = #sasl_failure{reason = 'malformed-request', + text = xmpp:mk_text(Txt, MyLang)}, + send_element(State, Err); + {<<"starttls">>, ?NS_TLS} -> + send_element(State, #starttls_failure{}); + {<<"compress">>, ?NS_COMPRESS} -> + Err = #compress_failure{reason = 'setup-failed'}, + send_element(State, Err); + _ -> + %% Maybe add something more? + State + end + end. + -spec process_stream_end(stop_reason(), state()) -> state(). process_stream_end(_, #{stream_state := disconnected} = State) -> State; @@ -423,11 +449,6 @@ process_stream(#stream_start{lang = Lang}, process_stream(#stream_start{to = undefined}, #{lang := Lang} = State) -> Txt = <<"Missing 'to' attribute">>, send_element(State, xmpp:serr_improper_addressing(Txt, Lang)); -process_stream(#stream_start{from = undefined, version = {1,0}}, - #{lang := Lang, xmlns := ?NS_SERVER, - stream_encrypted := true} = State) -> - Txt = <<"Missing 'from' attribute">>, - send_element(State, xmpp:serr_invalid_from(Txt, Lang)); process_stream(#stream_start{to = #jid{luser = U, lresource = R}}, #{lang := Lang} = State) when U /= <<"">>; R /= <<"">> -> Txt = <<"Improper 'to' attribute">>, @@ -450,9 +471,10 @@ process_stream(#stream_start{to = #jid{server = Server, lserver = LServer}, true -> State end, - State2 = if NS == ?NS_SERVER andalso Encrypted -> - State1#{remote_server => From#jid.lserver}; - true -> + State2 = case From of + #jid{lserver = RemoteServer} when NS == ?NS_SERVER -> + State1#{remote_server => RemoteServer}; + _ -> State1 end, State3 = try Mod:handle_stream_start(StreamStart, State2) @@ -517,7 +539,7 @@ process_element(Pkt, #{stream_state := StateName, lang := Lang} = State) -> #handshake{} -> State; #stream_error{} -> - process_stream_end({stream, Pkt}, State); + process_stream_end({stream, {in, Pkt}}, State); _ when StateName == wait_for_sasl_request; StateName == wait_for_handshake; StateName == wait_for_sasl_response -> @@ -707,35 +729,34 @@ process_starttls_failure(Why, State) -> -spec process_sasl_request(sasl_auth(), state()) -> state(). process_sasl_request(#sasl_auth{mechanism = Mech, text = ClientIn}, #{mod := Mod, lserver := LServer} = State) -> - GetPW = try Mod:get_password_fun(State) - catch _:undef -> fun(_) -> false end - end, - CheckPW = try Mod:check_password_fun(State) - catch _:undef -> fun(_, _, _) -> false end - end, - CheckPWDigest = try Mod:check_password_digest_fun(State) - catch _:undef -> fun(_, _, _, _, _) -> false end - end, - SASLState = cyrsasl:server_new(<<"jabber">>, LServer, <<"">>, [], - GetPW, CheckPW, CheckPWDigest), - State1 = State#{sasl_state => SASLState, sasl_mech => Mech}, + State1 = State#{sasl_mech => Mech}, Mechs = get_sasl_mechanisms(State1), - SASLResult = case lists:member(Mech, Mechs) of - true when Mech == <<"EXTERNAL">> -> - case xmpp_stream_pkix:authenticate(State1, ClientIn) of - {ok, Peer} -> - {ok, [{auth_module, pkix}, - {username, Peer}]}; - {error, _Reason, Peer} -> - %% TODO: return meaningful error - {error, 'not-authorized', Peer} - end; - true -> - cyrsasl:server_start(SASLState, Mech, ClientIn); - false -> - {error, 'invalid-mechanism'} - end, - process_sasl_result(SASLResult, State1). + case lists:member(Mech, Mechs) of + true when Mech == <<"EXTERNAL">> -> + Res = case xmpp_stream_pkix:authenticate(State1, ClientIn) of + {ok, Peer} -> + {ok, [{auth_module, pkix}, {username, Peer}]}; + {error, Reason, Peer} -> + {error, Reason, Peer} + end, + process_sasl_result(Res, State1); + true -> + GetPW = try Mod:get_password_fun(State1) + catch _:undef -> fun(_) -> false end + end, + CheckPW = try Mod:check_password_fun(State1) + catch _:undef -> fun(_, _, _) -> false end + end, + CheckPWDigest = try Mod:check_password_digest_fun(State1) + catch _:undef -> fun(_, _, _, _, _) -> false end + end, + SASLState = cyrsasl:server_new(<<"jabber">>, LServer, <<"">>, [], + GetPW, CheckPW, CheckPWDigest), + Res = cyrsasl:server_start(SASLState, Mech, ClientIn), + process_sasl_result(Res, State1#{sasl_state => SASLState}); + false -> + process_sasl_result({error, unsupported_mechanism, <<"">>}, State1) + end. -spec process_sasl_response(sasl_response(), state()) -> state(). process_sasl_response(#sasl_response{text = ClientIn}, @@ -751,9 +772,7 @@ process_sasl_result({ok, Props, ServerOut}, State) -> process_sasl_result({continue, ServerOut, NewSASLState}, State) -> process_sasl_continue(ServerOut, NewSASLState, State); process_sasl_result({error, Reason, User}, State) -> - process_sasl_failure(Reason, User, State); -process_sasl_result({error, Reason}, State) -> - process_sasl_failure(Reason, <<"">>, State). + process_sasl_failure(Reason, User, State). -spec process_sasl_success([cyrsasl:sasl_property()], binary(), state()) -> state(). process_sasl_success(Props, ServerOut, @@ -790,18 +809,20 @@ process_sasl_continue(ServerOut, NewSASLState, State) -> send_element(State1, #sasl_challenge{text = ServerOut}). -spec process_sasl_failure(atom(), binary(), state()) -> state(). -process_sasl_failure(Reason, User, - #{mod := Mod, sasl_mech := Mech} = State) -> - State1 = try Mod:handle_auth_failure(User, Mech, Reason, State) +process_sasl_failure(Err, User, + #{mod := Mod, sasl_mech := Mech, lang := Lang} = State) -> + {Reason, Text} = format_sasl_error(Mech, Err), + State1 = try Mod:handle_auth_failure(User, Mech, Text, State) catch _:undef -> State end, State2 = maps:remove(sasl_state, maps:remove(sasl_mech, State1)), State3 = State2#{stream_state => wait_for_sasl_request}, - send_element(State3, #sasl_failure{reason = Reason}). + send_element(State3, #sasl_failure{reason = Reason, + text = xmpp:mk_text(Text, Lang)}). -spec process_sasl_abort(state()) -> state(). process_sasl_abort(State) -> - process_sasl_failure('aborted', <<"">>, State). + process_sasl_failure(aborted, <<"">>, State). -spec send_features(state()) -> state(). send_features(#{stream_version := {1,0}, @@ -985,13 +1006,17 @@ send_element(#{xmlns := NS, mod := Mod} = State, Pkt) -> State1 = try Mod:handle_send(Pkt, Result, State) catch _:undef -> State end, - case Result of - _ when is_record(Pkt, stream_error) -> - process_stream_end({stream, Pkt}, State1); - ok -> - State1; - {error, Why} -> - process_stream_end({socket, Why}, State1) + case is_disconnected(State1) of + true -> State1; + false -> + case Result of + _ when is_record(Pkt, stream_error) -> + process_stream_end({stream, {out, Pkt}}, State1); + ok -> + State1; + {error, Why} -> + process_stream_end({socket, Why}, State1) + end end. -spec send_error(state(), xmpp_element(), stanza_error()) -> state(). @@ -1025,6 +1050,8 @@ send_text(_, _) -> {error, closed}. -spec close_socket(state()) -> state(). +close_socket(#{stream_state := disconnected} = State) -> + State; close_socket(#{sockmod := SockMod, socket := Socket} = State) -> SockMod:close(Socket), State#{stream_timeout => infinity, @@ -1052,6 +1079,7 @@ format_inet_error(Reason) -> -spec format_stream_error(atom() | 'see-other-host'(), undefined | text()) -> string(). format_stream_error(Reason, Txt) -> Slogan = case Reason of + undefined -> "no reason"; #'see-other-host'{} -> "see-other-host"; _ -> atom_to_list(Reason) end, @@ -1062,6 +1090,12 @@ format_stream_error(Reason, Txt) -> binary_to_list(Data) ++ " (" ++ Slogan ++ ")" end. +-spec format_sasl_error(cyrsasl:mechanism(), atom()) -> {atom(), binary()}. +format_sasl_error(<<"EXTERNAL">>, Err) -> + xmpp_stream_pkix:format_error(Err); +format_sasl_error(Mech, Err) -> + cyrsasl:format_error(Mech, Err). + -spec format(io:format(), list()) -> binary(). format(Fmt, Args) -> iolist_to_binary(io_lib:format(Fmt, Args)). diff --git a/src/xmpp_stream_out.erl b/src/xmpp_stream_out.erl index 08804e432..290a92a49 100644 --- a/src/xmpp_stream_out.erl +++ b/src/xmpp_stream_out.erl @@ -1,10 +1,23 @@ %%%------------------------------------------------------------------- -%%% @author Evgeny Khramtsov -%%% @copyright (C) 2016, Evgeny Khramtsov -%%% @doc -%%% -%%% @end %%% Created : 14 Dec 2016 by Evgeny Khramtsov +%%% +%%% +%%% ejabberd, Copyright (C) 2002-2016 ProcessOne +%%% +%%% This program is free software; you can redistribute it and/or +%%% modify it under the terms of the GNU General Public License as +%%% published by the Free Software Foundation; either version 2 of the +%%% License, or (at your option) any later version. +%%% +%%% This program is distributed in the hope that it will be useful, +%%% but WITHOUT ANY WARRANTY; without even the implied warranty of +%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +%%% General Public License for more details. +%%% +%%% You should have received a copy of the GNU General Public License along +%%% with this program; if not, write to the Free Software Foundation, Inc., +%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +%%% %%%------------------------------------------------------------------- -module(xmpp_stream_out). -behaviour(gen_server). @@ -39,7 +52,7 @@ -type network_error() :: {error, inet:posix() | inet_res:res_error()}. -type stop_reason() :: {idna, bad_string} | {dns, inet:posix() | inet_res:res_error()} | - {stream, reset | stream_error()} | + {stream, reset | {in | out, stream_error()}} | {tls, term()} | {pkix, binary()} | {auth, atom() | binary() | string()} | @@ -135,7 +148,7 @@ change_shaper(_, _) -> -spec format_error(stop_reason()) -> binary(). format_error({idna, _}) -> - <<"Not an IDN hostname">>; + <<"Remote domain is not an IDN hostname">>; format_error({dns, Reason}) -> format("DNS lookup failed: ~s", [format_inet_error(Reason)]); format_error({socket, Reason}) -> @@ -144,8 +157,10 @@ format_error({pkix, Reason}) -> format("Peer certificate rejected: ~s", [Reason]); format_error({stream, reset}) -> <<"Stream reset by peer">>; -format_error({stream, #stream_error{reason = Reason, text = Txt}}) -> - format("Stream failed: ~s", [format_stream_error(Reason, Txt)]); +format_error({stream, {in, #stream_error{reason = Reason, text = Txt}}}) -> + format("Stream closed by peer: ~s", [format_stream_error(Reason, Txt)]); +format_error({stream, {out, #stream_error{reason = Reason, text = Txt}}}) -> + format("Stream closed by us: ~s", [format_stream_error(Reason, Txt)]); format_error({tls, Reason}) -> format("TLS failed: ~w", [Reason]); format_error({auth, Reason}) -> @@ -264,7 +279,7 @@ handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) -> send_element(State1, Err) end); handle_info({'$gen_event', {xmlstreamelement, El}}, - #{xmlns := NS, lang := MyLang, mod := Mod} = State) -> + #{xmlns := NS, mod := Mod} = State) -> noreply( try xmpp:decode(El, NS, [ignore_els]) of Pkt -> @@ -281,10 +296,7 @@ handle_info({'$gen_event', {xmlstreamelement, El}}, end, case is_disconnected(State1) of true -> State1; - false -> - Txt = xmpp:io_format_error(Why), - Lang = select_lang(MyLang, xmpp:get_lang(El)), - send_error(State1, El, xmpp:err_bad_request(Txt, Lang)) + false -> process_invalid_xml(State1, El, Why) end end); handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}}, @@ -347,6 +359,17 @@ new_id() -> is_disconnected(#{stream_state := StreamState}) -> StreamState == disconnected. +-spec process_invalid_xml(state(), fxml:xmlel(), term()) -> state(). +process_invalid_xml(#{lang := MyLang} = State, El, Reason) -> + case xmpp:is_stanza(El) of + true -> + Txt = xmpp:io_format_error(Reason), + Lang = select_lang(MyLang, xmpp:get_lang(El)), + send_error(State, El, xmpp:err_bad_request(Txt, Lang)); + false -> + State + end. + -spec process_stream_end(stop_reason(), state()) -> state(). process_stream_end(_, #{stream_state := disconnected} = State) -> State; @@ -394,7 +417,7 @@ process_element(Pkt, #{stream_state := StateName} = State) -> #sasl_failure{} when StateName == wait_for_sasl_response -> process_sasl_failure(Pkt, State); #stream_error{} -> - process_stream_end({stream, Pkt}, State); + process_stream_end({stream, {in, Pkt}}, State); _ when is_record(Pkt, stream_features); is_record(Pkt, starttls_proceed); is_record(Pkt, starttls); @@ -612,7 +635,7 @@ send_element(#{xmlns := NS, mod := Mod} = State, Pkt) -> false -> case send_text(State1, Data) of _ when is_record(Pkt, stream_error) -> - process_stream_end({stream, Pkt}, State1); + process_stream_end({stream, {out, Pkt}}, State1); ok -> State1; {error, Why} -> @@ -650,6 +673,8 @@ send_trailer(State) -> close_socket(State). -spec close_socket(state()) -> state(). +close_socket(#{stream_state := disconnected} = State) -> + State; close_socket(State) -> case State of #{sockmod := SockMod, socket := Socket} -> @@ -674,6 +699,7 @@ format_inet_error(Reason) -> -spec format_stream_error(atom() | 'see-other-host'(), undefined | text()) -> string(). format_stream_error(Reason, Txt) -> Slogan = case Reason of + undefined -> "no reason"; #'see-other-host'{} -> "see-other-host"; _ -> atom_to_list(Reason) end, diff --git a/src/xmpp_stream_pkix.erl b/src/xmpp_stream_pkix.erl index 59f5d820e..5d64c5eb6 100644 --- a/src/xmpp_stream_pkix.erl +++ b/src/xmpp_stream_pkix.erl @@ -9,7 +9,7 @@ -module(xmpp_stream_pkix). %% API --export([authenticate/1, authenticate/2]). +-export([authenticate/1, authenticate/2, format_error/1]). -include("xmpp.hrl"). -include_lib("public_key/include/public_key.hrl"). @@ -19,21 +19,24 @@ %%% API %%%=================================================================== -spec authenticate(xmpp_stream_in:state() | xmpp_stream_out:state()) - -> {ok, binary()} | {error, binary(), binary()}. + -> {ok, binary()} | {error, atom(), binary()}. authenticate(State) -> authenticate(State, <<"">>). -spec authenticate(xmpp_stream_in:state() | xmpp_stream_out:state(), binary()) - -> {ok, binary()} | {error, binary(), binary()}. -authenticate(#{xmlns := ?NS_SERVER, remote_server := Peer, - sockmod := SockMod, socket := Socket}, _Authzid) -> + -> {ok, binary()} | {error, atom(), binary()}. +authenticate(#{xmlns := ?NS_SERVER, sockmod := SockMod, + socket := Socket} = State, Authzid) -> + Peer = try maps:get(remote_server, State) + catch _:{badkey, _} -> Authzid + end, case SockMod:get_peer_certificate(Socket) of {ok, Cert} -> case SockMod:get_verify_result(Socket) of 0 -> case ejabberd_idna:domain_utf8_to_ascii(Peer) of false -> - {error, <<"Cannot decode remote server name">>, Peer}; + {error, idna_failed, Peer}; AsciiPeer -> case lists:any( fun(D) -> match_domain(AsciiPeer, D) end, @@ -41,20 +44,34 @@ authenticate(#{xmlns := ?NS_SERVER, remote_server := Peer, true -> {ok, Peer}; false -> - {error, <<"Certificate host name mismatch">>, Peer} + {error, hostname_mismatch, Peer} end end; VerifyRes -> - {error, fast_tls:get_cert_verify_string(VerifyRes, Cert), Peer} + %% TODO: return atomic errors + %% This should be improved in fast_tls + Reason = fast_tls:get_cert_verify_string(VerifyRes, Cert), + {error, erlang:binary_to_atom(Reason, utf8), Peer} end; {error, _Reason} -> - {error, <<"Cannot get peer certificate">>, Peer}; + {error, get_cert_failed, Peer}; error -> - {error, <<"Cannot get peer certificate">>, Peer} + {error, get_cert_failed, Peer} end; authenticate(_State, _Authzid) -> %% TODO: client PKIX authentication - {error, <<"Client certificate verification not implemented">>, <<"">>}. + {error, client_not_supported, <<"">>}. + +format_error(idna_failed) -> + {'bad-protocol', <<"Remote domain is not an IDN hostname">>}; +format_error(hostname_mismatch) -> + {'not-authorized', <<"Certificate host name mismatch">>}; +format_error(get_cert_failed) -> + {'bad-protocol', <<"Failed to get peer certificate">>}; +format_error(client_not_supported) -> + {'invalid-mechanism', <<"Client certificate verification is not supported">>}; +format_error(Other) -> + {'not-authorized', erlang:atom_to_binary(Other, utf8)}. %%%=================================================================== %%% Internal functions