diff --git a/src/cyrsasl.erl b/src/cyrsasl.erl index 1edf44678..5c7eb7edb 100644 --- a/src/cyrsasl.erl +++ b/src/cyrsasl.erl @@ -25,13 +25,11 @@ -module(cyrsasl). --behaviour(ejabberd_config). - -author('alexey@process-one.net'). -export([start/0, register_mechanism/3, listmech/1, server_new/7, server_start/3, server_step/2, - get_mech/1, format_error/2, opt_type/1]). + get_mech/1, format_error/2]). -include("ejabberd.hrl"). -include("logger.hrl"). @@ -113,15 +111,9 @@ format_error(Mech, Reason) -> PasswordType :: password_type()) -> any(). register_mechanism(Mechanism, Module, PasswordType) -> - case is_disabled(Mechanism) of - false -> - ets:insert(sasl_mechanism, - #sasl_mechanism{mechanism = Mechanism, module = Module, - password_type = PasswordType}); - true -> - ?DEBUG("SASL mechanism ~p is disabled", [Mechanism]), - true - end. + ets:insert(sasl_mechanism, + #sasl_mechanism{mechanism = Mechanism, module = Module, + password_type = PasswordType}). check_credentials(_State, Props) -> User = proplists:get_value(authzid, Props, <<>>), @@ -134,20 +126,19 @@ check_credentials(_State, Props) -> -spec listmech(Host ::binary()) -> Mechanisms::mechanisms(). listmech(Host) -> - Mechs = ets:select(sasl_mechanism, - [{#sasl_mechanism{mechanism = '$1', - password_type = '$2', _ = '_'}, - case catch ejabberd_auth:store_type(Host) of - external -> [{'==', '$2', plain}]; - scram -> [{'/=', '$2', digest}]; - {'EXIT', {undef, [{Module, store_type, []} | _]}} -> - ?WARNING_MSG("~p doesn't implement the function store_type/0", - [Module]), - []; - _Else -> [] - end, - ['$1']}]), - filter_anonymous(Host, Mechs). + ets:select(sasl_mechanism, + [{#sasl_mechanism{mechanism = '$1', + password_type = '$2', _ = '_'}, + case catch ejabberd_auth:store_type(Host) of + external -> [{'==', '$2', plain}]; + scram -> [{'/=', '$2', digest}]; + {'EXIT', {undef, [{Module, store_type, []} | _]}} -> + ?WARNING_MSG("~p doesn't implement the function store_type/0", + [Module]), + []; + _Else -> [] + end, + ['$1']}]). -spec server_new(binary(), binary(), binary(), term(), fun(), fun(), fun()) -> sasl_state(). @@ -206,33 +197,3 @@ server_step(State, ClientIn) -> -spec get_mech(sasl_state()) -> binary(). get_mech(#sasl_state{mech_name = Mech}) -> Mech. - -%% Remove the anonymous mechanism from the list if not enabled for the given -%% host -%% --spec filter_anonymous(Host :: binary(), Mechs :: mechanisms()) -> mechanisms(). - -filter_anonymous(Host, Mechs) -> - case ejabberd_auth_anonymous:is_sasl_anonymous_enabled(Host) of - true -> Mechs; - false -> Mechs -- [<<"ANONYMOUS">>] - end. - --spec is_disabled(Mechanism :: mechanism()) -> boolean(). - -is_disabled(Mechanism) -> - Disabled = ejabberd_config:get_option( - disable_sasl_mechanisms, - fun(V) when is_list(V) -> - lists:map(fun(M) -> str:to_upper(M) end, V); - (V) -> - [str:to_upper(V)] - end, []), - lists:member(Mechanism, Disabled). - -opt_type(disable_sasl_mechanisms) -> - fun (V) when is_list(V) -> - lists:map(fun (M) -> str:to_upper(M) end, V); - (V) -> [str:to_upper(V)] - end; -opt_type(_) -> [disable_sasl_mechanisms]. diff --git a/src/cyrsasl_digest.erl b/src/cyrsasl_digest.erl index 9b4faca20..39055f2b1 100644 --- a/src/cyrsasl_digest.erl +++ b/src/cyrsasl_digest.erl @@ -59,7 +59,7 @@ start(_Opts) -> Fqdn = get_local_fqdn(), - ?INFO_MSG("FQDN used to check DIGEST-MD5 SASL authentication: ~p", + ?INFO_MSG("FQDN used to check DIGEST-MD5 SASL authentication: ~s", [Fqdn]), cyrsasl:register_mechanism(<<"DIGEST-MD5">>, ?MODULE, digest). diff --git a/src/ejabberd_c2s.erl b/src/ejabberd_c2s.erl index a10ee59a5..007a94dc9 100644 --- a/src/ejabberd_c2s.erl +++ b/src/ejabberd_c2s.erl @@ -33,9 +33,9 @@ %% xmpp_stream_in callbacks -export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]). --export([tls_options/1, tls_required/1, tls_verify/1, - compress_methods/1, bind/2, get_password_fun/1, - check_password_fun/1, check_password_digest_fun/1, +-export([tls_options/1, tls_required/1, tls_verify/1, tls_enabled/1, + compress_methods/1, bind/2, sasl_mechanisms/2, + get_password_fun/1, check_password_fun/1, check_password_digest_fun/1, unauthenticated_stream_features/1, authenticated_stream_features/1, handle_stream_start/2, handle_stream_end/2, handle_unauthenticated_packet/2, handle_authenticated_packet/2, @@ -47,7 +47,7 @@ process_terminated/2, process_info/2]). %% API -export([get_presence/1, get_subscription/2, get_subscribed/1, - open_session/1, call/3, send/2, close/1, close/2, stop/1, establish/1, + open_session/1, call/3, send/2, close/1, close/2, stop/1, reply/2, copy_state/2, set_timeout/2, add_hooks/1]). -include("ejabberd.hrl"). @@ -73,6 +73,9 @@ start_link(SockData, Opts) -> socket_type() -> xml_stream. +%%%=================================================================== +%%% Common API +%%%=================================================================== -spec call(pid(), term(), non_neg_integer() | infinity) -> term(). call(Ref, Msg, Timeout) -> xmpp_stream_in:call(Ref, Msg, Timeout). @@ -116,19 +119,16 @@ stop(Ref) -> send(Pid, Pkt) when is_pid(Pid) -> xmpp_stream_in:send(Pid, Pkt); send(#{lserver := LServer} = State, Pkt) -> - case ejabberd_hooks:run_fold(c2s_filter_send, LServer, {Pkt, State}, []) of + Pkt1 = fix_from_to(Pkt, State), + case ejabberd_hooks:run_fold(c2s_filter_send, LServer, {Pkt1, State}, []) of {drop, State1} -> State1; - {Pkt1, State1} -> xmpp_stream_in:send(State1, Pkt1) + {Pkt2, State1} -> xmpp_stream_in:send(State1, Pkt2) end. -spec set_timeout(state(), timeout()) -> state(). set_timeout(State, Timeout) -> xmpp_stream_in:set_timeout(State, Timeout). --spec establish(state()) -> state(). -establish(State) -> - xmpp_stream_in:establish(State). - -spec add_hooks(binary()) -> ok. add_hooks(Host) -> ejabberd_hooks:add(c2s_closed, Host, ?MODULE, process_closed, 100), @@ -162,7 +162,7 @@ copy_state(#{owner := Owner} = NewState, auth_module => AuthModule, pres_t => PresT, pres_a => PresA, pres_f => PresF}, - ejabberd_hooks:run_fold(c2s_copy_state, LServer, State2, [OldState]). + ejabberd_hooks:run_fold(c2s_copy_session, LServer, State2, [OldState]). -spec open_session(state()) -> {ok, state()} | state(). open_session(#{user := U, server := S, resource := R, @@ -195,14 +195,22 @@ process_info(#{lserver := LServer} = State, process_iq_in(State, Packet) end, if Pass -> - Packet1 = ejabberd_hooks:run_fold( - user_receive_packet, LServer, Packet, [State1]), - ejabberd_hooks:run(c2s_loop_debug, [{route, From, To, Packet}]), - send(State1, Packet1); + {Packet1, State2} = ejabberd_hooks:run_fold( + user_receive_packet, LServer, + {Packet, State1}, []), + case Packet1 of + drop -> State2; + _ -> send(State2, Packet1) + end; true -> - ejabberd_hooks:run(c2s_loop_debug, [{route, From, To, Packet}]), State1 end; +process_info(State, force_update_presence) -> + try maps:get(pres_last, State) of + Pres -> process_self_presence(State, Pres) + catch _:{badkey, _} -> + State + end; process_info(State, Info) -> ?WARNING_MSG("got unexpected info: ~p", [Info]), State. @@ -218,15 +226,21 @@ reject_unauthenticated_packet(State, Pkt) -> process_closed(State, Reason) -> stop(State#{stop_reason => Reason}). -process_terminated(#{socket := Socket, jid := JID} = State, +process_terminated(#{sockmod := SockMod, socket := Socket, jid := JID} = State, Reason) -> Status = format_reason(State, Reason), ?INFO_MSG("(~s) Closing c2s session for ~s: ~s", - [ejabberd_socket:pp(Socket), jid:to_string(JID), Status]), - Pres = #presence{type = unavailable, - status = xmpp:mk_text(Status), - from = JID, to = jid:remove_resource(JID)}, - State1 = broadcast_presence_unavailable(State, Pres), + [SockMod:pp(Socket), jid:to_string(JID), Status]), + State1 = case maps:is_key(pres_last, State) of + true -> + Pres = #presence{type = unavailable, + status = xmpp:mk_text(Status), + from = JID, + to = jid:remove_resource(JID)}, + broadcast_presence_unavailable(State, Pres); + false -> + State + end, bounce_message_queue(), State1; process_terminated(State, _Reason) -> @@ -235,13 +249,51 @@ process_terminated(State, _Reason) -> %%%=================================================================== %%% xmpp_stream_in callbacks %%%=================================================================== -tls_options(#{lserver := LServer, tls_options := TLSOpts}) -> - case ejabberd_config:get_option({domain_certfile, LServer}, - fun iolist_to_binary/1) of - undefined -> - TLSOpts; - CertFile -> - lists:keystore(certfile, 1, TLSOpts, {certfile, CertFile}) +tls_options(#{lserver := LServer, tls_options := DefaultOpts}) -> + TLSOpts1 = case ejabberd_config:get_option( + {c2s_certfile, LServer}, + fun iolist_to_binary/1, + ejabberd_config:get_option( + {domain_certfile, LServer}, + fun iolist_to_binary/1)) of + undefined -> []; + CertFile -> lists:keystore(certfile, 1, DefaultOpts, + {certfile, CertFile}) + end, + TLSOpts2 = case ejabberd_config:get_option( + {c2s_ciphers, LServer}, + fun iolist_to_binary/1) of + undefined -> TLSOpts1; + Ciphers -> lists:keystore(ciphers, 1, TLSOpts1, + {ciphers, Ciphers}) + end, + TLSOpts3 = case ejabberd_config:get_option( + {c2s_protocol_options, LServer}, + fun (Options) -> str:join(Options, <<$|>>) end) of + undefined -> TLSOpts2; + ProtoOpts -> lists:keystore(protocol_options, 1, TLSOpts2, + {protocol_options, ProtoOpts}) + end, + TLSOpts4 = case ejabberd_config:get_option( + {c2s_dhfile, LServer}, + fun iolist_to_binary/1) of + undefined -> TLSOpts3; + DHFile -> lists:keystore(dhfile, 1, TLSOpts3, + {dhfile, DHFile}) + end, + TLSOpts5 = case ejabberd_config:get_option( + {c2s_cafile, LServer}, + fun iolist_to_binary/1) of + undefined -> TLSOpts4; + CAFile -> lists:keystore(cafile, 1, TLSOpts4, + {cafile, CAFile}) + end, + case ejabberd_config:get_option( + {c2s_tls_compression, LServer}, + fun(B) when is_boolean(B) -> B end) of + undefined -> TLSOpts5; + false -> [compression_none | TLSOpts5]; + true -> lists:delete(compression_none, TLSOpts5) end. tls_required(#{tls_required := TLSRequired}) -> @@ -250,6 +302,11 @@ tls_required(#{tls_required := TLSRequired}) -> tls_verify(#{tls_verify := TLSVerify}) -> TLSVerify. +tls_enabled(#{tls_enabled := TLSEnabled, + tls_required := TLSRequired, + tls_verify := TLSVerify}) -> + TLSEnabled or TLSRequired or TLSVerify. + compress_methods(#{zlib := true}) -> [<<"zlib">>]; compress_methods(_) -> @@ -261,6 +318,20 @@ unauthenticated_stream_features(#{lserver := LServer}) -> authenticated_stream_features(#{lserver := LServer}) -> ejabberd_hooks:run_fold(c2s_post_auth_features, LServer, [], [LServer]). +sasl_mechanisms(Mechs, #{lserver := LServer}) -> + Mechs1 = ejabberd_config:get_option( + {disable_sasl_mechanisms, LServer}, + fun(V) when is_list(V) -> + lists:map(fun(M) -> str:to_upper(M) end, V); + (V) -> + [str:to_upper(V)] + end, []), + Mechs2 = case ejabberd_auth_anonymous:is_sasl_anonymous_enabled(LServer) of + true -> Mechs1; + false -> [<<"ANONYMOUS">>|Mechs1] + end, + Mechs -- Mechs2. + get_password_fun(#{lserver := LServer}) -> fun(U) -> ejabberd_auth:get_password_with_authmodule(U, LServer) @@ -279,7 +350,8 @@ check_password_digest_fun(#{lserver := LServer}) -> bind(<<"">>, State) -> bind(new_uniq_id(), State); bind(R, #{user := U, server := S, access := Access, lang := Lang, - lserver := LServer, socket := Socket, ip := IP} = State) -> + lserver := LServer, sockmod := SockMod, socket := Socket, + ip := IP} = State) -> case resource_conflict_action(U, S, R) of closenew -> {error, xmpp:err_conflict(), State}; @@ -289,38 +361,30 @@ bind(R, #{user := U, server := S, access := Access, lang := Lang, #{usr => jid:split(JID), ip => IP}, LServer) of allow -> - State1 = open_session(State#{resource => Resource}), + State1 = open_session(State#{resource => Resource, + sid => ejabberd_sm:make_sid()}), State2 = ejabberd_hooks:run_fold( c2s_session_opened, LServer, State1, []), ?INFO_MSG("(~s) Opened c2s session for ~s", - [ejabberd_socket:pp(Socket), jid:to_string(JID)]), + [SockMod:pp(Socket), jid:to_string(JID)]), {ok, State2}; deny -> ejabberd_hooks:run(forbidden_session_hook, LServer, [JID]), ?INFO_MSG("(~s) Forbidden c2s session for ~s", - [ejabberd_socket:pp(Socket), jid:to_string(JID)]), + [SockMod:pp(Socket), jid:to_string(JID)]), Txt = <<"Denied by ACL">>, {error, xmpp:err_not_allowed(Txt, Lang), State} end end. -handle_stream_start(StreamStart, - #{lserver := LServer, ip := IP, lang := Lang} = State) -> +handle_stream_start(StreamStart, #{lserver := LServer} = State) -> case ejabberd_router:is_my_host(LServer) of false -> send(State, xmpp:serr_host_unknown()); true -> - case check_bl_c2s(IP, Lang) of - false -> - change_shaper(State), - ejabberd_hooks:run_fold( - c2s_stream_started, LServer, State, [StreamStart]); - {true, LogReason, ReasonT} -> - ?INFO_MSG("Connection attempt from blacklisted IP ~s: ~s", - [jlib:ip_to_list(IP), LogReason]), - Err = xmpp:serr_policy_violation(ReasonT, Lang), - send(State, Err) - end + change_shaper(State), + ejabberd_hooks:run_fold( + c2s_stream_started, LServer, State, [StreamStart]) end. handle_stream_end(Reason, #{lserver := LServer} = State) -> @@ -328,18 +392,20 @@ handle_stream_end(Reason, #{lserver := LServer} = State) -> ejabberd_hooks:run_fold(c2s_closed, LServer, State1, [Reason]). handle_auth_success(User, Mech, AuthModule, - #{socket := Socket, ip := IP, lserver := LServer} = State) -> + #{socket := Socket, sockmod := SockMod, + ip := IP, lserver := LServer} = State) -> ?INFO_MSG("(~s) Accepted c2s ~s authentication for ~s@~s by ~s backend from ~s", - [ejabberd_socket:pp(Socket), Mech, User, LServer, + [SockMod:pp(Socket), Mech, User, LServer, ejabberd_auth:backend_type(AuthModule), ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]), State1 = State#{auth_module => AuthModule}, ejabberd_hooks:run_fold(c2s_auth_result, LServer, State1, [true, User]). handle_auth_failure(User, Mech, Reason, - #{socket := Socket, ip := IP, lserver := LServer} = State) -> + #{socket := Socket, sockmod := SockMod, + ip := IP, lserver := LServer} = State) -> ?INFO_MSG("(~s) Failed c2s ~s authentication ~sfrom ~s: ~s", - [ejabberd_socket:pp(Socket), Mech, + [SockMod:pp(Socket), Mech, if User /= <<"">> -> ["for ", User, "@", LServer, " "]; true -> "" end, @@ -355,17 +421,22 @@ handle_unauthenticated_packet(Pkt, #{lserver := LServer} = State) -> handle_authenticated_packet(Pkt, #{lserver := LServer} = State) when not ?is_stanza(Pkt) -> ejabberd_hooks:run_fold(c2s_authenticated_packet, LServer, State, [Pkt]); -handle_authenticated_packet(Pkt, #{lserver := LServer} = State) -> +handle_authenticated_packet(Pkt, #{lserver := LServer, jid := JID} = State) -> State1 = ejabberd_hooks:run_fold(c2s_authenticated_packet, LServer, State, [Pkt]), - Pkt1 = ejabberd_hooks:run_fold(user_send_packet, LServer, Pkt, [State1]), + #jid{luser = LUser} = JID, + {Pkt1, State2} = ejabberd_hooks:run_fold( + user_send_packet, LServer, {Pkt, State1}, []), case Pkt1 of - #presence{to = #jid{lresource = <<"">>}} -> - process_self_presence(State1, Pkt1); + drop -> + State2; + #presence{to = #jid{luser = LUser, lserver = LServer, + lresource = <<"">>}} -> + process_self_presence(State2, Pkt1); #presence{} -> - process_presence_out(State1, Pkt1); + process_presence_out(State2, Pkt1); _ -> - check_privacy_then_route(State1, Pkt1) + check_privacy_then_route(State2, Pkt1) end. handle_cdata(Data, #{lserver := LServer} = State) -> @@ -381,22 +452,34 @@ handle_send(Pkt, Result, #{lserver := LServer} = State) -> init([State, Opts]) -> Access = gen_mod:get_opt(access, Opts, fun acl:access_rules_validator/1, all), Shaper = gen_mod:get_opt(shaper, Opts, fun acl:shaper_rules_validator/1, none), - TLSOpts = lists:filter( - fun({certfile, _}) -> true; - ({ciphers, _}) -> true; - ({dhfile, _}) -> true; - (_) -> false - end, Opts), + TLSOpts1 = lists:filter( + fun({certfile, _}) -> true; + ({ciphers, _}) -> true; + ({dhfile, _}) -> true; + ({cafile, _}) -> true; + (_) -> false + end, Opts), + TLSOpts2 = case lists:keyfind(protocol_options, 1, Opts) of + false -> TLSOpts1; + {_, OptString} -> + ProtoOpts = str:join(OptString, <<$|>>), + [{protocol_options, ProtoOpts}|TLSOpts1] + end, + TLSOpts3 = case proplists:get_bool(tls_compression, Opts) of + false -> [compression_none | TLSOpts2]; + true -> TLSOpts2 + end, + TLSEnabled = proplists:get_bool(starttls, Opts), TLSRequired = proplists:get_bool(starttls_required, Opts), TLSVerify = proplists:get_bool(tls_verify, Opts), Zlib = proplists:get_bool(zlib, Opts), - State1 = State#{tls_options => TLSOpts, + State1 = State#{tls_options => TLSOpts3, tls_required => TLSRequired, + tls_enabled => TLSEnabled, tls_verify => TLSVerify, pres_a => ?SETS:new(), pres_f => ?SETS:new(), pres_t => ?SETS:new(), - sid => ejabberd_sm:make_sid(), zlib => Zlib, lang => ?MYLANG, server => ?MYNAME, @@ -426,12 +509,12 @@ handle_cast(Msg, #{lserver := LServer} = State) -> handle_info(Info, #{lserver := LServer} = State) -> ejabberd_hooks:run_fold(c2s_handle_info, LServer, State, [Info]). -terminate(Reason, #{sid := SID, jid := _, +terminate(Reason, #{sid := SID, user := U, server := S, resource := R, lserver := LServer} = State) -> - Status = format_reason(State, Reason), case maps:is_key(pres_last, State) of true -> + Status = format_reason(State, Reason), ejabberd_sm:close_session_unset_presence(SID, U, S, R, Status); false -> ejabberd_sm:close_session(SID, U, S, R) @@ -446,11 +529,6 @@ code_change(_OldVsn, State, _Extra) -> %%%=================================================================== %%% Internal functions %%%=================================================================== --spec check_bl_c2s({inet:ip_address(), non_neg_integer()}, binary()) - -> false | {true, binary(), binary()}. -check_bl_c2s({IP, _Port}, Lang) -> - ejabberd_hooks:run_fold(check_bl_c2s, false, [IP, Lang]). - -spec process_iq_in(state(), iq()) -> {boolean(), state()}. process_iq_in(State, #iq{} = IQ) -> case privacy_check_packet(State, IQ, in) of @@ -484,7 +562,7 @@ process_presence_in(#{lserver := LServer, pres_a := PresA} = State0, State = ejabberd_hooks:run_fold(c2s_presence_in, LServer, State0, [Pres]), case T of probe -> - NewState = do_some_magic(State, From), + NewState = add_to_pres_a(State, From), route_probe_reply(From, To, NewState), {false, NewState}; error -> @@ -495,7 +573,7 @@ process_presence_in(#{lserver := LServer, pres_a := PresA} = State0, allow when T == error -> {true, State}; allow -> - NewState = do_some_magic(State, From), + NewState = add_to_pres_a(State, From), {true, NewState}; deny -> {false, State} @@ -577,24 +655,27 @@ process_presence_out(#{user := User, server := Server, lserver := LServer, end. -spec process_self_presence(state(), presence()) -> state(). -process_self_presence(#{ip := IP, conn := Conn, +process_self_presence(#{ip := IP, conn := Conn, lserver := LServer, auth_module := AuthMod, sid := SID, user := U, server := S, resource := R} = State, #presence{type = unavailable} = Pres) -> Status = xmpp:get_text(Pres#presence.status), Info = [{ip, IP}, {conn, Conn}, {auth_module, AuthMod}], ejabberd_sm:unset_presence(SID, U, S, R, Status, Info), - State1 = broadcast_presence_unavailable(State, Pres), - maps:remove(pres_last, maps:remove(pres_timestamp, State1)); + {Pres1, State1} = ejabberd_hooks:run_fold( + c2s_self_presence, LServer, {Pres, State}, []), + State2 = broadcast_presence_unavailable(State1, Pres1), + maps:remove(pres_last, maps:remove(pres_timestamp, State2)); process_self_presence(#{lserver := LServer} = State, #presence{type = available} = Pres) -> PreviousPres = maps:get(pres_last, State, undefined), update_priority(State, Pres), - State1 = ejabberd_hooks:run_fold(user_available_hook, LServer, State, [Pres]), - State2 = State1#{pres_last => Pres, + {Pres1, State1} = ejabberd_hooks:run_fold( + c2s_self_presence, LServer, {Pres, State}, []), + State2 = State1#{pres_last => Pres1, pres_timestamp => p1_time_compat:timestamp()}, FromUnavailable = PreviousPres == undefined, - broadcast_presence_available(State2, Pres, FromUnavailable); + broadcast_presence_available(State2, Pres1, FromUnavailable); process_self_presence(State, _Pres) -> State. @@ -614,9 +695,9 @@ broadcast_presence_unavailable(#{pres_a := PresA} = State, Pres) -> -spec broadcast_presence_available(state(), presence(), boolean()) -> state(). broadcast_presence_available(#{pres_a := PresA, pres_f := PresF, - pres_t := PresT} = State, + pres_t := PresT, jid := JID} = State, Pres, _FromUnavailable = true) -> - Probe = #presence{type = probe}, + Probe = #presence{from = JID, type = probe}, TJIDs = filter_blocked(State, Probe, PresT), FJIDs = filter_blocked(State, Pres, PresF), route_multiple(State, TJIDs, Probe), @@ -739,6 +820,19 @@ get_conn_type(State) -> websocket -> websocket end. +-spec fix_from_to(xmpp_element(), state()) -> stanza(). +fix_from_to(Pkt, #{jid := JID}) when ?is_stanza(Pkt) -> + #jid{luser = U, lserver = S, lresource = R} = JID, + From = xmpp:get_from(Pkt), + From1 = case jid:tolower(From) of + {U, S, R} -> JID; + {U, S, _} -> jid:replace_resource(JID, From#jid.resource); + _ -> From + end, + xmpp:set_from_to(Pkt, From1, JID); +fix_from_to(Pkt, _State) -> + Pkt. + -spec change_shaper(state()) -> ok. change_shaper(#{shaper := ShaperName, ip := IP, lserver := LServer, user := U, server := S, resource := R} = State) -> @@ -748,8 +842,8 @@ change_shaper(#{shaper := ShaperName, ip := IP, lserver := LServer, LServer), xmpp_stream_in:change_shaper(State, Shaper). --spec do_some_magic(state(), jid()) -> state(). -do_some_magic(#{pres_a := PresA, pres_f := PresF} = State, From) -> +-spec add_to_pres_a(state(), jid()) -> state(). +add_to_pres_a(#{pres_a := PresA, pres_f := PresF} = State, From) -> LFrom = jid:tolower(From), LBFrom = jid:remove_resource(LFrom), case (?SETS):is_element(LFrom, PresA) orelse @@ -775,20 +869,41 @@ do_some_magic(#{pres_a := PresA, pres_f := PresF} = State, From) -> -spec format_reason(state(), term()) -> binary(). format_reason(#{stop_reason := Reason}, _) -> xmpp_stream_in:format_error(Reason); -format_reason(_, Reason) when Reason /= normal -> - <<"internal server error">>; +format_reason(_, normal) -> + <<"unknown reason">>; +format_reason(_, shutdown) -> + <<"stopped by supervisor">>; +format_reason(_, {shutdown, _}) -> + <<"stopped by supervisor">>; format_reason(_, _) -> - <<"">>. + <<"internal server error">>. transform_listen_option(Opt, Opts) -> [Opt|Opts]. opt_type(domain_certfile) -> fun iolist_to_binary/1; +opt_type(c2s_certfile) -> fun iolist_to_binary/1; +opt_type(c2s_ciphers) -> fun iolist_to_binary/1; +opt_type(c2s_dhfile) -> fun iolist_to_binary/1; +opt_type(c2s_cafile) -> fun iolist_to_binary/1; +opt_type(c2s_protocol_options) -> + fun (Options) -> str:join(Options, <<"|">>) end; +opt_type(c2s_tls_compression) -> + fun (true) -> true; + (false) -> false + end; opt_type(resource_conflict) -> fun (setresource) -> setresource; (closeold) -> closeold; (closenew) -> closenew; (acceptnew) -> acceptnew end; +opt_type(disable_sasl_mechanisms) -> + fun (V) when is_list(V) -> + lists:map(fun (M) -> str:to_upper(M) end, V); + (V) -> [str:to_upper(V)] + end; opt_type(_) -> - [domain_certfile, resource_conflict]. + [domain_certfile, c2s_certfile, c2s_ciphers, c2s_cafile, + c2s_protocol_options, c2s_tls_compression, resource_conflict, + disable_sasl_mechanisms]. diff --git a/src/ejabberd_hooks.erl b/src/ejabberd_hooks.erl index 612d5afe5..f63d1d75c 100644 --- a/src/ejabberd_hooks.erl +++ b/src/ejabberd_hooks.erl @@ -326,10 +326,9 @@ run1([{_Seq, Node, Module, Function} | Ls], Hook, Args) -> run1(Ls, Hook, Args) end; run1([{_Seq, Module, Function} | Ls], Hook, Args) -> - Res = safe_apply(Module, Function, Args), + Res = safe_apply(Hook, Module, Function, Args), case Res of - {'EXIT', Reason} -> - ?ERROR_MSG("~p~nrunning hook: ~p", [Reason, {Hook, Args}]), + 'EXIT' -> run1(Ls, Hook, Args); stop -> ok; @@ -362,10 +361,9 @@ run_fold1([{_Seq, Node, Module, Function} | Ls], Hook, Val, Args) -> run_fold1(Ls, Hook, NewVal, Args) end; run_fold1([{_Seq, Module, Function} | Ls], Hook, Val, Args) -> - Res = safe_apply(Module, Function, [Val | Args]), + Res = safe_apply(Hook, Module, Function, [Val | Args]), case Res of - {'EXIT', Reason} -> - ?ERROR_MSG("~p~nrunning hook: ~p", [Reason, {Hook, Args}]), + 'EXIT' -> run_fold1(Ls, Hook, Val, Args); stop -> stopped; @@ -375,12 +373,20 @@ run_fold1([{_Seq, Module, Function} | Ls], Hook, Val, Args) -> run_fold1(Ls, Hook, NewVal, Args) end. -safe_apply(Module, Function, Args) -> +safe_apply(Hook, Module, Function, Args) -> try if is_function(Function) -> apply(Function, Args); true -> apply(Module, Function, Args) end catch E:R when E /= exit, R /= normal -> - {'EXIT', {E, {R, erlang:get_stacktrace()}}} + ?ERROR_MSG("Hook ~p crashed when running ~p:~p/~p:~n" + "** Reason = ~p~n" + "** Arguments = ~p", + [Hook, Module, Function, length(Args), + {E, R, get_stacktrace()}, Args]), + 'EXIT' end. + +get_stacktrace() -> + [{Mod, Fun, Loc, Args} || {Mod, Fun, Args, Loc} <- erlang:get_stacktrace()]. diff --git a/src/ejabberd_http_ws.erl b/src/ejabberd_http_ws.erl index b92345dd4..6d90dba4b 100644 --- a/src/ejabberd_http_ws.erl +++ b/src/ejabberd_http_ws.erl @@ -120,7 +120,7 @@ init([{#ws{ip = IP, http_opts = HOpts}, _} = WS]) -> ({resend_on_timeout, _}) -> true; (_) -> false end, HOpts), - Opts = [{xml_socket, true} | ejabberd_c2s_config:get_c2s_limits() ++ SOpts], + Opts = ejabberd_c2s_config:get_c2s_limits() ++ SOpts, PingInterval = ejabberd_config:get_option( {websocket_ping_interval, ?MYNAME}, fun(I) when is_integer(I), I>=0 -> I end, diff --git a/src/ejabberd_listener.erl b/src/ejabberd_listener.erl index f720fc585..4191b1958 100644 --- a/src/ejabberd_listener.erl +++ b/src/ejabberd_listener.erl @@ -186,7 +186,9 @@ init_tcp(PortIP, Module, Opts, SockOpts, Port, IPS) -> listen_tcp(PortIP, Module, SockOpts, Port, IPS) -> case ets:lookup(listen_sockets, PortIP) of [{PortIP, ListenSocket}] -> - ?INFO_MSG("Reusing listening port for ~p", [PortIP]), + {_, _, Transport} = PortIP, + ?INFO_MSG("Reusing listening ~s port ~p at ~s", + [Transport, Port, IPS]), ets:delete(listen_sockets, PortIP), ListenSocket; _ -> @@ -330,21 +332,26 @@ accept(ListenSocket, Module, Opts, Interval) -> {ok, Socket} -> case {inet:sockname(Socket), inet:peername(Socket)} of {{ok, {Addr, Port}}, {ok, {PAddr, PPort}}} -> - ?INFO_MSG("Accepted connection ~s:~p -> ~s:~p", - [ejabberd_config:may_hide_data(inet_parse:ntoa(PAddr)), + CallMod = case is_frontend(Module) of + true -> ejabberd_frontend_socket; + false -> ejabberd_socket + end, + Receiver = case CallMod:start(strip_frontend(Module), + gen_tcp, Socket, Opts) of + {ok, RecvPid} -> RecvPid; + _ -> none + end, + ?INFO_MSG("(~p) Accepted connection ~s:~p -> ~s:~p", + [Receiver, + ejabberd_config:may_hide_data(inet_parse:ntoa(PAddr)), PPort, inet_parse:ntoa(Addr), Port]); _ -> ok end, - CallMod = case is_frontend(Module) of - true -> ejabberd_frontend_socket; - false -> ejabberd_socket - end, - CallMod:start(strip_frontend(Module), gen_tcp, Socket, Opts), accept(ListenSocket, Module, Opts, NewInterval); {error, Reason} -> - ?ERROR_MSG("(~w) Failed TCP accept: ~w", - [ListenSocket, Reason]), + ?ERROR_MSG("(~w) Failed TCP accept: ~s", + [ListenSocket, inet:format_error(Reason)]), accept(ListenSocket, Module, Opts, NewInterval) end. diff --git a/src/ejabberd_local.erl b/src/ejabberd_local.erl index a5ee6a242..48c4e863c 100644 --- a/src/ejabberd_local.erl +++ b/src/ejabberd_local.erl @@ -36,8 +36,7 @@ process_iq_reply/3, register_iq_handler/4, register_iq_handler/5, register_iq_response_handler/4, register_iq_response_handler/5, unregister_iq_handler/2, - unregister_iq_response_handler/2, refresh_iq_handlers/0, - bounce_resource_packet/3]). + unregister_iq_response_handler/2, bounce_resource_packet/3]). %% gen_server callbacks -export([init/1, handle_call/3, handle_cast/2, @@ -90,8 +89,13 @@ process_iq(From, To, #iq{type = T, lang = Lang, sub_els = [El]} = Packet) Err = xmpp:err_service_unavailable(Txt, Lang), ejabberd_router:route_error(To, From, Packet, Err) end; -process_iq(From, To, #iq{type = T} = Packet) when T == get; T == set -> - Err = xmpp:err_bad_request(), +process_iq(From, To, #iq{type = T, lang = Lang, sub_els = SubEls} = Packet) + when T == get; T == set -> + Txt = case SubEls of + [] -> <<"No child elements found">>; + _ -> <<"Too many child elements">> + end, + Err = xmpp:err_bad_request(Txt, Lang), ejabberd_router:route_error(To, From, Packet, Err); process_iq(From, To, #iq{type = T} = Packet) when T == result; T == error -> process_iq_reply(From, To, Packet). @@ -171,10 +175,6 @@ unregister_iq_response_handler(_Host, ID) -> unregister_iq_handler(Host, XMLNS) -> ejabberd_local ! {unregister_iq_handler, Host, XMLNS}. --spec refresh_iq_handlers() -> any(). -refresh_iq_handlers() -> - ejabberd_local ! refresh_iq_handlers. - -spec bounce_resource_packet(jid(), jid(), stanza()) -> stop. bounce_resource_packet(_From, #jid{lresource = <<"">>}, #presence{}) -> ok; @@ -228,14 +228,12 @@ handle_info({register_iq_handler, Host, XMLNS, Module, Function}, State) -> ets:insert(?IQTABLE, {{XMLNS, Host}, Module, Function}), - catch mod_disco:register_feature(Host, XMLNS), {noreply, State}; handle_info({register_iq_handler, Host, XMLNS, Module, Function, Opts}, State) -> ets:insert(?IQTABLE, {{XMLNS, Host}, Module, Function, Opts}), - catch mod_disco:register_feature(Host, XMLNS), {noreply, State}; handle_info({unregister_iq_handler, Host, XMLNS}, State) -> @@ -245,19 +243,6 @@ handle_info({unregister_iq_handler, Host, XMLNS}, _ -> ok end, ets:delete(?IQTABLE, {XMLNS, Host}), - catch mod_disco:unregister_feature(Host, XMLNS), - {noreply, State}; -handle_info(refresh_iq_handlers, State) -> - lists:foreach(fun (T) -> - case T of - {{XMLNS, Host}, _Module, _Function, _Opts} -> - catch mod_disco:register_feature(Host, XMLNS); - {{XMLNS, Host}, _Module, _Function} -> - catch mod_disco:register_feature(Host, XMLNS); - _ -> ok - end - end, - ets:tab2list(?IQTABLE)), {noreply, State}; handle_info({timeout, _TRef, ID}, State) -> process_iq_timeout(ID), diff --git a/src/ejabberd_piefxis.erl b/src/ejabberd_piefxis.erl index 36d734004..9e6cbd715 100644 --- a/src/ejabberd_piefxis.erl +++ b/src/ejabberd_piefxis.erl @@ -484,18 +484,17 @@ process_privacy(#privacy_query{lists = Lists, JID = jid:make(U, S), IQ = #iq{type = set, id = randoms:get_string(), from = JID, to = JID, sub_els = [PrivacyQuery]}, - Txt = <<"No module is handling this query">>, - Error = {error, xmpp:err_feature_not_implemented(Txt, ?MYLANG)}, - case mod_privacy:process_iq_set(Error, IQ, #userlist{}) of - {error, #stanza_error{reason = Reason}} = Err -> + case mod_privacy:process_iq(IQ) of + #iq{type = error} = ResIQ -> + #stanza_error{reason = Reason} = xmpp:get_error(ResIQ), if Reason == 'item-not-found', Lists == [], Active == undefined, Default /= undefined -> %% Failed to set default list because there is no %% list with such name. We shouldn't stop here. {ok, State}; true -> - stop("Failed to write privacy: ~p", [Err]) - end; + stop("Failed to write privacy: ~p", [Reason]) + end; _ -> {ok, State} end. diff --git a/src/ejabberd_receiver.erl b/src/ejabberd_receiver.erl index 0a33e30ec..ffa55806f 100644 --- a/src/ejabberd_receiver.erl +++ b/src/ejabberd_receiver.erl @@ -135,8 +135,8 @@ handle_call({starttls, TLSSocket}, _From, State) -> {ok, TLSData} -> {reply, ok, process_data(TLSData, NewState), ?HIBERNATE_TIMEOUT}; - {error, _Reason} -> - {stop, normal, ok, NewState} + {error, _} = Err -> + {stop, normal, Err, NewState} end; handle_call({compress, Data}, _From, #state{socket = Socket, sock_mod = SockMod} = diff --git a/src/ejabberd_router.erl b/src/ejabberd_router.erl index 5ce8a8afb..b1c9c9e48 100644 --- a/src/ejabberd_router.erl +++ b/src/ejabberd_router.erl @@ -76,8 +76,17 @@ start_link() -> -spec route(jid(), jid(), xmlel() | stanza()) -> ok. -route(From, To, Packet) -> - case catch do_route(From, To, Packet) of +route(#jid{} = From, #jid{} = To, #xmlel{} = El) -> + try xmpp:decode(El, ?NS_CLIENT, [ignore_els]) of + Pkt -> route(From, To, xmpp:set_from_to(Pkt, From, To)) + catch _:{xmpp_codec, Why} -> + ?ERROR_MSG("failed to decode xml element ~p when " + "routing from ~s to ~s: ~s", + [El, jid:to_string(From), jid:to_string(To), + xmpp:format_error(Why)]) + end; +route(#jid{} = From, #jid{} = To, Packet) -> + case catch do_route(From, To, xmpp:set_from_to(Packet, From, To)) of {'EXIT', Reason} -> ?ERROR_MSG("~p~nwhen processing: ~p", [Reason, {From, To, Packet}]); @@ -169,7 +178,7 @@ register_route(Domain, ServerHost, LocalHint) -> mnesia:transaction(F) end, if LocalHint == undefined -> - ?INFO_MSG("Route registered: ~s", [LDomain]); + ?DEBUG("Route registered: ~s", [LDomain]); true -> ok end @@ -218,7 +227,7 @@ unregister_route(Domain) -> end, mnesia:transaction(F) end, - ?INFO_MSG("Route unregistered: ~s", [LDomain]) + ?DEBUG("Route unregistered: ~s", [LDomain]) end. -spec unregister_routes([binary()]) -> ok. @@ -283,9 +292,9 @@ process_iq(From, To, #iq{} = IQ) -> true -> ejabberd_sm:process_iq(From, To, IQ) end; -process_iq(From, To, El) -> +process_iq(From, To, #xmlel{} = El) -> try xmpp:decode(El, ?NS_CLIENT, [ignore_els]) of - IQ -> process_iq(From, To, IQ) + IQ -> process_iq(From, To, xmpp:set_from_to(IQ, From, To)) catch _:{xmpp_codec, Why} -> Type = xmpp:get_type(El), if Type == <<"get">>; Type == <<"set">> -> @@ -409,70 +418,56 @@ code_change(_OldVsn, State, _Extra) -> %%-------------------------------------------------------------------- %%% Internal functions %%-------------------------------------------------------------------- --spec do_route(jid(), jid(), xmlel() | xmpp_element()) -> any(). +-spec do_route(jid(), jid(), stanza()) -> any(). do_route(OrigFrom, OrigTo, OrigPacket) -> - ?DEBUG("route~n\tfrom ~p~n\tto ~p~n\tpacket " - "~p~n", - [OrigFrom, OrigTo, OrigPacket]), + ?DEBUG("route:~n~s", [xmpp:pp(OrigPacket)]), case ejabberd_hooks:run_fold(filter_packet, - {OrigFrom, OrigTo, OrigPacket}, []) - of - {From, To, Packet} -> - LDstDomain = To#jid.lserver, - case mnesia:dirty_read(route, LDstDomain) of - [] -> - try xmpp:decode(Packet, ?NS_CLIENT, [ignore_els]) of - Pkt -> - ejabberd_s2s:route(From, To, Pkt) - catch _:{xmpp_codec, Why} -> - log_decoding_error(From, To, Packet, Why) - end; - [R] -> - do_route(From, To, Packet, R); - Rs -> - Value = get_domain_balancing(From, To, LDstDomain), - case get_component_number(LDstDomain) of - undefined -> - case [R || R <- Rs, node(R#route.pid) == node()] of - [] -> - R = lists:nth(erlang:phash(Value, length(Rs)), Rs), - do_route(From, To, Packet, R); - LRs -> - R = lists:nth(erlang:phash(Value, length(LRs)), LRs), - do_route(From, To, Packet, R) - end; - _ -> - SRs = lists:ukeysort(#route.local_hint, Rs), - R = lists:nth(erlang:phash(Value, length(SRs)), SRs), - do_route(From, To, Packet, R) - end - end; - drop -> ok + {OrigFrom, OrigTo, OrigPacket}, []) of + {From, To, Packet} -> + LDstDomain = To#jid.lserver, + case mnesia:dirty_read(route, LDstDomain) of + [] -> + ejabberd_s2s:route(From, To, Packet); + [Route] -> + do_route(From, To, Packet, Route); + Routes -> + balancing_route(From, To, Packet, Routes) + end; + drop -> + ok end. --spec do_route(jid(), jid(), xmlel() | xmpp_element(), #route{}) -> any(). -do_route(From, To, Packet, #route{local_hint = LocalHint, - pid = Pid}) when is_pid(Pid) -> - try xmpp:decode(Packet, ?NS_CLIENT, [ignore_els]) of - Pkt -> - case LocalHint of - {apply, Module, Function} when node(Pid) == node() -> - Module:Function(From, To, Pkt); - _ -> - Pid ! {route, From, To, Pkt} - end - catch error:{xmpp_codec, Why} -> - log_decoding_error(From, To, Packet, Why) +-spec do_route(jid(), jid(), stanza(), #route{}) -> any(). +do_route(From, To, Pkt, #route{local_hint = LocalHint, + pid = Pid}) when is_pid(Pid) -> + case LocalHint of + {apply, Module, Function} when node(Pid) == node() -> + Module:Function(From, To, Pkt); + _ -> + Pid ! {route, From, To, Pkt} end; -do_route(_From, _To, _Packet, _Route) -> +do_route(_From, _To, _Pkt, _Route) -> drop. --spec log_decoding_error(jid(), jid(), xmlel() | xmpp_element(), term()) -> ok. -log_decoding_error(From, To, Packet, Reason) -> - ?ERROR_MSG("failed to decode xml element ~p when " - "routing from ~s to ~s: ~s", - [Packet, jid:to_string(From), jid:to_string(To), - xmpp:format_error(Reason)]). +-spec balancing_route(jid(), jid(), stanza(), [#route{}]) -> any(). +balancing_route(From, To, Packet, Rs) -> + LDstDomain = To#jid.lserver, + Value = get_domain_balancing(From, To, LDstDomain), + case get_component_number(LDstDomain) of + undefined -> + case [R || R <- Rs, node(R#route.pid) == node()] of + [] -> + R = lists:nth(erlang:phash(Value, length(Rs)), Rs), + do_route(From, To, Packet, R); + LRs -> + R = lists:nth(erlang:phash(Value, length(LRs)), LRs), + do_route(From, To, Packet, R) + end; + _ -> + SRs = lists:ukeysort(#route.local_hint, Rs), + R = lists:nth(erlang:phash(Value, length(SRs)), SRs), + do_route(From, To, Packet, R) + end. -spec get_component_number(binary()) -> pos_integer() | undefined. get_component_number(LDomain) -> diff --git a/src/ejabberd_s2s.erl b/src/ejabberd_s2s.erl index af4d6a662..d57c91ed2 100644 --- a/src/ejabberd_s2s.erl +++ b/src/ejabberd_s2s.erl @@ -257,7 +257,7 @@ tls_verify(LServer) -> -spec tls_enabled(binary()) -> boolean(). tls_enabled(LServer) -> TLS = use_starttls(LServer), - TLS == true orelse TLS == optional. + TLS /= false. -spec zlib_enabled(binary()) -> boolean(). zlib_enabled(LServer) -> diff --git a/src/ejabberd_s2s_in.erl b/src/ejabberd_s2s_in.erl index a31af337e..cca8438c6 100644 --- a/src/ejabberd_s2s_in.erl +++ b/src/ejabberd_s2s_in.erl @@ -120,12 +120,8 @@ process_closed(State, _Reason) -> %%%=================================================================== %%% xmpp_stream_in callbacks %%%=================================================================== -tls_options(#{tls_compression := Compression, server_host := LServer}) -> - Opts = case Compression of - false -> [compression_none]; - true -> [] - end, - ejabberd_s2s:tls_options(LServer, Opts). +tls_options(#{tls_options := TLSOpts, server_host := LServer}) -> + ejabberd_s2s:tls_options(LServer, TLSOpts). tls_required(#{server_host := LServer}) -> ejabberd_s2s:tls_required(LServer). @@ -164,16 +160,18 @@ handle_stream_established(State) -> set_idle_timeout(State#{established => true}). handle_auth_success(RServer, Mech, _AuthModule, - #{socket := Socket, ip := IP, + #{sockmod := SockMod, + socket := Socket, ip := IP, auth_domains := AuthDomains, server_host := ServerHost, lserver := LServer} = State) -> ?INFO_MSG("(~s) Accepted inbound s2s ~s authentication ~s -> ~s (~s)", - [ejabberd_socket:pp(Socket), Mech, RServer, LServer, + [SockMod:pp(Socket), Mech, RServer, LServer, ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]), State1 = case ejabberd_s2s:allow_host(ServerHost, RServer) of true -> AuthDomains1 = sets:add_element(RServer, AuthDomains), + change_shaper(State, RServer), State#{auth_domains => AuthDomains1}; false -> State @@ -181,11 +179,12 @@ handle_auth_success(RServer, Mech, _AuthModule, ejabberd_hooks:run_fold(s2s_in_auth_result, ServerHost, State1, [true, RServer]). handle_auth_failure(RServer, Mech, Reason, - #{socket := Socket, ip := IP, + #{sockmod := SockMod, + socket := Socket, ip := IP, server_host := ServerHost, lserver := LServer} = State) -> ?INFO_MSG("(~s) Failed inbound s2s ~s authentication ~s -> ~s (~s): ~s", - [ejabberd_socket:pp(Socket), Mech, RServer, LServer, + [SockMod:pp(Socket), Mech, RServer, LServer, ejabberd_config:may_hide_data(jlib:ip_to_list(IP)), Reason]), ejabberd_hooks:run_fold(s2s_in_auth_result, ServerHost, State, [false, RServer]). @@ -204,10 +203,13 @@ handle_authenticated_packet(Pkt, State) -> LServer = ejabberd_router:host_of_route(To#jid.lserver), State1 = ejabberd_hooks:run_fold(s2s_in_authenticated_packet, LServer, State, [Pkt]), - Pkt1 = ejabberd_hooks:run_fold(s2s_receive_packet, LServer, - Pkt, [State1]), - ejabberd_router:route(From, To, Pkt1), - State1; + {Pkt1, State2} = ejabberd_hooks:run_fold(s2s_receive_packet, LServer, + {Pkt, State1}, []), + case Pkt1 of + drop -> ok; + _ -> ejabberd_router:route(From, To, Pkt1) + end, + State2; {error, Err} -> send(State, Err) end. @@ -225,8 +227,24 @@ handle_send(Pkt, Result, #{server_host := LServer} = State) -> init([State, Opts]) -> Shaper = gen_mod:get_opt(shaper, Opts, fun acl:shaper_rules_validator/1, none), - TLSCompression = proplists:get_bool(tls_compression, Opts), - State1 = State#{tls_compression => TLSCompression, + TLSOpts1 = lists:filter( + fun({certfile, _}) -> true; + ({ciphers, _}) -> true; + ({dhfile, _}) -> true; + ({cafile, _}) -> true; + (_) -> false + end, Opts), + TLSOpts2 = case lists:keyfind(protocol_options, 1, Opts) of + false -> TLSOpts1; + {_, OptString} -> + ProtoOpts = str:join(OptString, <<$|>>), + [{protocol_options, ProtoOpts}|TLSOpts1] + end, + TLSOpts3 = case proplists:get_bool(tls_compression, Opts) of + false -> [compression_none | TLSOpts2]; + true -> TLSOpts2 + end, + State1 = State#{tls_options => TLSOpts3, auth_domains => sets:new(), xmlns => ?NS_SERVER, lang => ?MYLANG, @@ -251,8 +269,16 @@ handle_cast(Msg, #{server_host := LServer} = State) -> handle_info(Info, #{server_host := LServer} = State) -> ejabberd_hooks:run_fold(s2s_in_handle_info, LServer, State, [Info]). -terminate(_Reason, _State) -> - ok. +terminate(Reason, #{auth_domains := AuthDomains}) -> + case Reason of + {process_limit, _} -> + sets:fold( + fun(Host, _) -> + ejabberd_s2s:external_host_overloaded(Host) + end, ok, AuthDomains); + _ -> + ok + end. code_change(_OldVsn, State, _Extra) -> {ok, State}. @@ -290,5 +316,11 @@ set_idle_timeout(#{server_host := LServer, set_idle_timeout(State) -> State. +-spec change_shaper(state(), binary()) -> ok. +change_shaper(#{shaper := ShaperName, server_host := ServerHost} = State, + RServer) -> + Shaper = acl:match_rule(ServerHost, ShaperName, jid:make(RServer)), + xmpp_stream_in:change_shaper(State, Shaper). + opt_type(_) -> []. diff --git a/src/ejabberd_s2s_out.erl b/src/ejabberd_s2s_out.erl index 6069c786c..5188d269b 100644 --- a/src/ejabberd_s2s_out.erl +++ b/src/ejabberd_s2s_out.erl @@ -1,10 +1,23 @@ %%%------------------------------------------------------------------- -%%% @author Evgeny Khramtsov -%%% @copyright (C) 2016, Evgeny Khramtsov -%%% @doc -%%% -%%% @end %%% Created : 16 Dec 2016 by Evgeny Khramtsov +%%% +%%% +%%% ejabberd, Copyright (C) 2002-2016 ProcessOne +%%% +%%% This program is free software; you can redistribute it and/or +%%% modify it under the terms of the GNU General Public License as +%%% published by the Free Software Foundation; either version 2 of the +%%% License, or (at your option) any later version. +%%% +%%% This program is distributed in the hope that it will be useful, +%%% but WITHOUT ANY WARRANTY; without even the implied warranty of +%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +%%% General Public License for more details. +%%% +%%% You should have received a copy of the GNU General Public License along +%%% with this program; if not, write to the Free Software Foundation, Inc., +%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +%%% %%%------------------------------------------------------------------- -module(ejabberd_s2s_out). -behaviour(xmpp_stream_out). @@ -14,9 +27,11 @@ -export([opt_type/1, transform_options/1]). %% xmpp_stream_out callbacks -export([tls_options/1, tls_required/1, tls_verify/1, tls_enabled/1, + connect_timeout/1, address_families/1, default_port/1, + dns_retries/1, dns_timeout/1, handle_auth_success/2, handle_auth_failure/3, handle_packet/2, handle_stream_end/2, handle_stream_downgraded/2, - handle_recv/3, handle_send/4, handle_cdata/2, + handle_recv/3, handle_send/3, handle_cdata/2, handle_stream_established/1, handle_timeout/1]). -export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]). @@ -92,12 +107,12 @@ add_hooks() -> %%% Hooks %%%=================================================================== process_auth_result(#{server := LServer, remote_server := RServer} = State, - false) -> + {false, Reason}) -> Delay = get_delay(), - ?INFO_MSG("Closing outbound s2s connection ~s -> ~s: authentication failed;" - " bouncing for ~p seconds", + ?INFO_MSG("Failed to establish outbound s2s connection ~s -> ~s: " + "authentication failed; bouncing for ~p seconds", [LServer, RServer, Delay]), - State1 = State#{on_route => bounce}, + State1 = State#{on_route => bounce, stop_reason => Reason}, State2 = close(State1), State3 = bounce_queue(State2), xmpp_stream_out:set_timeout(State3, timer:seconds(Delay)); @@ -113,7 +128,7 @@ process_closed(#{server := LServer, remote_server := RServer, process_closed(#{server := LServer, remote_server := RServer} = State, Reason) -> Delay = get_delay(), - ?INFO_MSG("Closing outbound s2s connection ~s -> ~s: ~s; " + ?INFO_MSG("Failed to establish outbound s2s connection ~s -> ~s: ~s; " "bouncing for ~p seconds", [LServer, RServer, xmpp_stream_out:format_error(Reason), Delay]), State1 = State#{on_route => bounce}, @@ -146,23 +161,65 @@ tls_verify(#{server := LServer}) -> tls_enabled(#{server := LServer}) -> ejabberd_s2s:tls_enabled(LServer). -handle_auth_success(Mech, #{socket := Socket, ip := IP, +connect_timeout(#{server := LServer}) -> + ejabberd_config:get_option( + {outgoing_s2s_timeout, LServer}, + fun(TimeOut) when is_integer(TimeOut), TimeOut > 0 -> + timer:seconds(TimeOut); + (infinity) -> + infinity + end, timer:seconds(10)). + +default_port(#{server := LServer}) -> + ejabberd_config:get_option( + {outgoing_s2s_port, LServer}, + fun(I) when is_integer(I), I > 0, I =< 65536 -> I end, + 5269). + +address_families(#{server := LServer}) -> + ejabberd_config:get_option( + {outgoing_s2s_families, LServer}, + fun(Families) -> + lists:map( + fun(ipv4) -> inet; + (ipv6) -> inet6 + end, Families) + end, [inet, inet6]). + +dns_retries(#{server := LServer}) -> + ejabberd_config:get_option( + {s2s_dns_retries, LServer}, + fun(I) when is_integer(I), I>=0 -> I end, + 2). + +dns_timeout(#{server := LServer}) -> + ejabberd_config:get_option( + {s2s_dns_timeout, LServer}, + fun(I) when is_integer(I), I>=0 -> + timer:seconds(I); + (infinity) -> + infinity + end, timer:seconds(10)). + +handle_auth_success(Mech, #{sockmod := SockMod, + socket := Socket, ip := IP, remote_server := RServer, server := LServer} = State) -> ?INFO_MSG("(~s) Accepted outbound s2s ~s authentication ~s -> ~s (~s)", - [ejabberd_socket:pp(Socket), Mech, LServer, RServer, + [SockMod:pp(Socket), Mech, LServer, RServer, ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]), ejabberd_hooks:run_fold(s2s_out_auth_result, LServer, State, [true]). handle_auth_failure(Mech, Reason, - #{socket := Socket, ip := IP, + #{sockmod := SockMod, + socket := Socket, ip := IP, remote_server := RServer, server := LServer} = State) -> ?INFO_MSG("(~s) Failed outbound s2s ~s authentication ~s -> ~s (~s): ~s", - [ejabberd_socket:pp(Socket), Mech, LServer, RServer, - ejabberd_config:may_hide_data(jlib:ip_to_list(IP)), Reason]), - State1 = State#{stop_reason => {auth, Reason}}, - ejabberd_hooks:run_fold(s2s_out_auth_result, LServer, State1, [false]). + [SockMod:pp(Socket), Mech, LServer, RServer, + ejabberd_config:may_hide_data(jlib:ip_to_list(IP)), + xmpp_stream_out:format_error(Reason)]), + ejabberd_hooks:run_fold(s2s_out_auth_result, LServer, State, [{false, Reason}]). handle_packet(Pkt, #{server := LServer} = State) -> ejabberd_hooks:run_fold(s2s_out_packet, LServer, State, [Pkt]). @@ -185,9 +242,8 @@ handle_cdata(Data, #{server := LServer} = State) -> handle_recv(El, Pkt, #{server := LServer} = State) -> ejabberd_hooks:run_fold(s2s_out_handle_recv, LServer, State, [El, Pkt]). -handle_send(Pkt, El, Data, #{server := LServer} = State) -> - ejabberd_hooks:run_fold(s2s_out_handle_send, LServer, - State, [Pkt, El, Data]). +handle_send(El, Pkt, #{server := LServer} = State) -> + ejabberd_hooks:run_fold(s2s_out_handle_send, LServer, State, [El, Pkt]). handle_timeout(#{on_route := Action} = State) -> case Action of @@ -298,7 +354,7 @@ get_delay() -> s2s_max_retry_delay, fun(I) when is_integer(I), I > 0 -> I end, 300), - crypto:rand_uniform(0, MaxDelay). + crypto:rand_uniform(1, MaxDelay). -spec set_idle_timeout(state()) -> state(). set_idle_timeout(#{on_route := send, server := LServer} = State) -> @@ -316,6 +372,7 @@ transform_options({outgoing_s2s_options, Families, Timeout}, Opts) -> "but it is better to fix your config: " "use 'outgoing_s2s_timeout' and " "'outgoing_s2s_families' instead.", []), + maybe_report_huge_timeout(outgoing_s2s_timeout, Timeout), [{outgoing_s2s_families, Families}, {outgoing_s2s_timeout, Timeout} | Opts]; @@ -327,15 +384,27 @@ transform_options({s2s_dns_options, S2SDNSOpts}, AllOpts) -> "'s2s_dns_retries' instead", []), lists:foldr( fun({timeout, T}, AccOpts) -> + maybe_report_huge_timeout(s2s_dns_timeout, T), [{s2s_dns_timeout, T}|AccOpts]; ({retries, R}, AccOpts) -> [{s2s_dns_retries, R}|AccOpts]; (_, AccOpts) -> AccOpts end, AllOpts, S2SDNSOpts); +transform_options({Opt, T}, Opts) + when Opt == outgoing_s2s_timeout; Opt == s2s_dns_timeout -> + maybe_report_huge_timeout(Opt, T), + [{outgoing_s2s_timeout, T}|Opts]; transform_options(Opt, Opts) -> [Opt|Opts]. +maybe_report_huge_timeout(Opt, T) when is_integer(T), T >= 1000 -> + ?WARNING_MSG("value '~p' of option '~p' is too big, " + "are you sure you have set seconds?", + [T, Opt]); +maybe_report_huge_timeout(_, _) -> + ok. + opt_type(outgoing_s2s_families) -> fun (Families) -> true = lists:all(fun (ipv4) -> true; @@ -354,7 +423,10 @@ opt_type(outgoing_s2s_timeout) -> opt_type(s2s_dns_retries) -> fun (I) when is_integer(I), I >= 0 -> I end; opt_type(s2s_dns_timeout) -> - fun (I) when is_integer(I), I >= 0 -> I end; + fun (TimeOut) when is_integer(TimeOut), TimeOut > 0 -> + TimeOut; + (infinity) -> infinity + end; opt_type(s2s_max_retry_delay) -> fun (I) when is_integer(I), I > 0 -> I end; opt_type(_) -> diff --git a/src/ejabberd_service.erl b/src/ejabberd_service.erl index 6ecd03a4c..d84de3db4 100644 --- a/src/ejabberd_service.erl +++ b/src/ejabberd_service.erl @@ -85,7 +85,8 @@ init([State, Opts]) -> dict:from_list([{global, Pass}]) end, CheckFrom = gen_mod:get_opt(check_from, Opts, - fun(Flag) when is_boolean(Flag) -> Flag end), + fun(Flag) when is_boolean(Flag) -> Flag end, + true), xmpp_stream_in:change_shaper(State, Shaper), State1 = State#{access => Access, xmlns => ?NS_COMPONENT, @@ -119,7 +120,7 @@ handle_stream_start(_StreamStart, end. get_password_fun(#{remote_server := RemoteServer, - socket := Socket, + socket := Socket, sockmod := SockMod, ip := IP, host_opts := HostOpts}) -> fun(_) -> @@ -129,7 +130,7 @@ get_password_fun(#{remote_server := RemoteServer, error -> ?ERROR_MSG("(~s) Domain ~s is unconfigured for " "external component from ~s", - [ejabberd_socket:pp(Socket), RemoteServer, + [SockMod:pp(Socket), RemoteServer, ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]), {false, undefined} end @@ -137,10 +138,11 @@ get_password_fun(#{remote_server := RemoteServer, handle_auth_success(_, Mech, _, #{remote_server := RemoteServer, host_opts := HostOpts, - socket := Socket, ip := IP} = State) -> + socket := Socket, sockmod := SockMod, + ip := IP} = State) -> ?INFO_MSG("(~s) Accepted external component ~s authentication " "for ~s from ~s", - [ejabberd_socket:pp(Socket), Mech, RemoteServer, + [SockMod:pp(Socket), Mech, RemoteServer, ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]), lists:foreach( fun (H) -> @@ -151,10 +153,11 @@ handle_auth_success(_, Mech, _, handle_auth_failure(_, Mech, Reason, #{remote_server := RemoteServer, + sockmod := SockMod, socket := Socket, ip := IP} = State) -> ?ERROR_MSG("(~s) Failed external component ~s authentication " "for ~s from ~s: ~s", - [ejabberd_socket:pp(Socket), Mech, RemoteServer, + [SockMod:pp(Socket), Mech, RemoteServer, ejabberd_config:may_hide_data(jlib:ip_to_list(IP)), Reason]), State. diff --git a/src/ejabberd_sm.erl b/src/ejabberd_sm.erl index a15d788d0..38b7ed15f 100644 --- a/src/ejabberd_sm.erl +++ b/src/ejabberd_sm.erl @@ -83,7 +83,6 @@ -include("xmpp.hrl"). -include("ejabberd_commands.hrl"). --include("mod_privacy.hrl"). -include("ejabberd_sm.hrl"). -callback init() -> ok | {error, any()}. @@ -576,24 +575,10 @@ do_route(From, To, Packet) -> %% or if there are no current sessions for the user. -spec is_privacy_allow(jid(), jid(), stanza()) -> boolean(). is_privacy_allow(From, To, Packet) -> - User = To#jid.user, - Server = To#jid.server, - PrivacyList = - ejabberd_hooks:run_fold(privacy_get_user_list, Server, - #userlist{}, [User, Server]), - is_privacy_allow(From, To, Packet, PrivacyList). - -%% Check if privacy rules allow this delivery -%% Function copied from ejabberd_c2s.erl --spec is_privacy_allow(jid(), jid(), stanza(), #userlist{}) -> boolean(). -is_privacy_allow(From, To, Packet, PrivacyList) -> - User = To#jid.user, - Server = To#jid.server, - allow == - ejabberd_hooks:run_fold(privacy_check_packet, Server, - allow, - [User, Server, PrivacyList, {From, To, Packet}, - in]). + LServer = To#jid.server, + allow == ejabberd_hooks:run_fold( + privacy_check_packet, LServer, allow, + [To, xmpp:set_from_to(Packet, From, To), in]). -spec route_message(jid(), jid(), message(), message_type()) -> any(). route_message(From, To, Packet, Type) -> @@ -757,10 +742,14 @@ process_iq(From, To, #iq{type = T, lang = Lang, sub_els = [El]} = Packet) Err = xmpp:err_service_unavailable(Txt, Lang), ejabberd_router:route_error(To, From, Packet, Err) end; -process_iq(From, To, #iq{type = T} = Packet) when T == get; T == set -> - Err = xmpp:err_bad_request(), - ejabberd_router:route_error(To, From, Packet, Err), - ok; +process_iq(From, To, #iq{type = T, lang = Lang, sub_els = SubEls} = Packet) + when T == get; T == set -> + Txt = case SubEls of + [] -> <<"No child elements found">>; + _ -> <<"Too many child elements">> + end, + Err = xmpp:err_bad_request(Txt, Lang), + ejabberd_router:route_error(To, From, Packet, Err); process_iq(_From, _To, #iq{}) -> ok. @@ -770,7 +759,7 @@ force_update_presence({LUser, LServer}) -> Mod = get_sm_backend(LServer), Ss = online(Mod:get_sessions(LUser, LServer)), lists:foreach(fun (#session{sid = {_, Pid}}) -> - Pid ! {force_update_presence, LUser, LServer} + Pid ! force_update_presence end, Ss). diff --git a/src/ejabberd_socket.erl b/src/ejabberd_socket.erl index 4e523a7e5..83b7ae9b9 100644 --- a/src/ejabberd_socket.erl +++ b/src/ejabberd_socket.erl @@ -33,10 +33,12 @@ connect/4, connect/5, starttls/2, - starttls/3, compress/1, compress/2, reset_stream/1, + send_element/2, + send_header/2, + send_trailer/1, send/2, send_xml/2, change_shaper/2, @@ -78,60 +80,63 @@ [proplists:property()]) -> {ok, pid()} | {error, term()} | ignore. -callback socket_type() -> xml_stream | independent | raw. +-define(is_http_socket(S), + (S#socket_state.sockmod == ejabberd_bosh orelse + S#socket_state.sockmod == ejabberd_http_ws)). + %%==================================================================== %% API %%==================================================================== --spec start(atom(), sockmod(), socket(), [{atom(), any()}]) -> any(). - +-spec start(atom(), sockmod(), socket(), [proplists:propery()]) + -> {ok, pid() | independent} | {error, inet:posix() | any()}. start(Module, SockMod, Socket, Opts) -> case Module:socket_type() of - xml_stream -> - MaxStanzaSize = case lists:keysearch(max_stanza_size, 1, - Opts) - of - {value, {_, Size}} -> Size; - _ -> infinity - end, - {ReceiverMod, Receiver, RecRef} = case catch - SockMod:custom_receiver(Socket) - of - {receiver, RecMod, RecPid} -> - {RecMod, RecPid, RecMod}; - _ -> - RecPid = - ejabberd_receiver:start(Socket, - SockMod, - none, - MaxStanzaSize), - {ejabberd_receiver, RecPid, - RecPid} - end, - SocketData = #socket_state{sockmod = SockMod, - socket = Socket, receiver = RecRef}, - case Module:start({?MODULE, SocketData}, Opts) of - {ok, Pid} -> - case SockMod:controlling_process(Socket, Receiver) of - ok -> ok; - {error, _Reason} -> SockMod:close(Socket) + independent -> {ok, independent}; + xml_stream -> + MaxStanzaSize = proplists:get_value(max_stanza_size, Opts, infinity), + {ReceiverMod, Receiver, RecRef} = + try SockMod:custom_receiver(Socket) of + {receiver, RecMod, RecPid} -> + {RecMod, RecPid, RecMod} + catch _:_ -> + RecPid = ejabberd_receiver:start( + Socket, SockMod, none, MaxStanzaSize), + {ejabberd_receiver, RecPid, RecPid} end, - ReceiverMod:become_controller(Receiver, Pid); - _ -> - SockMod:close(Socket), - case ReceiverMod of - ejabberd_receiver -> ReceiverMod:close(Receiver); - _ -> ok - end - end; - independent -> ok; - raw -> - case Module:start({SockMod, Socket}, Opts) of - {ok, Pid} -> - case SockMod:controlling_process(Socket, Pid) of - ok -> ok; - {error, _Reason} -> SockMod:close(Socket) - end; - {error, _Reason} -> SockMod:close(Socket) - end + SocketData = #socket_state{sockmod = SockMod, + socket = Socket, receiver = RecRef}, + case Module:start({?MODULE, SocketData}, Opts) of + {ok, Pid} -> + case SockMod:controlling_process(Socket, Receiver) of + ok -> + ReceiverMod:become_controller(Receiver, Pid), + {ok, Receiver}; + Err -> + SockMod:close(Socket), + Err + end; + Err -> + SockMod:close(Socket), + case ReceiverMod of + ejabberd_receiver -> ReceiverMod:close(Receiver); + _ -> ok + end, + Err + end; + raw -> + case Module:start({SockMod, Socket}, Opts) of + {ok, Pid} -> + case SockMod:controlling_process(Socket, Pid) of + ok -> + {ok, Pid}; + {error, _} = Err -> + SockMod:close(Socket), + Err + end; + Err -> + SockMod:close(Socket), + Err + end end. connect(Addr, Port, Opts) -> @@ -156,35 +161,31 @@ connect(Addr, Port, Opts, Timeout, Owner) -> {error, _Reason} = Error -> Error end. -starttls(SocketData, TLSOpts) -> - case fast_tls:tcp_to_tls(SocketData#socket_state.socket, TLSOpts) of +starttls(#socket_state{socket = Socket, + receiver = Receiver} = SocketData, TLSOpts) -> + case fast_tls:tcp_to_tls(Socket, TLSOpts) of {ok, TLSSocket} -> - ejabberd_receiver:starttls(SocketData#socket_state.receiver, TLSSocket), - {ok, SocketData#socket_state{socket = TLSSocket, sockmod = fast_tls}}; - Err -> - ?ERROR_MSG("starttls failed: ~p", [Err]), - Err - end. - -starttls(SocketData, TLSOpts, Data) -> - case fast_tls:tcp_to_tls(SocketData#socket_state.socket, TLSOpts) of - {ok, TLSSocket} -> - ejabberd_receiver:starttls(SocketData#socket_state.receiver, TLSSocket), - send(SocketData, Data), - {ok, SocketData#socket_state{socket = TLSSocket, sockmod = fast_tls}}; - Err -> - ?ERROR_MSG("starttls failed: ~p", [Err]), + case ejabberd_receiver:starttls(Receiver, TLSSocket) of + ok -> + {ok, SocketData#socket_state{socket = TLSSocket, + sockmod = fast_tls}}; + {error, _} = Err -> + Err + end; + {error, _} = Err -> Err end. compress(SocketData) -> compress(SocketData, undefined). compress(SocketData, Data) -> - {ok, ZlibSocket} = - ejabberd_receiver:compress(SocketData#socket_state.receiver, - Data), - SocketData#socket_state{socket = ZlibSocket, - sockmod = ezlib}. + case ejabberd_receiver:compress(SocketData#socket_state.receiver, Data) of + {ok, ZlibSocket} -> + {ok, SocketData#socket_state{socket = ZlibSocket, sockmod = ezlib}}; + Err -> + ?ERROR_MSG("compress failed: ~p", [Err]), + Err + end. reset_stream(SocketData) when is_pid(SocketData#socket_state.receiver) -> @@ -193,30 +194,41 @@ reset_stream(SocketData) when is_atom(SocketData#socket_state.receiver) -> (SocketData#socket_state.receiver):reset_stream(SocketData#socket_state.socket). --spec send(socket_state(), iodata()) -> ok. +-spec send_element(socket_state(), fxml:xmlel()) -> ok | {error, inet:posix()}. +send_element(SocketData, El) when ?is_http_socket(SocketData) -> + send_xml(SocketData, {xmlstreamelement, El}); +send_element(SocketData, El) -> + send(SocketData, fxml:element_to_binary(El)). -send(SocketData, Data) -> - ?DEBUG("Send XML on stream = ~p", [Data]), - case catch (SocketData#socket_state.sockmod):send( - SocketData#socket_state.socket, Data) of - ok -> ok; - {error, timeout} -> - ?INFO_MSG("Timeout on ~p:send",[SocketData#socket_state.sockmod]), - {error, timeout}; - Error -> - ?DEBUG("Error in ~p:send: ~p",[SocketData#socket_state.sockmod, Error]), - Error +-spec send_header(socket_state(), fxml:xmlel()) -> ok | {error, inet:posix()}. +send_header(SocketData, El) when ?is_http_socket(SocketData) -> + send_xml(SocketData, {xmlstreamstart, El#xmlel.name, El#xmlel.attrs}); +send_header(SocketData, El) -> + send(SocketData, fxml:element_to_header(El)). + +-spec send_trailer(socket_state()) -> ok | {error, inet:posix()}. +send_trailer(SocketData) when ?is_http_socket(SocketData) -> + send_xml(SocketData, {xmlstreamend, <<"stream:stream">>}); +send_trailer(SocketData) -> + send(SocketData, <<"">>). + +-spec send(socket_state(), iodata()) -> ok | {error, inet:posix()}. +send(#socket_state{sockmod = SockMod, socket = Socket} = SocketData, Data) -> + ?DEBUG("(~s) Send XML on stream = ~p", [pp(SocketData), Data]), + try SockMod:send(Socket, Data) + catch _:badarg -> + %% Some modules throw badarg exceptions on closed sockets + %% TODO: their code should be improved + {error, einval} end. -%% Can only be called when in c2s StateData#state.xml_socket is true -%% This function is used for HTTP bind -%% sockmod=ejabberd_http_ws|ejabberd_http_bind or any custom module --spec send_xml(socket_state(), fxml:xmlel()) -> any(). - -send_xml(SocketData, Data) -> - catch - (SocketData#socket_state.sockmod):send_xml(SocketData#socket_state.socket, - Data). +-spec send_xml(socket_state(), + {xmlstreamelement, fxml:xmlel()} | + {xmlstreamstart, binary(), [{binary(), binary()}]} | + {xmlstreamend, binary()} | + {xmlstreamraw, iodata()}) -> term(). +send_xml(SocketData, El) -> + (SocketData#socket_state.sockmod):send_xml(SocketData#socket_state.socket, El). change_shaper(SocketData, Shaper) when is_pid(SocketData#socket_state.receiver) -> diff --git a/src/mod_announce.erl b/src/mod_announce.erl index 2e182ed1e..15524ce1e 100644 --- a/src/mod_announce.erl +++ b/src/mod_announce.erl @@ -68,7 +68,7 @@ start(Host, Opts) -> ejabberd_hooks:add(disco_local_items, Host, ?MODULE, disco_items, 50), ejabberd_hooks:add(adhoc_local_items, Host, ?MODULE, announce_items, 50), ejabberd_hooks:add(adhoc_local_commands, Host, ?MODULE, announce_commands, 50), - ejabberd_hooks:add(user_available_hook, Host, + ejabberd_hooks:add(c2s_self_presence, Host, ?MODULE, send_motd, 50), register(gen_mod:get_module_proc(Host, ?PROCNAME), proc_lib:spawn(?MODULE, init, [])). @@ -123,7 +123,7 @@ stop(Host) -> ejabberd_hooks:delete(disco_local_items, Host, ?MODULE, disco_items, 50), ejabberd_hooks:delete(local_send_to_resource_hook, Host, ?MODULE, announce, 50), - ejabberd_hooks:delete(user_available_hook, Host, + ejabberd_hooks:delete(c2s_self_presence, Host, ?MODULE, send_motd, 50), Proc = gen_mod:get_module_proc(Host, ?PROCNAME), exit(whereis(Proc), stop), @@ -733,8 +733,13 @@ announce_motd_delete(LServer) -> Mod = gen_mod:db_mod(LServer, ?MODULE), Mod:delete_motd(LServer). --spec send_motd(jid()) -> ok | {atomic, any()}. -send_motd(#jid{luser = LUser, lserver = LServer} = JID) when LUser /= <<>> -> +-spec send_motd({presence(), ejabberd_c2s:state()}) -> {presence(), ejabberd_c2s:state()}. +send_motd({_, #{pres_last := _}} = Acc) -> + %% This is just a presence update, nothing to do + Acc; +send_motd({#presence{type = available}, + #{jid := #jid{luser = LUser, lserver = LServer} = JID}} = Acc) + when LUser /= <<>> -> Mod = gen_mod:db_mod(LServer, ?MODULE), case Mod:get_motd(LServer) of {ok, Packet} -> @@ -754,9 +759,10 @@ send_motd(#jid{luser = LUser, lserver = LServer} = JID) when LUser /= <<>> -> end; error -> ok - end; -send_motd(_) -> - ok. + end, + Acc; +send_motd(Acc) -> + Acc. get_stored_motd(LServer) -> Mod = gen_mod:db_mod(LServer, ?MODULE), diff --git a/src/mod_blocking.erl b/src/mod_blocking.erl index 45564daf4..5195bfb73 100644 --- a/src/mod_blocking.erl +++ b/src/mod_blocking.erl @@ -29,8 +29,8 @@ -protocol({xep, 191, '1.2'}). --export([start/2, stop/1, process_iq/1, c2s_handle_info/2, - process_iq_set/3, process_iq_get/3, mod_opt_type/1, depends/2]). +-export([start/2, stop/1, process_iq/1, mod_opt_type/1, depends/2, + disco_features/5]). -include("ejabberd.hrl"). -include("logger.hrl"). @@ -48,77 +48,73 @@ start(Host, Opts) -> IQDisc = gen_mod:get_opt(iqdisc, Opts, fun gen_iq_handler:check_type/1, one_queue), - ejabberd_hooks:add(privacy_iq_get, Host, ?MODULE, - process_iq_get, 40), - ejabberd_hooks:add(privacy_iq_set, Host, ?MODULE, - process_iq_set, 40), - ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE, - c2s_handle_info, 40), - mod_disco:register_feature(Host, ?NS_BLOCKING), + ejabberd_hooks:add(disco_local_features, Host, ?MODULE, disco_features, 50), gen_iq_handler:add_iq_handler(ejabberd_sm, Host, ?NS_BLOCKING, ?MODULE, process_iq, IQDisc). stop(Host) -> - ejabberd_hooks:delete(privacy_iq_get, Host, ?MODULE, - process_iq_get, 40), - ejabberd_hooks:delete(privacy_iq_set, Host, ?MODULE, - process_iq_set, 40), - ejabberd_hooks:delete(c2s_handle_info, Host, ?MODULE, - c2s_handle_info, 40), - mod_disco:unregister_feature(Host, ?NS_BLOCKING), - gen_iq_handler:remove_iq_handler(ejabberd_sm, Host, - ?NS_BLOCKING). + ejabberd_hooks:delete(disco_local_features, Host, ?MODULE, disco_features, 50), + gen_iq_handler:remove_iq_handler(ejabberd_sm, Host, ?NS_BLOCKING). depends(_Host, _Opts) -> [{mod_privacy, hard}]. +-spec disco_features({error, stanza_error()} | {result, [binary()]} | empty, + jid(), jid(), binary(), binary()) -> + {error, stanza_error()} | {result, [binary()]}. +disco_features({error, Err}, _From, _To, _Node, _Lang) -> + {error, Err}; +disco_features(empty, _From, _To, <<"">>, _Lang) -> + {result, [?NS_BLOCKING]}; +disco_features({result, Feats}, _From, _To, <<"">>, _Lang) -> + {result, [?NS_BLOCKING|Feats]}; +disco_features(Acc, _From, _To, _Node, _Lang) -> + Acc. + -spec process_iq(iq()) -> iq(). -process_iq(IQ) -> - xmpp:make_error(IQ, xmpp:err_not_allowed()). +process_iq(#iq{type = Type, + from = #jid{luser = U, lserver = S}, + to = #jid{luser = U, lserver = S}} = IQ) -> + case Type of + get -> process_iq_get(IQ); + set -> process_iq_set(IQ) + end; +process_iq(#iq{lang = Lang} = IQ) -> + Txt = <<"Query to another users is forbidden">>, + xmpp:make_error(IQ, xmpp:err_forbidden(Txt, Lang)). --spec process_iq_get({error, stanza_error()} | {result, xmpp_element() | undefined}, - iq(), userlist()) -> - {error, stanza_error()} | - {result, xmpp_element() | undefined}. -process_iq_get(_, #iq{lang = Lang, from = From, - sub_els = [#block_list{}]}, _) -> - #jid{luser = LUser, lserver = LServer} = From, - process_blocklist_get(LUser, LServer, Lang); -process_iq_get(Acc, _, _) -> Acc. +-spec process_iq_get(iq()) -> iq(). +process_iq_get(#iq{sub_els = [#block_list{}]} = IQ) -> + process_get(IQ); +process_iq_get(#iq{lang = Lang} = IQ) -> + Txt = <<"No module is handling this query">>, + xmpp:make_error(IQ, xmpp:err_service_unavailable(Txt, Lang)). --spec process_iq_set({error, stanza_error()} | - {result, xmpp_element() | undefined} | - {result, xmpp_element() | undefined, userlist()}, - iq(), userlist()) -> - {error, stanza_error()} | - {result, xmpp_element() | undefined} | - {result, xmpp_element() | undefined, userlist()}. -process_iq_set(Acc, #iq{from = From, lang = Lang, sub_els = [SubEl]}, _) -> - #jid{luser = LUser, lserver = LServer} = From, +-spec process_iq_set(iq()) -> iq(). +process_iq_set(#iq{lang = Lang, sub_els = [SubEl]} = IQ) -> case SubEl of #block{items = []} -> Txt = <<"No items found in this query">>, - {error, xmpp:err_bad_request(Txt, Lang)}; + xmpp:make_error(IQ, xmpp:err_bad_request(Txt, Lang)); #block{items = Items} -> JIDs = [jid:tolower(Item) || Item <- Items], - process_blocklist_block(LUser, LServer, JIDs, Lang); + process_block(IQ, JIDs); #unblock{items = []} -> - process_blocklist_unblock_all(LUser, LServer, Lang); + process_unblock_all(IQ); #unblock{items = Items} -> JIDs = [jid:tolower(Item) || Item <- Items], - process_blocklist_unblock(LUser, LServer, JIDs, Lang); + process_unblock(IQ, JIDs); _ -> - Acc - end; -process_iq_set(Acc, _, _) -> Acc. + Txt = <<"No module is handling this query">>, + xmpp:make_error(IQ, xmpp:err_service_unavailable(Txt, Lang)) + end. --spec list_to_blocklist_jids([listitem()], [ljid()]) -> [ljid()]. -list_to_blocklist_jids([], JIDs) -> JIDs; -list_to_blocklist_jids([#listitem{type = jid, - action = deny, value = JID} = - Item - | Items], - JIDs) -> +-spec listitems_to_jids([listitem()], [ljid()]) -> [ljid()]. +listitems_to_jids([], JIDs) -> + JIDs; +listitems_to_jids([#listitem{type = jid, + action = deny, value = JID} = Item | Items], + JIDs) -> Match = case Item of #listitem{match_all = true} -> true; @@ -130,20 +126,18 @@ list_to_blocklist_jids([#listitem{type = jid, _ -> false end, - if Match -> list_to_blocklist_jids(Items, [JID | JIDs]); - true -> list_to_blocklist_jids(Items, JIDs) + if Match -> listitems_to_jids(Items, [JID | JIDs]); + true -> listitems_to_jids(Items, JIDs) end; % Skip Privacy List items than cannot be mapped to Blocking items -list_to_blocklist_jids([_ | Items], JIDs) -> - list_to_blocklist_jids(Items, JIDs). +listitems_to_jids([_ | Items], JIDs) -> + listitems_to_jids(Items, JIDs). --spec process_blocklist_block(binary(), binary(), [ljid()], - binary()) -> - {error, stanza_error()} | - {result, undefined, userlist()}. -process_blocklist_block(LUser, LServer, JIDs, Lang) -> +-spec process_block(iq(), [ljid()]) -> iq(). +process_block(#iq{from = #jid{luser = LUser, lserver = LServer}, + lang = Lang} = IQ, JIDs) -> Filter = fun (List) -> - AlreadyBlocked = list_to_blocklist_jids(List, []), + AlreadyBlocked = listitems_to_jids(List, []), lists:foldr(fun (JID, List1) -> case lists:member(JID, AlreadyBlocked) of @@ -161,23 +155,21 @@ process_blocklist_block(LUser, LServer, JIDs, Lang) -> end, Mod = db_mod(LServer), case Mod:process_blocklist_block(LUser, LServer, Filter) of - {atomic, {ok, Default, List}} -> - UserList = make_userlist(Default, List), - broadcast_list_update(LUser, LServer, Default, - UserList), - broadcast_blocklist_event(LUser, LServer, - {block, [jid:make(J) || J <- JIDs]}), - {result, undefined, UserList}; - _Err -> + {atomic, {ok, Default, List}} -> + UserList = make_userlist(Default, List), + broadcast_list_update(LUser, LServer, UserList, Default), + broadcast_event(LUser, LServer, + #block{items = [jid:make(J) || J <- JIDs]}), + xmpp:make_iq_result(IQ); + _Err -> ?ERROR_MSG("Error processing ~p: ~p", [{LUser, LServer, JIDs}, _Err]), - {error, xmpp:err_internal_server_error(<<"Database failure">>, Lang)} + Err = xmpp:err_internal_server_error(<<"Database failure">>, Lang), + xmpp:make_error(IQ, Err) end. --spec process_blocklist_unblock_all(binary(), binary(), binary()) -> - {error, stanza_error()} | - {result, undefined} | - {result, undefined, userlist()}. -process_blocklist_unblock_all(LUser, LServer, Lang) -> +-spec process_unblock_all(iq()) -> iq(). +process_unblock_all(#iq{from = #jid{luser = LUser, lserver = LServer}, + lang = Lang} = IQ) -> Filter = fun (List) -> lists:filter(fun (#listitem{action = A}) -> A =/= deny end, @@ -185,23 +177,22 @@ process_blocklist_unblock_all(LUser, LServer, Lang) -> end, Mod = db_mod(LServer), case Mod:unblock_by_filter(LUser, LServer, Filter) of - {atomic, ok} -> {result, undefined}; - {atomic, {ok, Default, List}} -> - UserList = make_userlist(Default, List), - broadcast_list_update(LUser, LServer, Default, - UserList), - broadcast_blocklist_event(LUser, LServer, unblock_all), - {result, undefined, UserList}; - _Err -> + {atomic, ok} -> + xmpp:make_iq_result(IQ); + {atomic, {ok, Default, List}} -> + UserList = make_userlist(Default, List), + broadcast_list_update(LUser, LServer, UserList, Default), + broadcast_event(LUser, LServer, #unblock{}), + xmpp:make_iq_result(IQ); + _Err -> ?ERROR_MSG("Error processing ~p: ~p", [{LUser, LServer}, _Err]), - {error, xmpp:err_internal_server_error(<<"Database failure">>, Lang)} + Err = xmpp:err_internal_server_error(<<"Database failure">>, Lang), + xmpp:make_error(IQ, Err) end. --spec process_blocklist_unblock(binary(), binary(), [ljid()], binary()) -> - {error, stanza_error()} | - {result, undefined} | - {result, undefined, userlist()}. -process_blocklist_unblock(LUser, LServer, JIDs, Lang) -> +-spec process_unblock(iq(), [ljid()]) -> iq(). +process_unblock(#iq{from = #jid{luser = LUser, lserver = LServer}, + lang = Lang} = IQ, JIDs) -> Filter = fun (List) -> lists:filter(fun (#listitem{action = deny, type = jid, value = JID}) -> @@ -212,17 +203,18 @@ process_blocklist_unblock(LUser, LServer, JIDs, Lang) -> end, Mod = db_mod(LServer), case Mod:unblock_by_filter(LUser, LServer, Filter) of - {atomic, ok} -> {result, undefined}; - {atomic, {ok, Default, List}} -> - UserList = make_userlist(Default, List), - broadcast_list_update(LUser, LServer, Default, - UserList), - broadcast_blocklist_event(LUser, LServer, - {unblock, [jid:make(J) || J <- JIDs]}), - {result, undefined, UserList}; - _Err -> + {atomic, ok} -> + xmpp:make_iq_result(IQ); + {atomic, {ok, Default, List}} -> + UserList = make_userlist(Default, List), + broadcast_list_update(LUser, LServer, UserList, Default), + broadcast_event(LUser, LServer, + #unblock{items = [jid:make(J) || J <- JIDs]}), + xmpp:make_iq_result(IQ); + _Err -> ?ERROR_MSG("Error processing ~p: ~p", [{LUser, LServer, JIDs}, _Err]), - {error, xmpp:err_internal_server_error(<<"Database failure">>, Lang)} + Err = xmpp:err_internal_server_error(<<"Database failure">>, Lang), + xmpp:make_error(IQ, Err) end. -spec make_userlist(binary(), [listitem()]) -> userlist(). @@ -230,52 +222,36 @@ make_userlist(Name, List) -> NeedDb = mod_privacy:is_list_needdb(List), #userlist{name = Name, list = List, needdb = NeedDb}. --spec broadcast_list_update(binary(), binary(), binary(), userlist()) -> ok. -broadcast_list_update(LUser, LServer, Name, UserList) -> - ejabberd_sm:route(jid:make(LUser, LServer, <<"">>), - {privacy_list, UserList, Name}). +-spec broadcast_list_update(binary(), binary(), userlist(), binary()) -> ok. +broadcast_list_update(LUser, LServer, UserList, Name) -> + mod_privacy:push_list_update(jid:make(LUser, LServer), UserList, Name). --spec broadcast_blocklist_event(binary(), binary(), block_event()) -> ok. -broadcast_blocklist_event(LUser, LServer, Event) -> - JID = jid:make(LUser, LServer, <<"">>), - ejabberd_sm:route(JID, {blocking, Event}). +-spec broadcast_event(binary(), binary(), block_event()) -> ok. +broadcast_event(LUser, LServer, Event) -> + From = jid:make(LUser, LServer), + lists:foreach( + fun(R) -> + To = jid:replace_resource(From, R), + IQ = #iq{type = set, from = From, to = To, + id = <<"push", (randoms:get_string())/binary>>, + sub_els = [Event]}, + ejabberd_router:route(From, To, IQ) + end, ejabberd_sm:get_user_resources(LUser, LServer)). --spec process_blocklist_get(binary(), binary(), binary()) -> - {error, stanza_error()} | {result, block_list()}. -process_blocklist_get(LUser, LServer, Lang) -> +-spec process_get(iq()) -> iq(). +process_get(#iq{from = #jid{luser = LUser, lserver = LServer}, + lang = Lang} = IQ) -> Mod = db_mod(LServer), case Mod:process_blocklist_get(LUser, LServer) of - error -> - {error, xmpp:err_internal_server_error(<<"Database failure">>, Lang)}; - List -> - LJIDs = list_to_blocklist_jids(List, []), - Items = [jid:make(J) || J <- LJIDs], - {result, #block_list{items = Items}} + error -> + Err = xmpp:err_internal_server_error(<<"Database failure">>, Lang), + xmpp:make_error(IQ, Err); + List -> + LJIDs = listitems_to_jids(List, []), + Items = [jid:make(J) || J <- LJIDs], + xmpp:make_iq_result(IQ, #block_list{items = Items}) end. --spec c2s_handle_info(ejabberd_c2s:state(), term()) -> ejabberd_c2s:state(). -c2s_handle_info(#{user := U, server := S, resource := R} = State, - {blocking, Action}) -> - SubEl = case Action of - {block, JIDs} -> - #block{items = JIDs}; - {unblock, JIDs} -> - #unblock{items = JIDs}; - unblock_all -> - #unblock{} - end, - PushIQ = #iq{type = set, - from = jid:make(U, S), - to = jid:make(U, S, R), - id = <<"push", (randoms:get_string())/binary>>, - sub_els = [SubEl]}, - %% No need to replace active privacy list here, - %% blocking pushes are always accompanied by - %% Privacy List pushes - {stop, ejabberd_c2s:send(State, PushIQ)}; -c2s_handle_info(State, _) -> - State. - -spec db_mod(binary()) -> module(). db_mod(LServer) -> DBType = gen_mod:db_type(LServer, mod_privacy), diff --git a/src/mod_caps.erl b/src/mod_caps.erl index e2ec30305..d5a623669 100644 --- a/src/mod_caps.erl +++ b/src/mod_caps.erl @@ -47,7 +47,7 @@ -export([init/1, handle_info/2, handle_call/3, handle_cast/2, terminate/2, code_change/3]). --export([user_send_packet/4, user_receive_packet/5, +-export([user_send_packet/1, user_receive_packet/1, c2s_presence_in/2, mod_opt_type/1]). -include("ejabberd.hrl"). @@ -126,47 +126,51 @@ read_caps(Presence) -> Caps -> Caps end. --spec user_send_packet(stanza(), ejabberd_c2s:state(), jid(), jid()) -> stanza(). -user_send_packet(#presence{type = available} = Pkt, - _C2SState, - #jid{luser = User, lserver = Server} = From, - #jid{luser = User, lserver = Server, - lresource = <<"">>}) -> +-spec user_send_packet({stanza(), ejabberd_c2s:state()}) -> {stanza(), ejabberd_c2s:state()}. +user_send_packet({#presence{type = available, + from = #jid{luser = U, lserver = LServer} = From, + to = #jid{luser = U, lserver = LServer, + lresource = <<"">>}} = Pkt, + State}) -> case read_caps(Pkt) of nothing -> ok; #caps{version = Version, exts = Exts} = Caps -> - feature_request(Server, From, Caps, [Version | Exts]) + feature_request(LServer, From, Caps, [Version | Exts]) end, - Pkt; -user_send_packet(Pkt, _C2SState, _From, _To) -> - Pkt. + {Pkt, State}; +user_send_packet(Acc) -> + Acc. --spec user_receive_packet(stanza(), ejabberd_c2s:state(), - jid(), jid(), jid()) -> stanza(). -user_receive_packet(#presence{type = available} = Pkt, - _C2SState, - #jid{lserver = Server}, - From, _To) -> - IsRemote = not lists:member(From#jid.lserver, ?MYHOSTS), +-spec user_receive_packet({stanza(), ejabberd_c2s:state()}) -> {stanza(), ejabberd_c2s:state()}. +user_receive_packet({#presence{from = From, type = available} = Pkt, + #{lserver := LServer} = State}) -> + IsRemote = not ejabberd_router:is_my_host(From#jid.lserver), if IsRemote -> - case read_caps(Pkt) of - nothing -> ok; - #caps{version = Version, exts = Exts} = Caps -> - feature_request(Server, From, Caps, [Version | Exts]) - end; + case read_caps(Pkt) of + nothing -> ok; + #caps{version = Version, exts = Exts} = Caps -> + feature_request(LServer, From, Caps, [Version | Exts]) + end; true -> ok end, - Pkt; -user_receive_packet(Pkt, _C2SState, _JID, _From, _To) -> - Pkt. + {Pkt, State}; +user_receive_packet(Acc) -> + Acc. -spec caps_stream_features([xmpp_element()], binary()) -> [xmpp_element()]. caps_stream_features(Acc, MyHost) -> - case make_my_disco_hash(MyHost) of - <<"">> -> Acc; - Hash -> - [#caps{hash = <<"sha-1">>, node = ?EJABBERD_URI, version = Hash}|Acc] + case gen_mod:is_loaded(MyHost, ?MODULE) of + true -> + case make_my_disco_hash(MyHost) of + <<"">> -> + Acc; + Hash -> + [#caps{hash = <<"sha-1">>, node = ?EJABBERD_URI, + version = Hash}|Acc] + end; + false -> + Acc end. -spec disco_features({error, stanza_error()} | {result, [binary()]} | empty, @@ -238,7 +242,7 @@ c2s_presence_in(C2SState, end; _ -> gb_trees:delete_any(LFrom, Rs) end, - C2SState#{caps_resources := NewRs}; + C2SState#{caps_resources => NewRs}; true -> C2SState end. @@ -266,7 +270,7 @@ init([Host, Opts]) -> user_receive_packet, 75), ejabberd_hooks:add(c2s_post_auth_features, Host, ?MODULE, caps_stream_features, 75), - ejabberd_hooks:add(s2s_stream_features, Host, ?MODULE, + ejabberd_hooks:add(s2s_in_post_auth_features, Host, ?MODULE, caps_stream_features, 75), ejabberd_hooks:add(disco_local_features, Host, ?MODULE, disco_features, 75), @@ -295,7 +299,7 @@ terminate(_Reason, State) -> ?MODULE, user_receive_packet, 75), ejabberd_hooks:delete(c2s_post_auth_features, Host, ?MODULE, caps_stream_features, 75), - ejabberd_hooks:delete(s2s_stream_features, Host, + ejabberd_hooks:delete(s2s_in_post_auth_features, Host, ?MODULE, caps_stream_features, 75), ejabberd_hooks:delete(disco_local_features, Host, ?MODULE, disco_features, 75), diff --git a/src/mod_carboncopy.erl b/src/mod_carboncopy.erl index 5839a65b2..ea44aed95 100644 --- a/src/mod_carboncopy.erl +++ b/src/mod_carboncopy.erl @@ -35,8 +35,8 @@ -export([start/2, stop/1]). --export([user_send_packet/4, user_receive_packet/5, - iq_handler/1, remove_connection/4, +-export([user_send_packet/1, user_receive_packet/1, + iq_handler/1, remove_connection/4, disco_features/5, is_carbon_copy/1, mod_opt_type/1, depends/2]). -include("ejabberd.hrl"). @@ -59,7 +59,7 @@ is_carbon_copy(_) -> start(Host, Opts) -> IQDisc = gen_mod:get_opt(iqdisc, Opts,fun gen_iq_handler:check_type/1, one_queue), - mod_disco:register_feature(Host, ?NS_CARBONS_2), + ejabberd_hooks:add(disco_local_features, Host, ?MODULE, disco_features, 50), Mod = gen_mod:db_mod(Host, ?MODULE), Mod:init(Host, Opts), ejabberd_hooks:add(unset_presence_hook,Host, ?MODULE, remove_connection, 10), @@ -70,12 +70,24 @@ start(Host, Opts) -> stop(Host) -> gen_iq_handler:remove_iq_handler(ejabberd_sm, Host, ?NS_CARBONS_2), - mod_disco:unregister_feature(Host, ?NS_CARBONS_2), + ejabberd_hooks:delete(disco_local_features, Host, ?MODULE, disco_features, 50), %% why priority 89: to define clearly that we must run BEFORE mod_logdb hook (90) ejabberd_hooks:delete(user_send_packet,Host, ?MODULE, user_send_packet, 89), ejabberd_hooks:delete(user_receive_packet,Host, ?MODULE, user_receive_packet, 89), ejabberd_hooks:delete(unset_presence_hook,Host, ?MODULE, remove_connection, 10). +-spec disco_features({error, stanza_error()} | {result, [binary()]} | empty, + jid(), jid(), binary(), binary()) -> + {error, stanza_error()} | {result, [binary()]}. +disco_features({error, Err}, _From, _To, _Node, _Lang) -> + {error, Err}; +disco_features(empty, _From, _To, <<"">>, _Lang) -> + {result, [?NS_CARBONS_2]}; +disco_features({result, Feats}, _From, _To, <<"">>, _Lang) -> + {result, [?NS_CARBONS_2|Feats]}; +disco_features(Acc, _From, _To, _Node, _Lang) -> + Acc. + -spec iq_handler(iq()) -> iq(). iq_handler(#iq{type = set, lang = Lang, from = From, sub_els = [El]} = IQ) when is_record(El, carbons_enable); @@ -105,16 +117,24 @@ iq_handler(#iq{type = get, lang = Lang} = IQ)-> Txt = <<"Value 'get' of 'type' attribute is not allowed">>, xmpp:make_error(IQ, xmpp:err_not_allowed(Txt, Lang)). --spec user_send_packet(stanza(), ejabberd_c2s:state(), jid(), jid()) -> - stanza() | {stop, stanza()}. -user_send_packet(Packet, _C2SState, From, To) -> - check_and_forward(From, To, Packet, sent). +-spec user_send_packet({stanza(), ejabberd_c2s:state()}) + -> {stanza(), ejabberd_c2s:state()} | {stop, {stanza(), ejabberd_c2s:state()}}. +user_send_packet({Packet, C2SState}) -> + From = xmpp:get_from(Packet), + To = xmpp:get_to(Packet), + case check_and_forward(From, To, Packet, sent) of + {stop, Pkt} -> {stop, {Pkt, C2SState}}; + Pkt -> {Pkt, C2SState} + end. --spec user_receive_packet(stanza(), ejabberd_c2s:state(), - jid(), jid(), jid()) -> - stanza() | {stop, stanza()}. -user_receive_packet(Packet, _C2SState, JID, _From, To) -> - check_and_forward(JID, To, Packet, received). +-spec user_receive_packet({stanza(), ejabberd_c2s:state()}) + -> {stanza(), ejabberd_c2s:state()} | {stop, {stanza(), ejabberd_c2s:state()}}. +user_receive_packet({Packet, #{jid := JID} = C2SState}) -> + To = xmpp:get_to(Packet), + case check_and_forward(JID, To, Packet, received) of + {stop, Pkt} -> {stop, {Pkt, C2SState}}; + Pkt -> {Pkt, C2SState} + end. % Modified from original version: % - registered to the user_send_packet hook, to be called only once even for multicast diff --git a/src/mod_client_state.erl b/src/mod_client_state.erl index a838088fc..175929a57 100644 --- a/src/mod_client_state.erl +++ b/src/mod_client_state.erl @@ -34,8 +34,11 @@ -export([start/2, stop/1, mod_opt_type/1, depends/2]). %% ejabberd_hooks callbacks. --export([filter_presence/4, filter_chat_states/4, filter_pep/4, filter_other/4, - flush_queue/3, add_stream_feature/2]). +-export([filter_presence/1, filter_chat_states/1, + filter_pep/1, filter_other/1, + c2s_stream_started/2, add_stream_feature/2, + c2s_copy_session/2, c2s_authenticated_packet/2, + c2s_session_resumed/1]). -include("ejabberd.hrl"). -include("logger.hrl"). @@ -44,9 +47,10 @@ -define(CSI_QUEUE_MAX, 100). -type csi_type() :: presence | chatstate | {pep, binary()}. --type csi_key() :: {ljid(), csi_type()}. --type csi_stanza() :: {csi_key(), erlang:timestamp(), xmlel()}. --type csi_queue() :: [csi_stanza()]. +-type csi_queue() :: {non_neg_integer(), non_neg_integer(), map()}. +-type csi_timestamp() :: {non_neg_integer(), erlang:timestamp()}. +-type c2s_state() :: ejabberd_c2s:state(). +-type filter_acc() :: {stanza() | drop, c2s_state()}. %%-------------------------------------------------------------------- %% gen_mod callbacks. @@ -68,27 +72,33 @@ start(Host, Opts) -> fun(B) when is_boolean(B) -> B end, true), if QueuePresence; QueueChatStates; QueuePEP -> + ejabberd_hooks:add(c2s_stream_started, Host, ?MODULE, + c2s_stream_started, 50), ejabberd_hooks:add(c2s_post_auth_features, Host, ?MODULE, add_stream_feature, 50), + ejabberd_hooks:add(c2s_authenticated_packet, Host, ?MODULE, + c2s_authenticated_packet, 50), + ejabberd_hooks:add(c2s_copy_session, Host, ?MODULE, + c2s_copy_session, 50), + ejabberd_hooks:add(c2s_session_resumed, Host, ?MODULE, + c2s_session_resumed, 50), if QueuePresence -> - ejabberd_hooks:add(csi_filter_stanza, Host, ?MODULE, + ejabberd_hooks:add(c2s_filter_send, Host, ?MODULE, filter_presence, 50); true -> ok end, if QueueChatStates -> - ejabberd_hooks:add(csi_filter_stanza, Host, ?MODULE, + ejabberd_hooks:add(c2s_filter_send, Host, ?MODULE, filter_chat_states, 50); true -> ok end, if QueuePEP -> - ejabberd_hooks:add(csi_filter_stanza, Host, ?MODULE, + ejabberd_hooks:add(c2s_filter_send, Host, ?MODULE, filter_pep, 50); true -> ok end, - ejabberd_hooks:add(csi_filter_stanza, Host, ?MODULE, - filter_other, 100), - ejabberd_hooks:add(csi_flush_queue, Host, ?MODULE, - flush_queue, 50); + ejabberd_hooks:add(c2s_filter_send, Host, ?MODULE, + filter_other, 75); true -> ok end. @@ -108,27 +118,33 @@ stop(Host) -> fun(B) when is_boolean(B) -> B end, true), if QueuePresence; QueueChatStates; QueuePEP -> + ejabberd_hooks:delete(c2s_stream_started, Host, ?MODULE, + c2s_stream_started, 50), ejabberd_hooks:delete(c2s_post_auth_features, Host, ?MODULE, add_stream_feature, 50), + ejabberd_hooks:delete(c2s_authenticated_packet, Host, ?MODULE, + c2s_authenticated_packet, 50), + ejabberd_hooks:delete(c2s_copy_session, Host, ?MODULE, + c2s_copy_session, 50), + ejabberd_hooks:delete(c2s_session_resumed, Host, ?MODULE, + c2s_session_resumed, 50), if QueuePresence -> - ejabberd_hooks:delete(csi_filter_stanza, Host, ?MODULE, + ejabberd_hooks:delete(c2s_filter_send, Host, ?MODULE, filter_presence, 50); true -> ok end, if QueueChatStates -> - ejabberd_hooks:delete(csi_filter_stanza, Host, ?MODULE, + ejabberd_hooks:delete(c2s_filter_send, Host, ?MODULE, filter_chat_states, 50); true -> ok end, if QueuePEP -> - ejabberd_hooks:delete(csi_filter_stanza, Host, ?MODULE, + ejabberd_hooks:delete(c2s_filter_send, Host, ?MODULE, filter_pep, 50); true -> ok end, - ejabberd_hooks:delete(csi_filter_stanza, Host, ?MODULE, - filter_other, 100), - ejabberd_hooks:delete(csi_flush_queue, Host, ?MODULE, - flush_queue, 50); + ejabberd_hooks:delete(c2s_filter_send, Host, ?MODULE, + filter_other, 75); true -> ok end. @@ -150,29 +166,46 @@ depends(_Host, _Opts) -> %%-------------------------------------------------------------------- %% ejabberd_hooks callbacks. %%-------------------------------------------------------------------- +-spec c2s_stream_started(c2s_state(), stream_start()) -> c2s_state(). +c2s_stream_started(State, _) -> + State#{csi_state => active, csi_queue => queue_new()}. --spec filter_presence({ejabberd_c2s:state(), [stanza()]}, binary(), jid(), stanza()) - -> {ejabberd_c2s:state(), [stanza()]} | - {stop, {ejabberd_c2s:state(), [stanza()]}}. +-spec c2s_authenticated_packet(c2s_state(), xmpp_element()) -> c2s_state(). +c2s_authenticated_packet(C2SState, #csi{type = active}) -> + C2SState1 = C2SState#{csi_state => active}, + flush_queue(C2SState1); +c2s_authenticated_packet(C2SState, #csi{type = inactive}) -> + C2SState#{csi_state => inactive}; +c2s_authenticated_packet(C2SState, _) -> + C2SState. -filter_presence({C2SState, _OutStanzas} = Acc, Host, To, - #presence{type = Type} = Stanza) -> - if Type == available; Type == unavailable -> - ?DEBUG("Got availability presence stanza for ~s", - [jid:to_string(To)]), - queue_add(presence, Stanza, Host, C2SState); - true -> - Acc - end; -filter_presence(Acc, _Host, _To, _Stanza) -> Acc. +-spec c2s_copy_session(c2s_state(), c2s_state()) -> c2s_state(). +c2s_copy_session(C2SState, #{csi_state := State, csi_queue := Q}) -> + C2SState#{csi_state => State, csi_queue => Q}; +c2s_copy_session(C2SState, _) -> + C2SState. --spec filter_chat_states({ejabberd_c2s:state(), [stanza()]}, binary(), jid(), stanza()) - -> {ejabberd_c2s:state(), [stanza()]} | - {stop, {ejabberd_c2s:state(), [stanza()]}}. +-spec c2s_session_resumed(c2s_state()) -> c2s_state(). +c2s_session_resumed(C2SState) -> + flush_queue(C2SState). -filter_chat_states({C2SState, _OutStanzas} = Acc, Host, To, - #message{from = From} = Stanza) -> - case xmpp_util:is_standalone_chat_state(Stanza) of +-spec filter_presence(filter_acc()) -> filter_acc(). +filter_presence({#presence{meta = #{csi_resend := true}}, _} = Acc) -> + Acc; +filter_presence({#presence{to = To, type = Type} = Pres, + #{csi_state := inactive} = C2SState}) + when Type == available; Type == unavailable -> + ?DEBUG("Got availability presence stanza for ~s", [jid:to_string(To)]), + enqueue_stanza(presence, Pres, C2SState); +filter_presence(Acc) -> + Acc. + +-spec filter_chat_states(filter_acc()) -> filter_acc(). +filter_chat_states({#message{meta = #{csi_resend := true}}, _} = Acc) -> + Acc; +filter_chat_states({#message{from = From, to = To} = Msg, + #{csi_state := inactive} = C2SState} = Acc) -> + case xmpp_util:is_standalone_chat_state(Msg) of true -> case {From, To} of {#jid{luser = U, lserver = S}, #jid{luser = U, lserver = S}} -> @@ -181,105 +214,109 @@ filter_chat_states({C2SState, _OutStanzas} = Acc, Host, To, %% conversations across clients. Acc; _ -> - ?DEBUG("Got standalone chat state notification for ~s", - [jid:to_string(To)]), - queue_add(chatstate, Stanza, Host, C2SState) + ?DEBUG("Got standalone chat state notification for ~s", + [jid:to_string(To)]), + enqueue_stanza(chatstate, Msg, C2SState) end; false -> Acc end; -filter_chat_states(Acc, _Host, _To, _Stanza) -> Acc. +filter_chat_states(Acc) -> + Acc. --spec filter_pep({ejabberd_c2s:state(), [stanza()]}, binary(), jid(), stanza()) - -> {ejabberd_c2s:state(), [stanza()]} | - {stop, {ejabberd_c2s:state(), [stanza()]}}. - -filter_pep({C2SState, _OutStanzas} = Acc, Host, To, #message{} = Stanza) -> - case get_pep_node(Stanza) of +-spec filter_pep(filter_acc()) -> filter_acc(). +filter_pep({#message{meta = #{csi_resend := true}}, _} = Acc) -> + Acc; +filter_pep({#message{to = To} = Msg, + #{csi_state := inactive} = C2SState} = Acc) -> + case get_pep_node(Msg) of undefined -> Acc; Node -> ?DEBUG("Got PEP notification for ~s", [jid:to_string(To)]), - queue_add({pep, Node}, Stanza, Host, C2SState) + enqueue_stanza({pep, Node}, Msg, C2SState) end; -filter_pep(Acc, _Host, _To, _Stanza) -> Acc. +filter_pep(Acc) -> + Acc. --spec filter_other({ejabberd_c2s:state(), [stanza()]}, binary(), jid(), stanza()) - -> {ejabberd_c2s:state(), [stanza()]}. +-spec filter_other(filter_acc()) -> filter_acc(). +filter_other({Stanza, #{jid := JID} = C2SState} = Acc) when ?is_stanza(Stanza) -> + case xmpp:get_meta(Stanza) of + #{csi_resend := true} -> + Acc; + _ -> + ?DEBUG("Won't add stanza for ~s to CSI queue", [jid:to_string(JID)]), + From = xmpp:get_from(Stanza), + C2SState1 = dequeue_sender(From, C2SState), + {Stanza, C2SState1} + end; +filter_other(Acc) -> + Acc. -filter_other({C2SState, _OutStanzas}, Host, To, Stanza) -> - ?DEBUG("Won't add stanza for ~s to CSI queue", [jid:to_string(To)]), - queue_take(Stanza, Host, C2SState). - --spec flush_queue({ejabberd_c2s:state(), [stanza()]}, binary(), jid()) - -> {ejabberd_c2s:state(), [stanza()]}. - -flush_queue({C2SState, _OutStanzas}, Host, JID) -> - ?DEBUG("Going to flush CSI queue of ~s", [jid:to_string(JID)]), - Queue = get_queue(C2SState), - NewState = set_queue([], C2SState), - {NewState, get_stanzas(Queue, Host)}. - --spec add_stream_feature([stanza()], binary) -> [stanza()]. - -add_stream_feature(Features, _Host) -> - [#feature_csi{xmlns = <<"urn:xmpp:csi:0">>} | Features]. +-spec add_stream_feature([xmpp_element()], binary) -> [xmpp_element()]. +add_stream_feature(Features, Host) -> + case gen_mod:is_loaded(Host, ?MODULE) of + true -> + [#feature_csi{xmlns = <<"urn:xmpp:csi:0">>} | Features]; + false -> + Features + end. %%-------------------------------------------------------------------- %% Internal functions. %%-------------------------------------------------------------------- +-spec enqueue_stanza(csi_type(), stanza(), c2s_state()) -> filter_acc(). +enqueue_stanza(Type, Stanza, #{csi_state := inactive, + csi_queue := Q} = C2SState) -> + case queue_len(Q) >= ?CSI_QUEUE_MAX of + true -> + ?DEBUG("CSI queue too large, going to flush it", []), + C2SState1 = flush_queue(C2SState), + enqueue_stanza(Type, Stanza, C2SState1); + false -> + #jid{luser = U, lserver = S} = xmpp:get_from(Stanza), + Q1 = queue_in({U, S}, Type, Stanza, Q), + {stop, {drop, C2SState#{csi_queue => Q1}}} + end; +enqueue_stanza(_Type, Stanza, State) -> + {Stanza, State}. --spec queue_add(csi_type(), stanza(), binary(), term()) - -> {stop, {term(), [stanza()]}}. - -queue_add(Type, Stanza, Host, C2SState) -> - case get_queue(C2SState) of - Queue when length(Queue) >= ?CSI_QUEUE_MAX -> - ?DEBUG("CSI queue too large, going to flush it", []), - NewState = set_queue([], C2SState), - {stop, {NewState, get_stanzas(Queue, Host) ++ [Stanza]}}; - Queue -> - ?DEBUG("Adding stanza to CSI queue", []), - From = xmpp:get_from(Stanza), - Key = {jid:tolower(From), Type}, - Entry = {Key, p1_time_compat:timestamp(), Stanza}, - NewQueue = lists:keystore(Key, 1, Queue, Entry), - NewState = set_queue(NewQueue, C2SState), - {stop, {NewState, []}} +-spec dequeue_sender(jid(), c2s_state()) -> c2s_state(). +dequeue_sender(#jid{luser = U, lserver = S}, + #{csi_queue := Q, jid := JID} = C2SState) -> + ?DEBUG("Flushing packets of ~s@~s from CSI queue of ~s", + [U, S, jid:to_string(JID)]), + case queue_take({U, S}, Q) of + {Stanzas, Q1} -> + C2SState1 = flush_stanzas(C2SState, Stanzas), + C2SState1#{csi_queue => Q1}; + error -> + C2SState end. --spec queue_take(stanza(), binary(), term()) -> {term(), [stanza()]}. +-spec flush_queue(c2s_state()) -> c2s_state(). +flush_queue(#{csi_queue := Q, jid := JID} = C2SState) -> + ?DEBUG("Flushing CSI queue of ~s", [jid:to_string(JID)]), + C2SState1 = flush_stanzas(C2SState, queue_to_list(Q)), + C2SState1#{csi_queue => queue_new()}. -queue_take(Stanza, Host, C2SState) -> - From = xmpp:get_from(Stanza), - {LUser, LServer, _LResource} = jid:tolower(From), - {Selected, Rest} = lists:partition( - fun({{{U, S, _R}, _Type}, _Time, _Stanza}) -> - U == LUser andalso S == LServer - end, get_queue(C2SState)), - NewState = set_queue(Rest, C2SState), - {NewState, get_stanzas(Selected, Host) ++ [Stanza]}. +-spec flush_stanzas(c2s_state(), + [{csi_type(), csi_timestamp(), stanza()}]) -> c2s_state(). +flush_stanzas(#{lserver := LServer} = C2SState, Elems) -> + lists:foldl( + fun({_Type, Time, Stanza}, AccState) -> + Stanza1 = add_delay_info(Stanza, LServer, Time), + ejabberd_c2s:send(AccState, Stanza1) + end, C2SState, Elems). --spec set_queue(csi_queue(), ejabberd_c2s:state()) -> ejabberd_c2s:state(). - -set_queue(Queue, C2SState) -> - C2SState#{csi_queue => Queue}. - --spec get_queue(ejabberd_c2s:state()) -> csi_queue(). - -get_queue(C2SState) -> - maps:get(csi_queue, C2SState, []). - --spec get_stanzas(csi_queue(), binary()) -> [stanza()]. - -get_stanzas(Queue, Host) -> - lists:map(fun({_Key, Time, Stanza}) -> - xmpp_util:add_delay_info(Stanza, jid:make(Host), Time, - <<"Client Inactive">>) - end, Queue). +-spec add_delay_info(stanza(), binary(), csi_timestamp()) -> stanza(). +add_delay_info(Stanza, LServer, {_Seq, TimeStamp}) -> + Stanza1 = xmpp_util:add_delay_info( + Stanza, jid:make(LServer), TimeStamp, + <<"Client Inactive">>), + xmpp:put_meta(Stanza1, csi_resend, true). -spec get_pep_node(message()) -> binary() | undefined. - get_pep_node(#message{from = #jid{luser = <<>>}}) -> %% It's not PEP. undefined; @@ -290,3 +327,53 @@ get_pep_node(#message{} = Msg) -> _ -> undefined end. + +%%-------------------------------------------------------------------- +%% Queue interface +%%-------------------------------------------------------------------- +-spec queue_new() -> csi_queue(). +queue_new() -> + {0, 0, #{}}. + +-spec queue_in(term(), term(), term(), csi_queue()) -> csi_queue(). +queue_in(Key, Type, Val, {N, Seq, Q}) -> + Seq1 = Seq + 1, + Time = {Seq1, p1_time_compat:timestamp()}, + try maps:get(Key, Q) of + TypeVals -> + case lists:keymember(Type, 1, TypeVals) of + true -> + TypeVals1 = lists:keyreplace( + Type, 1, TypeVals, {Type, Time, Val}), + Q1 = maps:put(Key, TypeVals1, Q), + {N, Seq1, Q1}; + false -> + TypeVals1 = [{Type, Time, Val}|TypeVals], + Q1 = maps:put(Key, TypeVals1, Q), + {N + 1, Seq1, Q1} + end + catch _:{badkey, _} -> + Q1 = maps:put(Key, [{Type, Time, Val}], Q), + {N + 1, Seq1, Q1} + end. + +-spec queue_take(term(), csi_queue()) -> {list(), csi_queue()} | error. +queue_take(Key, {N, Seq, Q}) -> + case maps:take(Key, Q) of + {TypeVals, Q1} -> + {lists:keysort(2, TypeVals), {N-length(TypeVals), Seq, Q1}}; + error -> + error + end. + +-spec queue_len(csi_queue()) -> non_neg_integer(). +queue_len({N, _, _}) -> + N. + +-spec queue_to_list(csi_queue()) -> [term()]. +queue_to_list({_, _, Q}) -> + TypeVals = maps:fold( + fun(_, Vals, Acc) -> + Vals ++ Acc + end, [], Q), + lists:keysort(2, TypeVals). diff --git a/src/mod_disco.erl b/src/mod_disco.erl index 953d1da10..54720f716 100644 --- a/src/mod_disco.erl +++ b/src/mod_disco.erl @@ -37,9 +37,7 @@ get_local_features/5, get_local_services/5, process_sm_iq_items/1, process_sm_iq_info/1, get_sm_identity/5, get_sm_features/5, get_sm_items/5, - get_info/5, register_feature/2, unregister_feature/2, - register_extra_domain/2, unregister_extra_domain/2, - transform_module_options/1, mod_opt_type/1, depends/2]). + get_info/5, transform_module_options/1, mod_opt_type/1, depends/2]). -include("ejabberd.hrl"). -include("logger.hrl"). @@ -48,8 +46,10 @@ -include_lib("stdlib/include/ms_transform.hrl"). -include("mod_roster.hrl"). +-type features_acc() :: {error, stanza_error()} | {result, [binary()]} | empty. +-type items_acc() :: {error, stanza_error()} | {result, [disco_item()]} | empty. + start(Host, Opts) -> - ejabberd_local:refresh_iq_handlers(), IQDisc = gen_mod:get_opt(iqdisc, Opts, fun gen_iq_handler:check_type/1, one_queue), gen_iq_handler:add_iq_handler(ejabberd_local, Host, @@ -64,12 +64,9 @@ start(Host, Opts) -> gen_iq_handler:add_iq_handler(ejabberd_sm, Host, ?NS_DISCO_INFO, ?MODULE, process_sm_iq_info, IQDisc), - catch ets:new(disco_features, - [named_table, ordered_set, public]), - register_feature(Host, <<"iq">>), - register_feature(Host, <<"presence">>), catch ets:new(disco_extra_domains, - [named_table, ordered_set, public]), + [named_table, ordered_set, public, + {heir, erlang:group_leader(), none}]), ExtraDomains = gen_mod:get_opt(extra_domains, Opts, fun(Hs) -> [iolist_to_binary(H) || H <- Hs] @@ -78,10 +75,6 @@ start(Host, Opts) -> register_extra_domain(Host, Domain) end, ExtraDomains), - catch ets:new(disco_sm_features, - [named_table, ordered_set, public]), - catch ets:new(disco_sm_nodes, - [named_table, ordered_set, public]), ejabberd_hooks:add(disco_local_items, Host, ?MODULE, get_local_services, 100), ejabberd_hooks:add(disco_local_features, Host, ?MODULE, @@ -121,35 +114,14 @@ stop(Host) -> ?NS_DISCO_ITEMS), gen_iq_handler:remove_iq_handler(ejabberd_sm, Host, ?NS_DISCO_INFO), - catch ets:match_delete(disco_features, {{'_', Host}}), catch ets:match_delete(disco_extra_domains, {{'_', Host}}), ok. --spec register_feature(binary(), binary()) -> true. -register_feature(Host, Feature) -> - catch ets:new(disco_features, - [named_table, ordered_set, public]), - ets:insert(disco_features, {{Feature, Host}}). - --spec unregister_feature(binary(), binary()) -> true. -unregister_feature(Host, Feature) -> - catch ets:new(disco_features, - [named_table, ordered_set, public]), - ets:delete(disco_features, {Feature, Host}). - -spec register_extra_domain(binary(), binary()) -> true. register_extra_domain(Host, Domain) -> - catch ets:new(disco_extra_domains, - [named_table, ordered_set, public]), ets:insert(disco_extra_domains, {{Domain, Host}}). --spec unregister_extra_domain(binary(), binary()) -> true. -unregister_extra_domain(Host, Domain) -> - catch ets:new(disco_extra_domains, - [named_table, ordered_set, public]), - ets:delete(disco_extra_domains, {Domain, Host}). - -spec process_local_iq_items(iq()) -> iq(). process_local_iq_items(#iq{type = set, lang = Lang} = IQ) -> Txt = <<"Value 'set' of 'type' attribute is not allowed">>, @@ -198,22 +170,17 @@ get_local_identity(Acc, _From, _To, <<"">>, _Lang) -> get_local_identity(Acc, _From, _To, _Node, _Lang) -> Acc. --spec get_local_features({error, stanza_error()} | {result, [binary()]} | empty, - jid(), jid(), binary(), binary()) -> +-spec get_local_features(features_acc(), jid(), jid(), binary(), binary()) -> {error, stanza_error()} | {result, [binary()]}. get_local_features({error, _Error} = Acc, _From, _To, _Node, _Lang) -> Acc; -get_local_features(Acc, _From, To, <<"">>, _Lang) -> +get_local_features(Acc, _From, _To, <<"">>, _Lang) -> Feats = case Acc of {result, Features} -> Features; empty -> [] end, - Host = To#jid.lserver, - {result, - ets:select(disco_features, - ets:fun2ms(fun({{F, H}}) when H == Host -> F end)) - ++ Feats}; + {result, [<<"iq">>, <<"presence">>|Feats]}; get_local_features(Acc, _From, _To, _Node, Lang) -> case Acc of {result, _Features} -> Acc; @@ -222,9 +189,7 @@ get_local_features(Acc, _From, _To, _Node, Lang) -> {error, xmpp:err_item_not_found(Txt, Lang)} end. --spec get_local_services({error, stanza_error()} | {result, [disco_item()]} | empty, - jid(), jid(), - binary(), binary()) -> +-spec get_local_services(items_acc(), jid(), jid(), binary(), binary()) -> {error, stanza_error()} | {result, [disco_item()]}. get_local_services({error, _Error} = Acc, _From, _To, _Node, _Lang) -> @@ -296,9 +261,7 @@ process_sm_iq_items(#iq{type = get, lang = Lang, xmpp:make_error(IQ, xmpp:err_subscription_required(Txt, Lang)) end. --spec get_sm_items({error, stanza_error()} | {result, [disco_item()]} | empty, - jid(), jid(), - binary(), binary()) -> +-spec get_sm_items(items_acc(), jid(), jid(), binary(), binary()) -> {error, stanza_error()} | {result, [disco_item()]}. get_sm_items({error, _Error} = Acc, _From, _To, _Node, _Lang) -> @@ -383,8 +346,7 @@ get_sm_identity(Acc, _From, _ -> [] end. --spec get_sm_features({error, stanza_error()} | {result, [binary()]} | empty, - jid(), jid(), binary(), binary()) -> +-spec get_sm_features(features_acc(), jid(), jid(), binary(), binary()) -> {error, stanza_error()} | {result, [binary()]}. get_sm_features(empty, From, To, _Node, Lang) -> #jid{luser = LFrom, lserver = LSFrom} = From, diff --git a/src/mod_fail2ban.erl b/src/mod_fail2ban.erl index cc3b4bf7f..e8cc29816 100644 --- a/src/mod_fail2ban.erl +++ b/src/mod_fail2ban.erl @@ -29,7 +29,8 @@ -behaviour(gen_server). %% API --export([start_link/2, start/2, stop/1, c2s_auth_result/4, check_bl_c2s/3]). +-export([start_link/2, start/2, stop/1, c2s_auth_result/3, + c2s_stream_started/2]). -export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3, @@ -38,6 +39,7 @@ -include_lib("stdlib/include/ms_transform.hrl"). -include("ejabberd.hrl"). -include("logger.hrl"). +-include("xmpp.hrl"). -define(C2S_AUTH_BAN_LIFETIME, 3600). %% 1 hour -define(C2S_MAX_AUTH_FAILURES, 20). @@ -52,12 +54,12 @@ start_link(Host, Opts) -> Proc = gen_mod:get_module_proc(Host, ?MODULE), gen_server:start_link({local, Proc}, ?MODULE, [Host, Opts], []). --spec c2s_auth_result(boolean(), binary(), binary(), - {inet:ip_address(), non_neg_integer()}) -> ok. -c2s_auth_result(false, _User, LServer, {Addr, _Port}) -> +-spec c2s_auth_result(ejabberd_c2s:state(), boolean(), binary()) + -> ejabberd_c2s:state() | {stop, ejabberd_c2s:state()}. +c2s_auth_result(#{ip := {Addr, _}, lserver := LServer} = State, false, _User) -> case is_whitelisted(LServer, Addr) of true -> - ok; + State; false -> BanLifetime = gen_mod:get_module_opt( LServer, ?MODULE, c2s_auth_ban_lifetime, @@ -68,47 +70,41 @@ c2s_auth_result(false, _User, LServer, {Addr, _Port}) -> fun(I) when is_integer(I), I > 0 -> I end, ?C2S_MAX_AUTH_FAILURES), UnbanTS = p1_time_compat:system_time(seconds) + BanLifetime, - case ets:lookup(failed_auth, Addr) of - [{Addr, N, _, _}] -> - ets:insert(failed_auth, {Addr, N+1, UnbanTS, MaxFailures}); - [] -> - ets:insert(failed_auth, {Addr, 1, UnbanTS, MaxFailures}) - end, - ok + Attempts = case ets:lookup(failed_auth, Addr) of + [{Addr, N, _, _}] -> + ets:insert(failed_auth, + {Addr, N+1, UnbanTS, MaxFailures}), + N+1; + [] -> + ets:insert(failed_auth, + {Addr, 1, UnbanTS, MaxFailures}), + 1 + end, + if Attempts >= MaxFailures -> + log_and_disconnect(State, Attempts, UnbanTS); + true -> + State + end end; -c2s_auth_result(true, _User, _Server, _AddrPort) -> - ok. +c2s_auth_result(#{ip := {Addr, _}} = State, true, _User) -> + ets:delete(failed_auth, Addr), + State. --spec check_bl_c2s({true, binary(), binary()} | false, - {inet:ip_address(), non_neg_integer()}, - binary()) -> {stop, {true, binary(), binary()}} | false. -check_bl_c2s(_Acc, Addr, Lang) -> +-spec c2s_stream_started(ejabberd_c2s:state(), stream_start()) + -> ejabberd_c2s:state() | {stop, ejabberd_c2s:state()}. +c2s_stream_started(#{ip := {Addr, _}} = State, _) -> + ets:tab2list(failed_auth), case ets:lookup(failed_auth, Addr) of [{Addr, N, TS, MaxFailures}] when N >= MaxFailures -> case TS > p1_time_compat:system_time(seconds) of true -> - IP = jlib:ip_to_list(Addr), - UnbanDate = format_date( - calendar:now_to_universal_time(seconds_to_now(TS))), - LogReason = io_lib:fwrite( - "Too many (~p) failed authentications " - "from this IP address (~s). The address " - "will be unblocked at ~s UTC", - [N, IP, UnbanDate]), - ReasonT = io_lib:fwrite( - translate:translate( - Lang, - <<"Too many (~p) failed authentications " - "from this IP address (~s). The address " - "will be unblocked at ~s UTC">>), - [N, IP, UnbanDate]), - {stop, {true, LogReason, ReasonT}}; + log_and_disconnect(State, N, TS); false -> ets:delete(failed_auth, Addr), - false + State end; _ -> - false + State end. %%==================================================================== @@ -134,7 +130,7 @@ depends(_Host, _Opts) -> %%%=================================================================== init([Host, _Opts]) -> ejabberd_hooks:add(c2s_auth_result, Host, ?MODULE, c2s_auth_result, 100), - ejabberd_hooks:add(check_bl_c2s, ?MODULE, check_bl_c2s, 100), + ejabberd_hooks:add(c2s_stream_started, Host, ?MODULE, c2s_stream_started, 100), erlang:send_after(?CLEAN_INTERVAL, self(), clean), {ok, #state{host = Host}}. @@ -160,11 +156,11 @@ handle_info(_Info, State) -> terminate(_Reason, #state{host = Host}) -> ejabberd_hooks:delete(c2s_auth_result, Host, ?MODULE, c2s_auth_result, 100), + ejabberd_hooks:delete(c2s_stream_started, Host, ?MODULE, c2s_stream_started, 100), case is_loaded_at_other_hosts(Host) of true -> ok; false -> - ejabberd_hooks:delete(check_bl_c2s, ?MODULE, check_bl_c2s, 100), ets:delete(failed_auth) end. @@ -174,6 +170,21 @@ code_change(_OldVsn, State, _Extra) -> %%%=================================================================== %%% Internal functions %%%=================================================================== +-spec log_and_disconnect(ejabberd_c2s:state(), pos_integer(), erlang:timestamp()) + -> {stop, ejabberd_c2s:state()}. +log_and_disconnect(#{ip := {Addr, _}, lang := Lang} = State, Attempts, UnbanTS) -> + IP = jlib:ip_to_list(Addr), + UnbanDate = format_date( + calendar:now_to_universal_time(seconds_to_now(UnbanTS))), + Format = <<"Too many (~p) failed authentications " + "from this IP address (~s). The address " + "will be unblocked at ~s UTC">>, + Args = [Attempts, IP, UnbanDate], + ?INFO_MSG("Connection attempt from blacklisted IP ~s: ~s", + [IP, io_lib:fwrite(Format, Args)]), + Err = xmpp:serr_policy_violation({Format, Args}, Lang), + {stop, ejabberd_c2s:send(State, Err)}. + is_whitelisted(Host, Addr) -> Access = gen_mod:get_module_opt(Host, ?MODULE, access, fun(A) -> A end, diff --git a/src/mod_http_fileserver.erl b/src/mod_http_fileserver.erl index a896cb8b4..734dbb126 100644 --- a/src/mod_http_fileserver.erl +++ b/src/mod_http_fileserver.erl @@ -46,7 +46,7 @@ %% utility for other http modules -export([content_type/3]). --export([reopen_log/1, mod_opt_type/1, depends/2]). +-export([reopen_log/0, mod_opt_type/1, depends/2]). -include("ejabberd.hrl"). -include("logger.hrl"). @@ -236,7 +236,7 @@ check_docroot_is_readable(DRInfo, DocRoot) -> try_open_log(undefined, _Host) -> undefined; -try_open_log(FN, Host) -> +try_open_log(FN, _Host) -> FD = try open_log(FN) of FD1 -> FD1 catch @@ -244,7 +244,7 @@ try_open_log(FN, Host) -> ?ERROR_MSG("Cannot open access log file: ~p~nReason: ~p", [FN, Reason]), undefined end, - ejabberd_hooks:add(reopen_log_hook, Host, ?MODULE, reopen_log, 50), + ejabberd_hooks:add(reopen_log_hook, ?MODULE, reopen_log, 50), FD. %%-------------------------------------------------------------------- @@ -298,7 +298,8 @@ handle_info(_Info, State) -> %%-------------------------------------------------------------------- terminate(_Reason, State) -> close_log(State#state.accesslogfd), - ejabberd_hooks:delete(reopen_log_hook, State#state.host, ?MODULE, reopen_log, 50), + %% TODO: unregister the hook gracefully + %% ejabberd_hooks:delete(reopen_log_hook, State#state.host, ?MODULE, reopen_log, 50), ok. %%-------------------------------------------------------------------- @@ -410,8 +411,11 @@ reopen_log(FN, FD) -> close_log(FD), open_log(FN). -reopen_log(Host) -> - gen_server:cast(get_proc_name(Host), reopen_log). +reopen_log() -> + lists:foreach( + fun(Host) -> + gen_server:cast(get_proc_name(Host), reopen_log) + end, ?MYHOSTS). add_to_log(FileSize, Code, Request) -> gen_server:cast(get_proc_name(Request#request.host), diff --git a/src/mod_last.erl b/src/mod_last.erl index 2c17dcda3..7a08d362b 100644 --- a/src/mod_last.erl +++ b/src/mod_last.erl @@ -130,13 +130,10 @@ process_sm_iq(#iq{from = From, to = To, lang = Lang} = IQ) -> if (Subscription == both) or (Subscription == from) or (From#jid.luser == To#jid.luser) and (From#jid.lserver == To#jid.lserver) -> - UserListRecord = - ejabberd_hooks:run_fold(privacy_get_user_list, Server, - #userlist{}, [User, Server]), + Pres = xmpp:set_from_to(#presence{}, To, From), case ejabberd_hooks:run_fold(privacy_check_packet, Server, allow, - [User, Server, UserListRecord, - {To, From, #presence{}}, out]) of + [To, Pres, out]) of allow -> get_last_iq(IQ, User, Server); deny -> xmpp:make_error(IQ, xmpp:err_forbidden()) end; diff --git a/src/mod_mam.erl b/src/mod_mam.erl index edb0d1485..0e2d77d96 100644 --- a/src/mod_mam.erl +++ b/src/mod_mam.erl @@ -33,10 +33,10 @@ %% API -export([start/2, stop/1, depends/2]). --export([user_send_packet/4, user_send_packet_strip_tag/4, user_receive_packet/5, +-export([user_send_packet/1, user_send_packet_strip_tag/1, user_receive_packet/1, process_iq_v0_2/1, process_iq_v0_3/1, disco_sm_features/5, remove_user/2, remove_room/3, mod_opt_type/1, muc_process_iq/2, - muc_filter_message/5, message_is_archived/5, delete_old_messages/2, + muc_filter_message/5, message_is_archived/3, delete_old_messages/2, get_commands_spec/0, msg_to_el/4, get_room_config/4, set_room_option/3]). -include("xmpp.hrl"). @@ -200,46 +200,50 @@ set_room_option(_Acc, {mam, Val}, _Lang) -> set_room_option(Acc, _Property, _Lang) -> Acc. --spec user_receive_packet(stanza(), ejabberd_c2s:state(), jid(), jid(), jid()) -> stanza(). -user_receive_packet(Pkt, C2SState, JID, Peer, _To) -> +-spec user_receive_packet({stanza(), ejabberd_c2s:state()}) -> {stanza(), ejabberd_c2s:state()}. +user_receive_packet({Pkt, #{jid := JID} = C2SState}) -> + Peer = xmpp:get_from(Pkt), LUser = JID#jid.luser, LServer = JID#jid.lserver, - case should_archive(Pkt, LServer) of - true -> - NewPkt = strip_my_archived_tag(Pkt, LServer), - case store_msg(C2SState, NewPkt, LUser, LServer, Peer, recv) of - {ok, ID} -> - set_stanza_id(NewPkt, JID, ID); - _ -> - NewPkt - end; - _ -> - Pkt - end. + Pkt2 = case should_archive(Pkt, LServer) of + true -> + Pkt1 = strip_my_archived_tag(Pkt, LServer), + case store_msg(C2SState, Pkt1, LUser, LServer, Peer, recv) of + {ok, ID} -> + set_stanza_id(Pkt1, JID, ID); + _ -> + Pkt1 + end; + _ -> + Pkt + end, + {Pkt2, C2SState}. --spec user_send_packet(stanza(), ejabberd_c2s:state(), jid(), jid()) -> stanza(). -user_send_packet(Pkt, C2SState, JID, Peer) -> +-spec user_send_packet({stanza(), ejabberd_c2s:state()}) -> {stanza(), ejabberd_c2s:state()}. +user_send_packet({Pkt, #{jid := JID} = C2SState}) -> + Peer = xmpp:get_to(Pkt), LUser = JID#jid.luser, LServer = JID#jid.lserver, - case should_archive(Pkt, LServer) of - true -> - NewPkt = strip_my_archived_tag(Pkt, LServer), - case store_msg(C2SState, xmpp:set_from_to(NewPkt, JID, Peer), - LUser, LServer, Peer, send) of - {ok, ID} -> - set_stanza_id(NewPkt, JID, ID); - _ -> - NewPkt - end; - false -> - Pkt - end. + Pkt2 = case should_archive(Pkt, LServer) of + true -> + Pkt1 = strip_my_archived_tag(Pkt, LServer), + case store_msg(C2SState, xmpp:set_from_to(Pkt1, JID, Peer), + LUser, LServer, Peer, send) of + {ok, ID} -> + set_stanza_id(Pkt1, JID, ID); + _ -> + Pkt1 + end; + false -> + Pkt + end, + {Pkt2, C2SState}. --spec user_send_packet_strip_tag(stanza(), ejabberd_c2s:state(), - jid(), jid()) -> stanza(). -user_send_packet_strip_tag(Pkt, _C2SState, JID, _Peer) -> +-spec user_send_packet_strip_tag({stanza(), ejabberd_c2s:state()}) -> + {stanza(), ejabberd_c2s:state()}. +user_send_packet_strip_tag({Pkt, #{jid := JID} = C2SState}) -> LServer = JID#jid.lserver, - strip_my_archived_tag(Pkt, LServer). + {strip_my_archived_tag(Pkt, LServer), C2SState}. -spec muc_filter_message(message(), mod_muc_room:state(), jid(), jid(), binary()) -> message(). @@ -338,12 +342,12 @@ disco_sm_features({result, OtherFeatures}, disco_sm_features(Acc, _From, _To, _Node, _Lang) -> Acc. --spec message_is_archived(boolean(), ejabberd_c2s:state(), - jid(), jid(), message()) -> boolean(). -message_is_archived(true, _C2SState, _Peer, _JID, _Pkt) -> +-spec message_is_archived(boolean(), ejabberd_c2s:state(), message()) -> boolean(). +message_is_archived(true, _C2SState, _Pkt) -> true; -message_is_archived(false, C2SState, Peer, - #jid{luser = LUser, lserver = LServer}, Pkt) -> +message_is_archived(false, #{jid := JID} = C2SState, Pkt) -> + #jid{luser = LUser, lserver = LServer} = JID, + Peer = xmpp:get_from(Pkt), case gen_mod:get_module_opt(LServer, ?MODULE, assume_mam_usage, fun(B) when is_boolean(B) -> B end, false) of true -> diff --git a/src/mod_metrics.erl b/src/mod_metrics.erl index 7861542c5..1698690d4 100644 --- a/src/mod_metrics.erl +++ b/src/mod_metrics.erl @@ -38,8 +38,8 @@ -export([offline_message_hook/3, sm_register_connection_hook/3, sm_remove_connection_hook/3, - user_send_packet/4, user_receive_packet/5, - s2s_send_packet/3, s2s_receive_packet/3, + user_send_packet/1, user_receive_packet/1, + s2s_send_packet/3, s2s_receive_packet/1, remove_user/2, register_user/2]). %%==================================================================== @@ -86,23 +86,27 @@ sm_register_connection_hook(_SID, #jid{lserver=LServer}, _Info) -> sm_remove_connection_hook(_SID, #jid{lserver=LServer}, _Info) -> push(LServer, sm_remove_connection). --spec user_send_packet(stanza(), ejabberd_c2s:state(), jid(), jid()) -> stanza(). -user_send_packet(Packet, _C2SState, #jid{lserver=LServer}, _To) -> +-spec user_send_packet({stanza(), ejabberd_c2s:state()}) -> {stanza(), ejabberd_c2s:state()}. +user_send_packet({Packet, #{jid := #jid{lserver = LServer}} = C2SState}) -> push(LServer, user_send_packet), - Packet. + {Packet, C2SState}. --spec user_receive_packet(stanza(), ejabberd_c2s:state(), jid(), jid(), jid()) -> stanza(). -user_receive_packet(Packet, _C2SState, _JID, _From, #jid{lserver=LServer}) -> +-spec user_receive_packet({stanza(), ejabberd_c2s:state()}) -> {stanza(), ejabberd_c2s:state()}. +user_receive_packet({Packet, #{jid := #jid{lserver = LServer}} = C2SState}) -> push(LServer, user_receive_packet), - Packet. + {Packet, C2SState}. -spec s2s_send_packet(jid(), jid(), stanza()) -> any(). s2s_send_packet(#jid{lserver=LServer}, _To, _Packet) -> push(LServer, s2s_send_packet). --spec s2s_receive_packet(jid(), jid(), stanza()) -> any(). -s2s_receive_packet(_From, #jid{lserver=LServer}, _Packet) -> - push(LServer, s2s_receive_packet). +-spec s2s_receive_packet({stanza(), ejabberd_s2s_in:state()}) -> + {stanza(), ejabberd_s2s_in:state()}. +s2s_receive_packet({Packet, S2SState}) -> + To = xmpp:get_to(Packet), + LServer = ejabberd_router:host_of_route(To#jid.lserver), + push(LServer, s2s_receive_packet), + {Packet, S2SState}. -spec remove_user(binary(), binary()) -> any(). remove_user(_User, Server) -> diff --git a/src/mod_offline.erl b/src/mod_offline.erl index 8d58b14c9..c1768bf1c 100644 --- a/src/mod_offline.erl +++ b/src/mod_offline.erl @@ -44,7 +44,7 @@ store_packet/3, store_offline_msg/5, resend_offline_messages/2, - pop_offline_messages/3, + c2s_self_presence/1, get_sm_features/5, get_sm_identity/5, get_sm_items/5, @@ -62,6 +62,7 @@ get_offline_els/2, find_x_expire/2, c2s_handle_info/2, + c2s_copy_session/2, webadmin_page/3, webadmin_user/4, webadmin_user_parse_query/5]). @@ -91,6 +92,8 @@ -define(MAX_USER_MESSAGES, infinity). -type us() :: {binary(), binary()}. +-type c2s_state() :: ejabberd_c2s:state(). + -callback init(binary(), gen_mod:opts()) -> any(). -callback import(#offline_msg{}) -> ok. -callback store_messages(binary(), us(), [#offline_msg{}], @@ -142,8 +145,7 @@ init([Host, Opts]) -> no_queue), ejabberd_hooks:add(offline_message_hook, Host, ?MODULE, store_packet, 50), - ejabberd_hooks:add(resend_offline_messages_hook, Host, - ?MODULE, pop_offline_messages, 50), + ejabberd_hooks:add(c2s_self_presence, Host, ?MODULE, c2s_self_presence, 50), ejabberd_hooks:add(remove_user, Host, ?MODULE, remove_user, 50), ejabberd_hooks:add(anonymous_purge_hook, Host, @@ -158,6 +160,7 @@ init([Host, Opts]) -> ?MODULE, get_sm_items, 50), ejabberd_hooks:add(disco_info, Host, ?MODULE, get_info, 50), ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE, c2s_handle_info, 50), + ejabberd_hooks:add(c2s_copy_session, Host, ?MODULE, c2s_copy_session, 50), ejabberd_hooks:add(webadmin_page_host, Host, ?MODULE, webadmin_page, 50), ejabberd_hooks:add(webadmin_user, Host, @@ -202,8 +205,7 @@ terminate(_Reason, State) -> Host = State#state.host, ejabberd_hooks:delete(offline_message_hook, Host, ?MODULE, store_packet, 50), - ejabberd_hooks:delete(resend_offline_messages_hook, - Host, ?MODULE, pop_offline_messages, 50), + ejabberd_hooks:delete(c2s_self_presence, Host, ?MODULE, c2s_self_presence, 50), ejabberd_hooks:delete(remove_user, Host, ?MODULE, remove_user, 50), ejabberd_hooks:delete(anonymous_purge_hook, Host, @@ -214,6 +216,7 @@ terminate(_Reason, State) -> ejabberd_hooks:delete(disco_sm_items, Host, ?MODULE, get_sm_items, 50), ejabberd_hooks:delete(disco_info, Host, ?MODULE, get_info, 50), ejabberd_hooks:delete(c2s_handle_info, Host, ?MODULE, c2s_handle_info, 50), + ejabberd_hooks:delete(c2s_copy_session, Host, ?MODULE, c2s_copy_session, 50), ejabberd_hooks:delete(webadmin_page_host, Host, ?MODULE, webadmin_page, 50), ejabberd_hooks:delete(webadmin_user, Host, @@ -309,12 +312,18 @@ get_info(_Acc, #jid{luser = U, lserver = S} = JID, get_info(Acc, _From, _To, _Node, _Lang) -> Acc. --spec c2s_handle_info(ejabberd_c2s:state(), term()) -> ejabberd_c2s:state(). +-spec c2s_handle_info(c2s_state(), term()) -> c2s_state(). c2s_handle_info(State, {resend_offline, Flag}) -> {stop, State#{resend_offline => Flag}}; c2s_handle_info(State, _) -> State. +-spec c2s_copy_session(c2s_state(), c2s_state()) -> c2s_state(). +c2s_copy_session(State, #{resend_offline := Flag}) -> + State#{resend_offline => Flag}; +c2s_copy_session(State, _) -> + State. + -spec handle_offline_query(iq()) -> iq(). handle_offline_query(#iq{from = #jid{luser = U1, lserver = S1}, to = #jid{luser = U2, lserver = S2}, @@ -398,10 +407,10 @@ handle_offline_fetch(#jid{luser = U, lserver = S} = JID) -> ejabberd_sm:route(JID, {resend_offline, false}), lists:foreach( fun({Node, El}) -> - NewEl = set_offline_tag(El, Node), - From = xmpp:get_from(El), - To = xmpp:get_to(El), - ejabberd_router:route(From, To, NewEl) + El1 = set_offline_tag(El, Node), + From = xmpp:get_from(El1), + To = xmpp:get_to(El1), + ejabberd_router:route(From, To, El1) end, read_messages(U, S)). -spec fetch_msg_by_node(jid(), binary()) -> error | {ok, #offline_msg{}}. @@ -557,41 +566,67 @@ resend_offline_messages(User, Server) -> _ -> ok end. --spec pop_offline_messages([{route, jid(), jid(), message()}], - binary(), binary()) -> - [{route, jid(), jid(), message()}]. -pop_offline_messages(Ls, User, Server) -> - LUser = jid:nodeprep(User), - LServer = jid:nameprep(Server), +c2s_self_presence({#presence{type = available} = NewPres, State} = Acc) -> + NewPrio = get_priority_from_presence(NewPres), + LastPrio = try maps:get(pres_last, State) of + LastPres -> get_priority_from_presence(LastPres) + catch _:{badkey, _} -> + -1 + end, + if LastPrio < 0 andalso NewPrio >= 0 -> + route_offline_messages(State); + true -> + ok + end, + Acc; +c2s_self_presence(Acc) -> + Acc. + +-spec route_offline_messages(c2s_state()) -> ok. +route_offline_messages(#{jid := #jid{luser = LUser, lserver = LServer}} = State) -> Mod = gen_mod:db_mod(LServer, ?MODULE), case Mod:pop_messages(LUser, LServer) of - {ok, Rs} -> - TS = p1_time_compat:timestamp(), - Ls ++ - lists:flatmap( - fun(R) -> - case offline_msg_to_route(LServer, R) of - error -> []; - RouteMsg -> [RouteMsg] - end - end, - lists:filter( - fun(#offline_msg{packet = Pkt} = R) -> - Expire = case R#offline_msg.expire of - undefined -> - find_x_expire(TS, Pkt); - Exp -> - Exp - end, - case Expire of - never -> true; - TimeStamp -> TS < TimeStamp - end - end, Rs)); + {ok, OffMsgs} -> + lists:foreach( + fun(OffMsg) -> + route_offline_message(State, OffMsg) + end, OffMsgs); _ -> - Ls + ok end. +-spec route_offline_message(c2s_state(), #offline_msg{}) -> ok. +route_offline_message(#{lserver := LServer} = State, + #offline_msg{expire = Expire} = OffMsg) -> + case offline_msg_to_route(LServer, OffMsg) of + error -> + ok; + {route, From, To, Msg} -> + case is_message_expired(Expire, Msg) of + true -> + ok; + false -> + case privacy_check_packet(State, Msg, in) of + allow -> ejabberd_router:route(From, To, Msg); + false -> ok + end + end + end. + +-spec is_message_expired(erlang:timestamp() | never, message()) -> boolean(). +is_message_expired(Expire, Msg) -> + TS = p1_time_compat:timestamp(), + Expire1 = case Expire of + undefined -> find_x_expire(TS, Msg); + _ -> Expire + end, + Expire1 /= never andalso Expire1 =< TS. + +-spec privacy_check_packet(c2s_state(), stanza(), in | out) -> allow | deny. +privacy_check_packet(#{lserver := LServer} = State, Pkt, Dir) -> + ejabberd_hooks:run_fold(privacy_check_packet, + LServer, allow, [State, Pkt, Dir]). + remove_expired_messages(Server) -> LServer = jid:nameprep(Server), Mod = gen_mod:db_mod(LServer, ?MODULE), @@ -635,14 +670,15 @@ get_offline_els(LUser, LServer) -> -spec offline_msg_to_route(binary(), #offline_msg{}) -> {route, jid(), jid(), message()} | error. -offline_msg_to_route(LServer, #offline_msg{} = R) -> +offline_msg_to_route(LServer, #offline_msg{from = From, to = To} = R) -> try xmpp:decode(R#offline_msg.packet, ?NS_CLIENT, [ignore_els]) of Pkt -> - NewPkt = add_delay_info(Pkt, LServer, R#offline_msg.timestamp), - {route, R#offline_msg.from, R#offline_msg.to, NewPkt} + Pkt1 = xmpp:set_from_to(Pkt, From, To), + Pkt2 = add_delay_info(Pkt1, LServer, R#offline_msg.timestamp), + {route, From, To, Pkt2} catch _:{xmpp_codec, Why} -> ?ERROR_MSG("failed to decode packet ~p of user ~s: ~s", - [R#offline_msg.packet, jid:to_string(R#offline_msg.to), + [R#offline_msg.packet, jid:to_string(To), xmpp:format_error(Why)]), error end. @@ -840,9 +876,17 @@ count_offline_messages(User, Server) -> add_delay_info(Packet, _LServer, undefined) -> Packet; add_delay_info(Packet, LServer, {_, _, _} = TS) -> - xmpp_util:add_delay_info(Packet, jid:make(LServer), TS, + Packet1 = xmpp:put_meta(Packet, from_offline, true), + xmpp_util:add_delay_info(Packet1, jid:make(LServer), TS, <<"Offline storage">>). +-spec get_priority_from_presence(presence()) -> integer(). +get_priority_from_presence(#presence{priority = Prio}) -> + case Prio of + undefined -> 0; + _ -> Prio + end. + export(LServer) -> Mod = gen_mod:db_mod(LServer, ?MODULE), Mod:export(LServer). diff --git a/src/mod_ping.erl b/src/mod_ping.erl index 5e861b7f7..09550ee9a 100644 --- a/src/mod_ping.erl +++ b/src/mod_ping.erl @@ -54,8 +54,8 @@ -export([init/1, terminate/2, handle_call/3, handle_cast/2, handle_info/2, code_change/3]). --export([iq_ping/1, user_online/3, user_offline/3, - user_send/4, mod_opt_type/1, depends/2]). +-export([iq_ping/1, user_online/3, user_offline/3, disco_features/5, + user_send/1, mod_opt_type/1, depends/2]). -record(state, {host = <<"">>, @@ -116,7 +116,7 @@ init([Host, Opts]) -> end, none), IQDisc = gen_mod:get_opt(iqdisc, Opts, fun gen_iq_handler:check_type/1, no_queue), - mod_disco:register_feature(Host, ?NS_PING), + ejabberd_hooks:add(disco_local_features, Host, ?MODULE, disco_features, 50), gen_iq_handler:add_iq_handler(ejabberd_sm, Host, ?NS_PING, ?MODULE, iq_ping, IQDisc), gen_iq_handler:add_iq_handler(ejabberd_local, Host, @@ -145,11 +145,12 @@ terminate(_Reason, #state{host = Host}) -> ?MODULE, user_online, 100), ejabberd_hooks:delete(user_send_packet, Host, ?MODULE, user_send, 100), + ejabberd_hooks:delete(disco_local_features, Host, ?MODULE, + disco_features, 50), gen_iq_handler:remove_iq_handler(ejabberd_local, Host, ?NS_PING), gen_iq_handler:remove_iq_handler(ejabberd_sm, Host, - ?NS_PING), - mod_disco:unregister_feature(Host, ?NS_PING). + ?NS_PING). handle_call(stop, _From, State) -> {stop, normal, ok, State}; @@ -215,10 +216,22 @@ user_online(_SID, JID, _Info) -> user_offline(_SID, JID, _Info) -> stop_ping(JID#jid.lserver, JID). --spec user_send(stanza(), ejabberd_c2s:state(), jid(), jid()) -> stanza(). -user_send(Packet, _C2SState, JID, _From) -> +-spec user_send({stanza(), ejabberd_c2s:state()}) -> {stanza(), ejabberd_c2s:state()}. +user_send({Packet, #{jid := JID} = C2SState}) -> start_ping(JID#jid.lserver, JID), - Packet. + {Packet, C2SState}. + +-spec disco_features({error, stanza_error()} | {result, [binary()]} | empty, + jid(), jid(), binary(), binary()) -> + {error, stanza_error()} | {result, [binary()]}. +disco_features({error, Err}, _From, _To, _Node, _Lang) -> + {error, Err}; +disco_features(empty, _From, _To, <<"">>, _Lang) -> + {result, [?NS_PING]}; +disco_features({result, Feats}, _From, _To, <<"">>, _Lang) -> + {result, [?NS_PING|Feats]}; +disco_features(Acc, _From, _To, _Node, _Lang) -> + Acc. %%==================================================================== %% Internal functions diff --git a/src/mod_pres_counter.erl b/src/mod_pres_counter.erl index 955e53f6f..8da4f7b29 100644 --- a/src/mod_pres_counter.erl +++ b/src/mod_pres_counter.erl @@ -27,7 +27,7 @@ -behavior(gen_mod). --export([start/2, stop/1, check_packet/6, +-export([start/2, stop/1, check_packet/4, mod_opt_type/1, depends/2]). -include("ejabberd.hrl"). @@ -51,10 +51,12 @@ stop(Host) -> depends(_Host, _Opts) -> []. --spec check_packet(allow | deny, binary(), binary(), _, - {jid(), jid(), stanza()}, in | out) -> allow | deny. -check_packet(_, _User, Server, _PrivacyList, - {From, To, #presence{type = Type}}, Dir) -> +-spec check_packet(allow | deny, ejabberd_c2s:state() | jid(), + stanza(), in | out) -> allow | deny. +check_packet(Acc, #{jid := JID}, Packet, Dir) -> + check_packet(Acc, JID, Packet, Dir); +check_packet(_, #jid{lserver = LServer}, + #presence{from = From, to = To, type = Type}, Dir) -> IsSubscription = case Type of subscribe -> true; subscribed -> true; @@ -67,11 +69,11 @@ check_packet(_, _User, Server, _PrivacyList, in -> To; out -> From end, - update(Server, JID, Dir); + update(LServer, JID, Dir); true -> allow end; -check_packet(_, _User, _Server, _PrivacyList, _Pkt, _Dir) -> - allow. +check_packet(Acc, _, _, _) -> + Acc. update(Server, JID, Dir) -> StormCount = gen_mod:get_module_opt(Server, ?MODULE, count, diff --git a/src/mod_privacy.erl b/src/mod_privacy.erl index b28bbcea2..6eb939c3c 100644 --- a/src/mod_privacy.erl +++ b/src/mod_privacy.erl @@ -32,10 +32,10 @@ -behaviour(gen_mod). -export([start/2, stop/1, process_iq/1, export/1, import_info/0, - process_iq_set/3, process_iq_get/3, get_user_list/3, - check_packet/6, remove_user/2, encode_list_item/1, - is_list_needdb/1, updated_list/3, - import_start/2, import_stop/2, c2s_handle_info/2, + c2s_session_opened/1, c2s_copy_session/2, push_list_update/3, + user_send_packet/1, user_receive_packet/1, disco_features/5, + check_packet/4, remove_user/2, encode_list_item/1, + is_list_needdb/1, import_start/2, import_stop/2, item_to_xml/1, get_user_lists/2, import/5, set_privacy_list/1, mod_opt_type/1, depends/2]). @@ -64,106 +64,124 @@ start(Host, Opts) -> one_queue), Mod = gen_mod:db_mod(Host, Opts, ?MODULE), Mod:init(Host, Opts), - mod_disco:register_feature(Host, ?NS_PRIVACY), - ejabberd_hooks:add(privacy_iq_get, Host, ?MODULE, - process_iq_get, 50), - ejabberd_hooks:add(privacy_iq_set, Host, ?MODULE, - process_iq_set, 50), - ejabberd_hooks:add(privacy_get_user_list, Host, ?MODULE, - get_user_list, 50), + ejabberd_hooks:add(disco_local_features, Host, ?MODULE, + disco_features, 50), + ejabberd_hooks:add(c2s_session_opened, Host, ?MODULE, + c2s_session_opened, 50), + ejabberd_hooks:add(c2s_copy_session, Host, ?MODULE, + c2s_copy_session, 50), + ejabberd_hooks:add(user_send_packet, Host, ?MODULE, + user_send_packet, 50), + ejabberd_hooks:add(user_receive_packet, Host, ?MODULE, + user_receive_packet, 50), ejabberd_hooks:add(privacy_check_packet, Host, ?MODULE, check_packet, 50), - ejabberd_hooks:add(privacy_updated_list, Host, ?MODULE, - updated_list, 50), ejabberd_hooks:add(remove_user, Host, ?MODULE, remove_user, 50), - ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE, - c2s_handle_info, 50), gen_iq_handler:add_iq_handler(ejabberd_sm, Host, ?NS_PRIVACY, ?MODULE, process_iq, IQDisc). stop(Host) -> - mod_disco:unregister_feature(Host, ?NS_PRIVACY), - ejabberd_hooks:delete(privacy_iq_get, Host, ?MODULE, - process_iq_get, 50), - ejabberd_hooks:delete(privacy_iq_set, Host, ?MODULE, - process_iq_set, 50), - ejabberd_hooks:delete(privacy_get_user_list, Host, - ?MODULE, get_user_list, 50), + ejabberd_hooks:delete(disco_local_features, Host, ?MODULE, + disco_features, 50), + ejabberd_hooks:delete(c2s_session_opened, Host, ?MODULE, + c2s_session_opened, 50), + ejabberd_hooks:delete(c2s_copy_session, Host, ?MODULE, + c2s_copy_session, 50), + ejabberd_hooks:delete(user_send_packet, Host, ?MODULE, + user_send_packet, 50), + ejabberd_hooks:delete(user_receive_packet, Host, ?MODULE, + user_receive_packet, 50), ejabberd_hooks:delete(privacy_check_packet, Host, ?MODULE, check_packet, 50), - ejabberd_hooks:delete(privacy_updated_list, Host, - ?MODULE, updated_list, 50), ejabberd_hooks:delete(remove_user, Host, ?MODULE, remove_user, 50), - ejabberd_hooks:delete(c2s_handle_info, Host, ?MODULE, - c2s_handle_info, 50), gen_iq_handler:remove_iq_handler(ejabberd_sm, Host, ?NS_PRIVACY). --spec process_iq(iq()) -> iq(). -process_iq(IQ) -> - xmpp:make_error(IQ, xmpp:err_not_allowed()). - --spec process_iq_get({error, stanza_error()} | {result, xmpp_element() | undefined}, - iq(), userlist()) -> {error, stanza_error()} | - {result, xmpp_element() | undefined}. -process_iq_get(_, #iq{lang = Lang, - sub_els = [#privacy_query{default = Default, - active = Active}]}, - _) when Default /= undefined; Active /= undefined -> - Txt = <<"Only element is allowed in this query">>, - {error, xmpp:err_bad_request(Txt, Lang)}; -process_iq_get(_, #iq{from = From, lang = Lang, - sub_els = [#privacy_query{lists = Lists}]}, - #userlist{name = Active}) -> - #jid{luser = LUser, lserver = LServer} = From, - case Lists of - [] -> - process_lists_get(LUser, LServer, Active, Lang); - [#privacy_list{name = ListName}] -> - process_list_get(LUser, LServer, ListName, Lang); - _ -> - Txt = <<"Too many elements">>, - {error, xmpp:err_bad_request(Txt, Lang)} - end; -process_iq_get(Acc, _, _) -> +-spec disco_features({error, stanza_error()} | {result, [binary()]} | empty, + jid(), jid(), binary(), binary()) -> + {error, stanza_error()} | {result, [binary()]}. +disco_features({error, Err}, _From, _To, _Node, _Lang) -> + {error, Err}; +disco_features(empty, _From, _To, <<"">>, _Lang) -> + {result, [?NS_PRIVACY]}; +disco_features({result, Feats}, _From, _To, <<"">>, _Lang) -> + {result, [?NS_PRIVACY|Feats]}; +disco_features(Acc, _From, _To, _Node, _Lang) -> Acc. --spec process_lists_get(binary(), binary(), binary(), binary()) -> - {error, stanza_error()} | {result, privacy_query()}. -process_lists_get(LUser, LServer, Active, Lang) -> +-spec process_iq(iq()) -> iq(). +process_iq(#iq{type = Type, + from = #jid{luser = U, lserver = S}, + to = #jid{luser = U, lserver = S}} = IQ) -> + case Type of + get -> process_iq_get(IQ); + set -> process_iq_set(IQ) + end; +process_iq(#iq{lang = Lang} = IQ) -> + Txt = <<"Query to another users is forbidden">>, + xmpp:make_error(IQ, xmpp:err_forbidden(Txt, Lang)). + +-spec process_iq_get(iq()) -> iq(). +process_iq_get(#iq{lang = Lang, + sub_els = [#privacy_query{default = Default, + active = Active}]} = IQ) + when Default /= undefined; Active /= undefined -> + Txt = <<"Only element is allowed in this query">>, + xmpp:make_error(IQ, xmpp:err_bad_request(Txt, Lang)); +process_iq_get(#iq{lang = Lang, + sub_els = [#privacy_query{lists = Lists}]} = IQ) -> + case Lists of + [] -> + process_lists_get(IQ); + [#privacy_list{name = ListName}] -> + process_list_get(IQ, ListName); + _ -> + Txt = <<"Too many elements">>, + xmpp:make_error(IQ, xmpp:err_bad_request(Txt, Lang)) + end; +process_iq_get(#iq{lang = Lang} = IQ) -> + Txt = <<"No module is handling this query">>, + xmpp:make_error(IQ, xmpp:err_service_unavailable(Txt, Lang)). + +-spec process_lists_get(iq()) -> iq(). +process_lists_get(#iq{from = #jid{luser = LUser, lserver = LServer}, + lang = Lang, + meta = #{privacy_active_list := Active}} = IQ) -> Mod = gen_mod:db_mod(LServer, ?MODULE), case Mod:process_lists_get(LUser, LServer) of error -> Txt = <<"Database failure">>, - {error, xmpp:err_internal_server_error(Txt, Lang)}; + xmpp:make_error(IQ, xmpp:err_internal_server_error(Txt, Lang)); {_Default, []} -> - {result, #privacy_query{}}; + xmpp:make_iq_result(IQ, #privacy_query{}); {Default, ListNames} -> - {result, - #privacy_query{active = Active, - default = Default, - lists = [#privacy_list{name = ListName} - || ListName <- ListNames]}} + xmpp:make_iq_result( + IQ, + #privacy_query{active = Active, + default = Default, + lists = [#privacy_list{name = ListName} + || ListName <- ListNames]}) end. --spec process_list_get(binary(), binary(), binary(), binary()) -> - {error, stanza_error()} | {result, privacy_query()}. -process_list_get(LUser, LServer, Name, Lang) -> +-spec process_list_get(iq(), binary()) -> iq(). +process_list_get(#iq{from = #jid{luser = LUser, lserver = LServer}, + lang = Lang} = IQ, Name) -> Mod = gen_mod:db_mod(LServer, ?MODULE), case Mod:process_list_get(LUser, LServer, Name) of error -> Txt = <<"Database failure">>, - {error, xmpp:err_internal_server_error(Txt, Lang)}; + xmpp:make_error(IQ, xmpp:err_internal_server_error(Txt, Lang)); not_found -> Txt = <<"No privacy list with this name found">>, - {error, xmpp:err_item_not_found(Txt, Lang)}; + xmpp:make_error(IQ, xmpp:err_item_not_found(Txt, Lang)); Items -> LItems = lists:map(fun encode_list_item/1, Items), - {result, - #privacy_query{ - lists = [#privacy_list{name = Name, items = LItems}]}} + xmpp:make_iq_result( + IQ, + #privacy_query{ + lists = [#privacy_list{name = Name, items = LItems}]}) end. -spec item_to_xml(listitem()) -> xmlel(). @@ -228,69 +246,61 @@ decode_value(Type, Value) -> undefined -> none end. --spec process_iq_set({error, stanza_error()} | - {result, xmpp_element() | undefined} | - {result, xmpp_element() | undefined, userlist()}, - iq(), #userlist{}) -> - {error, stanza_error()} | - {result, xmpp_element() | undefined} | - {result, xmpp_element() | undefined, userlist()}. -process_iq_set(_, #iq{from = From, lang = Lang, - sub_els = [#privacy_query{default = Default, - active = Active, - lists = Lists}]}, - #userlist{} = UserList) -> - #jid{luser = LUser, lserver = LServer} = From, +-spec process_iq_set(iq()) -> iq(). +process_iq_set(#iq{lang = Lang, + sub_els = [#privacy_query{default = Default, + active = Active, + lists = Lists}]} = IQ) -> case Lists of [#privacy_list{items = Items, name = ListName}] when Default == undefined, Active == undefined -> - process_lists_set(LUser, LServer, ListName, Items, UserList, Lang); + process_lists_set(IQ, ListName, Items); [] when Default == undefined, Active /= undefined -> - process_active_set(LUser, LServer, Active, Lang); + process_active_set(IQ, Active); [] when Active == undefined, Default /= undefined -> - process_default_set(LUser, LServer, Default, Lang); + process_default_set(IQ, Default); _ -> Txt = <<"The stanza MUST contain only one element, " "one element, or one element">>, - {error, xmpp:err_bad_request(Txt, Lang)} + xmpp:make_error(IQ, xmpp:err_bad_request(Txt, Lang)) end; -process_iq_set(Acc, _, _) -> - Acc. +process_iq_set(#iq{lang = Lang} = IQ) -> + Txt = <<"No module is handling this query">>, + xmpp:make_error(IQ, xmpp:err_service_unavailable(Txt, Lang)). --spec process_default_set(binary(), binary(), none | binary(), - binary()) -> {error, stanza_error()} | {result, undefined}. -process_default_set(LUser, LServer, Value, Lang) -> +-spec process_default_set(iq(), binary()) -> iq(). +process_default_set(#iq{from = #jid{luser = LUser, lserver = LServer}, + lang = Lang} = IQ, Value) -> Mod = gen_mod:db_mod(LServer, ?MODULE), case Mod:process_default_set(LUser, LServer, Value) of {atomic, error} -> Txt = <<"Database failure">>, - {error, xmpp:err_internal_server_error(Txt, Lang)}; + xmpp:make_error(IQ, xmpp:err_internal_server_error(Txt, Lang)); {atomic, not_found} -> Txt = <<"No privacy list with this name found">>, - {error, xmpp:err_item_not_found(Txt, Lang)}; + xmpp:make_error(IQ, xmpp:err_item_not_found(Txt, Lang)); {atomic, ok} -> - {result, undefined}; + xmpp:make_iq_result(IQ); Err -> ?ERROR_MSG("failed to set default list '~s' for user ~s@~s: ~p", [Value, LUser, LServer, Err]), - {error, xmpp:err_internal_server_error()} + xmpp:make_error(IQ, xmpp:err_internal_server_error()) end. --spec process_active_set(binary(), binary(), none | binary(), binary()) -> - {error, stanza_error()} | - {result, undefined, userlist()}. -process_active_set(_LUser, _LServer, none, _Lang) -> - {result, undefined, #userlist{}}; -process_active_set(LUser, LServer, Name, Lang) -> +-spec process_active_set(IQ, none | binary()) -> IQ. +process_active_set(IQ, none) -> + xmpp:make_iq_result(xmpp:put_meta(IQ, privacy_list, #userlist{})); +process_active_set(#iq{from = #jid{luser = LUser, lserver = LServer}, + lang = Lang} = IQ, Name) -> Mod = gen_mod:db_mod(LServer, ?MODULE), case Mod:process_active_set(LUser, LServer, Name) of error -> Txt = <<"No privacy list with this name found">>, - {error, xmpp:err_item_not_found(Txt, Lang)}; + xmpp:make_error(IQ, xmpp:err_item_not_found(Txt, Lang)); Items -> NeedDb = is_list_needdb(Items), - {result, undefined, - #userlist{name = Name, list = Items, needdb = NeedDb}} + List = #userlist{name = Name, list = Items, needdb = NeedDb}, + xmpp:make_iq_result(xmpp:put_meta(IQ, privacy_list, List)) end. -spec set_privacy_list(privacy()) -> any(). @@ -298,57 +308,100 @@ set_privacy_list(#privacy{us = {_, LServer}} = Privacy) -> Mod = gen_mod:db_mod(LServer, ?MODULE), Mod:set_privacy_list(Privacy). --spec process_lists_set(binary(), binary(), binary(), [privacy_item()], - #userlist{}, binary()) -> {error, stanza_error()} | - {result, undefined}. -process_lists_set(_LUser, _LServer, Name, [], #userlist{name = Name}, Lang) -> +-spec process_lists_set(iq(), binary(), [privacy_item()]) -> iq(). +process_lists_set(#iq{meta = #{privacy_active_list := Name}, + lang = Lang} = IQ, Name, []) -> Txt = <<"Cannot remove active list">>, - {error, xmpp:err_conflict(Txt, Lang)}; -process_lists_set(LUser, LServer, Name, [], _UserList, Lang) -> + xmpp:make_error(IQ, xmpp:err_conflict(Txt, Lang)); +process_lists_set(#iq{from = #jid{luser = LUser, lserver = LServer} = From, + lang = Lang} = IQ, Name, []) -> Mod = gen_mod:db_mod(LServer, ?MODULE), case Mod:remove_privacy_list(LUser, LServer, Name) of {atomic, conflict} -> Txt = <<"Cannot remove default list">>, - {error, xmpp:err_conflict(Txt, Lang)}; + xmpp:make_error(IQ, xmpp:err_conflict(Txt, Lang)); {atomic, not_found} -> Txt = <<"No privacy list with this name found">>, - {error, xmpp:err_item_not_found(Txt, Lang)}; + xmpp:make_error(IQ, xmpp:err_item_not_found(Txt, Lang)); {atomic, ok} -> - ejabberd_sm:route(jid:make(LUser, LServer, <<"">>), - {privacy_list, #userlist{name = Name}, Name}), - {result, undefined}; + push_list_update(From, #userlist{name = Name}, Name), + xmpp:make_iq_result(IQ); Err -> ?ERROR_MSG("failed to remove privacy list '~s' for user ~s@~s: ~p", [Name, LUser, LServer, Err]), Txt = <<"Database failure">>, - {error, xmpp:err_internal_server_error(Txt, Lang)} + xmpp:make_error(IQ, xmpp:err_internal_server_error(Txt, Lang)) end; -process_lists_set(LUser, LServer, Name, Items, _UserList, Lang) -> +process_lists_set(#iq{from = #jid{luser = LUser, lserver = LServer} = From, + lang = Lang} = IQ, Name, Items) -> case catch lists:map(fun decode_item/1, Items) of {error, Why} -> Txt = xmpp:format_error(Why), - {error, xmpp:err_bad_request(Txt, Lang)}; + xmpp:make_error(IQ, xmpp:err_bad_request(Txt, Lang)); List -> Mod = gen_mod:db_mod(LServer, ?MODULE), case Mod:set_privacy_list(LUser, LServer, Name, List) of {atomic, ok} -> - NeedDb = is_list_needdb(List), - ejabberd_sm:route(jid:make(LUser, LServer, <<"">>), - {privacy_list, - #userlist{name = Name, - list = List, - needdb = NeedDb}, - Name}), - {result, undefined}; + UserList = #userlist{name = Name, list = List, + needdb = is_list_needdb(List)}, + push_list_update(From, UserList, Name), + xmpp:make_iq_result(IQ); Err -> ?ERROR_MSG("failed to set privacy list '~s' " "for user ~s@~s: ~p", [Name, LUser, LServer, Err]), Txt = <<"Database failure">>, - {error, xmpp:err_internal_server_error(Txt, Lang)} + xmpp:make_error(IQ, xmpp:err_internal_server_error(Txt, Lang)) end end. +-spec push_list_update(jid(), #userlist{}, binary() | none) -> ok. +push_list_update(From, List, Name) -> + BareFrom = jid:remove_resource(From), + lists:foreach( + fun(R) -> + To = jid:replace_resource(From, R), + IQ = #iq{type = set, from = BareFrom, to = To, + id = <<"push", (randoms:get_string())/binary>>, + sub_els = [#privacy_query{ + lists = [#privacy_list{name = Name}]}], + meta = #{privacy_updated_list => List}}, + ejabberd_router:route(BareFrom, To, IQ) + end, ejabberd_sm:get_user_resources(From#jid.luser, From#jid.lserver)). + +-spec user_send_packet({stanza(), ejabberd_c2s:state()}) -> {stanza(), ejabberd_c2s:state()}. +user_send_packet({#iq{type = Type, + to = #jid{luser = U, lserver = S, lresource = <<"">>}, + from = #jid{luser = U, lserver = S}, + sub_els = [_]} = IQ, + #{privacy_list := #userlist{name = Name}} = State}) + when Type == get; Type == set -> + NewIQ = case xmpp:has_subtag(IQ, #privacy_query{}) of + true -> xmpp:put_meta(IQ, privacy_active_list, Name); + false -> IQ + end, + {NewIQ, State}; +user_send_packet(Acc) -> + Acc. + +-spec user_receive_packet({stanza(), ejabberd_c2s:state()}) -> {stanza(), ejabberd_c2s:state()}. +user_receive_packet({#iq{type = result, meta = #{privacy_list := List}} = IQ, + State}) -> + {IQ, State#{privacy_list => List}}; +user_receive_packet({#iq{type = set, meta = #{privacy_updated_list := New}} = IQ, + #{user := U, server := S, resource := R, + privacy_list := Old} = State}) -> + State1 = if Old#userlist.name == New#userlist.name -> + State#{privacy_list => New}; + true -> + State + end, + From = jid:make(U, S, <<"">>), + To = jid:make(U, S, R), + {xmpp:set_from_to(IQ, From, To), State1}; +user_receive_packet(Acc) -> + Acc. + -spec decode_item(privacy_item()) -> listitem(). decode_item(#privacy_item{order = Order, action = Action, @@ -391,15 +444,20 @@ is_list_needdb(Items) -> end, Items). --spec get_user_list(userlist(), binary(), binary()) -> userlist(). -get_user_list(_Acc, User, Server) -> - LUser = jid:nodeprep(User), - LServer = jid:nameprep(Server), +-spec get_user_list(binary(), binary()) -> #userlist{}. +get_user_list(LUser, LServer) -> Mod = gen_mod:db_mod(LServer, ?MODULE), {Default, Items} = Mod:get_user_list(LUser, LServer), NeedDb = is_list_needdb(Items), - #userlist{name = Default, list = Items, - needdb = NeedDb}. + #userlist{name = Default, list = Items, needdb = NeedDb}. + +-spec c2s_session_opened(ejabberd_c2s:state()) -> ejabberd_c2s:state(). +c2s_session_opened(#{jid := #jid{luser = LUser, lserver = LServer}} = State) -> + State#{privacy_list => get_user_list(LUser, LServer)}. + +-spec c2s_copy_session(ejabberd_c2s:state(), ejabberd_c2s:state()) -> ejabberd_c2s:state(). +c2s_copy_session(State, #{privacy_list := List}) -> + State#{privacy_list => List}. -spec get_user_lists(binary(), binary()) -> {ok, privacy()} | error. get_user_lists(User, Server) -> @@ -411,59 +469,66 @@ get_user_lists(User, Server) -> %% From is the sender, To is the destination. %% If Dir = out, User@Server is the sender account (From). %% If Dir = in, User@Server is the destination account (To). --spec check_packet(allow | deny, binary(), binary(), userlist(), - {jid(), jid(), stanza()}, in | out) -> allow | deny. -check_packet(_, _User, _Server, _UserList, - {#jid{luser = <<"">>, lserver = Server} = _From, - #jid{lserver = Server} = _To, _}, - in) -> - allow; -check_packet(_, _User, _Server, _UserList, - {#jid{lserver = Server} = _From, - #jid{luser = <<"">>, lserver = Server} = _To, _}, - out) -> - allow; -check_packet(_, _User, _Server, _UserList, - {#jid{luser = User, lserver = Server} = _From, - #jid{luser = User, lserver = Server} = _To, _}, - _Dir) -> - allow; -check_packet(_, User, Server, - #userlist{list = List, needdb = NeedDb}, - {From, To, Packet}, Dir) -> - case List of - [] -> allow; - _ -> - PType = case Packet of - #message{} -> message; - #iq{} -> iq; - #presence{type = available} -> presence; - #presence{type = unavailable} -> presence; - _ -> other - end, - PType2 = case {PType, Dir} of - {message, in} -> message; - {iq, in} -> iq; - {presence, in} -> presence_in; - {presence, out} -> presence_out; - {_, _} -> other +-spec check_packet(allow | deny, ejabberd_c2s:state() | jid(), + stanza(), in | out) -> allow | deny. +check_packet(_, #{jid := #jid{luser = LUser, lserver = LServer}, + privacy_list := #userlist{list = List, needdb = NeedDb}}, + Packet, Dir) -> + From = xmpp:get_from(Packet), + To = xmpp:get_to(Packet), + case {From, To} of + {#jid{luser = <<"">>, lserver = LServer}, + #jid{lserver = LServer}} when Dir == in -> + %% Allow any packets from local server + allow; + {#jid{lserver = LServer}, + #jid{luser = <<"">>, lserver = LServer}} when Dir == out -> + %% Allow any packets to local server + allow; + {#jid{luser = LUser, lserver = LServer, lresource = <<"">>}, + #jid{luser = LUser, lserver = LServer}} when Dir == in -> + %% Allow incoming packets from user's bare jid to his full jid + allow; + {#jid{luser = LUser, lserver = LServer}, + #jid{luser = LUser, lserver = LServer, lresource = <<"">>}} when Dir == out -> + %% Allow outgoing packets from user's full jid to his bare JID + allow; + _ when List == [] -> + allow; + _ -> + PType = case Packet of + #message{} -> message; + #iq{} -> iq; + #presence{type = available} -> presence; + #presence{type = unavailable} -> presence; + _ -> other + end, + PType2 = case {PType, Dir} of + {message, in} -> message; + {iq, in} -> iq; + {presence, in} -> presence_in; + {presence, out} -> presence_out; + {_, _} -> other + end, + LJID = case Dir of + in -> jid:tolower(From); + out -> jid:tolower(To) end, - LJID = case Dir of - in -> jid:tolower(From); - out -> jid:tolower(To) - end, - {Subscription, Groups} = case NeedDb of - true -> - ejabberd_hooks:run_fold(roster_get_jid_info, - jid:nameprep(Server), - {none, []}, - [User, Server, - LJID]); - false -> {[], []} - end, - check_packet_aux(List, PType2, LJID, Subscription, - Groups) - end. + {Subscription, Groups} = + case NeedDb of + true -> + ejabberd_hooks:run_fold(roster_get_jid_info, + LServer, + {none, []}, + [LUser, LServer, LJID]); + false -> + {[], []} + end, + check_packet_aux(List, PType2, LJID, Subscription, Groups) + end; +check_packet(Acc, #jid{luser = LUser, lserver = LServer} = JID, Packet, Dir) -> + List = get_user_list(LUser, LServer), + check_packet(Acc, #{jid => JID, privacy_list => List}, Packet, Dir). -spec check_packet_aux([listitem()], message | iq | presence_in | presence_out | other, @@ -535,30 +600,6 @@ remove_user(User, Server) -> Mod = gen_mod:db_mod(LServer, ?MODULE), Mod:remove_user(LUser, LServer). -c2s_handle_info(#{privacy_list := Old, - user := U, server := S, resource := R} = State, - {privacy_list, New, Name}) -> - List = if Old#userlist.name == New#userlist.name -> New; - true -> Old - end, - From = jid:make(U, S), - To = jid:make(U, S, R), - PushIQ = #iq{type = set, from = From, to = To, - id = <<"push", (randoms:get_string())/binary>>, - sub_els = [#privacy_query{ - lists = [#privacy_list{name = Name}]}]}, - State1 = State#{privacy_list => List}, - {stop, ejabberd_c2s:send(State1, PushIQ)}; -c2s_handle_info(State, _) -> - State. - --spec updated_list(userlist(), userlist(), userlist()) -> userlist(). -updated_list(_, #userlist{name = OldName} = Old, - #userlist{name = NewName} = New) -> - if OldName == NewName -> New; - true -> Old - end. - numeric_to_binary(<<0, 0, _/binary>>) -> <<"0">>; numeric_to_binary(<<0, _, _:6/binary, T/binary>>) -> diff --git a/src/mod_privilege.erl b/src/mod_privilege.erl index c1ac5a3fc..ae0a67e72 100644 --- a/src/mod_privilege.erl +++ b/src/mod_privilege.erl @@ -38,7 +38,7 @@ terminate/2, code_change/3]). -export([component_connected/1, component_disconnected/2, roster_access/2, process_message/3, - process_presence_out/4, process_presence_in/5]). + process_presence_out/1, process_presence_in/1]). -include("ejabberd.hrl"). -include("logger.hrl"). @@ -133,10 +133,11 @@ roster_access(false, #iq{from = From, to = To, type = Type}) -> false end. --spec process_presence_out(stanza(), ejabberd_c2s:state(), jid(), jid()) -> stanza(). -process_presence_out(#presence{type = Type} = Pres, _C2SState, - #jid{luser = LUser, lserver = LServer} = From, - #jid{luser = LUser, lserver = LServer, lresource = <<"">>}) +-spec process_presence_out({stanza(), ejabberd_c2s:state()}) -> {stanza(), ejabberd_c2s:state()}. +process_presence_out({#presence{ + from = #jid{luser = LUser, lserver = LServer} = From, + to = #jid{luser = LUser, lserver = LServer, lresource = <<"">>}, + type = Type} = Pres, C2SState}) when Type == available; Type == unavailable -> %% Self-presence processing Permissions = get_permissions(LServer), @@ -151,15 +152,15 @@ process_presence_out(#presence{type = Type} = Pres, _C2SState, ok end end, dict:to_list(Permissions)), - Pres; -process_presence_out(Acc, _, _, _) -> + {Pres, C2SState}; +process_presence_out(Acc) -> Acc. --spec process_presence_in(stanza(), ejabberd_c2s:state(), - jid(), jid(), jid()) -> stanza(). -process_presence_in(#presence{type = Type} = Pres, _C2SState, _, - #jid{luser = U, lserver = S} = From, - #jid{luser = LUser, lserver = LServer}) +-spec process_presence_in({stanza(), ejabberd_c2s:state()}) -> {stanza(), ejabberd_c2s:state()}. +process_presence_in({#presence{ + from = #jid{luser = U, lserver = S} = From, + to = #jid{luser = LUser, lserver = LServer}, + type = Type} = Pres, C2SState}) when {U, S} /= {LUser, LServer} andalso (Type == available orelse Type == unavailable) -> Permissions = get_permissions(LServer), @@ -179,8 +180,8 @@ process_presence_in(#presence{type = Type} = Pres, _C2SState, _, ok end end, dict:to_list(Permissions)), - Pres; -process_presence_in(Acc, _, _, _, _) -> + {Pres, C2SState}; +process_presence_in(Acc) -> Acc. %%%=================================================================== diff --git a/src/mod_pubsub.erl b/src/mod_pubsub.erl index 8819e3a99..d631b0ad0 100644 --- a/src/mod_pubsub.erl +++ b/src/mod_pubsub.erl @@ -272,7 +272,6 @@ init([ServerHost, Opts]) -> ejabberd_mnesia:create(?MODULE, pubsub_last_item, [{ram_copies, [node()]}, {attributes, record_info(fields, pubsub_last_item)}]), - mod_disco:register_feature(ServerHost, ?NS_PUBSUB), lists:foreach( fun(H) -> T = gen_mod:get_module_proc(H, config), @@ -533,7 +532,7 @@ disco_local_features(Acc, _From, To, <<>>, _Lang) -> {result, I} -> I; _ -> [] end, - {result, Feats ++ [feature(F) || F <- features(Host, <<>>)]}; + {result, Feats ++ [?NS_PUBSUB|[feature(F) || F <- features(Host, <<>>)]]}; disco_local_features(Acc, _From, _To, _Node, _Lang) -> Acc. @@ -923,7 +922,6 @@ terminate(_Reason, gen_iq_handler:remove_iq_handler(ejabberd_local, Host, ?NS_PUBSUB_OWNER), gen_iq_handler:remove_iq_handler(ejabberd_local, Host, ?NS_VCARD), gen_iq_handler:remove_iq_handler(ejabberd_local, Host, ?NS_COMMANDS), - mod_disco:unregister_feature(ServerHost, ?NS_PUBSUB), case whereis(gen_mod:get_module_proc(ServerHost, ?LOOPNAME)) of undefined -> ?ERROR_MSG("~s process is dead, pubsub was broken", [?LOOPNAME]); diff --git a/src/mod_roster.erl b/src/mod_roster.erl index 5c207f3a4..085f50225 100644 --- a/src/mod_roster.erl +++ b/src/mod_roster.erl @@ -43,9 +43,9 @@ -export([start/2, stop/1, process_iq/1, export/1, import_info/0, process_local_iq/1, get_user_roster/2, - import/5, get_subscription_lists/3, get_roster/2, - import_start/2, import_stop/2, c2s_handle_info/2, - get_in_pending_subscriptions/3, in_subscription/6, + import/5, c2s_session_opened/1, get_roster/2, + import_start/2, import_stop/2, user_receive_packet/1, + c2s_self_presence/1, in_subscription/6, out_subscription/4, set_items/3, remove_user/2, get_jid_info/4, encode_item/1, webadmin_page/3, webadmin_user/4, get_versioning_feature/2, @@ -94,24 +94,24 @@ start(Host, Opts) -> ?MODULE, in_subscription, 50), ejabberd_hooks:add(roster_out_subscription, Host, ?MODULE, out_subscription, 50), - ejabberd_hooks:add(roster_get_subscription_lists, Host, - ?MODULE, get_subscription_lists, 50), + ejabberd_hooks:add(c2s_session_opened, Host, ?MODULE, + c2s_session_opened, 50), ejabberd_hooks:add(roster_get_jid_info, Host, ?MODULE, get_jid_info, 50), ejabberd_hooks:add(remove_user, Host, ?MODULE, remove_user, 50), ejabberd_hooks:add(anonymous_purge_hook, Host, ?MODULE, remove_user, 50), - ejabberd_hooks:add(resend_subscription_requests_hook, - Host, ?MODULE, get_in_pending_subscriptions, 50), + ejabberd_hooks:add(c2s_self_presence, Host, ?MODULE, + c2s_self_presence, 50), ejabberd_hooks:add(c2s_post_auth_features, Host, ?MODULE, get_versioning_feature, 50), ejabberd_hooks:add(webadmin_page_host, Host, ?MODULE, webadmin_page, 50), ejabberd_hooks:add(webadmin_user, Host, ?MODULE, webadmin_user, 50), - ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE, - c2s_handle_info, 50), + ejabberd_hooks:add(user_receive_packet, Host, ?MODULE, + user_receive_packet, 50), gen_iq_handler:add_iq_handler(ejabberd_sm, Host, ?NS_ROSTER, ?MODULE, process_iq, IQDisc). @@ -122,24 +122,24 @@ stop(Host) -> ?MODULE, in_subscription, 50), ejabberd_hooks:delete(roster_out_subscription, Host, ?MODULE, out_subscription, 50), - ejabberd_hooks:delete(roster_get_subscription_lists, - Host, ?MODULE, get_subscription_lists, 50), + ejabberd_hooks:delete(c2s_session_opened, Host, ?MODULE, + c2s_session_opened, 50), ejabberd_hooks:delete(roster_get_jid_info, Host, ?MODULE, get_jid_info, 50), ejabberd_hooks:delete(remove_user, Host, ?MODULE, remove_user, 50), ejabberd_hooks:delete(anonymous_purge_hook, Host, ?MODULE, remove_user, 50), - ejabberd_hooks:delete(resend_subscription_requests_hook, - Host, ?MODULE, get_in_pending_subscriptions, 50), + ejabberd_hooks:delete(c2s_self_presence, Host, ?MODULE, + c2s_self_presence, 50), ejabberd_hooks:delete(c2s_post_auth_features, Host, ?MODULE, get_versioning_feature, 50), ejabberd_hooks:delete(webadmin_page_host, Host, ?MODULE, webadmin_page, 50), ejabberd_hooks:delete(webadmin_user, Host, ?MODULE, webadmin_user, 50), - ejabberd_hooks:delete(c2s_handle_info, Host, ?MODULE, - c2s_handle_info, 50), + ejabberd_hooks:delete(user_receive_packet, Host, ?MODULE, + user_receive_packet, 50), gen_iq_handler:remove_iq_handler(ejabberd_sm, Host, ?NS_ROSTER). @@ -220,10 +220,16 @@ roster_version_on_db(Host) -> %% Returns a list that may contain an xmlelement with the XEP-237 feature if it's enabled. -spec get_versioning_feature([xmpp_element()], binary()) -> [xmpp_element()]. get_versioning_feature(Acc, Host) -> - case roster_versioning_enabled(Host) of - true -> - [#rosterver_feature{}|Acc]; - false -> [] + case gen_mod:is_loaded(Host, ?MODULE) of + true -> + case roster_versioning_enabled(Host) of + true -> + [#rosterver_feature{}|Acc]; + false -> + Acc + end; + false -> + Acc end. roster_version(LServer, LUser) -> @@ -423,8 +429,6 @@ process_iq_set(#iq{from = From, to = To, end. push_item(User, Server, From, Item) -> - ejabberd_sm:route(jid:make(User, Server, <<"">>), - {item, Item#roster.jid, Item#roster.subscription}), case roster_versioning_enabled(Server) of true -> push_item_version(Server, User, From, Item, @@ -446,15 +450,12 @@ push_item(User, Server, Resource, From, Item, not_found -> undefined; _ -> RosterVersion end, - ResIQ = #iq{type = set, -%% @doc Roster push, calculate and include the version attribute. -%% TODO: don't push to those who didn't load roster + To = jid:make(User, Server, Resource), + ResIQ = #iq{type = set, from = From, to = To, id = <<"push", (randoms:get_string())/binary>>, sub_els = [#roster_query{ver = Ver, items = [encode_item(Item)]}]}, - ejabberd_router:route(From, - jid:make(User, Server, Resource), - ResIQ). + ejabberd_router:route(From, To, xmpp:put_meta(ResIQ, roster_item, Item)). push_item_version(Server, User, From, Item, RosterVersion) -> @@ -464,19 +465,19 @@ push_item_version(Server, User, From, Item, end, ejabberd_sm:get_user_resources(User, Server)). -c2s_handle_info(State, {item, JID, Sub}) -> - {stop, roster_change(State, JID, Sub)}; -c2s_handle_info(State, _) -> - State. +-spec user_receive_packet({stanza(), ejabberd_c2s:state()}) -> {stanza(), ejabberd_c2s:state()}. +user_receive_packet({#iq{type = set, meta = #{roster_item := Item}} = IQ, State}) -> + {IQ, roster_change(State, Item)}; +user_receive_packet(Acc) -> + Acc. --spec roster_change(ejabberd_c2s:state(), jid(), subscription()) -> ejabberd_c2s:state(). -roster_change(#{user := U, server := S, resource := R} = State, - IJID, ISubscription) -> +-spec roster_change(ejabberd_c2s:state(), #roster{}) -> ejabberd_c2s:state(). +roster_change(#{user := U, server := S, resource := R, + pres_a := PresA, pres_f := PresF, pres_t := PresT} = State, + #roster{jid = IJID, subscription = ISubscription}) -> LIJID = jid:tolower(IJID), IsFrom = (ISubscription == both) or (ISubscription == from), IsTo = (ISubscription == both) or (ISubscription == to), - PresF = maps:get(pres_f, State, ?SETS:new()), - PresT = maps:get(pres_t, State, ?SETS:new()), OldIsFrom = ?SETS:is_element(LIJID, PresF), FSet = if IsFrom -> ?SETS:add_element(LIJID, PresF); true -> ?SETS:del_element(LIJID, PresF) @@ -490,7 +491,6 @@ roster_change(#{user := U, server := S, resource := R} = State, State1; LastPres -> From = jid:make(U, S, R), - PresA = maps:get(pres_a, State1, ?SETS:new()), To = jid:make(IJID), Cond1 = IsFrom andalso not OldIsFrom, Cond2 = not IsFrom andalso OldIsFrom andalso @@ -507,7 +507,7 @@ roster_change(#{user := U, server := S, resource := R} = State, end, A = ?SETS:add_element(LIJID, PresA), State1#{pres_a => A}; - Cond2 -> + Cond2 -> PU = #presence{from = From, to = To, type = unavailable}, case ejabberd_hooks:run_fold( privacy_check_packet, allow, @@ -524,26 +524,29 @@ roster_change(#{user := U, server := S, resource := R} = State, end end. --spec get_subscription_lists({[ljid()], [ljid()]}, binary(), binary()) - -> {[ljid()], [ljid()]}. -get_subscription_lists(_Acc, User, Server) -> - LUser = jid:nodeprep(User), - LServer = jid:nameprep(Server), +-spec c2s_session_opened(ejabberd_c2s:state()) -> ejabberd_c2s:state(). +c2s_session_opened(#{jid := #jid{luser = LUser, lserver = LServer} = JID, + pres_f := PresF, pres_t := PresT} = State) -> Mod = gen_mod:db_mod(LServer, ?MODULE), Items = Mod:get_only_items(LUser, LServer), - fill_subscription_lists(LServer, Items, [], []). + {F, T} = fill_subscription_lists(Items, PresF, PresT), + LJID = jid:tolower(jid:remove_resource(JID)), + State#{pres_f => ?SETS:add(LJID, F), pres_t => ?SETS:add(LJID, T)}. -fill_subscription_lists(LServer, [I | Is], F, T) -> +fill_subscription_lists([I | Is], F, T) -> J = element(3, I#roster.usj), - case I#roster.subscription of - both -> - fill_subscription_lists(LServer, Is, [J | F], [J | T]); - from -> - fill_subscription_lists(LServer, Is, [J | F], T); - to -> fill_subscription_lists(LServer, Is, F, [J | T]); - _ -> fill_subscription_lists(LServer, Is, F, T) - end; -fill_subscription_lists(_LServer, [], F, T) -> + {F1, T1} = case I#roster.subscription of + both -> + {?SETS:add_element(J, F), ?SETS:add_element(J, T)}; + from -> + {?SETS:add_element(J, F), T}; + to -> + {F, ?SETS:add_element(J, T)}; + _ -> + {F, T} + end, + fill_subscription_lists(Is, F1, T1); +fill_subscription_lists([], F, T) -> {F, T}. ask_to_pending(subscribe) -> out; @@ -836,27 +839,47 @@ process_item_set_t(LUser, LServer, #roster_item{jid = JID1} = QueryItem) -> end; process_item_set_t(_LUser, _LServer, _) -> ok. --spec get_in_pending_subscriptions([presence()], binary(), binary()) -> [presence()]. -get_in_pending_subscriptions(Ls, User, Server) -> - LServer = jid:nameprep(Server), - Mod = gen_mod:db_mod(LServer, ?MODULE), - get_in_pending_subscriptions(Ls, User, Server, Mod). +-spec c2s_self_presence({presence(), ejabberd_c2s:state()}) + -> {presence(), ejabberd_c2s:state()}. +c2s_self_presence({_, #{pres_last := _}} = Acc) -> + Acc; +c2s_self_presence({#presence{type = available} = Pkt, + #{lserver := LServer} = State}) -> + Prio = get_priority_from_presence(Pkt), + if Prio >= 0 -> + Mod = gen_mod:db_mod(LServer, ?MODULE), + State1 = resend_pending_subscriptions(State, Mod), + {Pkt, State1}; + true -> + {Pkt, State} + end; +c2s_self_presence(Acc) -> + Acc. -get_in_pending_subscriptions(Ls, User, Server, Mod) -> - JID = jid:make(User, Server, <<"">>), +-spec resend_pending_subscriptions(ejabberd_c2s:state(), module()) -> ejabberd_c2s:state(). +resend_pending_subscriptions(#{jid := JID} = State, Mod) -> + BareJID = jid:remove_resource(JID), Result = Mod:get_only_items(JID#jid.luser, JID#jid.lserver), - Ls ++ lists:flatmap( - fun(#roster{ask = Ask} = R) when Ask == in; Ask == both -> - Message = R#roster.askmessage, - Status = if is_binary(Message) -> (Message); - true -> <<"">> - end, - [#presence{from = R#roster.jid, to = JID, - type = subscribe, - status = xmpp:mk_text(Status)}]; - (_) -> - [] - end, Result). + lists:foldl( + fun(#roster{ask = Ask} = R, AccState) when Ask == in; Ask == both -> + Message = R#roster.askmessage, + Status = if is_binary(Message) -> (Message); + true -> <<"">> + end, + Sub = #presence{from = R#roster.jid, to = BareJID, + type = subscribe, + status = xmpp:mk_text(Status)}, + ejabberd_c2s:send(AccState, Sub); + (_, AccState) -> + AccState + end, State, Result). + +-spec get_priority_from_presence(presence()) -> integer(). +get_priority_from_presence(#presence{priority = Prio}) -> + case Prio of + undefined -> 0; + _ -> Prio + end. %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% diff --git a/src/mod_s2s_dialback.erl b/src/mod_s2s_dialback.erl index d0d78a30c..7e952f576 100644 --- a/src/mod_s2s_dialback.erl +++ b/src/mod_s2s_dialback.erl @@ -131,13 +131,14 @@ s2s_out_auth_result(#{db_verify := _} = State, _) -> %% in section 2.1.2, step 2 {stop, send_verify_request(State)}; s2s_out_auth_result(#{db_enabled := true, + sockmod := SockMod, socket := Socket, ip := IP, server := LServer, - remote_server := RServer} = State, false) -> + remote_server := RServer} = State, {false, _}) -> %% SASL authentication has failed, retrying with dialback %% Sending dialback request, section 2.1.1, step 1 ?INFO_MSG("(~s) Retrying with s2s dialback authentication: ~s -> ~s (~s)", - [ejabberd_socket:pp(Socket), LServer, RServer, + [SockMod:pp(Socket), LServer, RServer, ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]), State1 = maps:remove(stop_reason, State#{on_route => queue}), {stop, send_db_request(State1)}; @@ -150,6 +151,7 @@ s2s_out_downgraded(#{db_verify := _} = State, _) -> %% section 2.1.2, step 2 {stop, send_verify_request(State)}; s2s_out_downgraded(#{db_enabled := true, + sockmod := SockMod, socket := Socket, ip := IP, server := LServer, remote_server := RServer} = State, _) -> @@ -157,7 +159,7 @@ s2s_out_downgraded(#{db_enabled := true, %% section 2.1.1, step 1 ?INFO_MSG("(~s) Trying s2s dialback authentication with " "non-RFC compliant server: ~s -> ~s (~s)", - [ejabberd_socket:pp(Socket), LServer, RServer, + [SockMod:pp(Socket), LServer, RServer, ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]), {stop, send_db_request(State)}; s2s_out_downgraded(State, _) -> diff --git a/src/mod_service_log.erl b/src/mod_service_log.erl index ea7768bca..f27c4d0d8 100644 --- a/src/mod_service_log.erl +++ b/src/mod_service_log.erl @@ -29,8 +29,8 @@ -behaviour(gen_mod). --export([start/2, stop/1, log_user_send/4, - log_user_receive/5, mod_opt_type/1, depends/2]). +-export([start/2, stop/1, log_user_send/1, + log_user_receive/1, mod_opt_type/1, depends/2]). -include("ejabberd.hrl"). -include("logger.hrl"). @@ -54,15 +54,19 @@ stop(Host) -> depends(_Host, _Opts) -> []. --spec log_user_send(stanza(), ejabberd_c2s:state(), jid(), jid()) -> stanza(). -log_user_send(Packet, _C2SState, From, To) -> +-spec log_user_send({stanza(), ejabberd_c2s:state()}) -> {stanza(), ejabberd_c2s:state()}. +log_user_send({Packet, C2SState}) -> + From = xmpp:get_from(Packet), + To = xmpp:get_to(Packet), log_packet(From, To, Packet, From#jid.lserver), - Packet. + {Packet, C2SState}. --spec log_user_receive(stanza(), ejabberd_c2s:state(), jid(), jid(), jid()) -> stanza(). -log_user_receive(Packet, _C2SState, _JID, From, To) -> +-spec log_user_receive({stanza(), ejabberd_c2s:state()}) -> {stanza(), ejabberd_c2s:state()}. +log_user_receive({Packet, C2SState}) -> + From = xmpp:get_from(Packet), + To = xmpp:get_to(Packet), log_packet(From, To, Packet, To#jid.lserver), - Packet. + {Packet, C2SState}. -spec log_packet(jid(), jid(), stanza(), binary()) -> ok. log_packet(From, To, Packet, Host) -> diff --git a/src/mod_shared_roster.erl b/src/mod_shared_roster.erl index e91f7481a..e7510936f 100644 --- a/src/mod_shared_roster.erl +++ b/src/mod_shared_roster.erl @@ -31,9 +31,9 @@ -export([start/2, stop/1, export/1, import_info/0, webadmin_menu/3, webadmin_page/3, - get_user_roster/2, get_subscription_lists/3, + get_user_roster/2, c2s_session_opened/1, get_jid_info/4, import/5, process_item/2, import_start/2, - in_subscription/6, out_subscription/4, user_available/1, + in_subscription/6, out_subscription/4, c2s_self_presence/1, unset_presence/4, register_user/2, remove_user/2, list_groups/1, create_group/2, create_group/3, delete_group/2, get_group_opts/2, set_group_opts/3, @@ -54,6 +54,8 @@ -include("mod_shared_roster.hrl"). +-define(SETS, gb_sets). + -type group_options() :: [{atom(), any()}]. -callback init(binary(), gen_mod:opts()) -> any(). -callback import(binary(), binary(), [binary()]) -> ok. @@ -84,14 +86,14 @@ start(Host, Opts) -> ?MODULE, in_subscription, 30), ejabberd_hooks:add(roster_out_subscription, Host, ?MODULE, out_subscription, 30), - ejabberd_hooks:add(roster_get_subscription_lists, Host, - ?MODULE, get_subscription_lists, 70), + ejabberd_hooks:add(c2s_session_opened, Host, + ?MODULE, c2s_session_opened, 70), ejabberd_hooks:add(roster_get_jid_info, Host, ?MODULE, get_jid_info, 70), ejabberd_hooks:add(roster_process_item, Host, ?MODULE, process_item, 50), - ejabberd_hooks:add(user_available_hook, Host, ?MODULE, - user_available, 50), + ejabberd_hooks:add(c2s_self_presence, Host, ?MODULE, + c2s_self_presence, 50), ejabberd_hooks:add(unset_presence_hook, Host, ?MODULE, unset_presence, 50), ejabberd_hooks:add(register_user, Host, ?MODULE, @@ -112,14 +114,14 @@ stop(Host) -> ?MODULE, in_subscription, 30), ejabberd_hooks:delete(roster_out_subscription, Host, ?MODULE, out_subscription, 30), - ejabberd_hooks:delete(roster_get_subscription_lists, - Host, ?MODULE, get_subscription_lists, 70), + ejabberd_hooks:delete(c2s_session_opened, + Host, ?MODULE, c2s_session_opened, 70), ejabberd_hooks:delete(roster_get_jid_info, Host, ?MODULE, get_jid_info, 70), ejabberd_hooks:delete(roster_process_item, Host, ?MODULE, process_item, 50), - ejabberd_hooks:delete(user_available_hook, Host, - ?MODULE, user_available, 50), + ejabberd_hooks:delete(c2s_self_presence, Host, + ?MODULE, c2s_self_presence, 50), ejabberd_hooks:delete(unset_presence_hook, Host, ?MODULE, unset_presence, 50), ejabberd_hooks:delete(register_user, Host, ?MODULE, @@ -294,19 +296,21 @@ set_item(User, Server, Resource, Item) -> jid:make(Server), ResIQ). --spec get_subscription_lists({[ljid()], [ljid()]}, binary(), binary()) - -> {[ljid()], [ljid()]}. -get_subscription_lists({F, T}, User, Server) -> - LUser = jid:nodeprep(User), - LServer = jid:nameprep(Server), +c2s_session_opened(#{jid := #jid{luser = LUser, lserver = LServer} = JID, + pres_f := PresF, pres_t := PresT} = State) -> US = {LUser, LServer}, DisplayedGroups = get_user_displayed_groups(US), - SRUsers = lists:usort(lists:flatmap(fun (Group) -> - get_group_users(LServer, Group) - end, - DisplayedGroups)), - SRJIDs = [{U1, S1, <<"">>} || {U1, S1} <- SRUsers], - {lists:usort(SRJIDs ++ F), lists:usort(SRJIDs ++ T)}. + SRUsers = lists:flatmap(fun(Group) -> + get_group_users(LServer, Group) + end, + DisplayedGroups), + BareLJID = jid:tolower(jid:remove_resource(JID)), + PresBoth = lists:foldl( + fun({U, S}, Acc) -> + ?SETS:add_element({U, S, <<"">>}, Acc) + end, ?SETS:new(), [BareLJID|SRUsers]), + State#{pres_f => ?SETS:union(PresBoth, PresF), + pres_t => ?SETS:union(PresBoth, PresT)}. -spec get_jid_info({subscription(), [binary()]}, binary(), binary(), jid()) -> {subscription(), [binary()]}. @@ -739,12 +743,15 @@ push_roster_item(User, Server, ContactU, ContactS, groups = [GroupName]}, push_item(User, Server, Item). --spec user_available(jid()) -> ok. -user_available(New) -> +-spec c2s_self_presence({presence(), ejabberd_c2s:state()}) + -> {presence(), ejabberd_c2s:state()}. +c2s_self_presence({_, #{pres_last := _}} = Acc) -> + %% This is just a presence update, nothing to do + Acc; +c2s_self_presence({#presence{type = available}, #{jid := New}} = Acc) -> LUser = New#jid.luser, LServer = New#jid.lserver, - Resources = ejabberd_sm:get_user_resources(LUser, - LServer), + Resources = ejabberd_sm:get_user_resources(LUser, LServer), ?DEBUG("user_available for ~p @ ~p (~p resources)", [LUser, LServer, length(Resources)]), case length(Resources) of @@ -761,7 +768,10 @@ user_available(New) -> end, UserGroups); _ -> ok - end. + end, + Acc; +c2s_self_presence(Acc) -> + Acc. -spec unset_presence(binary(), binary(), binary(), binary()) -> ok. unset_presence(LUser, LServer, Resource, Status) -> diff --git a/src/mod_shared_roster_ldap.erl b/src/mod_shared_roster_ldap.erl index 97ead9f3d..777854b8e 100644 --- a/src/mod_shared_roster_ldap.erl +++ b/src/mod_shared_roster_ldap.erl @@ -39,7 +39,7 @@ -export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]). --export([get_user_roster/2, get_subscription_lists/3, +-export([get_user_roster/2, c2s_session_opened/1, get_jid_info/4, process_item/2, in_subscription/6, out_subscription/4, mod_opt_type/1, opt_type/1, depends/2]). @@ -49,6 +49,7 @@ -include("mod_roster.hrl"). -include("eldap.hrl"). +-define(SETS, gb_sets). -define(CACHE_SIZE, 1000). -define(USER_CACHE_VALIDITY, 300). %% in seconds -define(GROUP_CACHE_VALIDITY, 300). @@ -160,19 +161,21 @@ process_item(RosterItem, _Host) -> _ -> RosterItem#roster{subscription = both, ask = none} end. --spec get_subscription_lists({[ljid()], [ljid()]}, binary(), binary()) - -> {[ljid()], [ljid()]}. -get_subscription_lists({F, T}, User, Server) -> - LUser = jid:nodeprep(User), - LServer = jid:nameprep(Server), +c2s_session_opened(#{jid := #jid{luser = LUser, lserver = LServer} = JID, + pres_f := PresF, pres_t := PresT} = State) -> US = {LUser, LServer}, DisplayedGroups = get_user_displayed_groups(US), - SRUsers = lists:usort(lists:flatmap(fun (Group) -> - get_group_users(LServer, Group) - end, - DisplayedGroups)), - SRJIDs = [{U1, S1, <<"">>} || {U1, S1} <- SRUsers], - {lists:usort(SRJIDs ++ F), lists:usort(SRJIDs ++ T)}. + SRUsers = lists:flatmap(fun(Group) -> + get_group_users(LServer, Group) + end, + DisplayedGroups), + BareLJID = jid:tolower(jid:remove_resource(JID)), + PresBoth = lists:foldl( + fun({U, S}, Acc) -> + ?SETS:add_element({U, S, <<"">>}, Acc) + end, ?SETS:new(), [BareLJID|SRUsers]), + State#{pres_f => ?SETS:union(PresBoth, PresF), + pres_t => ?SETS:union(PresBoth, PresT)}. -spec get_jid_info({subscription(), [binary()]}, binary(), binary(), jid()) -> {subscription(), [binary()]}. @@ -246,8 +249,8 @@ init([Host, Opts]) -> ?MODULE, in_subscription, 30), ejabberd_hooks:add(roster_out_subscription, Host, ?MODULE, out_subscription, 30), - ejabberd_hooks:add(roster_get_subscription_lists, Host, - ?MODULE, get_subscription_lists, 70), + ejabberd_hooks:add(c2s_session_opened, Host, + ?MODULE, c2s_session_opened, 70), ejabberd_hooks:add(roster_get_jid_info, Host, ?MODULE, get_jid_info, 70), ejabberd_hooks:add(roster_process_item, Host, ?MODULE, @@ -275,8 +278,8 @@ terminate(_Reason, State) -> ?MODULE, in_subscription, 30), ejabberd_hooks:delete(roster_out_subscription, Host, ?MODULE, out_subscription, 30), - ejabberd_hooks:delete(roster_get_subscription_lists, - Host, ?MODULE, get_subscription_lists, 70), + ejabberd_hooks:delete(c2s_session_opened, + Host, ?MODULE, c2s_session_opened, 70), ejabberd_hooks:delete(roster_get_jid_info, Host, ?MODULE, get_jid_info, 70), ejabberd_hooks:delete(roster_process_item, Host, diff --git a/src/mod_sm.erl b/src/mod_sm.erl index 7e64e6a00..0382c60a9 100644 --- a/src/mod_sm.erl +++ b/src/mod_sm.erl @@ -31,8 +31,8 @@ -export([c2s_stream_init/2, c2s_stream_started/2, c2s_stream_features/2, c2s_authenticated_packet/2, c2s_unauthenticated_packet/2, c2s_unbinded_packet/2, c2s_closed/2, c2s_terminated/2, - c2s_handle_send/3, c2s_filter_send/1, c2s_handle_info/2, - c2s_handle_call/3, c2s_handle_recv/3]). + c2s_handle_send/3, c2s_handle_info/2, c2s_handle_call/3, + c2s_handle_recv/3]). -include("xmpp.hrl"). -include("logger.hrl"). @@ -63,7 +63,6 @@ start(Host, _Opts) -> c2s_authenticated_packet, 50), ejabberd_hooks:add(c2s_handle_send, Host, ?MODULE, c2s_handle_send, 50), ejabberd_hooks:add(c2s_handle_recv, Host, ?MODULE, c2s_handle_recv, 50), - ejabberd_hooks:add(c2s_filter_send, Host, ?MODULE, c2s_filter_send, 50), ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE, c2s_handle_info, 50), ejabberd_hooks:add(c2s_handle_call, Host, ?MODULE, c2s_handle_call, 50), ejabberd_hooks:add(c2s_closed, Host, ?MODULE, c2s_closed, 50), @@ -83,7 +82,6 @@ stop(Host) -> c2s_authenticated_packet, 50), ejabberd_hooks:delete(c2s_handle_send, Host, ?MODULE, c2s_handle_send, 50), ejabberd_hooks:delete(c2s_handle_recv, Host, ?MODULE, c2s_handle_recv, 50), - ejabberd_hooks:delete(c2s_filter_send, Host, ?MODULE, c2s_filter_send, 50), ejabberd_hooks:delete(c2s_handle_info, Host, ?MODULE, c2s_handle_info, 50), ejabberd_hooks:delete(c2s_handle_call, Host, ?MODULE, c2s_handle_call, 50), ejabberd_hooks:delete(c2s_closed, Host, ?MODULE, c2s_closed, 50), @@ -179,21 +177,33 @@ c2s_handle_recv(#{lang := Lang} = State, El, {error, Why}) -> c2s_handle_recv(State, _, _) -> State. -c2s_handle_send(#{mgmt_state := MgmtState} = State, Pkt, _Result) +c2s_handle_send(#{mgmt_state := MgmtState, + lang := Lang} = State, Pkt, SendResult) when MgmtState == pending; MgmtState == active -> - State1 = mgmt_queue_add(State, Pkt), case xmpp:is_stanza(Pkt) of true -> - send_rack(State1); + case mgmt_queue_add(State, Pkt) of + #{mgmt_max_queue := exceeded} = State1 -> + State2 = State1#{mgmt_resend => false}, + case MgmtState of + active -> + Err = xmpp:serr_policy_violation( + <<"Too many unacked stanzas">>, Lang), + send(State2, Err); + _ -> + ejabberd_c2s:stop(State2) + end; + State1 when SendResult == ok -> + send_rack(State1); + State1 -> + State1 + end; false -> - State1 + State end; c2s_handle_send(State, _Pkt, _Result) -> State. -c2s_filter_send({Pkt, State}) -> - {Pkt, State}. - c2s_handle_call(#{sid := {Time, _}} = State, {resume_session, Time}, From) -> ejabberd_c2s:reply(From, {resume, State}), @@ -216,6 +226,13 @@ c2s_handle_info(#{mgmt_state := pending, jid := JID} = State, ?DEBUG("Timed out waiting for resumption of stream for ~s", [jid:to_string(JID)]), ejabberd_c2s:stop(State#{mgmt_state => timeout}); +c2s_handle_info(#{jid := JID} = State, {_Ref, {resume, OldState}}) -> + %% This happens if the resume_session/1 request timed out; the new session + %% now receives the late response. + ?DEBUG("Received old session state for ~s after failed resumption", + [jid:to_string(JID)]), + route_unacked_stanzas(OldState#{mgmt_resend => false}), + State; c2s_handle_info(State, _) -> State. @@ -325,7 +342,7 @@ handle_a(State, #sm_a{h = H}) -> resend_rack(State1). -spec handle_resume(state(), sm_resume()) -> {ok, state()} | {error, state()}. -handle_resume(#{user := User, lserver := LServer, +handle_resume(#{user := User, lserver := LServer, sockmod := SockMod, lang := Lang, socket := Socket} = State, #sm_resume{h = H, previd = PrevID, xmlns = Xmlns}) -> R = case inherit_session_state(State, PrevID) of @@ -354,7 +371,7 @@ handle_resume(#{user := User, lserver := LServer, %% csi_flush_queue(State4), State5 = ejabberd_hooks:run_fold(c2s_session_resumed, LServer, State4, []), ?INFO_MSG("(~s) Resumed session for ~s", - [ejabberd_socket:pp(Socket), jid:to_string(JID)]), + [SockMod:pp(Socket), jid:to_string(JID)]), {ok, State5}; {error, El, Msg} -> ?INFO_MSG("Cannot resume session for ~s@~s: ~s", @@ -363,6 +380,8 @@ handle_resume(#{user := User, lserver := LServer, end. -spec transition_to_pending(state()) -> state(). +transition_to_pending(#{mgmt_state := active, mgmt_timeout := 0} = State) -> + ejabberd_c2s:stop(State); transition_to_pending(#{mgmt_state := active, jid := JID, lserver := LServer, mgmt_timeout := Timeout} = State) -> State1 = cancel_ack_timer(State), @@ -405,9 +424,9 @@ send_rack(#{mgmt_ack_timer := _} = State) -> send_rack(#{mgmt_xmlns := Xmlns, mgmt_stanzas_out := NumStanzasOut, mgmt_ack_timeout := AckTimeout} = State) -> - State1 = send(State, #sm_r{xmlns = Xmlns}), TRef = erlang:start_timer(AckTimeout, self(), ack_timeout), - State1#{mgmt_ack_timer => TRef, mgmt_stanzas_req => NumStanzasOut}. + State1 = State#{mgmt_ack_timer => TRef, mgmt_stanzas_req => NumStanzasOut}, + send(State1, #sm_r{xmlns = Xmlns}). resend_rack(#{mgmt_ack_timer := _, mgmt_queue := Queue, @@ -424,18 +443,13 @@ resend_rack(State) -> -spec mgmt_queue_add(state(), xmpp_element()) -> state(). mgmt_queue_add(#{mgmt_stanzas_out := NumStanzasOut, mgmt_queue := Queue} = State, Pkt) -> - case xmpp:is_stanza(Pkt) of - true -> - NewNum = case NumStanzasOut of - 4294967295 -> 0; - Num -> Num + 1 - end, - Queue1 = queue_in({NewNum, p1_time_compat:timestamp(), Pkt}, Queue), - State1 = State#{mgmt_queue => Queue1, mgmt_stanzas_out => NewNum}, - check_queue_length(State1); - false -> - State - end. + NewNum = case NumStanzasOut of + 4294967295 -> 0; + Num -> Num + 1 + end, + Queue1 = queue_in({NewNum, p1_time_compat:timestamp(), Pkt}, Queue), + State1 = State#{mgmt_queue => Queue1, mgmt_stanzas_out => NewNum}, + check_queue_length(State1). -spec mgmt_queue_drop(state(), non_neg_integer()) -> state(). mgmt_queue_drop(#{mgmt_queue := Queue} = State, NumHandled) -> @@ -510,20 +524,24 @@ route_unacked_stanzas(#{mgmt_state := MgmtState, %% easily lead to unexpected results as well. ?DEBUG("Dropping forwarded message stanza from ~s", [jid:to_string(From)]); - ({_, Time, El}) -> + ({_, Time, #message{} = Msg}) -> case ejabberd_hooks:run_fold(message_is_archived, LServer, false, - [State, El]) of + [State, Msg]) of true -> ?DEBUG("Dropping archived message stanza from ~s", - [jid:to_string(xmpp:get_from(El))]); + [jid:to_string(xmpp:get_from(Msg))]); false when ResendOnTimeout -> - NewEl = add_resent_delay_info(State, El, Time), + NewEl = add_resent_delay_info(State, Msg, Time), route(NewEl); false -> Txt = <<"User session terminated">>, - route_error(El, xmpp:err_service_unavailable(Txt, Lang)) - end + route_error(Msg, xmpp:err_service_unavailable(Txt, Lang)) + end; + ({_, _Time, El}) -> + %% Raw element of type 'error' resulting from a validation error + %% We cannot pass it to the router, it will generate an error + ?DEBUG("Do not route raw element from ack queue: ~p", [El]) end, Queue); route_unacked_stanzas(_State) -> ok. @@ -587,11 +605,13 @@ resume_session({Time, Pid}, _State) -> make_resume_id(#{sid := {Time, _}, resource := Resource}) -> jlib:term_to_base64({Resource, Time}). --spec add_resent_delay_info(state(), stanza(), erlang:timestamp()) -> stanza(). -add_resent_delay_info(_State, #iq{} = El, _Time) -> - El; -add_resent_delay_info(#{lserver := LServer}, El, Time) -> - xmpp_util:add_delay_info(El, jid:make(LServer), Time, <<"Resent">>). +-spec add_resent_delay_info(state(), stanza(), erlang:timestamp()) -> stanza(); + (state(), xmlel(), erlang:timestamp()) -> xmlel(). +add_resent_delay_info(#{lserver := LServer}, El, Time) + when is_record(El, message); is_record(El, presence) -> + xmpp_util:add_delay_info(El, jid:make(LServer), Time, <<"Resent">>); +add_resent_delay_info(_State, El, _Time) -> + El. -spec route(stanza()) -> ok. route(Pkt) -> diff --git a/src/mod_vcard_xupdate.erl b/src/mod_vcard_xupdate.erl index 4d1dfa2fc..900758e39 100644 --- a/src/mod_vcard_xupdate.erl +++ b/src/mod_vcard_xupdate.erl @@ -12,7 +12,7 @@ %% gen_mod callbacks -export([start/2, stop/1]). --export([update_presence/3, vcard_set/3, export/1, +-export([update_presence/1, vcard_set/3, export/1, import_info/0, import/5, import_start/2, mod_opt_type/1, depends/2]). @@ -33,14 +33,14 @@ start(Host, Opts) -> Mod = gen_mod:db_mod(Host, Opts, ?MODULE), Mod:init(Host, Opts), - ejabberd_hooks:add(c2s_update_presence, Host, ?MODULE, + ejabberd_hooks:add(c2s_self_presence, Host, ?MODULE, update_presence, 100), ejabberd_hooks:add(vcard_set, Host, ?MODULE, vcard_set, 100), ok. stop(Host) -> - ejabberd_hooks:delete(c2s_update_presence, Host, + ejabberd_hooks:delete(c2s_self_presence, Host, ?MODULE, update_presence, 100), ejabberd_hooks:delete(vcard_set, Host, ?MODULE, vcard_set, 100), @@ -52,10 +52,15 @@ depends(_Host, _Opts) -> %%==================================================================== %% Hooks %%==================================================================== --spec update_presence(presence(), binary(), binary()) -> presence(). -update_presence(#presence{type = available} = Packet, User, Host) -> - presence_with_xupdate(Packet, User, Host); -update_presence(Packet, _User, _Host) -> Packet. +-spec update_presence({presence(), ejabberd_c2s:state()}) + -> {presence(), ejabberd_c2s:state()}. +update_presence({#presence{type = available} = Pres, + #{jid := #jid{luser = LUser, lserver = LServer}} = State}) -> + Hash = get_xupdate(LUser, LServer), + Pres1 = xmpp:set_subtag(Pres, #vcard_xupdate{hash = Hash}), + {Pres1, State}; +update_presence(Acc) -> + Acc. -spec vcard_set(binary(), binary(), xmlel()) -> ok. vcard_set(LUser, LServer, VCARD) -> @@ -86,15 +91,6 @@ remove_xupdate(LUser, LServer) -> Mod = gen_mod:db_mod(LServer, ?MODULE), Mod:remove_xupdate(LUser, LServer). -%%%---------------------------------------------------------------------- -%%% Presence stanza rebuilding -%%%---------------------------------------------------------------------- - -presence_with_xupdate(Presence, User, Host) -> - Hash = get_xupdate(User, Host), - Presence1 = xmpp:remove_subtag(Presence, #vcard_xupdate{}), - xmpp:set_subtag(Presence1, #vcard_xupdate{hash = Hash}). - import_info() -> [{<<"vcard_xupdate">>, 3}]. @@ -110,5 +106,8 @@ export(LServer) -> Mod = gen_mod:db_mod(LServer, ?MODULE), Mod:export(LServer). +%%==================================================================== +%% Options +%%==================================================================== mod_opt_type(db_type) -> fun(T) -> ejabberd_config:v_db(?MODULE, T) end; mod_opt_type(_) -> [db_type]. diff --git a/src/xmpp_stream_in.erl b/src/xmpp_stream_in.erl index 1ad78d45b..b2b3b3072 100644 --- a/src/xmpp_stream_in.erl +++ b/src/xmpp_stream_in.erl @@ -20,9 +20,11 @@ %%% %%%------------------------------------------------------------------- -module(xmpp_stream_in). --behaviour(gen_server). +-define(GEN_SERVER, gen_server). +-behaviour(?GEN_SERVER). -protocol({rfc, 6120}). +-protocol({xep, 114, '1.6'}). %% API -export([start/3, start_link/3, call/3, cast/2, reply/2, stop/1, @@ -43,17 +45,18 @@ -include("xmpp.hrl"). -type state() :: map(). -type stop_reason() :: {stream, reset | {in | out, stream_error()}} | - {tls, term()} | + {tls, inet:posix() | atom() | binary()} | {socket, inet:posix() | closed | timeout} | internal_failure. --callback init(list()) -> {ok, state()} | {stop, term()} | ignore. +-callback init(list()) -> {ok, state()} | {error, term()} | ignore. -callback handle_cast(term(), state()) -> state(). -callback handle_call(term(), term(), state()) -> state(). -callback handle_info(term(), state()) -> state(). -callback terminate(term(), state()) -> any(). -callback code_change(term(), state(), term()) -> {ok, state()} | {error, term()}. --callback handle_stream_start(state()) -> state(). +-callback handle_stream_start(stream_start(), state()) -> state(). +-callback handle_stream_established(state()) -> state(). -callback handle_stream_end(stop_reason(), state()) -> state(). -callback handle_cdata(binary(), state()) -> state(). -callback handle_unauthenticated_packet(xmpp_element(), state()) -> state(). @@ -63,6 +66,7 @@ -callback handle_auth_failure(binary(), binary(), atom(), state()) -> state(). -callback handle_send(xmpp_element(), ok | {error, inet:posix()}, state()) -> state(). -callback handle_recv(fxml:xmlel(), xmpp_element() | {error, term()}, state()) -> state(). +-callback handle_timeout(state()) -> state(). -callback get_password_fun(state()) -> fun(). -callback check_password_fun(state()) -> fun(). -callback check_password_digest_fun(state()) -> fun(). @@ -71,6 +75,8 @@ -callback tls_options(state()) -> [proplists:property()]. -callback tls_required(state()) -> boolean(). -callback tls_verify(state()) -> boolean(). +-callback tls_enabled(state()) -> boolean(). +-callback sasl_mechanisms([cyrsasl:mechanism()], state()) -> [cyrsasl:mechanism()]. -callback unauthenticated_stream_features(state()) -> [xmpp_element()]. -callback authenticated_stream_features(state()) -> [xmpp_element()]. @@ -81,7 +87,8 @@ handle_info/2, terminate/2, code_change/3, - handle_stream_start/1, + handle_stream_start/2, + handle_stream_established/1, handle_stream_end/2, handle_cdata/2, handle_authenticated_packet/2, @@ -91,6 +98,7 @@ handle_auth_failure/4, handle_send/3, handle_recv/3, + handle_timeout/1, get_password_fun/1, check_password_fun/1, check_password_digest_fun/1, @@ -99,6 +107,8 @@ tls_options/1, tls_required/1, tls_verify/1, + tls_enabled/1, + sasl_mechanisms/2, unauthenticated_stream_features/1, authenticated_stream_features/1]). @@ -106,19 +116,19 @@ %%% API %%%=================================================================== start(Mod, Args, Opts) -> - gen_server:start(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS). + ?GEN_SERVER:start(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS). start_link(Mod, Args, Opts) -> - gen_server:start_link(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS). + ?GEN_SERVER:start_link(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS). call(Ref, Msg, Timeout) -> - gen_server:call(Ref, Msg, Timeout). + ?GEN_SERVER:call(Ref, Msg, Timeout). cast(Ref, Msg) -> - gen_server:cast(Ref, Msg). + ?GEN_SERVER:cast(Ref, Msg). reply(Ref, Reply) -> - gen_server:reply(Ref, Reply). + ?GEN_SERVER:reply(Ref, Reply). -spec stop(pid()) -> ok; (state()) -> no_return(). @@ -135,7 +145,7 @@ stop(_) -> send(Pid, Pkt) when is_pid(Pid) -> cast(Pid, {send, Pkt}); send(#{owner := Owner} = State, Pkt) when Owner == self() -> - send_element(State, Pkt); + send_pkt(State, Pkt); send(_, _) -> erlang:error(badarg). @@ -193,7 +203,7 @@ format_error({stream, {in, #stream_error{reason = Reason, text = Txt}}}) -> format_error({stream, {out, #stream_error{reason = Reason, text = Txt}}}) -> format("Stream closed by us: ~s", [format_stream_error(Reason, Txt)]); format_error({tls, Reason}) -> - format("TLS failed: ~w", [Reason]); + format("TLS failed: ~s", [format_tls_error(Reason)]); format_error(internal_failure) -> <<"Internal server error">>; format_error(Err) -> @@ -203,13 +213,9 @@ format_error(Err) -> %%% gen_server callbacks %%%=================================================================== init([Module, {SockMod, Socket}, Opts]) -> - XMLSocket = case lists:keyfind(xml_socket, 1, Opts) of - {_, XS} -> XS; - false -> false - end, Encrypted = proplists:get_bool(tls, Opts), SocketMonitor = SockMod:monitor(Socket), - case peername(SockMod, Socket) of + case SockMod:peername(Socket) of {ok, IP} -> Time = p1_time_compat:monotonic_time(milli_seconds), State = #{owner => self(), @@ -227,7 +233,6 @@ init([Module, {SockMod, Socket}, Opts]) -> stream_encrypted => Encrypted, stream_version => {1,0}, stream_authenticated => false, - xml_socket => XMLSocket, xmlns => ?NS_CLIENT, lang => <<"">>, user => <<"">>, @@ -238,18 +243,32 @@ init([Module, {SockMod, Socket}, Opts]) -> case try Module:init([State, Opts]) catch _:undef -> {ok, State} end of - {ok, State1} -> + {ok, State1} when not Encrypted -> {_, State2, Timeout} = noreply(State1), {ok, State2, Timeout}; - Err -> - Err + {ok, State1} when Encrypted -> + TLSOpts = try Module:tls_options(State1) + catch _:undef -> [] + end, + case SockMod:starttls(Socket, TLSOpts) of + {ok, TLSSocket} -> + State2 = State1#{socket => TLSSocket}, + {_, State3, Timeout} = noreply(State2), + {ok, State3, Timeout}; + {error, Reason} -> + {stop, Reason} + end; + {error, Reason} -> + {stop, Reason}; + ignore -> + ignore end; - {error, Reason} -> - {stop, Reason} + {error, _Reason} -> + ignore end. handle_cast({send, Pkt}, State) -> - noreply(send_element(State, Pkt)); + noreply(send_pkt(State, Pkt)); handle_cast(stop, State) -> {stop, normal, State}; handle_cast(Cast, #{mod := Mod} = State) -> @@ -278,7 +297,7 @@ handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}}, State1 = send_header(State), case is_disconnected(State1) of true -> State1; - false -> send_element(State1, xmpp:serr_invalid_xml()) + false -> send_pkt(State1, xmpp:serr_invalid_xml()) end catch _:{xmpp_codec, Why} -> State1 = send_header(State), @@ -288,7 +307,7 @@ handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}}, Txt = xmpp:io_format_error(Why), Lang = select_lang(MyLang, xmpp:get_lang(El)), Err = xmpp:serr_invalid_xml(Txt, Lang), - send_element(State1, Err) + send_pkt(State1, Err) end end); handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) -> @@ -303,7 +322,7 @@ handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) -> _ -> xmpp:serr_not_well_formed() end, - send_element(State1, Err) + send_pkt(State1, Err) end); handle_info({'$gen_event', {xmlstreamelement, El}}, #{xmlns := NS, mod := Mod} = State) -> @@ -339,7 +358,7 @@ handle_info(timeout, #{mod := Mod} = State) -> Disconnected = is_disconnected(State), noreply(try Mod:handle_timeout(State) catch _:undef when not Disconnected -> - send_element(State, xmpp:serr_connection_timeout()); + send_pkt(State, xmpp:serr_connection_timeout()); _:undef -> stop(State) end); @@ -385,14 +404,6 @@ new_id() -> is_disconnected(#{stream_state := StreamState}) -> StreamState == disconnected. --spec peername(term(), term()) -> {ok, {inet:ip_address(), inet:port_number()}}| - {error, inet:posix()}. -peername(SockMod, Socket) -> - case SockMod of - gen_tcp -> inet:peername(Socket); - _ -> SockMod:peername(Socket) - end. - -spec process_invalid_xml(state(), fxml:xmlel(), term()) -> state(). process_invalid_xml(#{lang := MyLang} = State, El, Reason) -> case xmpp:is_stanza(El) of @@ -408,12 +419,12 @@ process_invalid_xml(#{lang := MyLang} = State, El, Reason) -> Txt = xmpp:io_format_error(Reason), Err = #sasl_failure{reason = 'malformed-request', text = xmpp:mk_text(Txt, MyLang)}, - send_element(State, Err); + send_pkt(State, Err); {<<"starttls">>, ?NS_TLS} -> - send_element(State, #starttls_failure{}); + send_pkt(State, #starttls_failure{}); {<<"compress">>, ?NS_COMPRESS} -> Err = #compress_failure{reason = 'setup-failed'}, - send_element(State, Err); + send_pkt(State, Err); _ -> %% Maybe add something more? State @@ -434,9 +445,9 @@ process_stream(#stream_start{xmlns = XML_NS, stream_xmlns = STREAM_NS}, #{xmlns := NS} = State) when XML_NS /= NS; STREAM_NS /= ?NS_STREAM -> - send_element(State, xmpp:serr_invalid_namespace()); + send_pkt(State, xmpp:serr_invalid_namespace()); process_stream(#stream_start{version = {N, _}}, State) when N > 1 -> - send_element(State, xmpp:serr_unsupported_version()); + send_pkt(State, xmpp:serr_unsupported_version()); process_stream(#stream_start{lang = Lang}, #{xmlns := ?NS_CLIENT, lang := DefaultLang} = State) when size(Lang) > 35 -> @@ -445,14 +456,14 @@ process_stream(#stream_start{lang = Lang}, %% language tags MUST allow for language tags of at least 35 characters. %% Do not store long language tag to avoid possible DoS/flood attacks Txt = <<"Too long value of 'xml:lang' attribute">>, - send_element(State, xmpp:serr_policy_violation(Txt, DefaultLang)); + send_pkt(State, xmpp:serr_policy_violation(Txt, DefaultLang)); process_stream(#stream_start{to = undefined}, #{lang := Lang} = State) -> Txt = <<"Missing 'to' attribute">>, - send_element(State, xmpp:serr_improper_addressing(Txt, Lang)); + send_pkt(State, xmpp:serr_improper_addressing(Txt, Lang)); process_stream(#stream_start{to = #jid{luser = U, lresource = R}}, #{lang := Lang} = State) when U /= <<"">>; R /= <<"">> -> Txt = <<"Improper 'to' attribute">>, - send_element(State, xmpp:serr_improper_addressing(Txt, Lang)); + send_pkt(State, xmpp:serr_improper_addressing(Txt, Lang)); process_stream(#stream_start{to = #jid{lserver = RemoteServer}} = StreamStart, #{xmlns := ?NS_COMPONENT, mod := Mod} = State) -> State1 = State#{remote_server => RemoteServer, @@ -509,29 +520,29 @@ process_element(Pkt, #{stream_state := StateName, lang := Lang} = State) -> #starttls{} -> process_starttls_failure(unexpected_starttls_request, State); #sasl_auth{} when StateName == wait_for_starttls -> - send_element(State, #sasl_failure{reason = 'encryption-required'}); + send_pkt(State, #sasl_failure{reason = 'encryption-required'}); #sasl_auth{} when StateName == wait_for_sasl_request -> process_sasl_request(Pkt, State); #sasl_auth{} -> Txt = <<"SASL negotiation is not allowed in this state">>, - send_element(State, #sasl_failure{reason = 'not-authorized', + send_pkt(State, #sasl_failure{reason = 'not-authorized', text = xmpp:mk_text(Txt, Lang)}); #sasl_response{} when StateName == wait_for_starttls -> - send_element(State, #sasl_failure{reason = 'encryption-required'}); + send_pkt(State, #sasl_failure{reason = 'encryption-required'}); #sasl_response{} when StateName == wait_for_sasl_response -> process_sasl_response(Pkt, State); #sasl_response{} -> Txt = <<"SASL negotiation is not allowed in this state">>, - send_element(State, #sasl_failure{reason = 'not-authorized', + send_pkt(State, #sasl_failure{reason = 'not-authorized', text = xmpp:mk_text(Txt, Lang)}); #sasl_abort{} when StateName == wait_for_sasl_response -> process_sasl_abort(State); #sasl_abort{} -> - send_element(State, #sasl_failure{reason = 'aborted'}); + send_pkt(State, #sasl_failure{reason = 'aborted'}); #sasl_success{} -> State; #compress{} when StateName == wait_for_sasl_response -> - send_element(State, #compress_failure{reason = 'setup-failed'}); + send_pkt(State, #compress_failure{reason = 'setup-failed'}); #compress{} -> process_compress(Pkt, State); #handshake{} when StateName == wait_for_handshake -> @@ -570,7 +581,7 @@ process_authenticated_packet(Pkt, #{xmlns := NS, mod := Mod} = State) -> {ok, #iq{type = set, sub_els = [_]} = Pkt2} when NS == ?NS_CLIENT -> case xmpp:get_subtag(Pkt2, #xmpp_session{}) of #xmpp_session{} -> - send_element(State, xmpp:make_iq_result(Pkt2)); + send_pkt(State, xmpp:make_iq_result(Pkt2)); _ -> try Mod:handle_authenticated_packet(Pkt2, State) catch _:undef -> @@ -585,7 +596,7 @@ process_authenticated_packet(Pkt, #{xmlns := NS, mod := Mod} = State) -> send_error(State, Pkt, Err) end; {error, Err} -> - send_element(State, Err) + send_pkt(State, Err) end. -spec process_bind(xmpp_element(), state()) -> state(). @@ -604,7 +615,7 @@ process_bind(#iq{type = set, sub_els = [_]} = Pkt, server := S, resource := NewR} = State1} when NewR /= <<"">> -> Reply = #bind{jid = jid:make(U, S, NewR)}, - State2 = send_element(State1, xmpp:make_iq_result(Pkt, Reply)), + State2 = send_pkt(State1, xmpp:make_iq_result(Pkt, Reply)), process_stream_established(State2); {error, #stanza_error{}, State1} = Err -> send_error(State1, Pkt, Err) @@ -646,7 +657,7 @@ process_handshake(#handshake{data = Digest}, case is_disconnected(State1) of true -> State1; false -> - State2 = send_element(State1, #handshake{}), + State2 = send_pkt(State1, #handshake{}), process_stream_established(State2) end; false -> @@ -656,7 +667,7 @@ process_handshake(#handshake{data = Digest}, end, case is_disconnected(State1) of true -> State1; - false -> send_element(State1, xmpp:serr_not_authorized()) + false -> send_pkt(State1, xmpp:serr_not_authorized()) end end. @@ -674,7 +685,7 @@ process_stream_established(#{mod := Mod} = State) -> -spec process_compress(compress(), state()) -> state(). process_compress(#compress{}, #{stream_compressed := true} = State) -> - send_element(State, #compress_failure{reason = 'setup-failed'}); + send_pkt(State, #compress_failure{reason = 'setup-failed'}); process_compress(#compress{methods = HisMethods}, #{socket := Socket, sockmod := SockMod, mod := Mod} = State) -> MyMethods = try Mod:compress_methods(State) @@ -683,44 +694,60 @@ process_compress(#compress{methods = HisMethods}, CommonMethods = lists_intersection(MyMethods, HisMethods), case lists:member(<<"zlib">>, CommonMethods) of true -> - BCompressed = fxml:element_to_binary(xmpp:encode(#compressed{})), - ZlibSocket = SockMod:compress(Socket, BCompressed), - State#{socket => ZlibSocket, - stream_id => new_id(), - stream_header_sent => false, - stream_restarted => true, - stream_state => wait_for_stream, - stream_compressed => true}; - false -> - send_element(State, #compress_failure{reason = 'unsupported-method'}) - end. - --spec process_starttls(state()) -> state(). -process_starttls(#{socket := Socket, - sockmod := SockMod, mod := Mod} = State) -> - TLSOpts = try Mod:tls_options(State) - catch _:undef -> [] - end, - case SockMod:starttls(Socket, TLSOpts) of - {ok, TLSSocket} -> - State1 = send_element(State, #starttls_proceed{}), + State1 = send_pkt(State, #compressed{}), case is_disconnected(State1) of true -> State1; false -> - State1#{socket => TLSSocket, - stream_id => new_id(), - stream_header_sent => false, - stream_restarted => true, - stream_state => wait_for_stream, - stream_encrypted => true} + case SockMod:compress(Socket) of + {ok, ZlibSocket} -> + State1#{socket => ZlibSocket, + stream_id => new_id(), + stream_header_sent => false, + stream_restarted => true, + stream_state => wait_for_stream, + stream_compressed => true}; + {error, _} -> + Err = #compress_failure{reason = 'setup-failed'}, + send_pkt(State1, Err) + end end; - {error, Reason} -> - process_starttls_failure(Reason, State) + false -> + send_pkt(State, #compress_failure{reason = 'unsupported-method'}) + end. + +-spec process_starttls(state()) -> state(). +process_starttls(#{stream_encrypted := true} = State) -> + process_starttls_failure(already_encrypted, State); +process_starttls(#{socket := Socket, + sockmod := SockMod, mod := Mod} = State) -> + case is_starttls_available(State) of + true -> + TLSOpts = try Mod:tls_options(State) + catch _:undef -> [] + end, + case SockMod:starttls(Socket, TLSOpts) of + {ok, TLSSocket} -> + State1 = send_pkt(State, #starttls_proceed{}), + case is_disconnected(State1) of + true -> State1; + false -> + State1#{socket => TLSSocket, + stream_id => new_id(), + stream_header_sent => false, + stream_restarted => true, + stream_state => wait_for_stream, + stream_encrypted => true} + end; + {error, Reason} -> + process_starttls_failure(Reason, State) + end; + false -> + process_starttls_failure(starttls_unsupported, State) end. -spec process_starttls_failure(term(), state()) -> state(). process_starttls_failure(Why, State) -> - State1 = send_element(State, #starttls_failure{}), + State1 = send_pkt(State, #starttls_failure{}), case is_disconnected(State1) of true -> State1; false -> process_stream_end({tls, Why}, State1) @@ -780,17 +807,17 @@ process_sasl_success(Props, ServerOut, mod := Mod, sasl_mech := Mech} = State) -> User = identity(Props), AuthModule = proplists:get_value(auth_module, Props), - State1 = try Mod:handle_auth_success(User, Mech, AuthModule, State) - catch _:undef -> State - end, + State1 = send_pkt(State, #sasl_success{text = ServerOut}), case is_disconnected(State1) of true -> State1; false -> - SockMod:reset_stream(Socket), - State2 = send_element(State1, #sasl_success{text = ServerOut}), + State2 = try Mod:handle_auth_success(User, Mech, AuthModule, State1) + catch _:undef -> State1 + end, case is_disconnected(State2) of true -> State2; false -> + SockMod:reset_stream(Socket), State3 = maps:remove(sasl_state, maps:remove(sasl_mech, State2)), State3#{stream_id => new_id(), @@ -806,19 +833,23 @@ process_sasl_success(Props, ServerOut, process_sasl_continue(ServerOut, NewSASLState, State) -> State1 = State#{sasl_state => NewSASLState, stream_state => wait_for_sasl_response}, - send_element(State1, #sasl_challenge{text = ServerOut}). + send_pkt(State1, #sasl_challenge{text = ServerOut}). -spec process_sasl_failure(atom(), binary(), state()) -> state(). process_sasl_failure(Err, User, #{mod := Mod, sasl_mech := Mech, lang := Lang} = State) -> {Reason, Text} = format_sasl_error(Mech, Err), - State1 = try Mod:handle_auth_failure(User, Mech, Text, State) - catch _:undef -> State - end, - State2 = maps:remove(sasl_state, maps:remove(sasl_mech, State1)), - State3 = State2#{stream_state => wait_for_sasl_request}, - send_element(State3, #sasl_failure{reason = Reason, - text = xmpp:mk_text(Text, Lang)}). + State1 = send_pkt(State, #sasl_failure{reason = Reason, + text = xmpp:mk_text(Text, Lang)}), + case is_disconnected(State1) of + true -> State1; + false -> + State2 = try Mod:handle_auth_failure(User, Mech, Text, State1) + catch _:undef -> State1 + end, + State3 = maps:remove(sasl_state, maps:remove(sasl_mech, State2)), + State3#{stream_state => wait_for_sasl_request} + end. -spec process_sasl_abort(state()) -> state(). process_sasl_abort(State) -> @@ -835,7 +866,7 @@ send_features(#{stream_version := {1,0}, ++ get_tls_feature(State) ++ get_bind_feature(State) ++ get_session_feature(State) ++ get_other_features(State) end, - send_element(State, #stream_features{sub_els = Features}); + send_pkt(State, #stream_features{sub_els = Features}); send_features(State) -> %% clients and servers from stone age State. @@ -849,10 +880,13 @@ get_sasl_mechanisms(#{stream_encrypted := Encrypted, mod := Mod, TLSVerify = try Mod:tls_verify(State) catch _:undef -> false end, - if Encrypted andalso (TLSVerify orelse NS == ?NS_SERVER) -> - [<<"EXTERNAL">>|Mechs]; - true -> - Mechs + Mechs1 = if Encrypted andalso (TLSVerify orelse NS == ?NS_SERVER) -> + [<<"EXTERNAL">>|Mechs]; + true -> + Mechs + end, + try Mod:sasl_mechanisms(Mechs1, State) + catch _:undef -> Mechs1 end. -spec get_sasl_feature(state()) -> [sasl_mechanisms()]. @@ -882,8 +916,13 @@ get_compress_feature(_) -> -spec get_tls_feature(state()) -> [starttls()]. get_tls_feature(#{stream_authenticated := false, stream_encrypted := false} = State) -> - TLSRequired = is_starttls_required(State), - [#starttls{required = TLSRequired}]; + case is_starttls_available(State) of + true -> + TLSRequired = is_starttls_required(State), + [#starttls{required = TLSRequired}]; + false -> + [] + end; get_tls_feature(_) -> []. @@ -913,6 +952,12 @@ get_other_features(#{stream_authenticated := Auth, mod := Mod} = State) -> [] end. +-spec is_starttls_available(state()) -> boolean(). +is_starttls_available(#{mod := Mod} = State) -> + try Mod:tls_enabled(State) + catch _:undef -> true + end. + -spec is_starttls_required(state()) -> boolean(). is_starttls_required(#{mod := Mod} = State) -> try Mod:tls_required(State) @@ -967,13 +1012,14 @@ send_header(#{stream_id := StreamID, lang := MyLang, xmlns := NS, server := DefaultServer} = State, - #stream_start{to = To, lang = HisLang, version = HisVersion}) -> + #stream_start{to = HisTo, from = HisFrom, + lang = HisLang, version = HisVersion}) -> Lang = select_lang(MyLang, HisLang), NS_DB = if NS == ?NS_SERVER -> ?NS_SERVER_DIALBACK; true -> <<"">> end, - From = case To of - #jid{} -> To; + From = case HisTo of + #jid{} -> HisTo; undefined -> jid:make(DefaultServer) end, Version = case HisVersion of @@ -981,45 +1027,40 @@ send_header(#{stream_id := StreamID, {0,_} -> HisVersion; _ -> MyVersion end, - Header = xmpp:encode(#stream_start{version = Version, - lang = Lang, - xmlns = NS, - stream_xmlns = ?NS_STREAM, - db_xmlns = NS_DB, - id = StreamID, - from = From}), + StreamStart = #stream_start{version = Version, + lang = Lang, + xmlns = NS, + stream_xmlns = ?NS_STREAM, + db_xmlns = NS_DB, + id = StreamID, + to = HisFrom, + from = From}, State1 = State#{lang => Lang, stream_version => Version, stream_header_sent => true}, - case send_text(State1, fxml:element_to_header(Header)) of + case socket_send(State1, StreamStart) of ok -> State1; {error, Why} -> process_stream_end({socket, Why}, State1) end; send_header(State, _) -> State. --spec send_element(state(), xmpp_element()) -> state(). -send_element(#{xmlns := NS, mod := Mod} = State, Pkt) -> - El = xmpp:encode(Pkt, NS), - Data = fxml:element_to_binary(El), - Result = send_text(State, Data), +-spec send_pkt(state(), xmpp_element() | xmlel()) -> state(). +send_pkt(#{mod := Mod} = State, Pkt) -> + Result = socket_send(State, Pkt), State1 = try Mod:handle_send(Pkt, Result, State) catch _:undef -> State end, - case is_disconnected(State1) of - true -> State1; - false -> - case Result of - _ when is_record(Pkt, stream_error) -> - process_stream_end({stream, {out, Pkt}}, State1); - ok -> - State1; - {error, Why} -> - process_stream_end({socket, Why}, State1) - end + case Result of + _ when is_record(Pkt, stream_error) -> + process_stream_end({stream, {out, Pkt}}, State1); + ok -> + State1; + {error, Why} -> + process_stream_end({socket, Why}, State1) end. --spec send_error(state(), xmpp_element(), stanza_error()) -> state(). +-spec send_error(state(), xmpp_element() | xmlel(), stanza_error()) -> state(). send_error(State, Pkt, Err) -> case xmpp:is_stanza(Pkt) of true -> @@ -1030,7 +1071,7 @@ send_error(State, Pkt, Err) -> <<"error">> -> State; _ -> ErrPkt = xmpp:make_error(Pkt, Err), - send_element(State, ErrPkt) + send_pkt(State, ErrPkt) end; false -> State @@ -1038,15 +1079,23 @@ send_error(State, Pkt, Err) -> -spec send_trailer(state()) -> state(). send_trailer(State) -> - send_text(State, <<"">>), + socket_send(State, trailer), close_socket(State). --spec send_text(state(), binary()) -> ok | {error, inet:posix()}. -send_text(#{socket := Sock, sockmod := SockMod, - stream_state := StateName, - stream_header_sent := true}, Data) when StateName /= disconnected -> - SockMod:send(Sock, Data); -send_text(_, _) -> +-spec socket_send(state(), xmpp_element() | xmlel() | trailer) -> ok | {error, inet:posix()}. +socket_send(#{socket := Sock, sockmod := SockMod, + stream_state := StateName, + xmlns := NS, + stream_header_sent := true}, Pkt) when StateName /= disconnected -> + case Pkt of + trailer -> + SockMod:send_trailer(Sock); + #stream_start{} -> + SockMod:send_header(Sock, xmpp:encode(Pkt)); + _ -> + SockMod:send_element(Sock, xmpp:encode(Pkt, NS)) + end; +socket_send(_, _) -> {error, closed}. -spec close_socket(state()) -> state(). @@ -1096,6 +1145,12 @@ format_sasl_error(<<"EXTERNAL">>, Err) -> format_sasl_error(Mech, Err) -> cyrsasl:format_error(Mech, Err). +-spec format_tls_error(atom() | binary()) -> list(). +format_tls_error(Reason) when is_atom(Reason) -> + format_inet_error(Reason); +format_tls_error(Reason) -> + Reason. + -spec format(io:format(), list()) -> binary(). format(Fmt, Args) -> iolist_to_binary(io_lib:format(Fmt, Args)). diff --git a/src/xmpp_stream_out.erl b/src/xmpp_stream_out.erl index adbc6ffba..3dcecf6f6 100644 --- a/src/xmpp_stream_out.erl +++ b/src/xmpp_stream_out.erl @@ -20,9 +20,11 @@ %%% %%%------------------------------------------------------------------- -module(xmpp_stream_out). --behaviour(gen_server). +-define(GEN_SERVER, gen_server). +-behaviour(?GEN_SERVER). -protocol({rfc, 6120}). +-protocol({xep, 114, '1.6'}). %% API -export([start/3, start_link/3, call/3, cast/2, reply/2, connect/1, @@ -42,7 +44,6 @@ -define(TCP_SEND_TIMEOUT, 15000). -include("xmpp.hrl"). --include("logger.hrl"). -include_lib("kernel/include/inet.hrl"). -type state() :: map(). @@ -53,31 +54,87 @@ -type stop_reason() :: {idna, bad_string} | {dns, inet:posix() | inet_res:res_error()} | {stream, reset | {in | out, stream_error()}} | - {tls, term()} | + {tls, inet:posix() | atom() | binary()} | {pkix, binary()} | {auth, atom() | binary() | string()} | {socket, inet:posix() | closed | timeout} | internal_failure. --callback init(list()) -> {ok, state()} | {stop, term()} | ignore. +-callback init(list()) -> {ok, state()} | {error, term()} | ignore. +-callback handle_cast(term(), state()) -> state(). +-callback handle_call(term(), term(), state()) -> state(). +-callback handle_info(term(), state()) -> state(). +-callback terminate(term(), state()) -> any(). +-callback code_change(term(), state(), term()) -> {ok, state()} | {error, term()}. +-callback handle_stream_start(stream_start(), state()) -> state(). +-callback handle_stream_established(state()) -> state(). +-callback handle_stream_downgraded(stream_start(), state()) -> state(). +-callback handle_stream_end(stop_reason(), state()) -> state(). +-callback handle_cdata(binary(), state()) -> state(). +-callback handle_send(xmpp_element(), ok | {error, inet:posix()}, state()) -> state(). +-callback handle_recv(fxml:xmlel(), xmpp_element() | {error, term()}, state()) -> state(). +-callback handle_timeout(state()) -> state(). +-callback handle_authenticated_features(stream_features(), state()) -> state(). +-callback handle_unauthenticated_features(stream_features(), state()) -> state(). +-callback handle_auth_success(cyrsasl:mechanism(), state()) -> state(). +-callback handle_auth_failure(cyrsasl:mechanism(), binary(), state()) -> state(). +-callback handle_packet(xmpp_element(), state()) -> state(). +-callback tls_options(state()) -> [proplists:property()]. +-callback tls_required(state()) -> boolean(). +-callback tls_verify(state()) -> boolean(). +-callback tls_enabled(state()) -> boolean(). +-callback dns_timeout(state()) -> timeout(). +-callback dns_retries(state()) -> non_neg_integer(). +-callback default_port(state()) -> inet:port_number(). +-callback address_families(state()) -> [inet:address_family()]. +-callback connect_timeout(state()) -> timeout(). + +-optional_callbacks([init/1, + handle_cast/2, + handle_call/3, + handle_info/2, + terminate/2, + code_change/3, + handle_stream_start/2, + handle_stream_established/1, + handle_stream_downgraded/2, + handle_stream_end/2, + handle_cdata/2, + handle_send/3, + handle_recv/3, + handle_timeout/1, + handle_authenticated_features/2, + handle_unauthenticated_features/2, + handle_auth_success/2, + handle_auth_failure/3, + handle_packet/2, + tls_options/1, + tls_required/1, + tls_verify/1, + tls_enabled/1, + dns_timeout/1, + dns_retries/1, + default_port/1, + address_families/1, + connect_timeout/1]). %%%=================================================================== %%% API %%%=================================================================== start(Mod, Args, Opts) -> - gen_server:start(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS). + ?GEN_SERVER:start(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS). start_link(Mod, Args, Opts) -> - gen_server:start_link(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS). + ?GEN_SERVER:start_link(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS). call(Ref, Msg, Timeout) -> - gen_server:call(Ref, Msg, Timeout). + ?GEN_SERVER:call(Ref, Msg, Timeout). cast(Ref, Msg) -> - gen_server:cast(Ref, Msg). + ?GEN_SERVER:cast(Ref, Msg). reply(Ref, Reply) -> - gen_server:reply(Ref, Reply). + ?GEN_SERVER:reply(Ref, Reply). -spec connect(pid()) -> ok. connect(Ref) -> @@ -98,7 +155,7 @@ stop(_) -> send(Pid, Pkt) when is_pid(Pid) -> cast(Pid, {send, Pkt}); send(#{owner := Owner} = State, Pkt) when Owner == self() -> - send_element(State, Pkt); + send_pkt(State, Pkt); send(_, _) -> erlang:error(badarg). @@ -154,7 +211,8 @@ format_error({dns, Reason}) -> format_error({socket, Reason}) -> format("Connection failed: ~s", [format_inet_error(Reason)]); format_error({pkix, Reason}) -> - format("Peer certificate rejected: ~s", [Reason]); + {_, ErrTxt} = xmpp_stream_pkix:format_error(Reason), + format("Peer certificate rejected: ~s", [ErrTxt]); format_error({stream, reset}) -> <<"Stream reset by peer">>; format_error({stream, {in, #stream_error{reason = Reason, text = Txt}}}) -> @@ -162,7 +220,7 @@ format_error({stream, {in, #stream_error{reason = Reason, text = Txt}}}) -> format_error({stream, {out, #stream_error{reason = Reason, text = Txt}}}) -> format("Stream closed by us: ~s", [format_stream_error(Reason, Txt)]); format_error({tls, Reason}) -> - format("TLS failed: ~w", [Reason]); + format("TLS failed: ~s", [format_tls_error(Reason)]); format_error({auth, Reason}) -> format("Authentication failed: ~s", [Reason]); format_error(internal_failure) -> @@ -199,8 +257,10 @@ init([Mod, SockMod, From, To, Opts]) -> {ok, State1} -> {_, State2, Timeout} = noreply(State1), {ok, State2, Timeout}; - Err -> - Err + {error, Reason} -> + {stop, Reason}; + ignore -> + ignore end. -spec handle_call(term(), term(), state()) -> noreply(). @@ -239,7 +299,7 @@ handle_cast(connect, State) -> %% Ignoring connection attempts in other states noreply(State); handle_cast({send, Pkt}, State) -> - noreply(send_element(State, Pkt)); + noreply(send_pkt(State, Pkt)); handle_cast(stop, State) -> {stop, normal, State}; handle_cast(Cast, #{mod := Mod} = State) -> @@ -257,12 +317,12 @@ handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}}, #stream_start{} = Pkt -> process_stream(Pkt, State); _ -> - send_element(State, xmpp:serr_invalid_xml()) + send_pkt(State, xmpp:serr_invalid_xml()) catch _:{xmpp_codec, Why} -> Txt = xmpp:io_format_error(Why), Lang = select_lang(MyLang, xmpp:get_lang(El)), Err = xmpp:serr_invalid_xml(Txt, Lang), - send_element(State, Err) + send_pkt(State, Err) end); handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) -> State1 = send_header(State), @@ -276,7 +336,7 @@ handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) -> _ -> xmpp:serr_not_well_formed() end, - send_element(State1, Err) + send_pkt(State1, Err) end); handle_info({'$gen_event', {xmlstreamelement, El}}, #{xmlns := NS, mod := Mod} = State) -> @@ -291,7 +351,7 @@ handle_info({'$gen_event', {xmlstreamelement, El}}, false -> process_element(Pkt, State1) end catch _:{xmpp_codec, Why} -> - State1 = try Mod:handle_recv(El, undefined, State) + State1 = try Mod:handle_recv(El, {error, Why}, State) catch _:undef -> State end, case is_disconnected(State1) of @@ -312,7 +372,7 @@ handle_info(timeout, #{mod := Mod} = State) -> Disconnected = is_disconnected(State), noreply(try Mod:handle_timeout(State) catch _:undef when not Disconnected -> - send_element(State, xmpp:serr_connection_timeout()); + send_pkt(State, xmpp:serr_connection_timeout()); _:undef -> stop(State) end); @@ -384,9 +444,9 @@ process_stream(#stream_start{xmlns = XML_NS, stream_xmlns = STREAM_NS}, #{xmlns := NS} = State) when XML_NS /= NS; STREAM_NS /= ?NS_STREAM -> - send_element(State, xmpp:serr_invalid_namespace()); + send_pkt(State, xmpp:serr_invalid_namespace()); process_stream(#stream_start{version = {N, _}}, State) when N > 1 -> - send_element(State, xmpp:serr_unsupported_version()); + send_pkt(State, xmpp:serr_unsupported_version()); process_stream(#stream_start{lang = Lang, id = ID, version = Version} = StreamStart, #{mod := Mod} = State) -> @@ -451,15 +511,19 @@ process_features(#stream_features{sub_els = Els} = StreamFeatures, true -> State1; false -> TLSRequired = is_starttls_required(State1), + TLSAvailable = is_starttls_available(State1), %% TODO: improve xmpp.erl Msg = #message{sub_els = Els}, case xmpp:get_subtag(Msg, #starttls{}) of false when TLSRequired and not Encrypted -> Txt = <<"Use of STARTTLS required">>, - send_element(State1, xmpp:err_policy_violation(Txt, Lang)); - #starttls{} when not Encrypted -> + send_pkt(State1, xmpp:serr_policy_violation(Txt, Lang)); + #starttls{required = true} when not TLSAvailable and not Encrypted -> + Txt = <<"Use of STARTTLS forbidden">>, + send_pkt(State1, xmpp:serr_policy_violation(Txt, Lang)); + #starttls{} when TLSAvailable and not Encrypted -> State2 = State1#{stream_state => wait_for_starttls_response}, - send_element(State2, #starttls{}); + send_pkt(State2, #starttls{}); _ -> State2 = process_cert_verification(State1), case is_disconnected(State2) of @@ -497,7 +561,7 @@ process_sasl_mechanisms(Mechs, #{user := User, server := Server} = State) -> true -> State1 = State#{stream_state => wait_for_sasl_response}, Authzid = jid:to_string(jid:make(User, Server)), - send_element(State1, #sasl_auth{mechanism = Mech, text = Authzid}); + send_pkt(State1, #sasl_auth{mechanism = Mech, text = Authzid}); false -> process_sasl_failure( #sasl_failure{reason = 'invalid-mechanism'}, State) @@ -527,12 +591,12 @@ process_stream_downgrade(StreamStart, TLSRequired = is_starttls_required(State), if not Encrypted and TLSRequired -> Txt = <<"Use of STARTTLS required">>, - send_element(State, xmpp:err_policy_violation(Txt, Lang)); + send_pkt(State, xmpp:serr_policy_violation(Txt, Lang)); true -> State1 = State#{stream_state => downgraded}, try Mod:handle_stream_downgraded(StreamStart, State1) catch _:undef -> - send_element(State1, xmpp:serr_unsupported_version()) + send_pkt(State1, xmpp:serr_unsupported_version()) end end. @@ -576,7 +640,7 @@ process_sasl_success(#{mod := Mod, -spec process_sasl_failure(sasl_failure(), state()) -> state(). process_sasl_failure(#sasl_failure{reason = Reason}, #{mod := Mod} = State) -> - try Mod:handle_auth_failure(<<"EXTERNAL">>, Reason, State) + try Mod:handle_auth_failure(<<"EXTERNAL">>, {auth, Reason}, State) catch _:undef -> process_stream_end({auth, Reason}, State) end. @@ -592,6 +656,12 @@ is_starttls_required(#{mod := Mod} = State) -> catch _:undef -> false end. +-spec is_starttls_available(state()) -> boolean(). +is_starttls_available(#{mod := Mod} = State) -> + try Mod:tls_enabled(State) + catch _:undef -> true + end. + -spec send_header(state()) -> state(). send_header(#{remote_server := RemoteServer, stream_encrypted := Encrypted, @@ -610,40 +680,34 @@ send_header(#{remote_server := RemoteServer, true -> undefined end, - Header = xmpp:encode( - #stream_start{xmlns = NS, - lang = Lang, - stream_xmlns = ?NS_STREAM, - db_xmlns = NS_DB, - from = From, - to = jid:make(RemoteServer), - version = {1,0}}), - case send_text(State, fxml:element_to_header(Header)) of + StreamStart = #stream_start{xmlns = NS, + lang = Lang, + stream_xmlns = ?NS_STREAM, + db_xmlns = NS_DB, + from = From, + to = jid:make(RemoteServer), + version = {1,0}}, + case socket_send(State, StreamStart) of ok -> State; {error, Why} -> process_stream_end({socket, Why}, State) end. --spec send_element(state(), xmpp_element()) -> state(). -send_element(#{xmlns := NS, mod := Mod} = State, Pkt) -> - El = xmpp:encode(Pkt, NS), - Data = fxml:element_to_binary(El), - State1 = try Mod:handle_send(Pkt, El, Data, State) +-spec send_pkt(state(), xmpp_element() | xmlel()) -> state(). +send_pkt(#{mod := Mod} = State, Pkt) -> + Result = socket_send(State, Pkt), + State1 = try Mod:handle_send(Pkt, Result, State) catch _:undef -> State end, - case is_disconnected(State1) of - true -> State1; - false -> - case send_text(State1, Data) of - _ when is_record(Pkt, stream_error) -> - process_stream_end({stream, {out, Pkt}}, State1); - ok -> - State1; - {error, Why} -> - process_stream_end({socket, Why}, State1) - end + case Result of + _ when is_record(Pkt, stream_error) -> + process_stream_end({stream, {out, Pkt}}, State1); + ok -> + State1; + {error, Why} -> + process_stream_end({socket, Why}, State1) end. --spec send_error(state(), xmpp_element(), stanza_error()) -> state(). +-spec send_error(state(), xmpp_element() | xmlel(), stanza_error()) -> state(). send_error(State, Pkt, Err) -> case xmpp:is_stanza(Pkt) of true -> @@ -654,22 +718,29 @@ send_error(State, Pkt, Err) -> <<"error">> -> State; _ -> ErrPkt = xmpp:make_error(Pkt, Err), - send_element(State, ErrPkt) + send_pkt(State, ErrPkt) end; false -> State end. --spec send_text(state(), binary()) -> ok | {error, inet:posix()}. -send_text(#{sockmod := SockMod, socket := Socket, - stream_state := StateName}, Data) when StateName /= disconnected -> - SockMod:send(Socket, Data); -send_text(_, _) -> +-spec socket_send(state(), xmpp_element() | xmlel() | trailer) -> ok | {error, inet:posix()}. +socket_send(#{sockmod := SockMod, socket := Socket, xmlns := NS, + stream_state := StateName}, Pkt) when StateName /= disconnected -> + case Pkt of + trailer -> + SockMod:send_trailer(Socket); + #stream_start{} -> + SockMod:send_header(Socket, xmpp:encode(Pkt)); + _ -> + SockMod:send_element(Socket, xmpp:encode(Pkt, NS)) + end; +socket_send(_, _) -> {error, closed}. -spec send_trailer(state()) -> state(). send_trailer(State) -> - send_text(State, <<"">>), + socket_send(State, trailer), close_socket(State). -spec close_socket(state()) -> state(). @@ -710,6 +781,12 @@ format_stream_error(Reason, Txt) -> binary_to_list(Data) ++ " (" ++ Slogan ++ ")" end. +-spec format_tls_error(atom() | binary()) -> list(). +format_tls_error(Reason) when is_atom(Reason) -> + format_inet_error(Reason); +format_tls_error(Reason) -> + binary_to_list(Reason). + -spec format(io:format(), list()) -> binary(). format(Fmt, Args) -> iolist_to_binary(io_lib:format(Fmt, Args)). @@ -747,13 +824,16 @@ resolve(Host, State) -> end. -spec srv_lookup(string(), state()) -> {ok, [host_port()]} | network_error(). +srv_lookup(_Host, #{xmlns := ?NS_COMPONENT}) -> + %% Do not attempt to lookup SRV for component connections + {error, nxdomain}; srv_lookup(Host, State) -> %% Only perform SRV lookups for FQDN names case string:chr(Host, $.) of 0 -> {error, nxdomain}; _ -> - case inet_parse:address(Host) of + case inet:parse_address(Host) of {ok, _} -> {error, nxdomain}; {error, _} -> @@ -763,7 +843,7 @@ srv_lookup(Host, State) -> end end. --spec srv_lookup(string(), non_neg_integer(), integer()) -> +-spec srv_lookup(string(), timeout(), integer()) -> {ok, [host_port()]} | network_error(). srv_lookup(_Host, _Timeout, Retries) when Retries < 1 -> {error, timeout}; @@ -807,7 +887,7 @@ a_lookup([], _State, Err) -> Err. -spec a_lookup(inet:hostname(), inet:port_number(), inet:address_family(), - non_neg_integer(), integer()) -> {ok, [ip_port()]} | network_error(). + timeout(), integer()) -> {ok, [ip_port()]} | network_error(). a_lookup(_Host, _Port, _Family, _Timeout, Retries) when Retries < 1 -> {error, timeout}; a_lookup(Host, Port, Family, Timeout, Retries) -> @@ -861,7 +941,7 @@ connect(AddrPorts, #{sockmod := SockMod} = State) -> Timeout = get_connect_timeout(State), connect(AddrPorts, SockMod, Timeout, {error, nxdomain}). --spec connect([ip_port()], module(), non_neg_integer(), network_error()) -> +-spec connect([ip_port()], module(), timeout(), network_error()) -> {ok, term(), ip_port()} | network_error(). connect([{Addr, Port}|AddrPorts], SockMod, Timeout, _) -> Type = get_addr_type(Addr), @@ -883,12 +963,11 @@ connect([], _SockMod, _Timeout, Err) -> get_addr_type({_, _, _, _}) -> inet; get_addr_type({_, _, _, _, _, _, _, _}) -> inet6. --spec get_dns_timeout(state()) -> non_neg_integer(). +-spec get_dns_timeout(state()) -> timeout(). get_dns_timeout(#{mod := Mod} = State) -> - timer:seconds( - try Mod:dns_timeout(State) - catch _:undef -> 10 - end). + try Mod:dns_timeout(State) + catch _:undef -> timer:seconds(10) + end. -spec get_dns_retries(state()) -> non_neg_integer(). get_dns_retries(#{mod := Mod} = State) -> @@ -909,9 +988,8 @@ get_address_families(#{mod := Mod} = State) -> catch _:undef -> [inet, inet6] end. --spec get_connect_timeout(state()) -> non_neg_integer(). +-spec get_connect_timeout(state()) -> timeout(). get_connect_timeout(#{mod := Mod} = State) -> - timer:seconds( - try Mod:connect_timeout(State) - catch _:undef -> 10 - end). + try Mod:connect_timeout(State) + catch _:undef -> timer:seconds(10) + end.