diff --git a/ejabberd.yml.example b/ejabberd.yml.example index 8c6d026ee..fc8555265 100644 --- a/ejabberd.yml.example +++ b/ejabberd.yml.example @@ -702,6 +702,10 @@ modules: mod_vcard: search: false mod_version: {} + # Since 17.02 S2S Dialback (XEP-0220) and Stream Management (XEP-0198) + # are implemented in modules + mod_s2s_dialback: {} + mod_sm: {} ## ## Enable modules with custom options in a specific virtual host diff --git a/include/ejabberd.hrl b/include/ejabberd.hrl index f10d8d81e..419e91d0e 100644 --- a/include/ejabberd.hrl +++ b/include/ejabberd.hrl @@ -41,8 +41,6 @@ -define(COPYRIGHT, "Copyright (c) 2002-2017 ProcessOne"). --define(S2STIMEOUT, timer:minutes(10)). - %%-define(DBGFSM, true). -record(scram, diff --git a/include/ejabberd_router.hrl b/include/ejabberd_router.hrl new file mode 100644 index 000000000..8de23c4c7 --- /dev/null +++ b/include/ejabberd_router.hrl @@ -0,0 +1,6 @@ +-type local_hint() :: undefined | integer() | {apply, atom(), atom()}. + +-record(route, {domain :: binary(), + server_host :: binary(), + pid :: undefined | pid(), + local_hint :: local_hint()}). diff --git a/include/mod_muc.hrl b/include/mod_muc.hrl index fd62436e0..ef66e2c2b 100644 --- a/include/mod_muc.hrl +++ b/include/mod_muc.hrl @@ -22,11 +22,15 @@ {'_', binary()}, opts = [] :: list() | '_'}). --record(muc_online_room, - {name_host = {<<"">>, <<"">>} :: {binary(), binary()} | '$1' | - {'_', binary()} | '_', - pid = self() :: pid() | '$2' | '_' | '$1'}). - -record(muc_registered, {us_host = {{<<"">>, <<"">>}, <<"">>} :: {{binary(), binary()}, binary()} | '$1', nick = <<"">> :: binary()}). + +-record(muc_online_room, + {name_host :: {binary(), binary()} | '$1' | {'_', binary()} | '_', + pid :: pid() | '$2' | '_' | '$1'}). + +-record(muc_online_users, {us :: {binary(), binary()}, + resource :: binary() | '_', + room :: binary() | '_' | '$1', + host :: binary() | '_' | '$2'}). diff --git a/include/mod_muc_room.hrl b/include/mod_muc_room.hrl index 5c0b12e2a..010dc6e99 100644 --- a/include/mod_muc_room.hrl +++ b/include/mod_muc_room.hrl @@ -120,10 +120,3 @@ room_shaper = none :: shaper:shaper(), room_queue = queue:new() :: ?TQUEUE }). - --record(muc_online_users, {us = {<<>>, <<>>} :: {binary(), binary()}, - resource = <<>> :: binary() | '_', - room = <<>> :: binary() | '_' | '$1', - host = <<>> :: binary() | '_' | '$2'}). - --type muc_online_users() :: #muc_online_users{}. diff --git a/src/cyrsasl.erl b/src/cyrsasl.erl index fcc83d975..014df7e80 100644 --- a/src/cyrsasl.erl +++ b/src/cyrsasl.erl @@ -25,13 +25,11 @@ -module(cyrsasl). --behaviour(ejabberd_config). - -author('alexey@process-one.net'). -export([start/0, register_mechanism/3, listmech/1, server_new/7, server_start/3, server_step/2, - get_mech/1, format_error/2, opt_type/1]). + get_mech/1, format_error/2]). -include("ejabberd.hrl"). -include("logger.hrl"). @@ -113,15 +111,9 @@ format_error(Mech, Reason) -> PasswordType :: password_type()) -> any(). register_mechanism(Mechanism, Module, PasswordType) -> - case is_disabled(Mechanism) of - false -> ets:insert(sasl_mechanism, #sasl_mechanism{mechanism = Mechanism, module = Module, - password_type = PasswordType}); - true -> - ?DEBUG("SASL mechanism ~p is disabled", [Mechanism]), - true - end. + password_type = PasswordType}). check_credentials(_State, Props) -> User = proplists:get_value(authzid, Props, <<>>), @@ -134,7 +126,7 @@ check_credentials(_State, Props) -> -spec listmech(Host ::binary()) -> Mechanisms::mechanisms(). listmech(Host) -> - Mechs = ets:select(sasl_mechanism, + ets:select(sasl_mechanism, [{#sasl_mechanism{mechanism = '$1', password_type = '$2', _ = '_'}, case catch ejabberd_auth:store_type(Host) of @@ -146,8 +138,7 @@ listmech(Host) -> []; _Else -> [] end, - ['$1']}]), - filter_anonymous(Host, Mechs). + ['$1']}]). -spec server_new(binary(), binary(), binary(), term(), fun(), fun(), fun()) -> sasl_state(). @@ -206,33 +197,3 @@ server_step(State, ClientIn) -> -spec get_mech(sasl_state()) -> binary(). get_mech(#sasl_state{mech_name = Mech}) -> Mech. - -%% Remove the anonymous mechanism from the list if not enabled for the given -%% host -%% --spec filter_anonymous(Host :: binary(), Mechs :: mechanisms()) -> mechanisms(). - -filter_anonymous(Host, Mechs) -> - case ejabberd_auth_anonymous:is_sasl_anonymous_enabled(Host) of - true -> Mechs; - false -> Mechs -- [<<"ANONYMOUS">>] - end. - --spec is_disabled(Mechanism :: mechanism()) -> boolean(). - -is_disabled(Mechanism) -> - Disabled = ejabberd_config:get_option( - disable_sasl_mechanisms, - fun(V) when is_list(V) -> - lists:map(fun(M) -> str:to_upper(M) end, V); - (V) -> - [str:to_upper(V)] - end, []), - lists:member(Mechanism, Disabled). - -opt_type(disable_sasl_mechanisms) -> - fun (V) when is_list(V) -> - lists:map(fun (M) -> str:to_upper(M) end, V); - (V) -> [str:to_upper(V)] - end; -opt_type(_) -> [disable_sasl_mechanisms]. diff --git a/src/cyrsasl_digest.erl b/src/cyrsasl_digest.erl index bee74a6b2..fa5c5338b 100644 --- a/src/cyrsasl_digest.erl +++ b/src/cyrsasl_digest.erl @@ -59,7 +59,7 @@ start(_Opts) -> Fqdn = get_local_fqdn(), - ?INFO_MSG("FQDN used to check DIGEST-MD5 SASL authentication: ~p", + ?INFO_MSG("FQDN used to check DIGEST-MD5 SASL authentication: ~s", [Fqdn]), cyrsasl:register_mechanism(<<"DIGEST-MD5">>, ?MODULE, digest). diff --git a/src/ejabberd_admin.erl b/src/ejabberd_admin.erl index e70d0a1b8..67778e71c 100644 --- a/src/ejabberd_admin.erl +++ b/src/ejabberd_admin.erl @@ -414,7 +414,7 @@ send_service_message_all_mucs(Subject, AnnouncementText) -> fun(ServerHost) -> MUCHost = gen_mod:get_module_opt_host( ServerHost, mod_muc, <<"conference.@HOST@">>), - mod_muc:broadcast_service_message(MUCHost, Message) + mod_muc:broadcast_service_message(ServerHost, MUCHost, Message) end, ?MYHOSTS). diff --git a/src/ejabberd_app.erl b/src/ejabberd_app.erl index 98f664008..493600afc 100644 --- a/src/ejabberd_app.erl +++ b/src/ejabberd_app.erl @@ -54,8 +54,6 @@ start(normal, _Args) -> ejabberd_ctl:init(), ejabberd_commands:init(), ejabberd_admin:start(), - gen_mod:start(), - ext_mod:start(), setup_if_elixir_conf_used(), ejabberd_config:start(), set_settings_from_config(), @@ -66,11 +64,13 @@ start(normal, _Args) -> ejabberd_rdbms:start(), ejabberd_riak_sup:start(), ejabberd_redis:start(), + ejabberd_router:start(), + ejabberd_router_multicast:start(), + ejabberd_local:start(), ejabberd_sm:start(), cyrsasl:start(), - % Profiling - %ejabberd_debug:eprof_start(), - %ejabberd_debug:fprof_start(), + gen_mod:start(), + ext_mod:start(), maybe_add_nameservers(), ejabberd_auth:start(), ejabberd_oauth:start(), @@ -169,7 +169,7 @@ broadcast_c2s_shutdown() -> Children = ejabberd_sm:get_all_pids(), lists:foreach( fun(C2SPid) when node(C2SPid) == node() -> - C2SPid ! system_shutdown; + ejabberd_c2s:send(C2SPid, xmpp:serr_system_shutdown()); (_) -> ok end, Children). diff --git a/src/ejabberd_auth.erl b/src/ejabberd_auth.erl index a4f536627..b77e574d7 100644 --- a/src/ejabberd_auth.erl +++ b/src/ejabberd_auth.erl @@ -42,7 +42,7 @@ get_password_s/2, get_password_with_authmodule/2, is_user_exists/2, is_user_exists_in_other_modules/3, remove_user/2, remove_user/3, plain_password_required/1, - store_type/1, entropy/1]). + store_type/1, entropy/1, backend_type/1]). -export([auth_modules/1, opt_type/1]). @@ -188,7 +188,7 @@ try_register(User, Server, Password) -> true -> {atomic, exists}; false -> LServer = jid:nameprep(Server), - case lists:member(LServer, ?MYHOSTS) of + case ejabberd_router:is_my_host(LServer) of true -> Res = lists:foldl(fun (_M, {atomic, ok} = Res) -> Res; (M, _) -> @@ -412,6 +412,13 @@ entropy(B) -> length(S) * math:log(lists:sum(Set)) / math:log(2) end. +-spec backend_type(atom()) -> atom(). +backend_type(Mod) -> + case atom_to_list(Mod) of + "ejabberd_auth_" ++ T -> list_to_atom(T); + _ -> Mod + end. + %%%---------------------------------------------------------------------- %%% Internal functions %%%---------------------------------------------------------------------- diff --git a/src/ejabberd_auth_anonymous.erl b/src/ejabberd_auth_anonymous.erl index 9223083de..59d4c99e7 100644 --- a/src/ejabberd_auth_anonymous.erl +++ b/src/ejabberd_auth_anonymous.erl @@ -52,17 +52,7 @@ -include("logger.hrl"). -include("jid.hrl"). -%% Create the anonymous table if at least one virtual host has anonymous features enabled -%% Register to login / logout events --record(anonymous, {us = {<<"">>, <<"">>} :: {binary(), binary()}, - sid = ejabberd_sm:make_sid() :: ejabberd_sm:sid()}). - start(Host) -> - %% TODO: Check cluster mode - ejabberd_mnesia:create(?MODULE, anonymous, [{ram_copies, [node()]}, - {type, bag}, - {attributes, record_info(fields, anonymous)}]), - %% The hooks are needed to add / remove users from the anonymous tables ejabberd_hooks:add(sm_register_connection_hook, Host, ?MODULE, register_connection, 100), ejabberd_hooks:add(sm_remove_connection_hook, Host, @@ -119,56 +109,33 @@ allow_multiple_connections(Host) -> fun(V) when is_boolean(V) -> V end, false). -%% Check if user exist in the anonymus database anonymous_user_exist(User, Server) -> - LUser = jid:nodeprep(User), - LServer = jid:nameprep(Server), - US = {LUser, LServer}, - case catch mnesia:dirty_read({anonymous, US}) of - [] -> - false; - [_H|_T] -> - true - end. - -%% Remove connection from Mnesia tables -remove_connection(SID, LUser, LServer) -> - US = {LUser, LServer}, - F = fun () -> mnesia:delete_object({anonymous, US, SID}) - end, - mnesia:transaction(F). + lists:any( + fun({_LResource, Info}) -> + proplists:get_value(auth_module, Info) == ?MODULE + end, ejabberd_sm:get_user_info(User, Server)). %% Register connection -spec register_connection(ejabberd_sm:sid(), jid(), ejabberd_sm:info()) -> ok. -register_connection(SID, +register_connection(_SID, #jid{luser = LUser, lserver = LServer}, Info) -> - AuthModule = proplists:get_value(auth_module, Info, undefined), - case AuthModule == (?MODULE) of - true -> - ejabberd_hooks:run(register_user, LServer, - [LUser, LServer]), - US = {LUser, LServer}, - mnesia:sync_dirty(fun () -> - mnesia:write(#anonymous{us = US, - sid = SID}) - end); - false -> ok + case proplists:get_value(auth_module, Info) of + ?MODULE -> + ejabberd_hooks:run(register_user, LServer, [LUser, LServer]); + false -> + ok end. %% Remove an anonymous user from the anonymous users table -spec unregister_connection(ejabberd_sm:sid(), jid(), ejabberd_sm:info()) -> any(). -unregister_connection(SID, - #jid{luser = LUser, lserver = LServer}, _) -> - purge_hook(anonymous_user_exist(LUser, LServer), LUser, - LServer), - remove_connection(SID, LUser, LServer). - -%% Launch the hook to purge user data only for anonymous users -purge_hook(false, _LUser, _LServer) -> - ok; -purge_hook(true, LUser, LServer) -> - ejabberd_hooks:run(anonymous_purge_hook, LServer, - [LUser, LServer]). +unregister_connection(_SID, + #jid{luser = LUser, lserver = LServer}, Info) -> + case proplists:get_value(auth_module, Info) of + ?MODULE -> + ejabberd_hooks:run(remove_user, LServer, [LUser, LServer]); + _ -> + ok + end. %% --------------------------------- %% Specific anonymous auth functions @@ -258,8 +225,6 @@ get_password_s(User, Server) -> Password end. -%% Returns true if the user exists in the DB or if an anonymous user is logged -%% under the given name is_user_exists(User, Server) -> anonymous_user_exist(User, Server). diff --git a/src/ejabberd_bosh.erl b/src/ejabberd_bosh.erl index 1603d49dd..b38de2515 100644 --- a/src/ejabberd_bosh.erl +++ b/src/ejabberd_bosh.erl @@ -49,7 +49,7 @@ -include("ejabberd.hrl"). -include("logger.hrl"). --include("jlib.hrl"). +-include("xmpp.hrl"). -include("ejabberd_http.hrl"). diff --git a/src/ejabberd_c2s.erl b/src/ejabberd_c2s.erl index d39586423..ecd6321d4 100644 --- a/src/ejabberd_c2s.erl +++ b/src/ejabberd_c2s.erl @@ -1,11 +1,8 @@ -%%%---------------------------------------------------------------------- -%%% File : ejabberd_c2s.erl -%%% Author : Alexey Shchepin -%%% Purpose : Serve C2S connection -%%% Created : 16 Nov 2002 by Alexey Shchepin +%%%------------------------------------------------------------------- +%%% Created : 8 Dec 2016 by Evgeny Khramtsov %%% %%% -%%% ejabberd, Copyright (C) 2002-2017 ProcessOne +%%% ejabberd, Copyright (C) 2002-2016 ProcessOne %%% %%% This program is free software; you can redistribute it and/or %%% modify it under the terms of the GNU General Public License as @@ -21,2008 +18,718 @@ %%% with this program; if not, write to the Free Software Foundation, Inc., %%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. %%% -%%%---------------------------------------------------------------------- - +%%%------------------------------------------------------------------- -module(ejabberd_c2s). - +-behaviour(xmpp_stream_in). -behaviour(ejabberd_config). +-behaviour(ejabberd_socket). --author('alexey@process-one.net'). +-protocol({rfc, 6121}). --protocol({xep, 78, '2.5'}). --protocol({xep, 138, '2.0'}). --protocol({xep, 198, '1.3'}). --protocol({xep, 356, '7.1'}). - --update_info({update, 0}). - --define(GEN_FSM, p1_fsm). - --behaviour(?GEN_FSM). - -%% External exports --export([start/2, - stop/1, - start_link/2, - close/1, - send_text/2, - send_element/2, - socket_type/0, - get_presence/1, - get_last_presence/1, - get_aux_field/2, - set_aux_field/3, - del_aux_field/2, - get_subscription/2, - get_queued_stanzas/1, - get_csi_state/1, - set_csi_state/2, - get_resume_timeout/1, - set_resume_timeout/2, - send_filtered/5, - broadcast/4, - get_subscribed/1, - transform_listen_option/2]). - --export([init/1, wait_for_stream/2, wait_for_auth/2, - wait_for_feature_request/2, wait_for_bind/2, - wait_for_sasl_response/2, - wait_for_resume/2, session_established/2, - handle_event/3, handle_sync_event/4, code_change/4, - handle_info/3, terminate/3, print_state/1, opt_type/1]). +%% ejabberd_socket callbacks +-export([start/2, start_link/2, socket_type/0]). +%% ejabberd_config callbacks +-export([opt_type/1, transform_listen_option/2]). +%% xmpp_stream_in callbacks +-export([init/1, handle_call/3, handle_cast/2, + handle_info/2, terminate/2, code_change/3]). +-export([tls_options/1, tls_required/1, tls_verify/1, tls_enabled/1, + compress_methods/1, bind/2, sasl_mechanisms/2, + get_password_fun/1, check_password_fun/1, check_password_digest_fun/1, + unauthenticated_stream_features/1, authenticated_stream_features/1, + handle_stream_start/2, handle_stream_end/2, + handle_unauthenticated_packet/2, handle_authenticated_packet/2, + handle_auth_success/4, handle_auth_failure/4, handle_send/3, + handle_recv/3, handle_cdata/2, handle_unbinded_packet/2]). +%% Hooks +-export([handle_unexpected_cast/2, + reject_unauthenticated_packet/2, process_closed/2, + process_terminated/2, process_info/2]). +%% API +-export([get_presence/1, get_subscription/2, get_subscribed/1, + open_session/1, call/3, send/2, close/1, close/2, stop/1, + reply/2, copy_state/2, set_timeout/2, add_hooks/1]). -include("ejabberd.hrl"). +-include("xmpp.hrl"). -include("logger.hrl"). --include("xmpp.hrl"). -%%-include("legacy.hrl"). - --include("mod_privacy.hrl"). - -define(SETS, gb_sets). --define(DICT, dict). -%% pres_a contains all the presence available send (either through roster mechanism or directed). -%% Directed presence unavailable remove user from pres_a. --record(state, {socket, - sockmod, - socket_monitor, - xml_socket, - streamid, - sasl_state, - access, - shaper, - zlib = false, - tls = false, - tls_required = false, - tls_enabled = false, - tls_options = [], - authenticated = false, - jid, - user = <<"">>, server = <<"">>, resource = <<"">>, - sid, - pres_t = ?SETS:new(), - pres_f = ?SETS:new(), - pres_a = ?SETS:new(), - pres_last, - pres_timestamp, - privacy_list = #userlist{}, - conn = unknown, - auth_module = unknown, - ip, - aux_fields = [], - csi_state = active, - mgmt_state, - mgmt_xmlns, - mgmt_queue, - mgmt_max_queue, - mgmt_pending_since, - mgmt_timeout, - mgmt_max_timeout, - mgmt_ack_timeout, - mgmt_ack_timer, - mgmt_resend, - mgmt_stanzas_in = 0, - mgmt_stanzas_out = 0, - mgmt_stanzas_req = 0, - ask_offline = true, - lang = <<"">>}). - --type state_name() :: wait_for_stream | wait_for_auth | - wait_for_feature_request | wait_for_bind | - wait_for_sasl_response | wait_for_resume | - session_established. --type state() :: #state{}. --type fsm_stop() :: {stop, normal, state()}. --type fsm_next() :: {next_state, state_name(), state(), non_neg_integer()}. --type fsm_reply() :: {reply, any(), state_name(), state(), non_neg_integer()}. --type fsm_transition() :: fsm_stop() | fsm_next(). +-type state() :: map(). -export_type([state/0]). -%-define(DBGFSM, true). - --ifdef(DBGFSM). - --define(FSMOPTS, [{debug, [trace]}]). - --else. - --define(FSMOPTS, []). - --endif. - -%% This is the timeout to apply between event when starting a new -%% session: --define(C2S_OPEN_TIMEOUT, 60000). - --define(C2S_HIBERNATE_TIMEOUT, ejabberd_config:get_option(c2s_hibernate, fun(X) when is_integer(X); X == hibernate-> X end, 90000)). - --define(STREAM_HEADER, - <<"">>). - --define(STREAM_TRAILER, <<"">>). - -%% XEP-0198: - --define(IS_STREAM_MGMT_PACKET(Pkt), - is_record(Pkt, sm_enable) or - is_record(Pkt, sm_resume) or - is_record(Pkt, sm_a) or - is_record(Pkt, sm_r)). - -%%%---------------------------------------------------------------------- -%%% API -%%%---------------------------------------------------------------------- +%%%=================================================================== +%%% ejabberd_socket API +%%%=================================================================== start(SockData, Opts) -> - ?GEN_FSM:start(ejabberd_c2s, - [SockData, Opts], - fsm_limit_opts(Opts) ++ ?FSMOPTS). - -start_link(SockData, Opts) -> - (?GEN_FSM):start_link(ejabberd_c2s, - [SockData, Opts], - fsm_limit_opts(Opts) ++ ?FSMOPTS). - -socket_type() -> xml_stream. - -%% Return Username, Resource and presence information -get_presence(FsmRef) -> - (?GEN_FSM):sync_send_all_state_event(FsmRef, - {get_presence}, 1000). -get_last_presence(FsmRef) -> - (?GEN_FSM):sync_send_all_state_event(FsmRef, - {get_last_presence}, 1000). - --spec get_aux_field(any(), state()) -> {ok, any()} | error. -get_aux_field(Key, #state{aux_fields = Opts}) -> - case lists:keyfind(Key, 1, Opts) of - {_, Val} -> {ok, Val}; - false -> error + case proplists:get_value(supervisor, Opts, true) of + true -> + supervisor:start_child(ejabberd_c2s_sup, [SockData, Opts]); + _ -> + xmpp_stream_in:start(?MODULE, [SockData, Opts], + ejabberd_config:fsm_limit_opts(Opts)) end. --spec set_aux_field(any(), any(), state()) -> state(). -set_aux_field(Key, Val, - #state{aux_fields = Opts} = State) -> - Opts1 = lists:keydelete(Key, 1, Opts), - State#state{aux_fields = [{Key, Val} | Opts1]}. +start_link(SockData, Opts) -> + xmpp_stream_in:start_link(?MODULE, [SockData, Opts], + ejabberd_config:fsm_limit_opts(Opts)). --spec del_aux_field(any(), state()) -> state(). -del_aux_field(Key, #state{aux_fields = Opts} = State) -> - Opts1 = lists:keydelete(Key, 1, Opts), - State#state{aux_fields = Opts1}. +socket_type() -> + xml_stream. + +%%%=================================================================== +%%% Common API +%%%=================================================================== +-spec call(pid(), term(), non_neg_integer() | infinity) -> term(). +call(Ref, Msg, Timeout) -> + xmpp_stream_in:call(Ref, Msg, Timeout). + +reply(Ref, Reply) -> + xmpp_stream_in:reply(Ref, Reply). + +-spec get_presence(pid()) -> presence(). +get_presence(Ref) -> + call(Ref, get_presence, 1000). -spec get_subscription(jid() | ljid(), state()) -> both | from | to | none. -get_subscription(From = #jid{}, StateData) -> - get_subscription(jid:tolower(From), StateData); -get_subscription(LFrom, StateData) -> - LBFrom = setelement(3, LFrom, <<"">>), - F = (?SETS):is_element(LFrom, StateData#state.pres_f) - orelse - (?SETS):is_element(LBFrom, StateData#state.pres_f), - T = (?SETS):is_element(LFrom, StateData#state.pres_t) - orelse - (?SETS):is_element(LBFrom, StateData#state.pres_t), +get_subscription(#jid{} = From, State) -> + get_subscription(jid:tolower(From), State); +get_subscription(LFrom, #{pres_f := PresF, pres_t := PresT}) -> + LBFrom = jid:remove_resource(LFrom), + F = ?SETS:is_element(LFrom, PresF) orelse ?SETS:is_element(LBFrom, PresF), + T = ?SETS:is_element(LFrom, PresT) orelse ?SETS:is_element(LBFrom, PresT), if F and T -> both; F -> from; T -> to; true -> none end. -get_queued_stanzas(#state{mgmt_queue = Queue} = StateData) -> - lists:map(fun({_N, Time, El}) -> - add_resent_delay_info(StateData, El, Time) - end, queue:to_list(Queue)). +-spec get_subscribed(pid()) -> [ljid()]. +%% Return list of all available resources of contacts +get_subscribed(Ref) -> + call(Ref, get_subscribed, 1000). -get_csi_state(#state{csi_state = CsiState}) -> - CsiState. +close(Ref) -> + xmpp_stream_in:close(Ref). -set_csi_state(#state{} = StateData, CsiState) -> - StateData#state{csi_state = CsiState}; -set_csi_state(FsmRef, CsiState) -> - FsmRef ! {set_csi_state, CsiState}. +close(Ref, SendTrailer) -> + xmpp_stream_in:close(Ref, SendTrailer). -get_resume_timeout(#state{mgmt_timeout = Timeout}) -> - Timeout. +stop(Ref) -> + xmpp_stream_in:stop(Ref). -set_resume_timeout(#state{} = StateData, Timeout) -> - StateData#state{mgmt_timeout = Timeout}; -set_resume_timeout(FsmRef, Timeout) -> - FsmRef ! {set_resume_timeout, Timeout}. +-spec send(pid(), xmpp_element()) -> ok; + (state(), xmpp_element()) -> state(). +send(Pid, Pkt) when is_pid(Pid) -> + xmpp_stream_in:send(Pid, Pkt); +send(#{lserver := LServer} = State, Pkt) -> + Pkt1 = fix_from_to(Pkt, State), + case ejabberd_hooks:run_fold(c2s_filter_send, LServer, {Pkt1, State}, []) of + {drop, State1} -> State1; + {Pkt2, State1} -> xmpp_stream_in:send(State1, Pkt2) + end. --spec send_filtered(pid(), binary(), jid(), jid(), stanza()) -> any(). -send_filtered(FsmRef, Feature, From, To, Packet) -> - FsmRef ! {send_filtered, Feature, From, To, Packet}. +-spec set_timeout(state(), timeout()) -> state(). +set_timeout(State, Timeout) -> + xmpp_stream_in:set_timeout(State, Timeout). --spec broadcast(pid(), any(), jid(), stanza()) -> any(). -broadcast(FsmRef, Type, From, Packet) -> - FsmRef ! {broadcast, Type, From, Packet}. +-spec add_hooks(binary()) -> ok. +add_hooks(Host) -> + ejabberd_hooks:add(c2s_closed, Host, ?MODULE, process_closed, 100), + ejabberd_hooks:add(c2s_terminated, Host, ?MODULE, + process_terminated, 100), + ejabberd_hooks:add(c2s_unauthenticated_packet, Host, ?MODULE, + reject_unauthenticated_packet, 100), + ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE, + process_info, 100), + ejabberd_hooks:add(c2s_handle_cast, Host, ?MODULE, + handle_unexpected_cast, 100). --spec stop(pid()) -> any(). -stop(FsmRef) -> (?GEN_FSM):send_event(FsmRef, stop). +%% Copies content of one c2s state to another. +%% This is needed for session migration from one pid to another. +-spec copy_state(state(), state()) -> state(). +copy_state(#{owner := Owner} = NewState, + #{jid := JID, resource := Resource, sid := {Time, _}, + auth_module := AuthModule, lserver := LServer, + pres_t := PresT, pres_a := PresA, + pres_f := PresF} = OldState) -> + State1 = case OldState of + #{pres_last := Pres, pres_timestamp := PresTS} -> + NewState#{pres_last => Pres, pres_timestamp => PresTS}; + _ -> + NewState + end, + Conn = get_conn_type(State1), + State2 = State1#{jid => JID, resource => Resource, + conn => Conn, + sid => {Time, Owner}, + auth_module => AuthModule, + pres_t => PresT, pres_a => PresA, + pres_f => PresF}, + ejabberd_hooks:run_fold(c2s_copy_session, LServer, State2, [OldState]). --spec close(pid()) -> any(). -%% What is the difference between stop and close??? -close(FsmRef) -> (?GEN_FSM):send_event(FsmRef, closed). +-spec open_session(state()) -> {ok, state()} | state(). +open_session(#{user := U, server := S, resource := R, + sid := SID, ip := IP, auth_module := AuthModule} = State) -> + JID = jid:make(U, S, R), + change_shaper(State), + Conn = get_conn_type(State), + State1 = State#{conn => Conn, resource => R, jid => JID}, + Prio = try maps:get(pres_last, State) of + Pres -> get_priority_from_presence(Pres) + catch _:{badkey, _} -> + undefined + end, + Info = [{ip, IP}, {conn, Conn}, {auth_module, AuthModule}], + ejabberd_sm:open_session(SID, U, S, R, Prio, Info), + xmpp_stream_in:establish(State1). -%%%---------------------------------------------------------------------- -%%% Callback functions from gen_fsm -%%%---------------------------------------------------------------------- +%%%=================================================================== +%%% Hooks +%%%=================================================================== +process_info(#{lserver := LServer} = State, + {route, From, To, Packet0}) -> + Packet = xmpp:set_from_to(Packet0, From, To), + {Pass, State1} = case Packet of + #presence{} -> + process_presence_in(State, Packet); + #message{} -> + process_message_in(State, Packet); + #iq{} -> + process_iq_in(State, Packet) + end, + if Pass -> + {Packet1, State2} = ejabberd_hooks:run_fold( + user_receive_packet, LServer, + {Packet, State1}, []), + case Packet1 of + drop -> State2; + _ -> send(State2, Packet1) + end; + true -> + State1 + end; +process_info(State, force_update_presence) -> + try maps:get(pres_last, State) of + Pres -> process_self_presence(State, Pres) + catch _:{badkey, _} -> + State + end; +process_info(State, Info) -> + ?WARNING_MSG("got unexpected info: ~p", [Info]), + State. -init([{SockMod, Socket}, Opts]) -> - Access = gen_mod:get_opt(access, Opts, - fun acl:access_rules_validator/1, all), - Shaper = gen_mod:get_opt(shaper, Opts, - fun acl:shaper_rules_validator/1, none), - XMLSocket = case lists:keysearch(xml_socket, 1, Opts) of - {value, {_, XS}} -> XS; - _ -> false - end, - Zlib = proplists:get_bool(zlib, Opts), - StartTLS = proplists:get_bool(starttls, Opts), - StartTLSRequired = proplists:get_bool(starttls_required, Opts), - TLSEnabled = proplists:get_bool(tls, Opts), - TLS = StartTLS orelse - StartTLSRequired orelse TLSEnabled, - TLSOpts1 = lists:filter(fun ({certfile, _}) -> true; - ({ciphers, _}) -> true; - ({dhfile, _}) -> true; - (_) -> false - end, - Opts), - TLSOpts2 = case lists:keysearch(protocol_options, 1, Opts) of - {value, {_, O}} -> - [_|ProtocolOptions] = lists:foldl( - fun(X, Acc) -> X ++ Acc end, [], - [["|" | binary_to_list(Opt)] || Opt <- O, is_binary(Opt)] - ), - [{protocol_options, iolist_to_binary(ProtocolOptions)} | TLSOpts1]; - _ -> TLSOpts1 +handle_unexpected_cast(State, Msg) -> + ?WARNING_MSG("got unexpected cast: ~p", [Msg]), + State. + +reject_unauthenticated_packet(State, _Pkt) -> + Err = xmpp:serr_not_authorized(), + send(State, Err). + +process_closed(State, Reason) -> + stop(State#{stop_reason => Reason}). + +process_terminated(#{sockmod := SockMod, socket := Socket, jid := JID} = State, + Reason) -> + Status = format_reason(State, Reason), + ?INFO_MSG("(~s) Closing c2s session for ~s: ~s", + [SockMod:pp(Socket), jid:to_string(JID), Status]), + State1 = case maps:is_key(pres_last, State) of + true -> + Pres = #presence{type = unavailable, + status = xmpp:mk_text(Status), + from = JID, + to = jid:remove_resource(JID)}, + broadcast_presence_unavailable(State, Pres); + false -> + State + end, + bounce_message_queue(), + State1; +process_terminated(State, _Reason) -> + State. + +%%%=================================================================== +%%% xmpp_stream_in callbacks +%%%=================================================================== +tls_options(#{lserver := LServer, tls_options := DefaultOpts}) -> + TLSOpts1 = case ejabberd_config:get_option( + {c2s_certfile, LServer}, + fun iolist_to_binary/1, + ejabberd_config:get_option( + {domain_certfile, LServer}, + fun iolist_to_binary/1)) of + undefined -> []; + CertFile -> lists:keystore(certfile, 1, DefaultOpts, + {certfile, CertFile}) + end, + TLSOpts2 = case ejabberd_config:get_option( + {c2s_ciphers, LServer}, + fun iolist_to_binary/1) of + undefined -> TLSOpts1; + Ciphers -> lists:keystore(ciphers, 1, TLSOpts1, + {ciphers, Ciphers}) end, + TLSOpts3 = case ejabberd_config:get_option( + {c2s_protocol_options, LServer}, + fun (Options) -> str:join(Options, <<$|>>) end) of + undefined -> TLSOpts2; + ProtoOpts -> lists:keystore(protocol_options, 1, TLSOpts2, + {protocol_options, ProtoOpts}) + end, + TLSOpts4 = case ejabberd_config:get_option( + {c2s_dhfile, LServer}, + fun iolist_to_binary/1) of + undefined -> TLSOpts3; + DHFile -> lists:keystore(dhfile, 1, TLSOpts3, + {dhfile, DHFile}) + end, + TLSOpts5 = case ejabberd_config:get_option( + {c2s_cafile, LServer}, + fun iolist_to_binary/1) of + undefined -> TLSOpts4; + CAFile -> lists:keystore(cafile, 1, TLSOpts4, + {cafile, CAFile}) + end, + case ejabberd_config:get_option( + {c2s_tls_compression, LServer}, + fun(B) when is_boolean(B) -> B end) of + undefined -> TLSOpts5; + false -> [compression_none | TLSOpts5]; + true -> lists:delete(compression_none, TLSOpts5) + end. + +tls_required(#{tls_required := TLSRequired}) -> + TLSRequired. + +tls_verify(#{tls_verify := TLSVerify}) -> + TLSVerify. + +tls_enabled(#{tls_enabled := TLSEnabled, + tls_required := TLSRequired, + tls_verify := TLSVerify}) -> + TLSEnabled or TLSRequired or TLSVerify. + +compress_methods(#{zlib := true}) -> + [<<"zlib">>]; +compress_methods(_) -> + []. + +unauthenticated_stream_features(#{lserver := LServer}) -> + ejabberd_hooks:run_fold(c2s_pre_auth_features, LServer, [], [LServer]). + +authenticated_stream_features(#{lserver := LServer}) -> + ejabberd_hooks:run_fold(c2s_post_auth_features, LServer, [], [LServer]). + +sasl_mechanisms(Mechs, #{lserver := LServer}) -> + Mechs1 = ejabberd_config:get_option( + {disable_sasl_mechanisms, LServer}, + fun(V) when is_list(V) -> + lists:map(fun(M) -> str:to_upper(M) end, V); + (V) -> + [str:to_upper(V)] + end, []), + Mechs2 = case ejabberd_auth_anonymous:is_sasl_anonymous_enabled(LServer) of + true -> Mechs1; + false -> [<<"ANONYMOUS">>|Mechs1] + end, + Mechs -- Mechs2. + +get_password_fun(#{lserver := LServer}) -> + fun(U) -> + ejabberd_auth:get_password_with_authmodule(U, LServer) + end. + +check_password_fun(#{lserver := LServer}) -> + fun(U, AuthzId, P) -> + ejabberd_auth:check_password_with_authmodule(U, AuthzId, LServer, P) + end. + +check_password_digest_fun(#{lserver := LServer}) -> + fun(U, AuthzId, P, D, DG) -> + ejabberd_auth:check_password_with_authmodule(U, AuthzId, LServer, P, D, DG) + end. + +bind(<<"">>, State) -> + bind(new_uniq_id(), State); +bind(R, #{user := U, server := S, access := Access, lang := Lang, + lserver := LServer, sockmod := SockMod, socket := Socket, + ip := IP} = State) -> + case resource_conflict_action(U, S, R) of + closenew -> + {error, xmpp:err_conflict(), State}; + {accept_resource, Resource} -> + JID = jid:make(U, S, Resource), + case acl:access_matches(Access, + #{usr => jid:split(JID), ip => IP}, + LServer) of + allow -> + State1 = open_session(State#{resource => Resource, + sid => ejabberd_sm:make_sid()}), + State2 = ejabberd_hooks:run_fold( + c2s_session_opened, LServer, State1, []), + ?INFO_MSG("(~s) Opened c2s session for ~s", + [SockMod:pp(Socket), jid:to_string(JID)]), + {ok, State2}; + deny -> + ejabberd_hooks:run(forbidden_session_hook, LServer, [JID]), + ?INFO_MSG("(~s) Forbidden c2s session for ~s", + [SockMod:pp(Socket), jid:to_string(JID)]), + Txt = <<"Denied by ACL">>, + {error, xmpp:err_not_allowed(Txt, Lang), State} + end + end. + +handle_stream_start(StreamStart, #{lserver := LServer} = State) -> + case ejabberd_router:is_my_host(LServer) of + false -> + send(State, xmpp:serr_host_unknown()); + true -> + change_shaper(State), + ejabberd_hooks:run_fold( + c2s_stream_started, LServer, State, [StreamStart]) + end. + +handle_stream_end(Reason, #{lserver := LServer} = State) -> + State1 = State#{stop_reason => Reason}, + ejabberd_hooks:run_fold(c2s_closed, LServer, State1, [Reason]). + +handle_auth_success(User, Mech, AuthModule, + #{socket := Socket, sockmod := SockMod, + ip := IP, lserver := LServer} = State) -> + ?INFO_MSG("(~s) Accepted c2s ~s authentication for ~s@~s by ~s backend from ~s", + [SockMod:pp(Socket), Mech, User, LServer, + ejabberd_auth:backend_type(AuthModule), + ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]), + State1 = State#{auth_module => AuthModule}, + ejabberd_hooks:run_fold(c2s_auth_result, LServer, State1, [true, User]). + +handle_auth_failure(User, Mech, Reason, + #{socket := Socket, sockmod := SockMod, + ip := IP, lserver := LServer} = State) -> + ?INFO_MSG("(~s) Failed c2s ~s authentication ~sfrom ~s: ~s", + [SockMod:pp(Socket), Mech, + if User /= <<"">> -> ["for ", User, "@", LServer, " "]; + true -> "" + end, + ejabberd_config:may_hide_data(jlib:ip_to_list(IP)), Reason]), + ejabberd_hooks:run_fold(c2s_auth_result, LServer, State, [false, User]). + +handle_unbinded_packet(Pkt, #{lserver := LServer} = State) -> + ejabberd_hooks:run_fold(c2s_unbinded_packet, LServer, State, [Pkt]). + +handle_unauthenticated_packet(Pkt, #{lserver := LServer} = State) -> + ejabberd_hooks:run_fold(c2s_unauthenticated_packet, LServer, State, [Pkt]). + +handle_authenticated_packet(Pkt, #{lserver := LServer} = State) when not ?is_stanza(Pkt) -> + ejabberd_hooks:run_fold(c2s_authenticated_packet, + LServer, State, [Pkt]); +handle_authenticated_packet(Pkt, #{lserver := LServer, jid := JID} = State) -> + State1 = ejabberd_hooks:run_fold(c2s_authenticated_packet, + LServer, State, [Pkt]), + #jid{luser = LUser} = JID, + {Pkt1, State2} = ejabberd_hooks:run_fold( + user_send_packet, LServer, {Pkt, State1}, []), + case Pkt1 of + drop -> + State2; + #presence{to = #jid{luser = LUser, lserver = LServer, + lresource = <<"">>}} -> + process_self_presence(State2, Pkt1); + #presence{} -> + process_presence_out(State2, Pkt1); + _ -> + check_privacy_then_route(State2, Pkt1) + end. + +handle_cdata(Data, #{lserver := LServer} = State) -> + ejabberd_hooks:run_fold(c2s_handle_cdata, LServer, + State, [Data]). + +handle_recv(El, Pkt, #{lserver := LServer} = State) -> + ejabberd_hooks:run_fold(c2s_handle_recv, LServer, State, [El, Pkt]). + +handle_send(Pkt, Result, #{lserver := LServer} = State) -> + ejabberd_hooks:run_fold(c2s_handle_send, LServer, State, [Pkt, Result]). + +init([State, Opts]) -> + Access = gen_mod:get_opt(access, Opts, fun acl:access_rules_validator/1, all), + Shaper = gen_mod:get_opt(shaper, Opts, fun acl:shaper_rules_validator/1, none), + TLSOpts1 = lists:filter( + fun({certfile, _}) -> true; + ({ciphers, _}) -> true; + ({dhfile, _}) -> true; + ({cafile, _}) -> true; + (_) -> false + end, Opts), + TLSOpts2 = case lists:keyfind(protocol_options, 1, Opts) of + false -> TLSOpts1; + {_, OptString} -> + ProtoOpts = str:join(OptString, <<$|>>), + [{protocol_options, ProtoOpts}|TLSOpts1] + end, TLSOpts3 = case proplists:get_bool(tls_compression, Opts) of false -> [compression_none | TLSOpts2]; true -> TLSOpts2 end, - TLSOpts = [verify_none | TLSOpts3], - StreamMgmtEnabled = proplists:get_value(stream_management, Opts, true), - StreamMgmtState = if StreamMgmtEnabled -> inactive; - true -> disabled - end, - MaxAckQueue = case proplists:get_value(max_ack_queue, Opts) of - Limit when is_integer(Limit), Limit > 0 -> Limit; - infinity -> infinity; - _ -> 1000 - end, - ResumeTimeout = case proplists:get_value(resume_timeout, Opts) of - RTimeo when is_integer(RTimeo), RTimeo >= 0 -> RTimeo; - _ -> 300 - end, - MaxResumeTimeout = case proplists:get_value(max_resume_timeout, Opts) of - Max when is_integer(Max), Max >= ResumeTimeout -> Max; - _ -> ResumeTimeout - end, - AckTimeout = case proplists:get_value(ack_timeout, Opts) of - ATimeo when is_integer(ATimeo), ATimeo > 0 -> ATimeo * 1000; - infinity -> undefined; - _ -> 60000 - end, - ResendOnTimeout = case proplists:get_value(resend_on_timeout, Opts) of - Resend when is_boolean(Resend) -> Resend; - if_offline -> if_offline; - _ -> false - end, - IP = peerip(SockMod, Socket), - Socket1 = if TLSEnabled andalso - SockMod /= ejabberd_frontend_socket -> - SockMod:starttls(Socket, TLSOpts); - true -> Socket - end, - SocketMonitor = SockMod:monitor(Socket1), - StateData = #state{socket = Socket1, sockmod = SockMod, - socket_monitor = SocketMonitor, - xml_socket = XMLSocket, zlib = Zlib, tls = TLS, - tls_required = StartTLSRequired, - tls_enabled = TLSEnabled, tls_options = TLSOpts, - sid = ejabberd_sm:make_sid(), streamid = new_id(), - access = Access, shaper = Shaper, ip = IP, - mgmt_state = StreamMgmtState, - mgmt_max_queue = MaxAckQueue, - mgmt_timeout = ResumeTimeout, - mgmt_max_timeout = MaxResumeTimeout, - mgmt_ack_timeout = AckTimeout, - mgmt_resend = ResendOnTimeout}, - {ok, wait_for_stream, StateData, ?C2S_OPEN_TIMEOUT}. + TLSEnabled = proplists:get_bool(starttls, Opts), + TLSRequired = proplists:get_bool(starttls_required, Opts), + TLSVerify = proplists:get_bool(tls_verify, Opts), + Zlib = proplists:get_bool(zlib, Opts), + State1 = State#{tls_options => TLSOpts3, + tls_required => TLSRequired, + tls_enabled => TLSEnabled, + tls_verify => TLSVerify, + pres_a => ?SETS:new(), + pres_f => ?SETS:new(), + pres_t => ?SETS:new(), + zlib => Zlib, + lang => ?MYLANG, + server => ?MYNAME, + lserver => ?MYNAME, + access => Access, + shaper => Shaper}, + ejabberd_hooks:run_fold(c2s_init, {ok, State1}, [Opts]). --spec get_subscribed(pid()) -> [ljid()]. -%% Return list of all available resources of contacts, -get_subscribed(FsmRef) -> - (?GEN_FSM):sync_send_all_state_event(FsmRef, - get_subscribed, 1000). - -wait_for_stream({xmlstreamstart, Name, Attrs}, StateData) -> - try xmpp:decode(#xmlel{name = Name, attrs = Attrs}) of - #stream_start{xmlns = NS_CLIENT, stream_xmlns = NS_STREAM, - version = Version, lang = Lang} - when NS_CLIENT /= ?NS_CLIENT; NS_STREAM /= ?NS_STREAM -> - send_header(StateData, ?MYNAME, Version, Lang), - send_element(StateData, xmpp:serr_invalid_namespace()), - {stop, normal, StateData}; - #stream_start{lang = Lang, version = Version} when byte_size(Lang) > 35 -> - %% As stated in BCP47, 4.4.1: - %% Protocols or specifications that specify limited buffer sizes for - %% language tags MUST allow for language tags of at least 35 characters. - %% Do not store long language tag to avoid possible DoS/flood attacks - send_header(StateData, ?MYNAME, Version, ?MYLANG), - Txt = <<"Too long value of 'xml:lang' attribute">>, - send_element(StateData, - xmpp:serr_policy_violation(Txt, ?MYLANG)), - {stop, normal, StateData}; - #stream_start{to = undefined, lang = Lang, version = Version} -> - Txt = <<"Missing 'to' attribute">>, - send_header(StateData, ?MYNAME, Version, Lang), - send_element(StateData, - xmpp:serr_improper_addressing(Txt, Lang)), - {stop, normal, StateData}; - #stream_start{to = #jid{lserver = To}, lang = Lang, - version = Version} -> - Server = case StateData#state.server of - <<"">> -> To; - S -> S - end, - StreamVersion = case Version of - {1,0} -> {1,0}; - _ -> undefined - end, - IsBlacklistedIP = is_ip_blacklisted(StateData#state.ip, Lang), - case lists:member(Server, ?MYHOSTS) of - true when IsBlacklistedIP == false -> - change_shaper(StateData, jid:make(<<"">>, Server, <<"">>)), - case StreamVersion of - {1,0} -> - send_header(StateData, Server, {1,0}, ?MYLANG), - case StateData#state.authenticated of - false -> - TLS = StateData#state.tls, - TLSEnabled = StateData#state.tls_enabled, - TLSRequired = StateData#state.tls_required, - SASLState = cyrsasl:server_new( - <<"jabber">>, Server, <<"">>, [], - fun (U) -> - ejabberd_auth:get_password_with_authmodule( - U, Server) - end, - fun(U, AuthzId, P) -> - ejabberd_auth:check_password_with_authmodule( - U, AuthzId, Server, P) - end, - fun(U, AuthzId, P, D, DG) -> - ejabberd_auth:check_password_with_authmodule( - U, AuthzId, Server, P, D, DG) - end), - Mechs = - case TLSEnabled or not TLSRequired of - true -> - [#sasl_mechanisms{list = cyrsasl:listmech(Server)}]; - false -> - [] - end, - SockMod = - (StateData#state.sockmod):get_sockmod(StateData#state.socket), - Zlib = StateData#state.zlib, - CompressFeature = case Zlib andalso - ((SockMod == gen_tcp) orelse (SockMod == fast_tls)) of - true -> - [#compression{methods = [<<"zlib">>]}]; - _ -> - [] - end, - TLSFeature = - case (TLS == true) andalso - (TLSEnabled == false) andalso - (SockMod == gen_tcp) of - true -> - [#starttls{required = TLSRequired}]; - false -> - [] - end, - StreamFeatures1 = TLSFeature ++ CompressFeature ++ Mechs, - StreamFeatures = ejabberd_hooks:run_fold(c2s_stream_features, - Server, StreamFeatures1, [Server]), - send_element(StateData, - #stream_features{sub_els = StreamFeatures}), - fsm_next_state(wait_for_feature_request, - StateData#state{server = Server, - sasl_state = SASLState, - lang = Lang}); - _ -> - case StateData#state.resource of - <<"">> -> - RosterVersioningFeature = - ejabberd_hooks:run_fold(roster_get_versioning_feature, - Server, [], - [Server]), - StreamManagementFeature = - case stream_mgmt_enabled(StateData) of - true -> - [#feature_sm{xmlns = ?NS_STREAM_MGMT_2}, - #feature_sm{xmlns = ?NS_STREAM_MGMT_3}]; - false -> - [] - end, - SockMod = - (StateData#state.sockmod):get_sockmod( - StateData#state.socket), - Zlib = StateData#state.zlib, - CompressFeature = - case Zlib andalso - ((SockMod == gen_tcp) orelse (SockMod == fast_tls)) of - true -> - [#compression{methods = [<<"zlib">>]}]; - _ -> - [] - end, - StreamFeatures1 = - [#bind{}, #xmpp_session{optional = true}] - ++ - RosterVersioningFeature ++ - StreamManagementFeature ++ - CompressFeature ++ - ejabberd_hooks:run_fold(c2s_post_auth_features, - Server, [], [Server]), - StreamFeatures = ejabberd_hooks:run_fold(c2s_stream_features, - Server, StreamFeatures1, [Server]), - send_element(StateData, - #stream_features{sub_els = StreamFeatures}), - fsm_next_state(wait_for_bind, - StateData#state{server = Server, lang = Lang}); - _ -> - send_element(StateData, #stream_features{}), - fsm_next_state(session_established, - StateData#state{server = Server, lang = Lang}) - end - end; - _ -> - send_header(StateData, Server, StreamVersion, ?MYLANG), - if not StateData#state.tls_enabled and - StateData#state.tls_required -> - send_element( - StateData, - xmpp:serr_policy_violation( - <<"Use of STARTTLS required">>, Lang)), - {stop, normal, StateData}; - true -> - fsm_next_state(wait_for_auth, - StateData#state{server = Server, - lang = Lang}) - end - end; - true -> - IP = StateData#state.ip, - {true, LogReason, ReasonT} = IsBlacklistedIP, - ?INFO_MSG("Connection attempt from blacklisted IP ~s: ~s", - [jlib:ip_to_list(IP), LogReason]), - send_header(StateData, Server, StreamVersion, ?MYLANG), - send_element(StateData, xmpp:serr_policy_violation(ReasonT, Lang)), - {stop, normal, StateData}; - _ -> - send_header(StateData, ?MYNAME, StreamVersion, ?MYLANG), - send_element(StateData, xmpp:serr_host_unknown()), - {stop, normal, StateData} - end; - _ -> - send_header(StateData, ?MYNAME, {1,0}, ?MYLANG), - send_element(StateData, xmpp:serr_invalid_xml()), - {stop, normal, StateData} - catch _:{xmpp_codec, Why} -> - Txt = xmpp:format_error(Why), - send_header(StateData, ?MYNAME, {1,0}, ?MYLANG), - send_element(StateData, xmpp:serr_invalid_xml(Txt, ?MYLANG)), - {stop, normal, StateData} - end; -wait_for_stream(timeout, StateData) -> - {stop, normal, StateData}; -wait_for_stream({xmlstreamelement, _}, StateData) -> - send_element(StateData, xmpp:serr_not_well_formed()), - {stop, normal, StateData}; -wait_for_stream({xmlstreamend, _}, StateData) -> - send_element(StateData, xmpp:serr_not_well_formed()), - {stop, normal, StateData}; -wait_for_stream({xmlstreamerror, _}, StateData) -> - send_header(StateData, ?MYNAME, {1,0}, <<"">>), - send_element(StateData, xmpp:serr_not_well_formed()), - {stop, normal, StateData}; -wait_for_stream(closed, StateData) -> - {stop, normal, StateData}; -wait_for_stream(stop, StateData) -> - {stop, normal, StateData}. - -wait_for_auth({xmlstreamelement, #xmlel{} = El}, StateData) -> - decode_element(El, wait_for_auth, StateData); -wait_for_auth(Pkt, StateData) when ?IS_STREAM_MGMT_PACKET(Pkt) -> - fsm_next_state(wait_for_auth, dispatch_stream_mgmt(Pkt, StateData)); -wait_for_auth(#iq{type = get, sub_els = [#legacy_auth{}]} = IQ, StateData) -> - Auth = #legacy_auth{username = <<>>, password = <<>>, resource = <<>>}, - Res = case ejabberd_auth:plain_password_required(StateData#state.server) of - false -> - xmpp:make_iq_result(IQ, Auth#legacy_auth{digest = <<>>}); - true -> - xmpp:make_iq_result(IQ, Auth) - end, - send_element(StateData, Res), - fsm_next_state(wait_for_auth, StateData); -wait_for_auth(#iq{type = set, sub_els = [#legacy_auth{resource = <<"">>}]} = IQ, - StateData) -> - Lang = StateData#state.lang, - Txt = <<"No resource provided">>, - Err = xmpp:make_error(IQ, xmpp:err_not_acceptable(Txt, Lang)), - send_element(StateData, Err), - fsm_next_state(wait_for_auth, StateData); -wait_for_auth(#iq{type = set, sub_els = [#legacy_auth{username = U, - password = P0, - digest = D0, - resource = R}]} = IQ, - StateData) when is_binary(U), is_binary(R) -> - JID = jid:make(U, StateData#state.server, R), - case (JID /= error) andalso - acl:access_matches(StateData#state.access, - #{usr => jid:split(JID), ip => StateData#state.ip}, - StateData#state.server) == allow of - true -> - DGen = fun (PW) -> - p1_sha:sha(<<(StateData#state.streamid)/binary, PW/binary>>) - end, - P = if is_binary(P0) -> P0; true -> <<>> end, - D = if is_binary(D0) -> D0; true -> <<>> end, - case ejabberd_auth:check_password_with_authmodule( - U, U, StateData#state.server, P, D, DGen) of - {true, AuthModule} -> - ?INFO_MSG("(~w) Accepted legacy authentication for ~s by ~p from ~s", - [StateData#state.socket, - jid:to_string(JID), AuthModule, - ejabberd_config:may_hide_data(jlib:ip_to_list(StateData#state.ip))]), - ejabberd_hooks:run(c2s_auth_result, StateData#state.server, - [true, U, StateData#state.server, - StateData#state.ip]), - Conn = get_conn_type(StateData), - Info = [{ip, StateData#state.ip}, {conn, Conn}, - {auth_module, AuthModule}], - Res = xmpp:make_iq_result(IQ), - send_element(StateData, Res), - ejabberd_sm:open_session(StateData#state.sid, U, - StateData#state.server, R, - Info), - change_shaper(StateData, JID), - {Fs, Ts} = ejabberd_hooks:run_fold( - roster_get_subscription_lists, - StateData#state.server, - {[], []}, - [U, StateData#state.server]), - LJID = jid:tolower(jid:remove_resource(JID)), - Fs1 = [LJID | Fs], - Ts1 = [LJID | Ts], - PrivList = ejabberd_hooks:run_fold(privacy_get_user_list, - StateData#state.server, - #userlist{}, - [U, StateData#state.server]), - NewStateData = StateData#state{ - user = U, - resource = R, - jid = JID, - conn = Conn, - auth_module = AuthModule, - pres_f = (?SETS):from_list(Fs1), - pres_t = (?SETS):from_list(Ts1), - privacy_list = PrivList}, - fsm_next_state(session_established, NewStateData); - _ -> - ?INFO_MSG("(~w) Failed legacy authentication for ~s from ~s", - [StateData#state.socket, - jid:to_string(JID), - ejabberd_config:may_hide_data(jlib:ip_to_list(StateData#state.ip))]), - ejabberd_hooks:run(c2s_auth_result, StateData#state.server, - [false, U, StateData#state.server, - StateData#state.ip]), - Lang = StateData#state.lang, - Txt = <<"Legacy authentication failed">>, - Err = xmpp:make_error(IQ, xmpp:err_not_authorized(Txt, Lang)), - send_element(StateData, Err), - fsm_next_state(wait_for_auth, StateData) - end; - false when JID == error -> - ?INFO_MSG("(~w) Forbidden legacy authentication " - "for username '~s' with resource '~s'", - [StateData#state.socket, U, R]), - Err = xmpp:make_error(IQ, xmpp:err_jid_malformed()), - send_element(StateData, Err), - fsm_next_state(wait_for_auth, StateData); - false -> - ?INFO_MSG("(~w) Forbidden legacy authentication for ~s from ~s", - [StateData#state.socket, - jid:to_string(JID), - ejabberd_config:may_hide_data(jlib:ip_to_list(StateData#state.ip))]), - ejabberd_hooks:run(c2s_auth_result, StateData#state.server, - [false, U, StateData#state.server, - StateData#state.ip]), - Lang = StateData#state.lang, - Txt = <<"Legacy authentication forbidden">>, - Err = xmpp:make_error(IQ, xmpp:err_not_allowed(Txt, Lang)), - send_element(StateData, Err), - fsm_next_state(wait_for_auth, StateData) - end; -wait_for_auth(timeout, StateData) -> - {stop, normal, StateData}; -wait_for_auth({xmlstreamend, _Name}, StateData) -> - {stop, normal, StateData}; -wait_for_auth({xmlstreamerror, _}, StateData) -> - send_element(StateData, xmpp:serr_not_well_formed()), - {stop, normal, StateData}; -wait_for_auth(closed, StateData) -> - {stop, normal, StateData}; -wait_for_auth(stop, StateData) -> - {stop, normal, StateData}; -wait_for_auth(Pkt, StateData) -> - process_unauthenticated_stanza(StateData, Pkt), - fsm_next_state(wait_for_auth, StateData). - -wait_for_feature_request({xmlstreamelement, El}, StateData) -> - decode_element(El, wait_for_feature_request, StateData); -wait_for_feature_request(Pkt, StateData) when ?IS_STREAM_MGMT_PACKET(Pkt) -> - fsm_next_state(wait_for_feature_request, - dispatch_stream_mgmt(Pkt, StateData)); -wait_for_feature_request(#sasl_auth{mechanism = Mech, - text = ClientIn}, - #state{tls_enabled = TLSEnabled, - tls_required = TLSRequired} = StateData) - when TLSEnabled or not TLSRequired -> - case cyrsasl:server_start(StateData#state.sasl_state, Mech, ClientIn) of - {ok, Props} -> - (StateData#state.sockmod):reset_stream(StateData#state.socket), - U = identity(Props), - AuthModule = proplists:get_value(auth_module, Props, undefined), - ?INFO_MSG("(~w) Accepted authentication for ~s by ~p from ~s", - [StateData#state.socket, U, AuthModule, - ejabberd_config:may_hide_data(jlib:ip_to_list(StateData#state.ip))]), - ejabberd_hooks:run(c2s_auth_result, StateData#state.server, - [true, U, StateData#state.server, - StateData#state.ip]), - send_element(StateData, #sasl_success{}), - fsm_next_state(wait_for_stream, - StateData#state{streamid = new_id(), - authenticated = true, - auth_module = AuthModule, - sasl_state = undefined, - user = U}); - {continue, ServerOut, NewSASLState} -> - send_element(StateData, #sasl_challenge{text = ServerOut}), - fsm_next_state(wait_for_sasl_response, - StateData#state{sasl_state = NewSASLState}); - {error, Error, Username} -> - {Reason, ErrTxt} = cyrsasl:format_error(Mech, Error), - ?INFO_MSG("(~w) Failed authentication for ~s@~s from ~s: ~s", - [StateData#state.socket, - Username, StateData#state.server, - ejabberd_config:may_hide_data(jlib:ip_to_list(StateData#state.ip)), - ErrTxt]), - ejabberd_hooks:run(c2s_auth_result, StateData#state.server, - [false, Username, StateData#state.server, - StateData#state.ip]), - send_element(StateData, #sasl_failure{reason = Reason, - text = xmpp:mk_text(ErrTxt)}), - fsm_next_state(wait_for_feature_request, StateData); - {error, Error} -> - {Reason, ErrTxt} = cyrsasl:format_error(Mech, Error), - send_element(StateData, #sasl_failure{reason = Reason, - text = xmpp:mk_text(ErrTxt)}), - fsm_next_state(wait_for_feature_request, StateData) - end; -wait_for_feature_request(#starttls{}, - #state{tls = true, tls_enabled = false} = StateData) -> - case (StateData#state.sockmod):get_sockmod(StateData#state.socket) of - gen_tcp -> - TLSOpts = case ejabberd_config:get_option( - {domain_certfile, StateData#state.server}, - fun iolist_to_binary/1) of - undefined -> - StateData#state.tls_options; - CertFile -> - lists:keystore(certfile, 1, - StateData#state.tls_options, - {certfile, CertFile}) - end, - Socket = StateData#state.socket, - BProceed = fxml:element_to_binary(xmpp:encode(#starttls_proceed{})), - TLSSocket = (StateData#state.sockmod):starttls(Socket, TLSOpts, BProceed), - fsm_next_state(wait_for_stream, - StateData#state{socket = TLSSocket, - streamid = new_id(), - tls_enabled = true}); - _ -> - Lang = StateData#state.lang, - Txt = <<"Unsupported TLS transport">>, - send_element(StateData, xmpp:serr_policy_violation(Txt, Lang)), - {stop, normal, StateData} - end; -wait_for_feature_request(#compress{} = Comp, StateData) -> - Zlib = StateData#state.zlib, - SockMod = (StateData#state.sockmod):get_sockmod(StateData#state.socket), - if Zlib == true, (SockMod == gen_tcp) or (SockMod == fast_tls) -> - process_compression_request(Comp, wait_for_feature_request, StateData); - true -> - send_element(StateData, #compress_failure{reason = 'setup-failed'}), - fsm_next_state(wait_for_feature_request, StateData) - end; -wait_for_feature_request(timeout, StateData) -> - {stop, normal, StateData}; -wait_for_feature_request({xmlstreamend, _Name}, - StateData) -> - {stop, normal, StateData}; -wait_for_feature_request({xmlstreamerror, _}, - StateData) -> - send_element(StateData, xmpp:serr_not_well_formed()), - {stop, normal, StateData}; -wait_for_feature_request(closed, StateData) -> - {stop, normal, StateData}; -wait_for_feature_request(stop, StateData) -> - {stop, normal, StateData}; -wait_for_feature_request(_Pkt, - #state{tls_required = TLSRequired, - tls_enabled = TLSEnabled} = StateData) - when TLSRequired and not TLSEnabled -> - Lang = StateData#state.lang, - Txt = <<"Use of STARTTLS required">>, - send_element(StateData, xmpp:serr_policy_violation(Txt, Lang)), - {stop, normal, StateData}; -wait_for_feature_request(Pkt, StateData) -> - process_unauthenticated_stanza(StateData, Pkt), - fsm_next_state(wait_for_feature_request, StateData). - -wait_for_sasl_response({xmlstreamelement, El}, StateData) -> - decode_element(El, wait_for_sasl_response, StateData); -wait_for_sasl_response(Pkt, StateData) when ?IS_STREAM_MGMT_PACKET(Pkt) -> - fsm_next_state(wait_for_sasl_response, - dispatch_stream_mgmt(Pkt, StateData)); -wait_for_sasl_response(#sasl_response{text = ClientIn}, StateData) -> - case cyrsasl:server_step(StateData#state.sasl_state, ClientIn) of - {ok, Props} -> - catch (StateData#state.sockmod):reset_stream(StateData#state.socket), - U = identity(Props), - AuthModule = proplists:get_value(auth_module, Props, <<>>), - ?INFO_MSG("(~w) Accepted authentication for ~s by ~p from ~s", - [StateData#state.socket, U, AuthModule, - ejabberd_config:may_hide_data(jlib:ip_to_list(StateData#state.ip))]), - ejabberd_hooks:run(c2s_auth_result, StateData#state.server, - [true, U, StateData#state.server, - StateData#state.ip]), - send_element(StateData, #sasl_success{}), - fsm_next_state(wait_for_stream, - StateData#state{streamid = new_id(), - authenticated = true, - auth_module = AuthModule, - sasl_state = undefined, - user = U}); - {ok, Props, ServerOut} -> - (StateData#state.sockmod):reset_stream(StateData#state.socket), - U = identity(Props), - AuthModule = proplists:get_value(auth_module, Props, undefined), - ?INFO_MSG("(~w) Accepted authentication for ~s by ~p from ~s", - [StateData#state.socket, U, AuthModule, - ejabberd_config:may_hide_data(jlib:ip_to_list(StateData#state.ip))]), - ejabberd_hooks:run(c2s_auth_result, StateData#state.server, - [true, U, StateData#state.server, - StateData#state.ip]), - send_element(StateData, #sasl_success{text = ServerOut}), - fsm_next_state(wait_for_stream, - StateData#state{streamid = new_id(), - authenticated = true, - auth_module = AuthModule, - sasl_state = undefined, - user = U}); - {continue, ServerOut, NewSASLState} -> - send_element(StateData, #sasl_challenge{text = ServerOut}), - fsm_next_state(wait_for_sasl_response, - StateData#state{sasl_state = NewSASLState}); - {error, Error, Username} -> - {Reason, ErrTxt} = cyrsasl:format_error(StateData#state.sasl_state, Error), - ?INFO_MSG("(~w) Failed authentication for ~s@~s from ~s: ~s", - [StateData#state.socket, - Username, StateData#state.server, - ejabberd_config:may_hide_data(jlib:ip_to_list(StateData#state.ip)), - ErrTxt]), - ejabberd_hooks:run(c2s_auth_result, StateData#state.server, - [false, Username, StateData#state.server, - StateData#state.ip]), - send_element(StateData, #sasl_failure{reason = Reason, - text = xmpp:mk_text(ErrTxt)}), - fsm_next_state(wait_for_feature_request, StateData); - {error, Error} -> - {Reason, ErrTxt} = cyrsasl:format_error(StateData#state.sasl_state, Error), - send_element(StateData, #sasl_failure{reason = Reason, - text = xmpp:mk_text(ErrTxt)}), - fsm_next_state(wait_for_feature_request, StateData) - end; -wait_for_sasl_response(timeout, StateData) -> - {stop, normal, StateData}; -wait_for_sasl_response({xmlstreamend, _Name}, - StateData) -> - {stop, normal, StateData}; -wait_for_sasl_response({xmlstreamerror, _}, - StateData) -> - send_element(StateData, xmpp:serr_not_well_formed()), - {stop, normal, StateData}; -wait_for_sasl_response(closed, StateData) -> - {stop, normal, StateData}; -wait_for_sasl_response(stop, StateData) -> - {stop, normal, StateData}; -wait_for_sasl_response(Pkt, StateData) -> - process_unauthenticated_stanza(StateData, Pkt), - fsm_next_state(wait_for_feature_request, StateData). - --spec resource_conflict_action(binary(), binary(), binary()) -> - {accept_resource, binary()} | closenew. -resource_conflict_action(U, S, R) -> - OptionRaw = case ejabberd_sm:is_existing_resource(U, S, R) of - true -> - ejabberd_config:get_option( - {resource_conflict, S}, - fun(setresource) -> setresource; - (closeold) -> closeold; - (closenew) -> closenew; - (acceptnew) -> acceptnew - end); - false -> - acceptnew - end, - Option = case OptionRaw of - setresource -> setresource; - closeold -> - acceptnew; %% ejabberd_sm will close old session - closenew -> closenew; - acceptnew -> acceptnew; - _ -> acceptnew %% default ejabberd behavior - end, - case Option of - acceptnew -> {accept_resource, R}; - closenew -> closenew; - setresource -> - Rnew = new_uniq_id(), - {accept_resource, Rnew} - end. - --spec decode_element(xmlel(), state_name(), state()) -> fsm_transition(). -decode_element(#xmlel{} = El, StateName, StateData) -> - try case xmpp:decode(El, ?NS_CLIENT, [ignore_els]) of - #iq{sub_els = [_], type = T} = Pkt when T == set; T == get -> - NewPkt = xmpp:decode_els( - Pkt, ?NS_CLIENT, - fun(SubEl) when StateName == session_established -> - case xmpp:get_ns(SubEl) of - ?NS_PRIVACY -> true; - ?NS_BLOCKING -> true; - _ -> false - end; - (SubEl) -> - xmpp:is_known_tag(SubEl) - end), - ?MODULE:StateName(NewPkt, StateData); - Pkt -> - ?MODULE:StateName(Pkt, StateData) - end - catch error:{xmpp_codec, Why} -> - NS = xmpp:get_ns(El), - fsm_next_state( - StateName, - case xmpp:is_stanza(El) of - true -> - Lang = xmpp:get_lang(El), - Txt = xmpp:format_error(Why), - send_error(StateData, El, xmpp:err_bad_request(Txt, Lang)); - false when NS == ?NS_STREAM_MGMT_2; NS == ?NS_STREAM_MGMT_3 -> - Err = #sm_failed{reason = 'bad-request', xmlns = NS}, - send_element(StateData, Err), - StateData; - false -> - StateData - end) - end. - -wait_for_bind({xmlstreamelement, El}, StateData) -> - decode_element(El, wait_for_bind, StateData); -wait_for_bind(#sm_resume{} = Pkt, StateData) -> - case handle_resume(StateData, Pkt) of - {ok, ResumedState} -> - fsm_next_state(session_established, ResumedState); - error -> - fsm_next_state(wait_for_bind, StateData) - end; -wait_for_bind(Pkt, StateData) when ?IS_STREAM_MGMT_PACKET(Pkt) -> - fsm_next_state(wait_for_bind, dispatch_stream_mgmt(Pkt, StateData)); -wait_for_bind(#iq{type = set, - sub_els = [#bind{resource = R0}]} = IQ, StateData) -> - U = StateData#state.user, - R = case R0 of - <<>> -> new_uniq_id(); - _ -> R0 - end, - case resource_conflict_action(U, StateData#state.server, R) of - closenew -> - Err = xmpp:make_error(IQ, xmpp:err_conflict()), - send_element(StateData, Err), - fsm_next_state(wait_for_bind, StateData); - {accept_resource, R2} -> - JID = jid:make(U, StateData#state.server, R2), - StateData2 = StateData#state{resource = R2, jid = JID}, - case open_session(StateData2) of - {ok, StateData3} -> - Res = xmpp:make_iq_result(IQ, #bind{jid = JID}), - try - send_element(StateData3, Res) - catch - exit:normal -> close(self()) - end, - fsm_next_state_pack(session_established,StateData3); - {error, Error} -> - Err = xmpp:make_error(IQ, Error), - send_element(StateData, Err), - fsm_next_state(wait_for_bind, StateData) - end - end; -wait_for_bind(#compress{} = Comp, StateData) -> - Zlib = StateData#state.zlib, - SockMod = (StateData#state.sockmod):get_sockmod(StateData#state.socket), - if Zlib == true, (SockMod == gen_tcp) or (SockMod == fast_tls) -> - process_compression_request(Comp, wait_for_bind, StateData); - true -> - send_element(StateData, #compress_failure{reason = 'setup-failed'}), - fsm_next_state(wait_for_bind, StateData) - end; -wait_for_bind(timeout, StateData) -> - {stop, normal, StateData}; -wait_for_bind({xmlstreamend, _Name}, StateData) -> - {stop, normal, StateData}; -wait_for_bind({xmlstreamerror, _}, StateData) -> - send_element(StateData, xmpp:serr_not_well_formed()), - {stop, normal, StateData}; -wait_for_bind(closed, StateData) -> - {stop, normal, StateData}; -wait_for_bind(stop, StateData) -> - {stop, normal, StateData}; -wait_for_bind(Pkt, StateData) -> - fsm_next_state( - wait_for_bind, - case xmpp:is_stanza(Pkt) of - true -> - send_error(StateData, Pkt, xmpp:err_not_acceptable()); - false -> - StateData - end). - --spec open_session(state()) -> {ok, state()} | {error, stanza_error()}. -open_session(StateData) -> - U = StateData#state.user, - R = StateData#state.resource, - JID = StateData#state.jid, - Lang = StateData#state.lang, - IP = StateData#state.ip, - case acl:access_matches(StateData#state.access, - #{usr => jid:split(JID), ip => IP}, - StateData#state.server) of - allow -> - ?INFO_MSG("(~w) Opened session for ~s", - [StateData#state.socket, jid:to_string(JID)]), - change_shaper(StateData, JID), - {Fs, Ts} = ejabberd_hooks:run_fold( - roster_get_subscription_lists, - StateData#state.server, - {[], []}, - [U, StateData#state.server]), - LJID = jid:tolower(jid:remove_resource(JID)), - Fs1 = [LJID | Fs], - Ts1 = [LJID | Ts], - PrivList = - ejabberd_hooks:run_fold( - privacy_get_user_list, - StateData#state.server, - #userlist{}, - [U, StateData#state.server]), - Conn = get_conn_type(StateData), - Info = [{ip, StateData#state.ip}, {conn, Conn}, - {auth_module, StateData#state.auth_module}], - ejabberd_sm:open_session( - StateData#state.sid, U, StateData#state.server, R, Info), - UpdatedStateData = - StateData#state{ - conn = Conn, - pres_f = ?SETS:from_list(Fs1), - pres_t = ?SETS:from_list(Ts1), - privacy_list = PrivList}, - {ok, UpdatedStateData}; - _ -> - ejabberd_hooks:run(forbidden_session_hook, - StateData#state.server, [JID]), - ?INFO_MSG("(~w) Forbidden session for ~s", - [StateData#state.socket, jid:to_string(JID)]), - Txt = <<"Denied by ACL">>, - {error, xmpp:err_not_allowed(Txt, Lang)} - end. - -session_established({xmlstreamelement, El}, StateData) -> - decode_element(El, session_established, StateData); -session_established(Pkt, StateData) when ?IS_STREAM_MGMT_PACKET(Pkt) -> - fsm_next_state(session_established, dispatch_stream_mgmt(Pkt, StateData)); -session_established(#csi{type = active}, StateData) -> - NewStateData = csi_flush_queue(StateData), - fsm_next_state(session_established, NewStateData#state{csi_state = active}); -session_established(#csi{type = inactive}, StateData) -> - fsm_next_state(session_established, StateData#state{csi_state = inactive}); -%% We hibernate the process to reduce memory consumption after a -%% configurable activity timeout -session_established(timeout, StateData) -> - Options = [], - proc_lib:hibernate(?GEN_FSM, enter_loop, - [?MODULE, Options, session_established, StateData]), - fsm_next_state(session_established, StateData); -session_established({xmlstreamend, _Name}, StateData) -> - {stop, normal, StateData}; -session_established({xmlstreamerror, - <<"XML stanza is too big">> = E}, - StateData) -> - send_element(StateData, - xmpp:serr_policy_violation(E, StateData#state.lang)), - {stop, normal, StateData}; -session_established({xmlstreamerror, _}, StateData) -> - send_element(StateData, xmpp:serr_not_well_formed()), - {stop, normal, StateData}; -session_established(closed, #state{mgmt_state = active} = StateData) -> - catch (StateData#state.sockmod):close(StateData#state.socket), - fsm_next_state(wait_for_resume, StateData); -session_established(closed, StateData) -> - {stop, normal, StateData}; -session_established(stop, StateData) -> - {stop, normal, StateData}; -session_established(Pkt, StateData) when ?is_stanza(Pkt) -> - FromJID = StateData#state.jid, - case check_from(Pkt, FromJID) of - 'invalid-from' -> - send_element(StateData, xmpp:serr_invalid_from()), - {stop, normal, StateData}; - _ -> - NewStateData = update_num_stanzas_in(StateData, Pkt), - session_established2(Pkt, NewStateData) - end; -session_established(_Pkt, StateData) -> - fsm_next_state(session_established, StateData). - --spec session_established2(xmpp_element(), state()) -> fsm_next(). -%% Process packets sent by user (coming from user on c2s XMPP connection) -session_established2(Pkt, StateData) -> - User = StateData#state.user, - Server = StateData#state.server, - FromJID = StateData#state.jid, - ToJID = case xmpp:get_to(Pkt) of - undefined -> jid:make(User, Server, <<"">>); - J -> J - end, - Lang = case xmpp:get_lang(Pkt) of - <<"">> -> StateData#state.lang; - L -> L +handle_call(get_presence, From, #{jid := JID} = State) -> + Pres = try maps:get(pres_last, State) + catch _:{badkey, _} -> + BareJID = jid:remove_resource(JID), + #presence{from = JID, to = BareJID, type = unavailable} end, - NewPkt = xmpp:set_lang(Pkt, Lang), - NewState = - case NewPkt of - #presence{} -> - Presence0 = ejabberd_hooks:run_fold( - c2s_update_presence, Server, NewPkt, - [User, Server]), - Presence = ejabberd_hooks:run_fold( - user_send_packet, Server, Presence0, - [StateData, FromJID, ToJID]), - case ToJID of - #jid{user = User, server = Server, resource = <<"">>} -> - ?DEBUG("presence_update(~p,~n\t~p,~n\t~p)", - [FromJID, Presence, StateData]), - presence_update(FromJID, Presence, - StateData); - _ -> - presence_track(FromJID, ToJID, Presence, - StateData) - end; - #iq{type = T, sub_els = [El]} when T == set; T == get -> - NS = xmpp:get_ns(El), - if NS == ?NS_BLOCKING; NS == ?NS_PRIVACY -> - IQ = xmpp:set_from_to(Pkt, FromJID, ToJID), - process_privacy_iq(IQ, StateData); - NS == ?NS_SESSION -> - Res = xmpp:make_iq_result(Pkt), - send_stanza(StateData, Res); - true -> - NewPkt0 = ejabberd_hooks:run_fold( - user_send_packet, Server, NewPkt, - [StateData, FromJID, ToJID]), - check_privacy_route(FromJID, StateData, FromJID, - ToJID, NewPkt0) - end; - _ -> - NewPkt0 = ejabberd_hooks:run_fold( - user_send_packet, Server, NewPkt, - [StateData, FromJID, ToJID]), - check_privacy_route(FromJID, StateData, FromJID, - ToJID, NewPkt0) - end, - ejabberd_hooks:run(c2s_loop_debug, - [{xmlstreamelement, Pkt}]), - fsm_next_state(session_established, NewState). + reply(From, Pres), + State; +handle_call(get_subscribed, From, #{pres_f := PresF} = State) -> + reply(From, ?SETS:to_list(PresF)), + State; +handle_call(Request, From, #{lserver := LServer} = State) -> + ejabberd_hooks:run_fold( + c2s_handle_call, LServer, State, [Request, From]). -wait_for_resume({xmlstreamelement, _El} = Event, StateData) -> - Result = session_established(Event, StateData), - fsm_next_state(wait_for_resume, element(3, Result)); -wait_for_resume(timeout, StateData) -> - ?DEBUG("Timed out waiting for resumption of stream for ~s", - [jid:to_string(StateData#state.jid)]), - {stop, normal, StateData#state{mgmt_state = timeout}}; -wait_for_resume(Event, StateData) -> - ?DEBUG("Ignoring event while waiting for resumption: ~p", [Event]), - fsm_next_state(wait_for_resume, StateData). +handle_cast(Msg, #{lserver := LServer} = State) -> + ejabberd_hooks:run_fold(c2s_handle_cast, LServer, State, [Msg]). -handle_event(_Event, StateName, StateData) -> - fsm_next_state(StateName, StateData). +handle_info(Info, #{lserver := LServer} = State) -> + ejabberd_hooks:run_fold(c2s_handle_info, LServer, State, [Info]). -handle_sync_event({get_presence}, _From, StateName, - StateData) -> - User = StateData#state.user, - PresLast = StateData#state.pres_last, - Show = get_showtag(PresLast), - Status = get_statustag(PresLast), - Resource = StateData#state.resource, - Reply = {User, Resource, Show, Status}, - fsm_reply(Reply, StateName, StateData); -handle_sync_event({get_last_presence}, _From, StateName, - StateData) -> - User = StateData#state.user, - Server = StateData#state.server, - PresLast = StateData#state.pres_last, - Resource = StateData#state.resource, - Reply = {User, Server, Resource, PresLast}, - fsm_reply(Reply, StateName, StateData); - -handle_sync_event(get_subscribed, _From, StateName, - StateData) -> - Subscribed = (?SETS):to_list(StateData#state.pres_f), - {reply, Subscribed, StateName, StateData}; -handle_sync_event({resume_session, Time}, _From, _StateName, - StateData) when element(1, StateData#state.sid) == Time -> - %% The old session should be closed before the new one is opened, so we do - %% this here instead of leaving it to the terminate callback - ejabberd_sm:close_session(StateData#state.sid, - StateData#state.user, - StateData#state.server, - StateData#state.resource), - {stop, normal, {resume, StateData}, StateData#state{mgmt_state = resumed}}; -handle_sync_event({resume_session, _Time}, _From, StateName, - StateData) -> - {reply, {error, <<"Previous session not found">>}, StateName, StateData}; -handle_sync_event(_Event, _From, StateName, - StateData) -> - Reply = ok, fsm_reply(Reply, StateName, StateData). - -code_change(_OldVsn, StateName, StateData, _Extra) -> - {ok, StateName, StateData}. - -handle_info({send_text, Text}, StateName, StateData) -> - send_text(StateData, Text), - ejabberd_hooks:run(c2s_loop_debug, [Text]), - fsm_next_state(StateName, StateData); -handle_info(replaced, StateName, StateData) -> - Lang = StateData#state.lang, - Pkt = xmpp:serr_conflict(<<"Replaced by new connection">>, Lang), - handle_info({kick, replaced, Pkt}, StateName, StateData); -handle_info(kick, StateName, StateData) -> - Lang = StateData#state.lang, - Pkt = xmpp:serr_policy_violation(<<"has been kicked">>, Lang), - handle_info({kick, kicked_by_admin, Pkt}, StateName, StateData); -handle_info({kick, Reason, Pkt}, _StateName, StateData) -> - send_element(StateData, Pkt), - {stop, normal, - StateData#state{authenticated = Reason}}; -handle_info({route, _From, _To, {broadcast, Data}}, - StateName, StateData) -> - ?DEBUG("broadcast~n~p~n", [Data]), - case Data of - {item, IJID, ISubscription} -> - fsm_next_state(StateName, - roster_change(IJID, ISubscription, StateData)); - {exit, Reason} -> - Lang = StateData#state.lang, - send_element(StateData, xmpp:serr_conflict(Reason, Lang)), - {stop, normal, StateData}; - {privacy_list, PrivList, PrivListName} -> - case ejabberd_hooks:run_fold(privacy_updated_list, - StateData#state.server, - false, - [StateData#state.privacy_list, - PrivList]) of - false -> - fsm_next_state(StateName, StateData); - NewPL -> - PrivPushIQ = - #iq{type = set, - from = jid:remove_resource(StateData#state.jid), - to = StateData#state.jid, - id = <<"push", (randoms:get_string())/binary>>, - sub_els = [#privacy_query{ - lists = [#privacy_list{ - name = PrivListName}]}]}, - NewState = send_stanza(StateData, PrivPushIQ), - fsm_next_state(StateName, - NewState#state{privacy_list = NewPL}) - end; - {blocking, What} -> - NewState = route_blocking(What, StateData), - fsm_next_state(StateName, NewState); - _ -> - fsm_next_state(StateName, StateData) - end; -%% Process Packets that are to be send to the user -handle_info({route, From, To, Packet}, StateName, StateData) when ?is_stanza(Packet) -> - {Pass, NewState} = - case Packet of - #presence{type = T} -> - State = ejabberd_hooks:run_fold(c2s_presence_in, - StateData#state.server, - StateData, - [{From, To, Packet}]), - case T of - probe -> - LFrom = jid:tolower(From), - LBFrom = jid:remove_resource(LFrom), - NewStateData = - case (?SETS):is_element(LFrom, State#state.pres_a) - orelse (?SETS):is_element(LBFrom, State#state.pres_a) of - true -> State; - false -> - case (?SETS):is_element(LFrom, State#state.pres_f) of - true -> - A = (?SETS):add_element(LFrom, State#state.pres_a), - State#state{pres_a = A}; - false -> - case (?SETS):is_element(LBFrom, State#state.pres_f) of - true -> - A = (?SETS):add_element(LBFrom, State#state.pres_a), - State#state{pres_a = A}; - false -> - State - end - end - end, - process_presence_probe(From, To, NewStateData), - {false, NewStateData}; - error -> - NewA = ?SETS:del_element(jid:tolower(From), State#state.pres_a), - {true, State#state{pres_a = NewA}}; - subscribe -> - SRes = is_privacy_allow(State, From, To, Packet, in), - {SRes, State}; - subscribed -> - SRes = is_privacy_allow(State, From, To, Packet, in), - {SRes, State}; - unsubscribe -> - SRes = is_privacy_allow(State, From, To, Packet, in), - {SRes, State}; - unsubscribed -> - SRes = is_privacy_allow(State, From, To, Packet, in), - {SRes, State}; - _ -> - case privacy_check_packet(State, From, To, Packet, in) of - allow -> - LFrom = jid:tolower(From), - LBFrom = jid:remove_resource(LFrom), - case (?SETS):is_element(LFrom, State#state.pres_a) - orelse (?SETS):is_element(LBFrom, State#state.pres_a) of - true -> - {true, State}; - false -> - case (?SETS):is_element(LFrom, State#state.pres_f) of - true -> - A = (?SETS):add_element(LFrom, State#state.pres_a), - {true, State#state{pres_a = A}}; - false -> - case (?SETS):is_element(LBFrom, - State#state.pres_f) of - true -> - A = (?SETS):add_element( - LBFrom, - State#state.pres_a), - {true, State#state{pres_a = A}}; - false -> - {true, State} - end - end - end; - deny -> {false, State} - end - end; - #iq{type = T} -> - case xmpp:has_subtag(Packet, #last{}) of - true when T == get; T == set -> - LFrom = jid:tolower(From), - LBFrom = jid:remove_resource(LFrom), - HasFromSub = ((?SETS):is_element(LFrom, StateData#state.pres_f) - orelse (?SETS):is_element(LBFrom, StateData#state.pres_f)) - andalso is_privacy_allow(StateData, To, From, #presence{}, out), - case HasFromSub of - true -> - case privacy_check_packet( - StateData, From, To, Packet, in) of - allow -> - {true, StateData}; - deny -> - ejabberd_router:route_error( - To, From, Packet, - xmpp:err_service_unavailable()), - {false, StateData} - end; - _ -> - ejabberd_router:route_error( - To, From, Packet, xmpp:err_forbidden()), - {false, StateData} - end; - _ -> - case privacy_check_packet(StateData, From, To, Packet, in) of - allow -> - {true, StateData}; - deny -> - ejabberd_router:route_error( - To, From, Packet, xmpp:err_service_unavailable()), - {false, StateData} - end - end; - #message{type = T} -> - case privacy_check_packet(StateData, From, To, Packet, in) of - allow -> - {true, StateData}; - deny -> - case T of - groupchat -> ok; - headline -> ok; - _ -> - case xmpp:has_subtag(Packet, #muc_user{}) of - true -> - ok; - false -> - ejabberd_router:route_error( - To, From, Packet, xmpp:err_service_unavailable()) - end - end, - {false, StateData} - end - end, - if Pass -> - FixedPacket0 = xmpp:set_from_to(Packet, From, To), - FixedPacket = ejabberd_hooks:run_fold( - user_receive_packet, - NewState#state.server, - FixedPacket0, - [NewState, NewState#state.jid, From, To]), - SentStateData = send_packet(NewState, FixedPacket), - ejabberd_hooks:run(c2s_loop_debug, [{route, From, To, Packet}]), - fsm_next_state(StateName, SentStateData); +terminate(Reason, #{sid := SID, + user := U, server := S, resource := R, + lserver := LServer} = State) -> + case maps:is_key(pres_last, State) of true -> - ejabberd_hooks:run(c2s_loop_debug, [{route, From, To, Packet}]), - fsm_next_state(StateName, NewState) - end; -handle_info({'DOWN', Monitor, _Type, _Object, _Info}, - _StateName, StateData) - when Monitor == StateData#state.socket_monitor -> - if StateData#state.mgmt_state == active; - StateData#state.mgmt_state == pending -> - fsm_next_state(wait_for_resume, StateData); - true -> - {stop, normal, StateData} - end; -handle_info(system_shutdown, StateName, StateData) -> - case StateName of - wait_for_stream -> - send_header(StateData, ?MYNAME, {1,0}, <<"en">>), - send_element(StateData, xmpp:serr_system_shutdown()), - ok; - _ -> - send_element(StateData, xmpp:serr_system_shutdown()), - ok + Status = format_reason(State, Reason), + ejabberd_sm:close_session_unset_presence(SID, U, S, R, Status); + false -> + ejabberd_sm:close_session(SID, U, S, R) end, - {stop, normal, StateData}; -handle_info({route_xmlstreamelement, El}, _StateName, StateData) -> - {next_state, NStateName, NStateData, _Timeout} = - session_established({xmlstreamelement, El}, StateData), - fsm_next_state(NStateName, NStateData); -handle_info({force_update_presence, LUser, LServer}, StateName, - #state{jid = #jid{luser = LUser, lserver = LServer}} = StateData) -> - NewStateData = case StateData#state.pres_last of - #presence{} -> - Presence = - ejabberd_hooks:run_fold(c2s_update_presence, - LServer, - StateData#state.pres_last, - [LUser, LServer]), - StateData2 = StateData#state{pres_last = Presence}, - presence_update(StateData2#state.jid, Presence, - StateData2), - StateData2; - undefined -> StateData - end, - fsm_next_state(StateName, NewStateData); -handle_info({send_filtered, Feature, From, To, Packet}, StateName, StateData) -> - Drop = ejabberd_hooks:run_fold(c2s_filter_packet, StateData#state.server, - true, [StateData#state.server, StateData, - Feature, To, Packet]), - NewStateData = if Drop -> - ?DEBUG("Dropping packet from ~p to ~p", - [jid:to_string(From), - jid:to_string(To)]), - StateData; - true -> - FinalPacket = xmpp:set_from_to(Packet, From, To), - case StateData#state.jid of - To -> - case privacy_check_packet(StateData, From, To, - FinalPacket, in) of - deny -> - StateData; - allow -> - send_stanza(StateData, FinalPacket) - end; - _ -> - ejabberd_router:route(From, To, FinalPacket), - StateData - end - end, - fsm_next_state(StateName, NewStateData); -handle_info({broadcast, Type, From, Packet}, StateName, StateData) -> - Recipients = ejabberd_hooks:run_fold( - c2s_broadcast_recipients, StateData#state.server, - [], - [StateData#state.server, StateData, Type, From, Packet]), - lists:foreach( - fun(USR) -> - ejabberd_router:route( - From, jid:make(USR), Packet) - end, lists:usort(Recipients)), - fsm_next_state(StateName, StateData); -handle_info({set_csi_state, CsiState}, StateName, StateData) -> - fsm_next_state(StateName, StateData#state{csi_state = CsiState}); -handle_info({set_resume_timeout, Timeout}, StateName, StateData) -> - fsm_next_state(StateName, StateData#state{mgmt_timeout = Timeout}); -handle_info(dont_ask_offline, StateName, StateData) -> - fsm_next_state(StateName, StateData#state{ask_offline = false}); -handle_info(close, StateName, StateData) -> - ?DEBUG("Timeout waiting for stream management acknowledgement of ~s", - [jid:to_string(StateData#state.jid)]), - close(self()), - fsm_next_state(StateName, StateData#state{mgmt_ack_timer = undefined}); -handle_info({_Ref, {resume, OldStateData}}, StateName, StateData) -> - %% This happens if the resume_session/1 request timed out; the new session - %% now receives the late response. - ?DEBUG("Received old session state for ~s after failed resumption", - [jid:to_string(OldStateData#state.jid)]), - handle_unacked_stanzas(OldStateData#state{mgmt_resend = false}), - fsm_next_state(StateName, StateData); -handle_info(Info, StateName, StateData) -> - ?ERROR_MSG("Unexpected info: ~p", [Info]), - fsm_next_state(StateName, StateData). + ejabberd_hooks:run_fold(c2s_terminated, LServer, State, [Reason]); +terminate(Reason, #{lserver := LServer} = State) -> + ejabberd_hooks:run_fold(c2s_terminated, LServer, State, [Reason]). --spec print_state(state()) -> state(). -print_state(State = #state{pres_t = T, pres_f = F, pres_a = A}) -> - State#state{pres_t = {pres_t, (?SETS):size(T)}, - pres_f = {pres_f, (?SETS):size(F)}, - pres_a = {pres_a, (?SETS):size(A)}}. +code_change(_OldVsn, State, _Extra) -> + {ok, State}. -terminate(_Reason, StateName, StateData) -> - case StateData#state.mgmt_state of - resumed -> - ?INFO_MSG("Closing former stream of resumed session for ~s", - [jid:to_string(StateData#state.jid)]); - _ -> - if StateName == session_established; - StateName == wait_for_resume -> - case StateData#state.authenticated of - replaced -> - ?INFO_MSG("(~w) Replaced session for ~s", - [StateData#state.socket, - jid:to_string(StateData#state.jid)]), - From = StateData#state.jid, - Lang = StateData#state.lang, - Status = <<"Replaced by new connection">>, - Packet = #presence{ - type = unavailable, - status = xmpp:mk_text(Status, Lang)}, - ejabberd_sm:close_session_unset_presence(StateData#state.sid, - StateData#state.user, - StateData#state.server, - StateData#state.resource, - Status), - presence_broadcast(StateData, From, - StateData#state.pres_a, Packet); - _ -> - ?INFO_MSG("(~w) Close session for ~s", - [StateData#state.socket, - jid:to_string(StateData#state.jid)]), - EmptySet = (?SETS):new(), - case StateData of - #state{pres_last = undefined, pres_a = EmptySet} -> - ejabberd_sm:close_session(StateData#state.sid, - StateData#state.user, - StateData#state.server, - StateData#state.resource); - _ -> - From = StateData#state.jid, - Packet = #presence{type = unavailable}, - ejabberd_sm:close_session_unset_presence(StateData#state.sid, - StateData#state.user, - StateData#state.server, - StateData#state.resource, - <<"">>), - presence_broadcast(StateData, From, - StateData#state.pres_a, Packet) - end, - case StateData#state.mgmt_state of - timeout -> - Info = [{num_stanzas_in, - StateData#state.mgmt_stanzas_in}], - ejabberd_sm:set_offline_info(StateData#state.sid, - StateData#state.user, - StateData#state.server, - StateData#state.resource, - Info); - _ -> +%%%=================================================================== +%%% Internal functions +%%%=================================================================== +-spec process_iq_in(state(), iq()) -> {boolean(), state()}. +process_iq_in(State, #iq{} = IQ) -> + case privacy_check_packet(State, IQ, in) of + allow -> + {true, State}; + deny -> + route_error(IQ, xmpp:err_service_unavailable()), + {false, State} + end. + +-spec process_message_in(state(), message()) -> {boolean(), state()}. +process_message_in(State, #message{type = T} = Msg) -> + case privacy_check_packet(State, Msg, in) of + allow -> + {true, State}; + deny when T == groupchat; T == headline -> + ok; + deny -> + case xmpp:has_subtag(Msg, #muc_user{}) of + true -> + ok; + false -> + route_error(Msg, xmpp:err_service_unavailable()) + end, + {false, State} + end. + +-spec process_presence_in(state(), presence()) -> {boolean(), state()}. +process_presence_in(#{lserver := LServer, pres_a := PresA} = State0, + #presence{from = From, to = To, type = T} = Pres) -> + State = ejabberd_hooks:run_fold(c2s_presence_in, LServer, State0, [Pres]), + case T of + probe -> + NewState = add_to_pres_a(State, From), + route_probe_reply(From, To, NewState), + {false, NewState}; + error -> + A = ?SETS:del_element(jid:tolower(From), PresA), + {true, State#{pres_a => A}}; + _ -> + case privacy_check_packet(State, Pres, in) of + allow when T == error -> + {true, State}; + allow -> + NewState = add_to_pres_a(State, From), + {true, NewState}; + deny -> + {false, State} + end + end. + +-spec route_probe_reply(jid(), jid(), state()) -> ok. +route_probe_reply(From, To, #{lserver := LServer, pres_f := PresF, + pres_last := LastPres, + pres_timestamp := TS} = State) -> + LFrom = jid:tolower(From), + LBFrom = jid:remove_resource(LFrom), + case ?SETS:is_element(LFrom, PresF) + orelse ?SETS:is_element(LBFrom, PresF) of + true -> + %% To is my JID + Packet = xmpp_util:add_delay_info(LastPres, To, TS), + case privacy_check_packet(State, Packet, out) of + deny -> + ok; + allow -> + ejabberd_hooks:run(presence_probe_hook, + LServer, + [From, To, self()]), + %% Don't route a presence probe to oneself + case From == To of + false -> + route(xmpp:set_from_to(Packet, To, From)); + true -> ok - end - end, - handle_unacked_stanzas(StateData), - bounce_messages(); - true -> - ok - end - end, - catch send_trailer(StateData), - (StateData#state.sockmod):close(StateData#state.socket), + end + end; + false -> + ok + end; +route_probe_reply(_, _, _) -> ok. -%%%---------------------------------------------------------------------- -%%% Internal functions -%%%---------------------------------------------------------------------- --spec change_shaper(state(), jid()) -> ok. -change_shaper(StateData, JID) -> - Shaper = acl:access_matches(StateData#state.shaper, - #{usr => jid:split(JID), ip => StateData#state.ip}, - StateData#state.server), - (StateData#state.sockmod):change_shaper(StateData#state.socket, - Shaper). - --spec send_text(state(), iodata()) -> ok | {error, any()}. -send_text(StateData, Text) when StateData#state.mgmt_state == pending -> - ?DEBUG("Cannot send text while waiting for resumption: ~p", [Text]); -send_text(StateData, Text) when StateData#state.xml_socket -> - ?DEBUG("Send Text on stream = ~p", [Text]), - (StateData#state.sockmod):send_xml(StateData#state.socket, - {xmlstreamraw, Text}); -send_text(StateData, Text) when StateData#state.mgmt_state == active -> - ?DEBUG("Send XML on stream = ~p", [Text]), - case catch (StateData#state.sockmod):send(StateData#state.socket, Text) of - {'EXIT', _} -> - (StateData#state.sockmod):close(StateData#state.socket), - {error, closed}; - _ -> - ok - end; -send_text(StateData, Text) -> - ?DEBUG("Send XML on stream = ~p", [Text]), - (StateData#state.sockmod):send(StateData#state.socket, Text). - --spec send_element(state(), xmlel() | xmpp_element()) -> ok | {error, any()}. -send_element(StateData, El) when StateData#state.mgmt_state == pending -> - ?DEBUG("Cannot send element while waiting for resumption: ~p", [El]); -send_element(StateData, #xmlel{} = El) when StateData#state.xml_socket -> - ?DEBUG("Send XML on stream = ~p", [fxml:element_to_binary(El)]), - (StateData#state.sockmod):send_xml(StateData#state.socket, - {xmlstreamelement, El}); -send_element(StateData, #xmlel{} = El) -> - send_text(StateData, fxml:element_to_binary(El)); -send_element(StateData, Pkt) -> - send_element(StateData, xmpp:encode(Pkt, ?NS_CLIENT)). - --spec send_error(state(), xmlel() | stanza(), stanza_error()) -> state(). -send_error(StateData, Stanza, Error) -> - Type = xmpp:get_type(Stanza), - if Type == error; Type == result; - Type == <<"error">>; Type == <<"result">> -> - StateData; - true -> - send_stanza(StateData, xmpp:make_error(Stanza, Error)) - end. - --spec send_stanza(state(), xmpp_element()) -> state(). -send_stanza(StateData, Stanza) when StateData#state.csi_state == inactive -> - csi_filter_stanza(StateData, Stanza); -send_stanza(StateData, Stanza) when StateData#state.mgmt_state == pending -> - mgmt_queue_add(StateData, Stanza); -send_stanza(StateData, Stanza) when StateData#state.mgmt_state == active -> - NewStateData = mgmt_queue_add(StateData, Stanza), - mgmt_send_stanza(NewStateData, Stanza); -send_stanza(StateData, Stanza) -> - send_element(StateData, Stanza), - StateData. - --spec send_packet(state(), xmpp_element()) -> state(). -send_packet(StateData, Packet) -> - case xmpp:is_stanza(Packet) of - true -> - send_stanza(StateData, Packet); - false -> - send_element(StateData, Packet), - StateData - end. - --spec send_header(state(), binary(), binary(), binary()) -> ok | {error, any()}. -send_header(StateData, Server, Version, Lang) -> - Header = #xmlel{name = Name, attrs = Attrs} = - xmpp:encode(#stream_start{version = Version, - lang = Lang, - xmlns = ?NS_CLIENT, - stream_xmlns = ?NS_STREAM, - id = StateData#state.streamid, - from = jid:make(Server)}), - if StateData#state.xml_socket -> - (StateData#state.sockmod):send_xml(StateData#state.socket, - {xmlstreamstart, Name, Attrs}); - true -> - send_text(StateData, fxml:element_to_header(Header)) - end. - --spec send_trailer(state()) -> ok | {error, any()}. -send_trailer(StateData) - when StateData#state.mgmt_state == pending -> - ?DEBUG("Cannot send stream trailer while waiting for resumption", []); -send_trailer(StateData) - when StateData#state.xml_socket -> - (StateData#state.sockmod):send_xml(StateData#state.socket, - {xmlstreamend, <<"stream:stream">>}); -send_trailer(StateData) -> - send_text(StateData, ?STREAM_TRAILER). - --spec new_id() -> binary(). -new_id() -> randoms:get_string(). - --spec new_uniq_id() -> binary(). -new_uniq_id() -> - iolist_to_binary([randoms:get_string(), - integer_to_binary(p1_time_compat:unique_integer([positive]))]). - --spec get_conn_type(state()) -> c2s | c2s_tls | c2s_compressed | websocket | - c2s_compressed_tls | http_bind. -get_conn_type(StateData) -> - case (StateData#state.sockmod):get_transport(StateData#state.socket) of - tcp -> c2s; - tls -> c2s_tls; - tcp_zlib -> c2s_compressed; - tls_zlib -> c2s_compressed_tls; - http_bind -> http_bind; - websocket -> websocket - end. - --spec process_presence_probe(jid(), jid(), state()) -> ok. -process_presence_probe(From, To, StateData) -> - LFrom = jid:tolower(From), - LBFrom = setelement(3, LFrom, <<"">>), - case StateData#state.pres_last of - undefined -> - ok; - _ -> - Cond = ((?SETS):is_element(LFrom, StateData#state.pres_f) - orelse - ((LFrom /= LBFrom) andalso - (?SETS):is_element(LBFrom, StateData#state.pres_f))), - if Cond -> - %% To is the one sending the presence (the probe target) - Packet = xmpp_util:add_delay_info( - StateData#state.pres_last, To, - StateData#state.pres_timestamp), - case privacy_check_packet(StateData, To, From, Packet, out) of - deny -> - ok; - allow -> - Pid=element(2, StateData#state.sid), - ejabberd_hooks:run(presence_probe_hook, StateData#state.server, [From, To, Pid]), - %% Don't route a presence probe to oneself - case From == To of - false -> - ejabberd_router:route(To, From, Packet); - true -> - ok - end - end; - true -> - ok - end - end. - -%% User updates his presence (non-directed presence packet) --spec presence_update(jid(), presence(), state()) -> state(). -presence_update(From, Packet, StateData) -> - #presence{type = Type} = Packet, - case Type of - unavailable -> - Status = xmpp:get_text(Packet#presence.status), - Info = [{ip, StateData#state.ip}, - {conn, StateData#state.conn}, - {auth_module, StateData#state.auth_module}], - ejabberd_sm:unset_presence(StateData#state.sid, - StateData#state.user, - StateData#state.server, - StateData#state.resource, Status, Info), - presence_broadcast(StateData, From, - StateData#state.pres_a, Packet), - StateData#state{pres_last = undefined, - pres_timestamp = undefined, pres_a = (?SETS):new()}; - error -> StateData; - probe -> StateData; - subscribe -> StateData; - subscribed -> StateData; - unsubscribe -> StateData; - unsubscribed -> StateData; - _ -> - OldPriority = case StateData#state.pres_last of - undefined -> 0; - OldPresence -> get_priority_from_presence(OldPresence) - end, - NewPriority = get_priority_from_presence(Packet), - update_priority(NewPriority, Packet, StateData), - FromUnavail = (StateData#state.pres_last == undefined), - ?DEBUG("from unavail = ~p~n", [FromUnavail]), - NewStateData = StateData#state{pres_last = Packet, - pres_timestamp = p1_time_compat:timestamp()}, - NewState = if FromUnavail -> - ejabberd_hooks:run(user_available_hook, - NewStateData#state.server, - [NewStateData#state.jid]), - ResentStateData = if NewPriority >= 0 -> - resend_offline_messages(NewStateData), - resend_subscription_requests(NewStateData); - true -> NewStateData - end, - presence_broadcast_first(From, ResentStateData, - Packet); - true -> - presence_broadcast_to_trusted(NewStateData, From, - NewStateData#state.pres_f, - NewStateData#state.pres_a, - Packet), - if OldPriority < 0, NewPriority >= 0 -> - resend_offline_messages(NewStateData); - true -> ok - end, - NewStateData - end, - NewState - end. - -%% User sends a directed presence packet --spec presence_track(jid(), jid(), presence(), state()) -> state(). -presence_track(From, To, Packet, StateData) -> - #presence{type = Type} = Packet, +-spec process_presence_out(state(), presence()) -> state(). +process_presence_out(#{user := User, server := Server, lserver := LServer, + jid := JID, lang := Lang, pres_a := PresA} = State, + #presence{from = From, to = To, type = Type} = Pres) -> LTo = jid:tolower(To), - User = StateData#state.user, - Server = StateData#state.server, - Lang = StateData#state.lang, - case privacy_check_packet(StateData, From, To, Packet, out) of + case privacy_check_packet(State, Pres, out) of deny -> ErrText = <<"Your active privacy list has denied " "the routing of this stanza.">>, Err = xmpp:err_not_acceptable(ErrText, Lang), - send_error(StateData, xmpp:set_from_to(Packet, From, To), Err); + xmpp_stream_in:send_error(State, Pres, Err); allow when Type == subscribe; Type == subscribed; Type == unsubscribe; Type == unsubscribed -> - Access = gen_mod:get_module_opt(Server, mod_roster, access, + Access = gen_mod:get_module_opt(LServer, mod_roster, access, fun(A) when is_atom(A) -> A end, all), - MyBareJID = jid:make(User, Server, <<"">>), - case acl:match_rule(Server, Access, MyBareJID) of + MyBareJID = jid:remove_resource(JID), + case acl:match_rule(LServer, Access, MyBareJID) of deny -> ErrText = <<"Denied by ACL">>, Err = xmpp:err_forbidden(ErrText, Lang), - send_error(StateData, xmpp:set_from_to(Packet, From, To), Err); + xmpp_stream_in:send_error(State, Pres, Err); allow -> ejabberd_hooks:run(roster_out_subscription, - Server, + LServer, [User, Server, To, Type]), - ejabberd_router:route(jid:remove_resource(From), To, Packet), - StateData + BareFrom = jid:remove_resource(From), + route(xmpp:set_from_to(Pres, BareFrom, To)), + State end; allow when Type == error; Type == probe -> - ejabberd_router:route(From, To, Packet), - StateData; + route(Pres), + State; allow -> - ejabberd_router:route(From, To, Packet), + route(Pres), A = case Type of - available -> - ?SETS:add_element(LTo, StateData#state.pres_a); - unavailable -> - ?SETS:del_element(LTo, StateData#state.pres_a) + available -> ?SETS:add_element(LTo, PresA); + unavailable -> ?SETS:del_element(LTo, PresA) end, - StateData#state{pres_a = A} + State#{pres_a => A} end. --spec check_privacy_route(jid(), state(), jid(), jid(), stanza()) -> state(). -check_privacy_route(From, StateData, FromRoute, To, - Packet) -> - case privacy_check_packet(StateData, From, To, Packet, - out) of +-spec process_self_presence(state(), presence()) -> state(). +process_self_presence(#{ip := IP, conn := Conn, lserver := LServer, + auth_module := AuthMod, sid := SID, + user := U, server := S, resource := R} = State, + #presence{type = unavailable} = Pres) -> + Status = xmpp:get_text(Pres#presence.status), + Info = [{ip, IP}, {conn, Conn}, {auth_module, AuthMod}], + ejabberd_sm:unset_presence(SID, U, S, R, Status, Info), + {Pres1, State1} = ejabberd_hooks:run_fold( + c2s_self_presence, LServer, {Pres, State}, []), + State2 = broadcast_presence_unavailable(State1, Pres1), + maps:remove(pres_last, maps:remove(pres_timestamp, State2)); +process_self_presence(#{lserver := LServer} = State, + #presence{type = available} = Pres) -> + PreviousPres = maps:get(pres_last, State, undefined), + update_priority(State, Pres), + {Pres1, State1} = ejabberd_hooks:run_fold( + c2s_self_presence, LServer, {Pres, State}, []), + State2 = State1#{pres_last => Pres1, + pres_timestamp => p1_time_compat:timestamp()}, + FromUnavailable = PreviousPres == undefined, + broadcast_presence_available(State2, Pres1, FromUnavailable); +process_self_presence(State, _Pres) -> + State. + +-spec update_priority(state(), presence()) -> ok. +update_priority(#{ip := IP, conn := Conn, auth_module := AuthMod, + sid := SID, user := U, server := S, resource := R}, + Pres) -> + Priority = get_priority_from_presence(Pres), + Info = [{ip, IP}, {conn, Conn}, {auth_module, AuthMod}], + ejabberd_sm:set_presence(SID, U, S, R, Priority, Pres, Info). + +-spec broadcast_presence_unavailable(state(), presence()) -> state(). +broadcast_presence_unavailable(#{pres_a := PresA} = State, Pres) -> + JIDs = filter_blocked(State, Pres, PresA), + route_multiple(State, JIDs, Pres), + State#{pres_a => ?SETS:new()}. + +-spec broadcast_presence_available(state(), presence(), boolean()) -> state(). +broadcast_presence_available(#{pres_a := PresA, pres_f := PresF, + pres_t := PresT, jid := JID} = State, + Pres, _FromUnavailable = true) -> + Probe = #presence{from = JID, type = probe}, + TJIDs = filter_blocked(State, Probe, PresT), + FJIDs = filter_blocked(State, Pres, PresF), + route_multiple(State, TJIDs, Probe), + route_multiple(State, FJIDs, Pres), + State#{pres_a => ?SETS:union(PresA, PresF)}; +broadcast_presence_available(#{pres_a := PresA, pres_f := PresF} = State, + Pres, _FromUnavailable = false) -> + JIDs = filter_blocked(State, Pres, ?SETS:intersection(PresA, PresF)), + route_multiple(State, JIDs, Pres), + State. + +-spec check_privacy_then_route(state(), stanza()) -> state(). +check_privacy_then_route(#{lang := Lang} = State, Pkt) -> + case privacy_check_packet(State, Pkt, out) of deny -> - Lang = StateData#state.lang, ErrText = <<"Your active privacy list has denied " "the routing of this stanza.">>, Err = xmpp:err_not_acceptable(ErrText, Lang), - send_error(StateData, xmpp:set_from_to(Packet, From, To), Err); + xmpp_stream_in:send_error(State, Pkt, Err); allow -> - ejabberd_router:route(FromRoute, To, Packet), - StateData + route(Pkt), + State end. -%% Check if privacy rules allow this delivery --spec privacy_check_packet(state(), jid(), jid(), stanza(), in | out) -> allow | deny. -privacy_check_packet(StateData, From, To, Packet, - Dir) -> - ejabberd_hooks:run_fold(privacy_check_packet, - StateData#state.server, allow, - [StateData#state.user, StateData#state.server, - StateData#state.privacy_list, {From, To, Packet}, - Dir]). - --spec is_privacy_allow(state(), jid(), jid(), stanza(), in | out) -> boolean(). -is_privacy_allow(StateData, From, To, Packet, Dir) -> - allow == - privacy_check_packet(StateData, From, To, Packet, Dir). - -%% Send presence when disconnecting --spec presence_broadcast(state(), jid(), ?SETS:set(), presence()) -> ok. -presence_broadcast(StateData, From, JIDSet, Packet) -> - JIDs = ?SETS:to_list(JIDSet), - JIDs2 = format_and_check_privacy(From, StateData, Packet, JIDs, out), - Server = StateData#state.server, - send_multiple(From, Server, JIDs2, Packet). - --spec presence_broadcast_to_trusted( - state(), jid(), ?SETS:set(), ?SETS:set(), presence()) -> ok. -%% Send presence when updating presence -presence_broadcast_to_trusted(StateData, From, Trusted, JIDSet, Packet) -> - JIDs = ?SETS:to_list(?SETS:intersection(Trusted, JIDSet)), - JIDs2 = format_and_check_privacy(From, StateData, Packet, JIDs, out), - Server = StateData#state.server, - send_multiple(From, Server, JIDs2, Packet). - -%% Send presence when connecting --spec presence_broadcast_first(jid(), state(), presence()) -> state(). -presence_broadcast_first(From, StateData, Packet) -> - JIDsProbe = - ?SETS:fold( - fun(JID, L) -> [JID | L] end, - [], - StateData#state.pres_t), - PacketProbe = #presence{type = probe}, - JIDs2Probe = format_and_check_privacy(From, StateData, PacketProbe, JIDsProbe, out), - Server = StateData#state.server, - send_multiple(From, Server, JIDs2Probe, PacketProbe), - {As, JIDs} = - ?SETS:fold( - fun(JID, {A, JID_list}) -> - {?SETS:add_element(JID, A), JID_list++[JID]} - end, - {StateData#state.pres_a, []}, - StateData#state.pres_f), - JIDs2 = format_and_check_privacy(From, StateData, Packet, JIDs, out), - send_multiple(From, Server, JIDs2, Packet), - StateData#state{pres_a = As}. - --spec format_and_check_privacy( - jid(), state(), stanza(), [ljid()], in | out) -> [jid()]. -format_and_check_privacy(From, StateData, Packet, JIDs, Dir) -> - FJIDs = [jid:make(JID) || JID <- JIDs], - lists:filter( - fun(FJID) -> - case ejabberd_hooks:run_fold( - privacy_check_packet, StateData#state.server, - allow, - [StateData#state.user, - StateData#state.server, - StateData#state.privacy_list, - {From, FJID, Packet}, - Dir]) of - deny -> false; - allow -> true - end - end, - FJIDs). - --spec send_multiple(jid(), binary(), [jid()], stanza()) -> ok. -send_multiple(From, Server, JIDs, Packet) -> - ejabberd_router_multicast:route_multicast(From, Server, JIDs, Packet). - --spec roster_change(jid(), both | from | none | remove | to, state()) -> state(). -roster_change(IJID, ISubscription, StateData) -> - LIJID = jid:tolower(IJID), - IsFrom = (ISubscription == both) or (ISubscription == from), - IsTo = (ISubscription == both) or (ISubscription == to), - OldIsFrom = (?SETS):is_element(LIJID, StateData#state.pres_f), - FSet = if - IsFrom -> (?SETS):add_element(LIJID, StateData#state.pres_f); - true -> ?SETS:del_element(LIJID, StateData#state.pres_f) - end, - TSet = if - IsTo -> (?SETS):add_element(LIJID, StateData#state.pres_t); - true -> ?SETS:del_element(LIJID, StateData#state.pres_t) - end, - case StateData#state.pres_last of - undefined -> - StateData#state{pres_f = FSet, pres_t = TSet}; - P -> - ?DEBUG("roster changed for ~p~n", - [StateData#state.user]), - From = StateData#state.jid, - To = jid:make(IJID), - Cond1 = IsFrom andalso not OldIsFrom, - Cond2 = not IsFrom andalso OldIsFrom andalso - ((?SETS):is_element(LIJID, StateData#state.pres_a)), - if Cond1 -> - ?DEBUG("C1: ~p~n", [LIJID]), - case privacy_check_packet(StateData, From, To, P, out) - of - deny -> ok; - allow -> ejabberd_router:route(From, To, P) - end, - A = (?SETS):add_element(LIJID, StateData#state.pres_a), - StateData#state{pres_a = A, pres_f = FSet, - pres_t = TSet}; - Cond2 -> - ?DEBUG("C2: ~p~n", [LIJID]), - PU = #presence{type = unavailable}, - case privacy_check_packet(StateData, From, To, PU, out) - of - deny -> ok; - allow -> ejabberd_router:route(From, To, PU) - end, - A = ?SETS:del_element(LIJID, StateData#state.pres_a), - StateData#state{pres_a = A, pres_f = FSet, - pres_t = TSet}; - true -> StateData#state{pres_f = FSet, pres_t = TSet} - end - end. - --spec update_priority(integer(), presence(), state()) -> ok. -update_priority(Priority, Packet, StateData) -> - Info = [{ip, StateData#state.ip}, {conn, StateData#state.conn}, - {auth_module, StateData#state.auth_module}], - ejabberd_sm:set_presence(StateData#state.sid, - StateData#state.user, StateData#state.server, - StateData#state.resource, Priority, Packet, Info). +-spec privacy_check_packet(state(), stanza(), in | out) -> allow | deny. +privacy_check_packet(#{lserver := LServer} = State, Pkt, Dir) -> + ejabberd_hooks:run_fold(privacy_check_packet, LServer, allow, [State, Pkt, Dir]). -spec get_priority_from_presence(presence()) -> integer(). get_priority_from_presence(#presence{priority = Prio}) -> @@ -2031,817 +738,177 @@ get_priority_from_presence(#presence{priority = Prio}) -> _ -> Prio end. --spec process_privacy_iq(iq(), state()) -> state(). -process_privacy_iq(#iq{from = From, to = To, - type = Type, lang = Lang} = IQ, StateData) -> - Txt = <<"No module is handling this query">>, - {Res, NewStateData} = - case Type of - get -> - R = ejabberd_hooks:run_fold( - privacy_iq_get, - StateData#state.server, - {error, xmpp:err_feature_not_implemented(Txt, Lang)}, - [IQ, StateData#state.privacy_list]), - {R, StateData}; - set -> - case ejabberd_hooks:run_fold( - privacy_iq_set, - StateData#state.server, - {error, xmpp:err_feature_not_implemented(Txt, Lang)}, - [IQ, StateData#state.privacy_list]) - of - {result, R, NewPrivList} -> - {{result, R}, - StateData#state{privacy_list = - NewPrivList}}; - R -> {R, StateData} - end - end, - IQRes = case Res of - {result, Result} -> - xmpp:make_iq_result(IQ, Result); - {error, Error} -> - xmpp:make_error(IQ, Error) - end, - ejabberd_router:route(To, From, IQRes), - NewStateData. +-spec filter_blocked(state(), presence(), ?SETS:set()) -> [jid()]. +filter_blocked(#{jid := From} = State, Pres, LJIDSet) -> + ?SETS:fold( + fun(LJID, Acc) -> + To = jid:make(LJID), + Pkt = xmpp:set_from_to(Pres, From, To), + case privacy_check_packet(State, Pkt, out) of + allow -> [To|Acc]; + deny -> Acc + end + end, [], LJIDSet). --spec resend_offline_messages(state()) -> ok. -resend_offline_messages(#state{ask_offline = true} = StateData) -> - case ejabberd_hooks:run_fold(resend_offline_messages_hook, - StateData#state.server, [], - [StateData#state.user, StateData#state.server]) - of - Rs -> %%when is_list(Rs) -> - lists:foreach(fun ({route, From, To, Packet}) -> - Pass = case privacy_check_packet(StateData, - From, To, - Packet, in) - of - allow -> true; - deny -> false - end, - if Pass -> - ejabberd_router:route(From, To, Packet); - true -> ok - end - end, - Rs) - end; -resend_offline_messages(_StateData) -> - ok. +-spec route(stanza()) -> ok. +route(Pkt) -> + From = xmpp:get_from(Pkt), + To = xmpp:get_to(Pkt), + ejabberd_router:route(From, To, Pkt). --spec resend_subscription_requests(state()) -> state(). -resend_subscription_requests(#state{user = User, - server = Server} = StateData) -> - PendingSubscriptions = - ejabberd_hooks:run_fold(resend_subscription_requests_hook, - Server, [], [User, Server]), - lists:foldl(fun (XMLPacket, AccStateData) -> - send_packet(AccStateData, XMLPacket) +-spec route_error(stanza(), stanza_error()) -> ok. +route_error(Pkt, Err) -> + From = xmpp:get_from(Pkt), + To = xmpp:get_to(Pkt), + ejabberd_router:route_error(To, From, Pkt, Err). + +-spec route_multiple(state(), [jid()], stanza()) -> ok. +route_multiple(#{lserver := LServer}, JIDs, Pkt) -> + From = xmpp:get_from(Pkt), + ejabberd_router_multicast:route_multicast(From, LServer, JIDs, Pkt). + +-spec resource_conflict_action(binary(), binary(), binary()) -> + {accept_resource, binary()} | closenew. +resource_conflict_action(U, S, R) -> + OptionRaw = case ejabberd_sm:is_existing_resource(U, S, R) of + true -> + ejabberd_config:get_option( + {resource_conflict, S}, + fun(setresource) -> setresource; + (closeold) -> closeold; + (closenew) -> closenew; + (acceptnew) -> acceptnew + end); + false -> + acceptnew end, - StateData, - PendingSubscriptions). - --spec get_showtag(undefined | presence()) -> binary(). -get_showtag(undefined) -> <<"unavailable">>; -get_showtag(#presence{show = undefined}) -> <<"available">>; -get_showtag(#presence{show = Show}) -> atom_to_binary(Show, utf8). - --spec get_statustag(undefined | presence()) -> binary(). -get_statustag(#presence{status = Status}) -> xmpp:get_text(Status); -get_statustag(undefined) -> <<"">>. - --spec process_unauthenticated_stanza(state(), iq()) -> ok | {error, any()}. -process_unauthenticated_stanza(StateData, #iq{type = T, lang = L} = IQ) - when T == set; T == get -> - Lang = if L == undefined; L == <<"">> -> StateData#state.lang; - true -> L - end, - NewIQ = IQ#iq{lang = Lang}, - Res = ejabberd_hooks:run_fold(c2s_unauthenticated_iq, - StateData#state.server, empty, - [StateData#state.server, NewIQ, - StateData#state.ip]), - case Res of - empty -> - Txt = <<"Authentication required">>, - Err0 = xmpp:make_error(IQ, xmpp:err_service_unavailable(Txt, Lang)), - Err1 = Err0#iq{from = jid:make(<<>>, StateData#state.server, <<>>), - to = undefined}, - send_element(StateData, Err1); - _ -> - send_element(StateData, Res) - end; -process_unauthenticated_stanza(_StateData, _) -> - %% Drop any stanza, which isn't IQ stanza - ok. - --spec peerip(ejabberd_socket:sockmod(), - ejabberd_socket:socket()) -> - {inet:ip_address(), non_neg_integer()} | undefined. -peerip(SockMod, Socket) -> - IP = case SockMod of - gen_tcp -> inet:peername(Socket); - _ -> SockMod:peername(Socket) - end, - case IP of - {ok, IPOK} -> IPOK; - _ -> undefined + Option = case OptionRaw of + setresource -> setresource; + closeold -> + acceptnew; %% ejabberd_sm will close old session + closenew -> closenew; + acceptnew -> acceptnew; + _ -> acceptnew %% default ejabberd behavior + end, + case Option of + acceptnew -> {accept_resource, R}; + closenew -> closenew; + setresource -> + Rnew = new_uniq_id(), + {accept_resource, Rnew} end. -%% fsm_next_state_pack: Pack the StateData structure to improve -%% sharing. --spec fsm_next_state_pack(state_name(), state()) -> fsm_transition(). -fsm_next_state_pack(StateName, StateData) -> - fsm_next_state_gc(StateName, pack(StateData)). +-spec bounce_message_queue() -> ok. +bounce_message_queue() -> + receive {route, From, To, Pkt} -> + ejabberd_router:route(From, To, Pkt), + bounce_message_queue() + after 0 -> + ok + end. --spec fsm_next_state_gc(state_name(), state()) -> fsm_transition(). -%% fsm_next_state_gc: Garbage collect the process heap to make use of -%% the newly packed StateData structure. -fsm_next_state_gc(StateName, PackedStateData) -> - erlang:garbage_collect(), - fsm_next_state(StateName, PackedStateData). +-spec new_uniq_id() -> binary(). +new_uniq_id() -> + iolist_to_binary( + [randoms:get_string(), + integer_to_binary(p1_time_compat:unique_integer([positive]))]). -%% fsm_next_state: Generate the next_state FSM tuple with different -%% timeout, depending on the future state --spec fsm_next_state(state_name(), state()) -> fsm_transition(). -fsm_next_state(session_established, #state{mgmt_max_queue = exceeded} = - StateData) -> - ?WARNING_MSG("ACK queue too long, terminating session for ~s", - [jid:to_string(StateData#state.jid)]), - Err = xmpp:serr_policy_violation(<<"Too many unacked stanzas">>, - StateData#state.lang), - send_element(StateData, Err), - {stop, normal, StateData#state{mgmt_resend = false}}; -fsm_next_state(session_established, #state{mgmt_state = pending} = StateData) -> - fsm_next_state(wait_for_resume, StateData); -fsm_next_state(session_established, StateData) -> - {next_state, session_established, StateData, - ?C2S_HIBERNATE_TIMEOUT}; -fsm_next_state(wait_for_resume, #state{mgmt_timeout = 0} = StateData) -> - {stop, normal, StateData}; -fsm_next_state(wait_for_resume, #state{mgmt_pending_since = undefined, - sid = SID, jid = JID, ip = IP, - conn = Conn, auth_module = AuthModule, - server = Host} = StateData) -> - case StateData of - #state{mgmt_ack_timer = undefined} -> - ok; - #state{mgmt_ack_timer = Timer} -> - erlang:cancel_timer(Timer) - end, - ?INFO_MSG("Waiting for resumption of stream for ~s", - [jid:to_string(JID)]), - Info = [{ip, IP}, {conn, Conn}, {auth_module, AuthModule}], - NewStateData = ejabberd_hooks:run_fold(c2s_session_pending, Host, StateData, - [SID, JID, Info]), - {next_state, wait_for_resume, - NewStateData#state{mgmt_state = pending, - mgmt_pending_since = os:timestamp()}, - NewStateData#state.mgmt_timeout}; -fsm_next_state(wait_for_resume, StateData) -> - Diff = timer:now_diff(os:timestamp(), StateData#state.mgmt_pending_since), - Timeout = max(StateData#state.mgmt_timeout - Diff div 1000, 1), - {next_state, wait_for_resume, StateData, Timeout}; -fsm_next_state(StateName, StateData) -> - {next_state, StateName, StateData, ?C2S_OPEN_TIMEOUT}. +-spec get_conn_type(state()) -> c2s | c2s_tls | c2s_compressed | websocket | + c2s_compressed_tls | http_bind. +get_conn_type(State) -> + case xmpp_stream_in:get_transport(State) of + tcp -> c2s; + tls -> c2s_tls; + tcp_zlib -> c2s_compressed; + tls_zlib -> c2s_compressed_tls; + http_bind -> http_bind; + websocket -> websocket + end. -%% fsm_reply: Generate the reply FSM tuple with different timeout, -%% depending on the future state --spec fsm_reply(_, state_name(), state()) -> fsm_reply(). -fsm_reply(Reply, session_established, StateData) -> - {reply, Reply, session_established, StateData, - ?C2S_HIBERNATE_TIMEOUT}; -fsm_reply(Reply, wait_for_resume, StateData) -> - Diff = timer:now_diff(os:timestamp(), StateData#state.mgmt_pending_since), - Timeout = max(StateData#state.mgmt_timeout - Diff div 1000, 1), - {reply, Reply, wait_for_resume, StateData, Timeout}; -fsm_reply(Reply, StateName, StateData) -> - {reply, Reply, StateName, StateData, ?C2S_OPEN_TIMEOUT}. +-spec fix_from_to(xmpp_element(), state()) -> stanza(). +fix_from_to(Pkt, #{jid := JID}) when ?is_stanza(Pkt) -> + #jid{luser = U, lserver = S, lresource = R} = JID, + From = xmpp:get_from(Pkt), + From1 = case jid:tolower(From) of + {U, S, R} -> JID; + {U, S, _} -> jid:replace_resource(JID, From#jid.resource); + _ -> From + end, + xmpp:set_from_to(Pkt, From1, JID); +fix_from_to(Pkt, _State) -> + Pkt. -%% Used by c2s blacklist plugins --spec is_ip_blacklisted(undefined | {inet:ip_address(), non_neg_integer()}, - binary()) -> false | {true, binary(), binary()}. -is_ip_blacklisted(undefined, _Lang) -> false; -is_ip_blacklisted({IP, _Port}, Lang) -> - ejabberd_hooks:run_fold(check_bl_c2s, false, [IP, Lang]). +-spec change_shaper(state()) -> ok. +change_shaper(#{shaper := ShaperName, ip := IP, lserver := LServer, + user := U, server := S, resource := R} = State) -> + JID = jid:make(U, S, R), + Shaper = acl:access_matches(ShaperName, + #{usr => jid:split(JID), ip => IP}, + LServer), + xmpp_stream_in:change_shaper(State, Shaper). -%% Check from attributes -%% returns invalid-from|NewElement --spec check_from(stanza(), jid()) -> 'invalid-from' | stanza(). -check_from(Pkt, FromJID) -> - JID = xmpp:get_from(Pkt), - case JID of - undefined -> - Pkt; - #jid{} -> - if - (JID#jid.luser == FromJID#jid.luser) and - (JID#jid.lserver == FromJID#jid.lserver) and - (JID#jid.lresource == FromJID#jid.lresource) -> - Pkt; - (JID#jid.luser == FromJID#jid.luser) and - (JID#jid.lserver == FromJID#jid.lserver) and - (JID#jid.lresource == <<"">>) -> - Pkt; +-spec add_to_pres_a(state(), jid()) -> state(). +add_to_pres_a(#{pres_a := PresA, pres_f := PresF} = State, From) -> + LFrom = jid:tolower(From), + LBFrom = jid:remove_resource(LFrom), + case (?SETS):is_element(LFrom, PresA) orelse + (?SETS):is_element(LBFrom, PresA) of + true -> + State; + false -> + case (?SETS):is_element(LFrom, PresF) of true -> - 'invalid-from' + A = (?SETS):add_element(LFrom, PresA), + State#{pres_a => A}; + false -> + case (?SETS):is_element(LBFrom, PresF) of + true -> + A = (?SETS):add_element(LBFrom, PresA), + State#{pres_a => A}; + false -> + State + end end end. -fsm_limit_opts(Opts) -> - case lists:keysearch(max_fsm_queue, 1, Opts) of - {value, {_, N}} when is_integer(N) -> [{max_queue, N}]; - _ -> - case ejabberd_config:get_option( - max_fsm_queue, - fun(I) when is_integer(I), I > 0 -> I end) of - undefined -> []; - N -> [{max_queue, N}] - end - end. - --spec bounce_messages() -> ok. -bounce_messages() -> - receive - {route, From, To, El} -> - ejabberd_router:route(From, To, El), bounce_messages() - after 0 -> ok - end. - --spec process_compression_request(compress(), state_name(), state()) -> fsm_next(). -process_compression_request(#compress{methods = []}, StateName, StateData) -> - send_element(StateData, #compress_failure{reason = 'setup-failed'}), - fsm_next_state(StateName, StateData); -process_compression_request(#compress{methods = Ms}, StateName, StateData) -> - case lists:member(<<"zlib">>, Ms) of - true -> - Socket = StateData#state.socket, - BCompressed = fxml:element_to_binary(xmpp:encode(#compressed{})), - ZlibSocket = (StateData#state.sockmod):compress(Socket, BCompressed), - fsm_next_state(wait_for_stream, - StateData#state{socket = ZlibSocket, - streamid = new_id()}); - false -> - send_element(StateData, - #compress_failure{reason = 'unsupported-method'}), - fsm_next_state(StateName, StateData) - end. - -%%%---------------------------------------------------------------------- -%%% XEP-0191 -%%%---------------------------------------------------------------------- - --spec route_blocking( - {block, [jid()]} | {unblock, [jid()]} | unblock_all, state()) -> state(). -route_blocking(What, StateData) -> - SubEl = case What of - {block, JIDs} -> - #block{items = JIDs}; - {unblock, JIDs} -> - #unblock{items = JIDs}; - unblock_all -> - #unblock{} - end, - PrivPushIQ = #iq{type = set, id = <<"push">>, sub_els = [SubEl], - from = jid:remove_resource(StateData#state.jid), - to = StateData#state.jid}, - %% No need to replace active privacy list here, - %% blocking pushes are always accompanied by - %% Privacy List pushes - send_stanza(StateData, PrivPushIQ). - -%%%---------------------------------------------------------------------- -%%% XEP-0198 -%%%---------------------------------------------------------------------- --spec stream_mgmt_enabled(state()) -> boolean(). -stream_mgmt_enabled(#state{mgmt_state = disabled}) -> - false; -stream_mgmt_enabled(_StateData) -> - true. - --spec dispatch_stream_mgmt(xmpp_element(), state()) -> state(). -dispatch_stream_mgmt(El, #state{mgmt_state = MgmtState} = StateData) - when MgmtState == active; - MgmtState == pending -> - perform_stream_mgmt(El, StateData); -dispatch_stream_mgmt(El, StateData) -> - negotiate_stream_mgmt(El, StateData). - --spec negotiate_stream_mgmt(xmpp_element(), state()) -> state(). -negotiate_stream_mgmt(_El, #state{resource = <<"">>} = StateData) -> - %% XEP-0198 says: "For client-to-server connections, the client MUST NOT - %% attempt to enable stream management until after it has completed Resource - %% Binding unless it is resuming a previous session". However, it also - %% says: "Stream management errors SHOULD be considered recoverable", so we - %% won't bail out. - send_element(StateData, #sm_failed{reason = 'unexpected-request', - xmlns = ?NS_STREAM_MGMT_3}), - StateData; -negotiate_stream_mgmt(Pkt, StateData) -> - Xmlns = xmpp:get_ns(Pkt), - case stream_mgmt_enabled(StateData) of - true -> - case Pkt of - #sm_enable{} -> - handle_enable(StateData#state{mgmt_xmlns = Xmlns}, Pkt); - _ -> - Res = if is_record(Pkt, sm_a); - is_record(Pkt, sm_r); - is_record(Pkt, sm_resume) -> - #sm_failed{reason = 'unexpected-request', - xmlns = Xmlns}; - true -> - #sm_failed{reason = 'bad-request', - xmlns = Xmlns} - end, - send_element(StateData, Res), - StateData - end; - false -> - send_element(StateData, - #sm_failed{reason = 'service-unavailable', - xmlns = Xmlns}), - StateData - end. - --spec perform_stream_mgmt(xmpp_element(), state()) -> state(). -perform_stream_mgmt(Pkt, StateData) -> - case xmpp:get_ns(Pkt) of - Xmlns when Xmlns == StateData#state.mgmt_xmlns -> - case Pkt of - #sm_r{} -> - handle_r(StateData); - #sm_a{} -> - handle_a(StateData, Pkt); - _ -> - Res = if is_record(Pkt, sm_enable); - is_record(Pkt, sm_resume) -> - #sm_failed{reason = 'unexpected-request', - xmlns = Xmlns}; - true -> - #sm_failed{reason = 'bad-request', - xmlns = Xmlns} - end, - send_element(StateData, Res), - StateData - end; - _ -> - send_element(StateData, - #sm_failed{reason = 'unsupported-version', - xmlns = StateData#state.mgmt_xmlns}) - end. - --spec handle_enable(state(), sm_enable()) -> state(). -handle_enable(#state{mgmt_timeout = DefaultTimeout, - mgmt_max_timeout = MaxTimeout} = StateData, - #sm_enable{resume = Resume, max = Max}) -> - Timeout = if Resume == false -> - 0; - Max /= undefined, Max > 0, Max =< MaxTimeout -> - Max; - true -> - DefaultTimeout - end, - Res = if Timeout > 0 -> - ?INFO_MSG("Stream management with resumption enabled for ~s", - [jid:to_string(StateData#state.jid)]), - #sm_enabled{xmlns = StateData#state.mgmt_xmlns, - id = make_resume_id(StateData), - resume = true, - max = Timeout}; - true -> - ?INFO_MSG("Stream management without resumption enabled for ~s", - [jid:to_string(StateData#state.jid)]), - #sm_enabled{xmlns = StateData#state.mgmt_xmlns} - end, - send_element(StateData, Res), - StateData#state{mgmt_state = active, - mgmt_queue = queue:new(), - mgmt_timeout = Timeout * 1000}. - --spec handle_r(state()) -> state(). -handle_r(StateData) -> - Res = #sm_a{xmlns = StateData#state.mgmt_xmlns, - h = StateData#state.mgmt_stanzas_in}, - send_element(StateData, Res), - StateData. - --spec handle_a(state(), sm_a()) -> state(). -handle_a(StateData, #sm_a{h = H}) -> - NewStateData = check_h_attribute(StateData, H), - maybe_renew_ack_request(NewStateData). - --spec handle_resume(state(), sm_resume()) -> {ok, state()} | error. -handle_resume(StateData, #sm_resume{h = H, previd = PrevID, xmlns = Xmlns}) -> - R = case stream_mgmt_enabled(StateData) of - true -> - case inherit_session_state(StateData, PrevID) of - {ok, InheritedState, Info} -> - {ok, InheritedState, Info, H}; - {error, Err, InH} -> - {error, #sm_failed{reason = 'item-not-found', - h = InH, xmlns = Xmlns}, Err}; - {error, Err} -> - {error, #sm_failed{reason = 'item-not-found', - xmlns = Xmlns}, Err} - end; - false -> - {error, #sm_failed{reason = 'service-unavailable', - xmlns = Xmlns}, - <<"XEP-0198 disabled">>} - end, - case R of - {ok, ResumedState, ResumedInfo, NumHandled} -> - NewState = check_h_attribute(ResumedState, NumHandled), - AttrXmlns = NewState#state.mgmt_xmlns, - AttrId = make_resume_id(NewState), - AttrH = NewState#state.mgmt_stanzas_in, - send_element(NewState, #sm_resumed{xmlns = AttrXmlns, - h = AttrH, - previd = AttrId}), - SendFun = fun(_F, _T, El, Time) -> - NewEl = add_resent_delay_info(NewState, El, Time), - send_element(NewState, NewEl) - end, - handle_unacked_stanzas(NewState, SendFun), - send_element(NewState, #sm_r{xmlns = AttrXmlns}), - NewState1 = csi_flush_queue(NewState), - NewState2 = ejabberd_hooks:run_fold(c2s_session_resumed, - StateData#state.server, - NewState1, - [NewState1#state.sid, - NewState1#state.jid, - ResumedInfo]), - ?INFO_MSG("Resumed session for ~s", - [jid:to_string(NewState2#state.jid)]), - {ok, NewState2}; - {error, El, Msg} -> - send_element(StateData, El), - ?INFO_MSG("Cannot resume session for ~s@~s: ~s", - [StateData#state.user, StateData#state.server, Msg]), - error - end. - --spec check_h_attribute(state(), non_neg_integer()) -> state(). -check_h_attribute(#state{mgmt_stanzas_out = NumStanzasOut} = StateData, H) - when H > NumStanzasOut -> - ?DEBUG("~s acknowledged ~B stanzas, but only ~B were sent", - [jid:to_string(StateData#state.jid), H, NumStanzasOut]), - mgmt_queue_drop(StateData#state{mgmt_stanzas_out = H}, NumStanzasOut); -check_h_attribute(#state{mgmt_stanzas_out = NumStanzasOut} = StateData, H) -> - ?DEBUG("~s acknowledged ~B of ~B stanzas", - [jid:to_string(StateData#state.jid), H, NumStanzasOut]), - mgmt_queue_drop(StateData, H). - --spec update_num_stanzas_in(state(), xmpp_element()) -> state(). -update_num_stanzas_in(#state{mgmt_state = MgmtState} = StateData, El) - when MgmtState == active; - MgmtState == pending -> - NewNum = case {xmpp:is_stanza(El), StateData#state.mgmt_stanzas_in} of - {true, 4294967295} -> - 0; - {true, Num} -> - Num + 1; - {false, Num} -> - Num - end, - StateData#state{mgmt_stanzas_in = NewNum}; -update_num_stanzas_in(StateData, _El) -> - StateData. - -mgmt_send_stanza(StateData, Stanza) -> - case send_element(StateData, Stanza) of - ok -> - maybe_request_ack(StateData); - _ -> - StateData#state{mgmt_state = pending} - end. - -maybe_request_ack(#state{mgmt_ack_timer = undefined} = StateData) -> - request_ack(StateData); -maybe_request_ack(StateData) -> - StateData. - -request_ack(#state{mgmt_xmlns = Xmlns, - mgmt_ack_timeout = AckTimeout} = StateData) -> - AckReq = #sm_r{xmlns = Xmlns}, - case {send_element(StateData, AckReq), AckTimeout} of - {ok, undefined} -> - ok; - {ok, Timeout} -> - Timer = erlang:send_after(Timeout, self(), close), - StateData#state{mgmt_ack_timer = Timer, - mgmt_stanzas_req = StateData#state.mgmt_stanzas_out}; - _ -> - StateData#state{mgmt_state = pending} - end. - -maybe_renew_ack_request(#state{mgmt_ack_timer = undefined} = StateData) -> - StateData; -maybe_renew_ack_request(#state{mgmt_ack_timer = Timer, - mgmt_queue = Queue, - mgmt_stanzas_out = NumStanzasOut, - mgmt_stanzas_req = NumStanzasReq} = StateData) -> - erlang:cancel_timer(Timer), - case NumStanzasReq < NumStanzasOut andalso not queue:is_empty(Queue) of - true -> - request_ack(StateData#state{mgmt_ack_timer = undefined}); - false -> - StateData#state{mgmt_ack_timer = undefined} - end. - --spec mgmt_queue_add(state(), xmpp_element()) -> state(). -mgmt_queue_add(StateData, El) -> - NewNum = case StateData#state.mgmt_stanzas_out of - 4294967295 -> - 0; - Num -> - Num + 1 - end, - NewQueue = queue:in({NewNum, p1_time_compat:timestamp(), El}, StateData#state.mgmt_queue), - NewState = StateData#state{mgmt_queue = NewQueue, - mgmt_stanzas_out = NewNum}, - check_queue_length(NewState). - --spec mgmt_queue_drop(state(), non_neg_integer()) -> state(). -mgmt_queue_drop(StateData, NumHandled) -> - NewQueue = jlib:queue_drop_while(fun({N, _T, _E}) -> N =< NumHandled end, - StateData#state.mgmt_queue), - StateData#state{mgmt_queue = NewQueue}. - --spec check_queue_length(state()) -> state(). -check_queue_length(#state{mgmt_max_queue = Limit} = StateData) - when Limit == infinity; - Limit == exceeded -> - StateData; -check_queue_length(#state{mgmt_queue = Queue, - mgmt_max_queue = Limit} = StateData) -> - case queue:len(Queue) > Limit of - true -> - StateData#state{mgmt_max_queue = exceeded}; - false -> - StateData - end. - --spec handle_unacked_stanzas(state(), fun((_, _, _, _) -> _)) -> ok. -handle_unacked_stanzas(#state{mgmt_state = MgmtState} = StateData, F) - when MgmtState == active; - MgmtState == pending; - MgmtState == timeout -> - Queue = StateData#state.mgmt_queue, - case queue:len(Queue) of - 0 -> - ok; - N -> - ?DEBUG("~B stanza(s) were not acknowledged by ~s", - [N, jid:to_string(StateData#state.jid)]), - lists:foreach( - fun({_, Time, Pkt}) -> - From = xmpp:get_from(Pkt), - To = xmpp:get_to(Pkt), - case {From, To} of - {#jid{}, #jid{}} -> - F(From, To, Pkt, Time); - {_, _} -> - ?DEBUG("Dropping stanza due to invalid JID(s)", []) - end - end, queue:to_list(Queue)) - end; -handle_unacked_stanzas(_StateData, _F) -> - ok. - --spec handle_unacked_stanzas(state()) -> ok. -handle_unacked_stanzas(#state{mgmt_state = MgmtState} = StateData) - when MgmtState == active; - MgmtState == pending; - MgmtState == timeout -> - ResendOnTimeout = - case StateData#state.mgmt_resend of - Resend when is_boolean(Resend) -> - Resend; - if_offline -> - Resource = StateData#state.resource, - case ejabberd_sm:get_user_resources(StateData#state.user, - StateData#state.server) of - [Resource] -> % Same resource opened new session - true; - [] -> - true; - _ -> - false - end - end, - Lang = StateData#state.lang, - ReRoute = case ResendOnTimeout of - true -> - fun(From, To, El, Time) -> - NewEl = add_resent_delay_info(StateData, El, Time), - ejabberd_router:route(From, To, NewEl) - end; - false -> - fun(From, To, El, _Time) -> - Txt = <<"User session terminated">>, - ejabberd_router:route_error( - To, From, El, xmpp:err_service_unavailable(Txt, Lang)) - end - end, - F = fun(From, _To, #presence{}, _Time) -> - ?DEBUG("Dropping presence stanza from ~s", - [jid:to_string(From)]); - (From, To, #iq{} = El, _Time) -> - Txt = <<"User session terminated">>, - ejabberd_router:route_error( - To, From, El, xmpp:err_service_unavailable(Txt, Lang)); - (From, _To, #message{meta = #{carbon_copy := true}}, _Time) -> - %% XEP-0280 says: "When a receiving server attempts to deliver a - %% forked message, and that message bounces with an error for - %% any reason, the receiving server MUST NOT forward that error - %% back to the original sender." Resending such a stanza could - %% easily lead to unexpected results as well. - ?DEBUG("Dropping forwarded message stanza from ~s", - [jid:to_string(From)]); - (From, To, El, Time) -> - case ejabberd_hooks:run_fold(message_is_archived, - StateData#state.server, false, - [StateData, From, - StateData#state.jid, El]) of - true -> - ?DEBUG("Dropping archived message stanza from ~p", - [jid:to_string(xmpp:get_from(El))]); - false -> - ReRoute(From, To, El, Time) - end - end, - handle_unacked_stanzas(StateData, F); -handle_unacked_stanzas(_StateData) -> - ok. - --spec inherit_session_state(state(), binary()) -> {ok, state()} | - {error, binary()} | - {error, binary(), non_neg_integer()}. -inherit_session_state(#state{user = U, server = S} = StateData, ResumeID) -> - case jlib:base64_to_term(ResumeID) of - {term, {R, Time}} -> - case ejabberd_sm:get_session_pid(U, S, R) of - none -> - case ejabberd_sm:get_offline_info(Time, U, S, R) of - none -> - {error, <<"Previous session PID not found">>}; - Info -> - case proplists:get_value(num_stanzas_in, Info) of - undefined -> - {error, <<"Previous session timed out">>}; - H -> - {error, <<"Previous session timed out">>, H} - end - end; - OldPID -> - OldSID = {Time, OldPID}, - case catch resume_session(OldSID) of - {resume, OldStateData} -> - NewSID = {Time, self()}, % Old time, new PID - Priority = case OldStateData#state.pres_last of - undefined -> - 0; - Presence -> - get_priority_from_presence(Presence) - end, - Conn = get_conn_type(StateData), - Info = [{ip, StateData#state.ip}, {conn, Conn}, - {auth_module, StateData#state.auth_module}], - ejabberd_sm:open_session(NewSID, U, S, R, - Priority, Info), - {ok, StateData#state{conn = Conn, - sid = NewSID, - jid = OldStateData#state.jid, - resource = OldStateData#state.resource, - pres_t = OldStateData#state.pres_t, - pres_f = OldStateData#state.pres_f, - pres_a = OldStateData#state.pres_a, - pres_last = OldStateData#state.pres_last, - pres_timestamp = OldStateData#state.pres_timestamp, - privacy_list = OldStateData#state.privacy_list, - aux_fields = OldStateData#state.aux_fields, - mgmt_xmlns = OldStateData#state.mgmt_xmlns, - mgmt_queue = OldStateData#state.mgmt_queue, - mgmt_timeout = OldStateData#state.mgmt_timeout, - mgmt_stanzas_in = OldStateData#state.mgmt_stanzas_in, - mgmt_stanzas_out = OldStateData#state.mgmt_stanzas_out, - mgmt_state = active, - csi_state = active}, Info}; - {error, Msg} -> - {error, Msg}; - _ -> - {error, <<"Cannot grab session state">>} - end - end; - _ -> - {error, <<"Invalid 'previd' value">>} - end. - --spec resume_session({integer(), pid()}) -> any(). -resume_session({Time, PID}) -> - (?GEN_FSM):sync_send_all_state_event(PID, {resume_session, Time}, 15000). - --spec make_resume_id(state()) -> binary(). -make_resume_id(StateData) -> - {Time, _} = StateData#state.sid, - jlib:term_to_base64({StateData#state.resource, Time}). - --spec add_resent_delay_info(state(), stanza(), erlang:timestamp()) -> stanza(). -add_resent_delay_info(_State, #iq{} = El, _Time) -> - El; -add_resent_delay_info(#state{server = From}, El, Time) -> - xmpp_util:add_delay_info(El, jid:make(From), Time, <<"Resent">>). - -%%%---------------------------------------------------------------------- -%%% XEP-0352 -%%%---------------------------------------------------------------------- --spec csi_filter_stanza(state(), stanza()) -> state(). -csi_filter_stanza(#state{csi_state = CsiState, jid = JID, server = Server} = - StateData, Stanza) -> - {StateData1, Stanzas} = ejabberd_hooks:run_fold(csi_filter_stanza, Server, - {StateData, [Stanza]}, - [Server, JID, Stanza]), - StateData2 = lists:foldl(fun(CurStanza, AccState) -> - send_stanza(AccState, CurStanza) - end, StateData1#state{csi_state = active}, - Stanzas), - StateData2#state{csi_state = CsiState}. - --spec csi_flush_queue(state()) -> state(). -csi_flush_queue(#state{csi_state = CsiState, jid = JID, server = Server} = - StateData) -> - {StateData1, Stanzas} = ejabberd_hooks:run_fold(csi_flush_queue, Server, - {StateData, []}, - [Server, JID]), - StateData2 = lists:foldl(fun(CurStanza, AccState) -> - send_stanza(AccState, CurStanza) - end, StateData1#state{csi_state = active}, - Stanzas), - StateData2#state{csi_state = CsiState}. - -%%%---------------------------------------------------------------------- -%%% JID Set memory footprint reduction code -%%%---------------------------------------------------------------------- - -%% Try to reduce the heap footprint of the four presence sets -%% by ensuring that we re-use strings and Jids wherever possible. --spec pack(state()) -> state(). -pack(S = #state{pres_a = A, pres_f = F, - pres_t = T}) -> - {NewA, Pack2} = pack_jid_set(A, gb_trees:empty()), - {NewF, Pack3} = pack_jid_set(F, Pack2), - {NewT, _Pack4} = pack_jid_set(T, Pack3), - S#state{pres_a = NewA, pres_f = NewF, - pres_t = NewT}. - -pack_jid_set(Set, Pack) -> - Jids = (?SETS):to_list(Set), - {PackedJids, NewPack} = pack_jids(Jids, Pack, []), - {(?SETS):from_list(PackedJids), NewPack}. - -pack_jids([], Pack, Acc) -> {Acc, Pack}; -pack_jids([{U, S, R} = Jid | Jids], Pack, Acc) -> - case gb_trees:lookup(Jid, Pack) of - {value, PackedJid} -> - pack_jids(Jids, Pack, [PackedJid | Acc]); - none -> - {NewU, Pack1} = pack_string(U, Pack), - {NewS, Pack2} = pack_string(S, Pack1), - {NewR, Pack3} = pack_string(R, Pack2), - NewJid = {NewU, NewS, NewR}, - NewPack = gb_trees:insert(NewJid, NewJid, Pack3), - pack_jids(Jids, NewPack, [NewJid | Acc]) - end. - -pack_string(String, Pack) -> - case gb_trees:lookup(String, Pack) of - {value, PackedString} -> {PackedString, Pack}; - none -> {String, gb_trees:insert(String, String, Pack)} - end. +-spec format_reason(state(), term()) -> binary(). +format_reason(#{stop_reason := Reason}, _) -> + xmpp_stream_in:format_error(Reason); +format_reason(_, normal) -> + <<"unknown reason">>; +format_reason(_, shutdown) -> + <<"stopped by supervisor">>; +format_reason(_, {shutdown, _}) -> + <<"stopped by supervisor">>; +format_reason(_, _) -> + <<"internal server error">>. transform_listen_option(Opt, Opts) -> [Opt|Opts]. --spec identity([{atom(), binary()}]) -> binary(). -identity(Props) -> - case proplists:get_value(authzid, Props, <<>>) of - <<>> -> proplists:get_value(username, Props, <<>>); - AuthzId -> AuthzId - end. - opt_type(domain_certfile) -> fun iolist_to_binary/1; -opt_type(max_fsm_queue) -> - fun (I) when is_integer(I), I > 0 -> I end; +opt_type(c2s_certfile) -> fun iolist_to_binary/1; +opt_type(c2s_ciphers) -> fun iolist_to_binary/1; +opt_type(c2s_dhfile) -> fun iolist_to_binary/1; +opt_type(c2s_cafile) -> fun iolist_to_binary/1; +opt_type(c2s_protocol_options) -> + fun (Options) -> str:join(Options, <<"|">>) end; +opt_type(c2s_tls_compression) -> + fun (true) -> true; + (false) -> false + end; opt_type(resource_conflict) -> fun (setresource) -> setresource; (closeold) -> closeold; (closenew) -> closenew; (acceptnew) -> acceptnew end; +opt_type(disable_sasl_mechanisms) -> + fun (V) when is_list(V) -> + lists:map(fun (M) -> str:to_upper(M) end, V); + (V) -> [str:to_upper(V)] + end; opt_type(_) -> - [domain_certfile, max_fsm_queue, resource_conflict]. + [domain_certfile, c2s_certfile, c2s_ciphers, c2s_cafile, + c2s_protocol_options, c2s_tls_compression, resource_conflict, + disable_sasl_mechanisms]. diff --git a/src/ejabberd_cluster.erl b/src/ejabberd_cluster.erl index 17e21af94..a331a0084 100644 --- a/src/ejabberd_cluster.erl +++ b/src/ejabberd_cluster.erl @@ -28,6 +28,7 @@ %% API -export([get_nodes/0, call/4, multicall/3, multicall/4]). -export([join/1, leave/1, get_known_nodes/0]). +-export([node_id/0, get_node_by_id/1]). -include("ejabberd.hrl"). -include("logger.hrl"). @@ -108,3 +109,31 @@ leave([Master|_], Node) -> erlang:halt(0) end), ok. + +-spec node_id() -> binary(). +node_id() -> + integer_to_binary(erlang:phash2(node())). + +-spec get_node_by_id(binary()) -> node(). +get_node_by_id(Hash) -> + try binary_to_integer(Hash) of + I -> match_node_id(I) + catch _:_ -> + node() + end. + +%%%=================================================================== +%%% Internal functions +%%%=================================================================== +-spec match_node_id(integer()) -> node(). +match_node_id(I) -> + match_node_id(I, get_nodes()). + +-spec match_node_id(integer(), [node()]) -> node(). +match_node_id(I, [Node|Nodes]) -> + case erlang:phash2(Node) of + I -> Node; + _ -> match_node_id(I, Nodes) + end; +match_node_id(_I, []) -> + node(). diff --git a/src/ejabberd_config.erl b/src/ejabberd_config.erl index 989f21c43..e15da3319 100644 --- a/src/ejabberd_config.erl +++ b/src/ejabberd_config.erl @@ -35,10 +35,12 @@ get_version/0, get_myhosts/0, get_mylang/0, get_ejabberd_config_path/0, is_using_elixir_config/0, prepare_opt_val/4, convert_table_to_binary/5, - transform_options/1, collect_options/1, default_db/2, + transform_options/1, collect_options/1, convert_to_yaml/1, convert_to_yaml/2, v_db/2, env_binary_to_list/2, opt_type/1, may_hide_data/1, - is_elixir_enabled/0, v_dbs/1, v_dbs_mods/1]). + is_elixir_enabled/0, v_dbs/1, v_dbs_mods/1, + default_db/1, default_db/2, default_ram_db/1, default_ram_db/2, + fsm_limit_opts/1]). -export([start/2]). @@ -906,11 +908,26 @@ v_dbs_mods(Mod) -> (atom_to_binary(M, utf8))/binary>>, utf8) end, ets:match(module_db, {Mod, '$1'})). --spec default_db(binary(), module()) -> atom(). +-spec default_db(module()) -> atom(). +default_db(Module) -> + default_db(global, Module). +-spec default_db(binary(), module()) -> atom(). default_db(Host, Module) -> + default_db(default_db, Host, Module). + +-spec default_ram_db(module()) -> atom(). +default_ram_db(Module) -> + default_ram_db(global, Module). + +-spec default_ram_db(binary(), module()) -> atom(). +default_ram_db(Host, Module) -> + default_db(default_ram_db, Host, Module). + +-spec default_db(default_db | default_ram_db, binary(), module()) -> atom(). +default_db(Opt, Host, Module) -> case ejabberd_config:get_option( - {default_db, Host}, fun(T) when is_atom(T) -> T end) of + {Opt, Host}, fun(T) when is_atom(T) -> T end) of undefined -> mnesia; DBType -> @@ -918,8 +935,8 @@ default_db(Host, Module) -> v_db(Module, DBType) catch error:badarg -> ?WARNING_MSG("Module '~s' doesn't support database '~s' " - "defined in option 'default_db', using " - "'mnesia' as fallback", [Module, DBType]), + "defined in option '~s', using " + "'mnesia' as fallback", [Module, DBType, Opt]), mnesia end end. @@ -1405,8 +1422,15 @@ opt_type(hosts) -> end; opt_type(language) -> fun iolist_to_binary/1; +opt_type(max_fsm_queue) -> + fun (I) when is_integer(I), I > 0 -> I end; +opt_type(default_db) -> + fun(T) when is_atom(T) -> T end; +opt_type(default_ram_db) -> + fun(T) when is_atom(T) -> T end; opt_type(_) -> - [hide_sensitive_log_data, hosts, language]. + [hide_sensitive_log_data, hosts, language, + default_db, default_ram_db]. -spec may_hide_data(string()) -> string(); (binary()) -> binary(). @@ -1423,3 +1447,17 @@ may_hide_data(Data) -> true -> "hidden_by_ejabberd" end. + +-spec fsm_limit_opts([proplists:property()]) -> [{max_queue, pos_integer()}]. +fsm_limit_opts(Opts) -> + case lists:keyfind(max_fsm_queue, 1, Opts) of + {_, I} when is_integer(I), I>0 -> + [{max_queue, I}]; + false -> + case get_option( + max_fsm_queue, + fun(I) when is_integer(I), I>0 -> I end) of + undefined -> []; + N -> [{max_queue, N}] + end + end. diff --git a/src/ejabberd_frontend_socket.erl b/src/ejabberd_frontend_socket.erl deleted file mode 100644 index cc0223741..000000000 --- a/src/ejabberd_frontend_socket.erl +++ /dev/null @@ -1,261 +0,0 @@ -%%%------------------------------------------------------------------- -%%% File : ejabberd_frontend_socket.erl -%%% Author : Alexey Shchepin -%%% Purpose : Frontend socket with zlib and TLS support library -%%% Created : 23 Aug 2006 by Alexey Shchepin -%%% -%%% -%%% ejabberd, Copyright (C) 2002-2017 ProcessOne -%%% -%%% This program is free software; you can redistribute it and/or -%%% modify it under the terms of the GNU General Public License as -%%% published by the Free Software Foundation; either version 2 of the -%%% License, or (at your option) any later version. -%%% -%%% This program is distributed in the hope that it will be useful, -%%% but WITHOUT ANY WARRANTY; without even the implied warranty of -%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -%%% General Public License for more details. -%%% -%%% You should have received a copy of the GNU General Public License along -%%% with this program; if not, write to the Free Software Foundation, Inc., -%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. -%%% -%%%---------------------------------------------------------------------- - --module(ejabberd_frontend_socket). - --author('alexey@process-one.net'). - --behaviour(gen_server). - -%% API --export([start/4, - start_link/5, - %connect/3, - starttls/2, - starttls/3, - compress/1, - compress/2, - reset_stream/1, - send/2, - change_shaper/2, - monitor/1, - get_sockmod/1, - get_transport/1, - get_peer_certificate/1, - get_verify_result/1, - close/1, - sockname/1, peername/1]). - -%% gen_server callbacks --export([init/1, handle_call/3, handle_cast/2, - handle_info/2, terminate/2, code_change/3]). - --record(state, {sockmod, socket, receiver}). - --define(HIBERNATE_TIMEOUT, 90000). - -%%==================================================================== -%% API -%%==================================================================== -start_link(Module, SockMod, Socket, Opts, Receiver) -> - gen_server:start_link(?MODULE, - [Module, SockMod, Socket, Opts, Receiver], []). - -start(Module, SockMod, Socket, Opts) -> - case Module:socket_type() of - xml_stream -> - MaxStanzaSize = case lists:keysearch(max_stanza_size, 1, - Opts) - of - {value, {_, Size}} -> Size; - _ -> infinity - end, - Receiver = ejabberd_receiver:start(Socket, SockMod, - none, MaxStanzaSize), - case SockMod:controlling_process(Socket, Receiver) of - ok -> ok; - {error, _Reason} -> SockMod:close(Socket) - end, - supervisor:start_child(ejabberd_frontend_socket_sup, - [Module, SockMod, Socket, Opts, Receiver]); - raw -> - %{ok, Pid} = Module:start({SockMod, Socket}, Opts), - %case SockMod:controlling_process(Socket, Pid) of - % ok -> - % ok; - % {error, _Reason} -> - % SockMod:close(Socket) - %end - todo - end. - -starttls(FsmRef, _TLSOpts) -> - %% TODO: Frontend improvements planned by Aleksey - %%gen_server:call(FsmRef, {starttls, TLSOpts}), - FsmRef. - -starttls(FsmRef, TLSOpts, Data) -> - gen_server:call(FsmRef, {starttls, TLSOpts, Data}), - FsmRef. - -compress(FsmRef) -> compress(FsmRef, undefined). - -compress(FsmRef, Data) -> - gen_server:call(FsmRef, {compress, Data}), FsmRef. - -reset_stream(FsmRef) -> - gen_server:call(FsmRef, reset_stream). - -send(FsmRef, Data) -> - gen_server:call(FsmRef, {send, Data}). - -change_shaper(FsmRef, Shaper) -> - gen_server:call(FsmRef, {change_shaper, Shaper}). - -monitor(FsmRef) -> erlang:monitor(process, FsmRef). - -get_sockmod(FsmRef) -> - gen_server:call(FsmRef, get_sockmod). - -get_transport(FsmRef) -> - gen_server:call(FsmRef, get_transport). - -get_peer_certificate(FsmRef) -> - gen_server:call(FsmRef, get_peer_certificate). - -get_verify_result(FsmRef) -> - gen_server:call(FsmRef, get_verify_result). - -close(FsmRef) -> gen_server:call(FsmRef, close). - -sockname(FsmRef) -> gen_server:call(FsmRef, sockname). - -peername(_FsmRef) -> - %% TODO: Frontend improvements planned by Aleksey - %%gen_server:call(FsmRef, peername). - {ok, {{0, 0, 0, 0}, 0}}. - -%%==================================================================== -%% gen_server callbacks -%%==================================================================== - -init([Module, SockMod, Socket, Opts, Receiver]) -> - Node = ejabberd_node_groups:get_closest_node(backend), - {SockMod2, Socket2} = check_starttls(SockMod, Socket, Receiver, Opts), - {ok, Pid} = - rpc:call(Node, Module, start, [{?MODULE, self()}, Opts]), - ejabberd_receiver:become_controller(Receiver, Pid), - {ok, #state{sockmod = SockMod2, - socket = Socket2, - receiver = Receiver}}. - -handle_call({starttls, TLSOpts}, _From, State) -> - {ok, TLSSocket} = fast_tls:tcp_to_tls(State#state.socket, TLSOpts), - ejabberd_receiver:starttls(State#state.receiver, TLSSocket), - Reply = ok, - {reply, Reply, State#state{socket = TLSSocket, sockmod = fast_tls}, - ?HIBERNATE_TIMEOUT}; - -handle_call({starttls, TLSOpts, Data}, _From, State) -> - {ok, TLSSocket} = fast_tls:tcp_to_tls(State#state.socket, TLSOpts), - ejabberd_receiver:starttls(State#state.receiver, TLSSocket), - catch (State#state.sockmod):send( - State#state.socket, Data), - Reply = ok, - {reply, Reply, - State#state{socket = TLSSocket, sockmod = fast_tls}, - ?HIBERNATE_TIMEOUT}; -handle_call({compress, Data}, _From, State) -> - {ok, ZlibSocket} = - ejabberd_receiver:compress(State#state.receiver, Data), - Reply = ok, - {reply, Reply, - State#state{socket = ZlibSocket, sockmod = ezlib}, - ?HIBERNATE_TIMEOUT}; -handle_call(reset_stream, _From, State) -> - ejabberd_receiver:reset_stream(State#state.receiver), - Reply = ok, - {reply, Reply, State, ?HIBERNATE_TIMEOUT}; -handle_call({send, Data}, _From, State) -> - catch (State#state.sockmod):send(State#state.socket, Data), - Reply = ok, - {reply, Reply, State, ?HIBERNATE_TIMEOUT}; -handle_call({change_shaper, Shaper}, _From, State) -> - ejabberd_receiver:change_shaper(State#state.receiver, - Shaper), - Reply = ok, - {reply, Reply, State, ?HIBERNATE_TIMEOUT}; -handle_call(get_sockmod, _From, State) -> - Reply = State#state.sockmod, - {reply, Reply, State, ?HIBERNATE_TIMEOUT}; -handle_call(get_transport, _From, State) -> - Reply = case State#state.sockmod of - gen_tcp -> tcp; - fast_tls -> tls; - ezlib -> - case ezlib:get_sockmod(State#state.socket) of - tcp -> tcp_zlib; - tls -> tls_zlib - end; - ejabberd_http_bind -> http_bind; - ejabberd_http_ws -> websocket - end, - {reply, Reply, State, ?HIBERNATE_TIMEOUT}; -handle_call(get_peer_certificate, _From, State) -> - Reply = fast_tls:get_peer_certificate(State#state.socket), - {reply, Reply, State, ?HIBERNATE_TIMEOUT}; -handle_call(get_verify_result, _From, State) -> - Reply = fast_tls:get_verify_result(State#state.socket), - {reply, Reply, State, ?HIBERNATE_TIMEOUT}; -handle_call(close, _From, State) -> - ejabberd_receiver:close(State#state.receiver), - Reply = ok, - {stop, normal, Reply, State}; -handle_call(sockname, _From, State) -> - #state{sockmod = SockMod, socket = Socket} = State, - Reply = - case SockMod of - gen_tcp -> - inet:sockname(Socket); - _ -> - SockMod:sockname(Socket) - end, - {reply, Reply, State, ?HIBERNATE_TIMEOUT}; -handle_call(peername, _From, State) -> - #state{sockmod = SockMod, socket = Socket} = State, - Reply = case SockMod of - gen_tcp -> inet:peername(Socket); - _ -> SockMod:peername(Socket) - end, - {reply, Reply, State, ?HIBERNATE_TIMEOUT}; -handle_call(_Request, _From, State) -> - Reply = ok, {reply, Reply, State, ?HIBERNATE_TIMEOUT}. - -handle_cast(_Msg, State) -> - {noreply, State, ?HIBERNATE_TIMEOUT}. - -handle_info(timeout, State) -> - proc_lib:hibernate(gen_server, enter_loop, - [?MODULE, [], State]), - {noreply, State, ?HIBERNATE_TIMEOUT}; -handle_info(_Info, State) -> - {noreply, State, ?HIBERNATE_TIMEOUT}. - -terminate(_Reason, _State) -> ok. - -code_change(_OldVsn, State, _Extra) -> {ok, State}. - -check_starttls(SockMod, Socket, Receiver, Opts) -> - TLSEnabled = proplists:get_bool(tls, Opts), - TLSOpts = lists:filter(fun({certfile, _}) -> true; - (_) -> false - end, Opts), - if TLSEnabled -> - {ok, TLSSocket} = fast_tls:tcp_to_tls(Socket, TLSOpts), - ejabberd_receiver:starttls(Receiver, TLSSocket), - {fast_tls, TLSSocket}; - true -> - {SockMod, Socket} - end. diff --git a/src/ejabberd_hooks.erl b/src/ejabberd_hooks.erl index 589b0d6a3..9f782b235 100644 --- a/src/ejabberd_hooks.erl +++ b/src/ejabberd_hooks.erl @@ -326,10 +326,9 @@ run1([{_Seq, Node, Module, Function} | Ls], Hook, Args) -> run1(Ls, Hook, Args) end; run1([{_Seq, Module, Function} | Ls], Hook, Args) -> - Res = safe_apply(Module, Function, Args), + Res = safe_apply(Hook, Module, Function, Args), case Res of - {'EXIT', Reason} -> - ?ERROR_MSG("~p~nrunning hook: ~p", [Reason, {Hook, Args}]), + 'EXIT' -> run1(Ls, Hook, Args); stop -> ok; @@ -362,10 +361,9 @@ run_fold1([{_Seq, Node, Module, Function} | Ls], Hook, Val, Args) -> run_fold1(Ls, Hook, NewVal, Args) end; run_fold1([{_Seq, Module, Function} | Ls], Hook, Val, Args) -> - Res = safe_apply(Module, Function, [Val | Args]), + Res = safe_apply(Hook, Module, Function, [Val | Args]), case Res of - {'EXIT', Reason} -> - ?ERROR_MSG("~p~nrunning hook: ~p", [Reason, {Hook, Args}]), + 'EXIT' -> run_fold1(Ls, Hook, Val, Args); stop -> stopped; @@ -375,9 +373,20 @@ run_fold1([{_Seq, Module, Function} | Ls], Hook, Val, Args) -> run_fold1(Ls, Hook, NewVal, Args) end. -safe_apply(Module, Function, Args) -> - if is_function(Function) -> - catch apply(Function, Args); +safe_apply(Hook, Module, Function, Args) -> + try if is_function(Function) -> + apply(Function, Args); true -> - catch apply(Module, Function, Args) + apply(Module, Function, Args) + end + catch E:R when E /= exit, R /= normal -> + ?ERROR_MSG("Hook ~p crashed when running ~p:~p/~p:~n" + "** Reason = ~p~n" + "** Arguments = ~p", + [Hook, Module, Function, length(Args), + {E, R, get_stacktrace()}, Args]), + 'EXIT' end. + +get_stacktrace() -> + [{Mod, Fun, Loc, Args} || {Mod, Fun, Args, Loc} <- erlang:get_stacktrace()]. diff --git a/src/ejabberd_http.erl b/src/ejabberd_http.erl index 84de6da33..bd3291508 100644 --- a/src/ejabberd_http.erl +++ b/src/ejabberd_http.erl @@ -327,7 +327,7 @@ add_header(Name, Value, State)-> get_host_really_served(undefined, Provided) -> Provided; get_host_really_served(Default, Provided) -> - case lists:member(Provided, ?MYHOSTS) of + case ejabberd_router:is_my_host(Provided) of true -> Provided; false -> Default end. diff --git a/src/ejabberd_http_ws.erl b/src/ejabberd_http_ws.erl index 4a4677aee..6ac257c91 100644 --- a/src/ejabberd_http_ws.erl +++ b/src/ejabberd_http_ws.erl @@ -120,7 +120,7 @@ init([{#ws{ip = IP, http_opts = HOpts}, _} = WS]) -> ({resend_on_timeout, _}) -> true; (_) -> false end, HOpts), - Opts = [{xml_socket, true} | ejabberd_c2s_config:get_c2s_limits() ++ SOpts], + Opts = ejabberd_c2s_config:get_c2s_limits() ++ SOpts, PingInterval = ejabberd_config:get_option( {websocket_ping_interval, ?MYNAME}, fun(I) when is_integer(I), I>=0 -> I end, diff --git a/src/ejabberd_listener.erl b/src/ejabberd_listener.erl index 66a2775a7..bad1da134 100644 --- a/src/ejabberd_listener.erl +++ b/src/ejabberd_listener.erl @@ -56,8 +56,7 @@ bind_tcp_ports() -> Ls -> lists:foreach( fun({Port, Module, Opts}) -> - ModuleRaw = strip_frontend(Module), - case ModuleRaw:socket_type() of + case Module:socket_type() of independent -> ok; _ -> bind_tcp_port(Port, Module, Opts) @@ -112,9 +111,8 @@ report_duplicated_portips(L) -> start(Port, Module, Opts) -> %% Check if the module is an ejabberd listener or an independent listener - ModuleRaw = strip_frontend(Module), - case ModuleRaw:socket_type() of - independent -> ModuleRaw:start_listener(Port, Opts); + case Module:socket_type() of + independent -> Module:start_listener(Port, Opts); _ -> start_dependent(Port, Module, Opts) end. @@ -186,7 +184,9 @@ init_tcp(PortIP, Module, Opts, SockOpts, Port, IPS) -> listen_tcp(PortIP, Module, SockOpts, Port, IPS) -> case ets:lookup(listen_sockets, PortIP) of [{PortIP, ListenSocket}] -> - ?INFO_MSG("Reusing listening port for ~p", [PortIP]), + {_, _, Transport} = PortIP, + ?INFO_MSG("Reusing listening ~s port ~p at ~s", + [Transport, Port, IPS]), ets:delete(listen_sockets, PortIP), ListenSocket; _ -> @@ -330,21 +330,22 @@ accept(ListenSocket, Module, Opts, Interval) -> {ok, Socket} -> case {inet:sockname(Socket), inet:peername(Socket)} of {{ok, {Addr, Port}}, {ok, {PAddr, PPort}}} -> - ?INFO_MSG("(~w) Accepted connection ~s:~p -> ~s:~p", - [Socket, ejabberd_config:may_hide_data(inet_parse:ntoa(PAddr)), PPort, - inet_parse:ntoa(Addr), Port]); + Receiver = case ejabberd_socket:start(Module, + gen_tcp, Socket, Opts) of + {ok, RecvPid} -> RecvPid; + _ -> none + end, + ?INFO_MSG("(~p) Accepted connection ~s:~p -> ~s:~p", + [Receiver, + ejabberd_config:may_hide_data(inet_parse:ntoa(PAddr)), + PPort, inet_parse:ntoa(Addr), Port]); _ -> ok end, - CallMod = case is_frontend(Module) of - true -> ejabberd_frontend_socket; - false -> ejabberd_socket - end, - CallMod:start(strip_frontend(Module), gen_tcp, Socket, Opts), accept(ListenSocket, Module, Opts, NewInterval); {error, Reason} -> - ?ERROR_MSG("(~w) Failed TCP accept: ~w", - [ListenSocket, Reason]), + ?ERROR_MSG("(~w) Failed TCP accept: ~s", + [ListenSocket, inet:format_error(Reason)]), accept(ListenSocket, Module, Opts, NewInterval) end. @@ -394,7 +395,7 @@ start_module_sup(_Port, Module) -> Proc1 = gen_mod:get_module_proc(<<"sup">>, Module), ChildSpec1 = {Proc1, - {ejabberd_tmp_sup, start_link, [Proc1, strip_frontend(Module)]}, + {ejabberd_tmp_sup, start_link, [Proc1, Module]}, permanent, infinity, supervisor, @@ -489,18 +490,6 @@ delete_listener(PortIP, Module, Opts) -> stop_listener(PortIP1, Module). --spec is_frontend({frontend, module} | module()) -> boolean(). - -is_frontend({frontend, _Module}) -> true; -is_frontend(_) -> false. - -%% @doc(FrontMod) -> atom() -%% where FrontMod = atom() | {frontend, atom()} --spec strip_frontend({frontend, module()} | module()) -> module(). - -strip_frontend({frontend, Module}) -> Module; -strip_frontend(Module) when is_atom(Module) -> Module. - maybe_start_sip(esip_socket) -> ejabberd:start_app(esip); maybe_start_sip(_) -> diff --git a/src/ejabberd_local.erl b/src/ejabberd_local.erl index 9c7345e78..b5e1d8abc 100644 --- a/src/ejabberd_local.erl +++ b/src/ejabberd_local.erl @@ -30,14 +30,13 @@ -behaviour(gen_server). %% API --export([start_link/0]). +-export([start/0, start_link/0]). -export([route/3, route_iq/4, route_iq/5, process_iq/3, process_iq_reply/3, register_iq_handler/4, register_iq_handler/5, register_iq_response_handler/4, register_iq_response_handler/5, unregister_iq_handler/2, - unregister_iq_response_handler/2, refresh_iq_handlers/0, - bounce_resource_packet/3]). + unregister_iq_response_handler/2, bounce_resource_packet/3]). %% gen_server callbacks -export([init/1, handle_call/3, handle_cast/2, @@ -69,6 +68,11 @@ %% Function: start_link() -> {ok,Pid} | ignore | {error,Error} %% Description: Starts the server %%-------------------------------------------------------------------- +start() -> + ChildSpec = {?MODULE, {?MODULE, start_link, []}, + transient, 1000, worker, [?MODULE]}, + supervisor:start_child(ejabberd_sup, ChildSpec). + start_link() -> gen_server:start_link({local, ?MODULE}, ?MODULE, [], []). @@ -90,8 +94,13 @@ process_iq(From, To, #iq{type = T, lang = Lang, sub_els = [El]} = Packet) Err = xmpp:err_service_unavailable(Txt, Lang), ejabberd_router:route_error(To, From, Packet, Err) end; -process_iq(From, To, #iq{type = T} = Packet) when T == get; T == set -> - Err = xmpp:err_bad_request(), +process_iq(From, To, #iq{type = T, lang = Lang, sub_els = SubEls} = Packet) + when T == get; T == set -> + Txt = case SubEls of + [] -> <<"No child elements found">>; + _ -> <<"Too many child elements">> + end, + Err = xmpp:err_bad_request(Txt, Lang), ejabberd_router:route_error(To, From, Packet, Err); process_iq(From, To, #iq{type = T} = Packet) when T == result; T == error -> process_iq_reply(From, To, Packet). @@ -171,10 +180,6 @@ unregister_iq_response_handler(_Host, ID) -> unregister_iq_handler(Host, XMLNS) -> ejabberd_local ! {unregister_iq_handler, Host, XMLNS}. --spec refresh_iq_handlers() -> any(). -refresh_iq_handlers() -> - ejabberd_local ! refresh_iq_handlers. - -spec bounce_resource_packet(jid(), jid(), stanza()) -> stop. bounce_resource_packet(_From, #jid{lresource = <<"">>}, #presence{}) -> ok; @@ -228,14 +233,12 @@ handle_info({register_iq_handler, Host, XMLNS, Module, Function}, State) -> ets:insert(?IQTABLE, {{XMLNS, Host}, Module, Function}), - catch mod_disco:register_feature(Host, XMLNS), {noreply, State}; handle_info({register_iq_handler, Host, XMLNS, Module, Function, Opts}, State) -> ets:insert(?IQTABLE, {{XMLNS, Host}, Module, Function, Opts}), - catch mod_disco:register_feature(Host, XMLNS), {noreply, State}; handle_info({unregister_iq_handler, Host, XMLNS}, State) -> @@ -245,19 +248,6 @@ handle_info({unregister_iq_handler, Host, XMLNS}, _ -> ok end, ets:delete(?IQTABLE, {XMLNS, Host}), - catch mod_disco:unregister_feature(Host, XMLNS), - {noreply, State}; -handle_info(refresh_iq_handlers, State) -> - lists:foreach(fun (T) -> - case T of - {{XMLNS, Host}, _Module, _Function, _Opts} -> - catch mod_disco:register_feature(Host, XMLNS); - {{XMLNS, Host}, _Module, _Function} -> - catch mod_disco:register_feature(Host, XMLNS); - _ -> ok - end - end, - ets:tab2list(?IQTABLE)), {noreply, State}; handle_info({timeout, _TRef, ID}, State) -> process_iq_timeout(ID), diff --git a/src/ejabberd_node_groups.erl b/src/ejabberd_node_groups.erl deleted file mode 100644 index f44df73fd..000000000 --- a/src/ejabberd_node_groups.erl +++ /dev/null @@ -1,173 +0,0 @@ -%%%---------------------------------------------------------------------- -%%% File : ejabberd_node_groups.erl -%%% Author : Alexey Shchepin -%%% Purpose : Distributed named node groups based on pg2 module -%%% Created : 1 Nov 2006 by Alexey Shchepin -%%% -%%% -%%% ejabberd, Copyright (C) 2002-2017 ProcessOne -%%% -%%% This program is free software; you can redistribute it and/or -%%% modify it under the terms of the GNU General Public License as -%%% published by the Free Software Foundation; either version 2 of the -%%% License, or (at your option) any later version. -%%% -%%% This program is distributed in the hope that it will be useful, -%%% but WITHOUT ANY WARRANTY; without even the implied warranty of -%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -%%% General Public License for more details. -%%% -%%% You should have received a copy of the GNU General Public License along -%%% with this program; if not, write to the Free Software Foundation, Inc., -%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. -%%% -%%%---------------------------------------------------------------------- - --module(ejabberd_node_groups). - --behaviour(ejabberd_config). --author('alexey@process-one.net'). - --behaviour(gen_server). - -%% API --export([start_link/0, - join/1, - leave/1, - get_members/1, - get_closest_node/1]). - --export([init/1, handle_call/3, handle_cast/2, - handle_info/2, terminate/2, code_change/3, opt_type/1]). - --define(PG2, pg2). - --record(state, {}). - -%%==================================================================== -%% API -%%==================================================================== -%%-------------------------------------------------------------------- -%% Function: start_link() -> {ok,Pid} | ignore | {error,Error} -%% Description: Starts the server -%%-------------------------------------------------------------------- -start_link() -> - gen_server:start_link({local, ?MODULE}, ?MODULE, [], []). - -join(Name) -> - PG = {?MODULE, Name}, - pg2:create(PG), - pg2:join(PG, whereis(?MODULE)). - -leave(Name) -> - PG = {?MODULE, Name}, - pg2:leave(PG, whereis(?MODULE)). - -get_members(Name) -> - PG = {?MODULE, Name}, - [node(P) || P <- pg2:get_members(PG)]. - -get_closest_node(Name) -> - PG = {?MODULE, Name}, - node(pg2:get_closest_pid(PG)). - -%%==================================================================== -%% gen_server callbacks -%%==================================================================== - -%%-------------------------------------------------------------------- -%% Function: init(Args) -> {ok, State} | -%% {ok, State, Timeout} | -%% ignore | -%% {stop, Reason} -%% Description: Initiates the server -%%-------------------------------------------------------------------- -init([]) -> - {FE, BE} = - case ejabberd_config:get_option( - node_type, - fun(frontend) -> frontend; - (backend) -> backend; - (generic) -> generic - end, generic) of - frontend -> - {true, false}; - backend -> - {false, true}; - generic -> - {true, true}; - undefined -> - {true, true} - end, - if - FE -> - join(frontend); - true -> - ok - end, - if - BE -> - join(backend); - true -> - ok - end, - {ok, #state{}}. - -%%-------------------------------------------------------------------- -%% Function: %% handle_call(Request, From, State) -> {reply, Reply, State} | -%% {reply, Reply, State, Timeout} | -%% {noreply, State} | -%% {noreply, State, Timeout} | -%% {stop, Reason, Reply, State} | -%% {stop, Reason, State} -%% Description: Handling call messages -%%-------------------------------------------------------------------- -handle_call(_Request, _From, State) -> - Reply = ok, - {reply, Reply, State}. - -%%-------------------------------------------------------------------- -%% Function: handle_cast(Msg, State) -> {noreply, State} | -%% {noreply, State, Timeout} | -%% {stop, Reason, State} -%% Description: Handling cast messages -%%-------------------------------------------------------------------- -handle_cast(_Msg, State) -> - {noreply, State}. - -%%-------------------------------------------------------------------- -%% Function: handle_info(Info, State) -> {noreply, State} | -%% {noreply, State, Timeout} | -%% {stop, Reason, State} -%% Description: Handling all non call/cast messages -%%-------------------------------------------------------------------- -handle_info(_Info, State) -> - {noreply, State}. - -%%-------------------------------------------------------------------- -%% Function: terminate(Reason, State) -> void() -%% Description: This function is called by a gen_server when it is about to -%% terminate. It should be the opposite of Module:init/1 and do any necessary -%% cleaning up. When it returns, the gen_server terminates with Reason. -%% The return value is ignored. -%%-------------------------------------------------------------------- -terminate(_Reason, _State) -> - ok. - -%%-------------------------------------------------------------------- -%% Func: code_change(OldVsn, State, Extra) -> {ok, NewState} -%% Description: Convert process state when code is changed -%%-------------------------------------------------------------------- -code_change(_OldVsn, State, _Extra) -> - {ok, State}. - -%%-------------------------------------------------------------------- -%%% Internal functions -%%-------------------------------------------------------------------- - -opt_type(node_type) -> - fun (frontend) -> frontend; - (backend) -> backend; - (generic) -> generic - end; -opt_type(_) -> [node_type]. diff --git a/src/ejabberd_piefxis.erl b/src/ejabberd_piefxis.erl index 6eefc045b..1115f16cb 100644 --- a/src/ejabberd_piefxis.erl +++ b/src/ejabberd_piefxis.erl @@ -347,7 +347,7 @@ process_el({xmlstreamelement, #xmlel{name = <<"host">>, JIDS = fxml:get_attr_s(<<"jid">>, Attrs), case jid:from_string(JIDS) of #jid{lserver = S} -> - case lists:member(S, ?MYHOSTS) of + case ejabberd_router:is_my_host(S) of true -> process_users(Els, State#state{server = S}); false -> @@ -481,17 +481,16 @@ process_privacy(#privacy_query{lists = Lists, JID = jid:make(U, S), IQ = #iq{type = set, id = randoms:get_string(), from = JID, to = JID, sub_els = [PrivacyQuery]}, - Txt = <<"No module is handling this query">>, - Error = {error, xmpp:err_feature_not_implemented(Txt, ?MYLANG)}, - case mod_privacy:process_iq_set(Error, IQ, #userlist{}) of - {error, #stanza_error{reason = Reason}} = Err -> + case mod_privacy:process_iq(IQ) of + #iq{type = error} = ResIQ -> + #stanza_error{reason = Reason} = xmpp:get_error(ResIQ), if Reason == 'item-not-found', Lists == [], Active == undefined, Default /= undefined -> %% Failed to set default list because there is no %% list with such name. We shouldn't stop here. {ok, State}; true -> - stop("Failed to write privacy: ~p", [Err]) + stop("Failed to write privacy: ~p", [Reason]) end; _ -> {ok, State} diff --git a/src/ejabberd_receiver.erl b/src/ejabberd_receiver.erl index 4f9be17da..355fcbbd2 100644 --- a/src/ejabberd_receiver.erl +++ b/src/ejabberd_receiver.erl @@ -135,8 +135,8 @@ handle_call({starttls, TLSSocket}, _From, State) -> {ok, TLSData} -> {reply, ok, process_data(TLSData, NewState), ?HIBERNATE_TIMEOUT}; - {error, _Reason} -> - {stop, normal, ok, NewState} + {error, _} = Err -> + {stop, normal, Err, NewState} end; handle_call({compress, Data}, _From, #state{socket = Socket, sock_mod = SockMod} = diff --git a/src/ejabberd_router.erl b/src/ejabberd_router.erl index 17cb7e279..64a6234c8 100644 --- a/src/ejabberd_router.erl +++ b/src/ejabberd_router.erl @@ -5,7 +5,7 @@ %%% Created : 27 Nov 2002 by Alexey Shchepin %%% %%% -%%% ejabberd, Copyright (C) 2002-2017 ProcessOne +%%% ejabberd, Copyright (C) 2002-2016 ProcessOne %%% %%% This program is free software; you can redistribute it and/or %%% modify it under the terms of the GNU General Public License as @@ -34,7 +34,6 @@ %% API -export([route/3, route_error/4, - register_route/1, register_route/2, register_route/3, register_routes/1, @@ -42,41 +41,48 @@ process_iq/3, unregister_route/1, unregister_routes/1, - dirty_get_all_routes/0, - dirty_get_all_domains/0 - ]). + get_all_routes/0, + is_my_route/1, + is_my_host/1, + get_backend/0]). --export([start_link/0]). +-export([start/0, start_link/0]). -export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3, opt_type/1]). -include("ejabberd.hrl"). -include("logger.hrl"). - +-include("ejabberd_router.hrl"). -include("xmpp.hrl"). --type local_hint() :: undefined | integer() | {apply, atom(), atom()}. - --record(route, {domain, server_host, pid, local_hint}). +-callback init() -> any(). +-callback register_route(binary(), binary(), local_hint(), + undefined | pos_integer()) -> ok | {error, term()}. +-callback unregister_route(binary(), undefined | pos_integer()) -> ok | {error, term()}. +-callback find_routes(binary()) -> [#route{}]. +-callback host_of_route(binary()) -> {ok, binary()} | error. +-callback is_my_route(binary()) -> boolean(). +-callback is_my_host(binary()) -> boolean(). +-callback get_all_routes() -> [binary()]. -record(state, {}). %%==================================================================== %% API %%==================================================================== -%%-------------------------------------------------------------------- -%% Function: start_link() -> {ok,Pid} | ignore | {error,Error} -%% Description: Starts the server -%%-------------------------------------------------------------------- +start() -> + ChildSpec = {?MODULE, {?MODULE, start_link, []}, + transient, 1000, worker, [?MODULE]}, + supervisor:start_child(ejabberd_sup, ChildSpec). + start_link() -> gen_server:start_link({local, ?MODULE}, ?MODULE, [], []). -spec route(jid(), jid(), xmlel() | stanza()) -> ok. - route(#jid{} = From, #jid{} = To, #xmlel{} = El) -> try xmpp:decode(El, ?NS_CLIENT, [ignore_els]) of - Pkt -> route(From, To, xmpp:set_from_to(Pkt, From, To)) + Pkt -> route(From, To, Pkt) catch _:{xmpp_codec, Why} -> ?ERROR_MSG("failed to decode xml element ~p when " "routing from ~s to ~s: ~s", @@ -96,7 +102,6 @@ route(#jid{} = From, #jid{} = To, Packet) -> %% RFC3920 9.3.1 -spec route_error(jid(), jid(), xmlel(), xmlel()) -> ok; (jid(), jid(), stanza(), stanza_error()) -> ok. - route_error(From, To, #xmlel{} = ErrPacket, #xmlel{} = OrigPacket) -> #xmlel{attrs = Attrs} = OrigPacket, case <<"error">> == fxml:get_attr_s(<<"type">>, Attrs) of @@ -111,222 +116,118 @@ route_error(From, To, Packet, #stanza_error{} = Err) -> ejabberd_router:route(From, To, xmpp:make_error(Packet, Err)) end. --spec register_route(binary()) -> term(). - -register_route(Domain) -> - ?WARNING_MSG("~s:register_route/1 is deprected, " - "use ~s:register_route/2 instead", - [?MODULE, ?MODULE]), - register_route(Domain, ?MYNAME). - --spec register_route(binary(), binary()) -> term(). - +-spec register_route(binary(), binary()) -> ok. register_route(Domain, ServerHost) -> register_route(Domain, ServerHost, undefined). --spec register_route(binary(), binary(), local_hint()) -> term(). - +-spec register_route(binary(), binary(), local_hint()) -> ok. register_route(Domain, ServerHost, LocalHint) -> case {jid:nameprep(Domain), jid:nameprep(ServerHost)} of - {error, _} -> erlang:error({invalid_domain, Domain}); - {_, error} -> erlang:error({invalid_domain, ServerHost}); - {LDomain, LServerHost} -> - Pid = self(), - case get_component_number(LDomain) of - undefined -> - F = fun () -> - mnesia:write(#route{domain = LDomain, pid = Pid, - server_host = LServerHost, - local_hint = LocalHint}) - end, - mnesia:transaction(F); - N -> - F = fun () -> - case mnesia:wread({route, LDomain}) of - [] -> - mnesia:write(#route{domain = LDomain, - server_host = LServerHost, - pid = Pid, - local_hint = 1}), - lists:foreach( - fun (I) -> - mnesia:write( - #route{domain = LDomain, - pid = undefined, - server_host = LServerHost, - local_hint = I}) - end, - lists:seq(2, N)); - Rs -> - lists:any( - fun (#route{pid = undefined, - local_hint = I} = R) -> - mnesia:write( - #route{domain = LDomain, - pid = Pid, - server_host = LServerHost, - local_hint = I}), - mnesia:delete_object(R), - true; - (_) -> false - end, - Rs) - end - end, - mnesia:transaction(F) - end + {error, _} -> + erlang:error({invalid_domain, Domain}); + {_, error} -> + erlang:error({invalid_domain, ServerHost}); + {LDomain, LServerHost} -> + Mod = get_backend(), + case Mod:register_route(LDomain, LServerHost, LocalHint, + get_component_number(LDomain)) of + ok -> + ?DEBUG("Route registered: ~s", [LDomain]); + {error, Err} -> + ?ERROR_MSG("Failed to register route ~s: ~p", + [LDomain, Err]) + end end. -spec register_routes([{binary(), binary()}]) -> ok. - register_routes(Domains) -> lists:foreach(fun ({Domain, ServerHost}) -> register_route(Domain, ServerHost) end, Domains). --spec unregister_route(binary()) -> term(). - +-spec unregister_route(binary()) -> ok. unregister_route(Domain) -> case jid:nameprep(Domain) of - error -> erlang:error({invalid_domain, Domain}); - LDomain -> - Pid = self(), - case get_component_number(LDomain) of - undefined -> - F = fun () -> - case mnesia:match_object(#route{domain = LDomain, - pid = Pid, _ = '_'}) - of - [R] -> mnesia:delete_object(R); - _ -> ok - end - end, - mnesia:transaction(F); - _ -> - F = fun () -> - case mnesia:match_object(#route{domain = LDomain, - pid = Pid, _ = '_'}) - of - [R] -> - I = R#route.local_hint, - ServerHost = R#route.server_host, - mnesia:write(#route{domain = LDomain, - server_host = ServerHost, - pid = undefined, - local_hint = I}), - mnesia:delete_object(R); - _ -> ok - end - end, - mnesia:transaction(F) - end + error -> + erlang:error({invalid_domain, Domain}); + LDomain -> + Mod = get_backend(), + case Mod:unregister_route(LDomain, get_component_number(LDomain)) of + ok -> + ?DEBUG("Route unregistered: ~s", [LDomain]); + {error, Err} -> + ?ERROR_MSG("Failed to unregister route ~s: ~p", + [LDomain, Err]) + end end. -spec unregister_routes([binary()]) -> ok. - unregister_routes(Domains) -> lists:foreach(fun (Domain) -> unregister_route(Domain) end, Domains). --spec dirty_get_all_routes() -> [binary()]. - -dirty_get_all_routes() -> - lists:usort(mnesia:dirty_all_keys(route)) -- (?MYHOSTS). - --spec dirty_get_all_domains() -> [binary()]. - -dirty_get_all_domains() -> - lists:usort(mnesia:dirty_all_keys(route)). +-spec get_all_routes() -> [binary()]. +get_all_routes() -> + Mod = get_backend(), + Mod:get_all_routes(). -spec host_of_route(binary()) -> binary(). - host_of_route(Domain) -> case jid:nameprep(Domain) of error -> erlang:error({invalid_domain, Domain}); LDomain -> - case mnesia:dirty_read(route, LDomain) of - [#route{server_host = ServerHost}|_] -> - ServerHost; - [] -> - erlang:error({unregistered_route, Domain}) + Mod = get_backend(), + case Mod:host_of_route(LDomain) of + {ok, ServerHost} -> ServerHost; + error -> erlang:error({unregistered_route, Domain}) end end. --spec process_iq(jid(), jid(), iq() | xmlel()) -> any(). +-spec is_my_route(binary()) -> boolean(). +is_my_route(Domain) -> + case jid:nameprep(Domain) of + error -> + erlang:error({invalid_domain, Domain}); + LDomain -> + Mod = get_backend(), + Mod:is_my_route(LDomain) + end. + +-spec is_my_host(binary()) -> boolean(). +is_my_host(Domain) -> + case jid:nameprep(Domain) of + error -> + erlang:error({invalid_domain, Domain}); + LDomain -> + Mod = get_backend(), + Mod:is_my_host(LDomain) + end. + +-spec process_iq(jid(), jid(), iq()) -> any(). process_iq(From, To, #iq{} = IQ) -> if To#jid.luser == <<"">> -> ejabberd_local:process_iq(From, To, IQ); true -> ejabberd_sm:process_iq(From, To, IQ) - end; -process_iq(From, To, El) -> - try xmpp:decode(El, ?NS_CLIENT, [ignore_els]) of - IQ -> process_iq(From, To, IQ) - catch _:{xmpp_codec, Why} -> - Type = xmpp:get_type(El), - if Type == <<"get">>; Type == <<"set">> -> - Txt = xmpp:format_error(Why), - Lang = xmpp:get_lang(El), - Err = xmpp:make_error(El, xmpp:err_bad_request(Txt, Lang)), - ejabberd_router:route(To, From, Err); - true -> - ok - end end. %%==================================================================== %% gen_server callbacks %%==================================================================== - -%%-------------------------------------------------------------------- -%% Function: init(Args) -> {ok, State} | -%% {ok, State, Timeout} | -%% ignore | -%% {stop, Reason} -%% Description: Initiates the server -%%-------------------------------------------------------------------- init([]) -> - update_tables(), - ejabberd_mnesia:create(?MODULE, route, - [{ram_copies, [node()]}, - {type, bag}, - {attributes, record_info(fields, route)}]), - mnesia:add_table_copy(route, node(), ram_copies), - mnesia:subscribe({table, route, simple}), - lists:foreach(fun (Pid) -> erlang:monitor(process, Pid) - end, - mnesia:dirty_select(route, - [{{route, '_', '$1', '_'}, [], ['$1']}])), + Mod = get_backend(), + Mod:init(), {ok, #state{}}. -%%-------------------------------------------------------------------- -%% Function: %% handle_call(Request, From, State) -> {reply, Reply, State} | -%% {reply, Reply, State, Timeout} | -%% {noreply, State} | -%% {noreply, State, Timeout} | -%% {stop, Reason, Reply, State} | -%% {stop, Reason, State} -%% Description: Handling call messages -%%-------------------------------------------------------------------- handle_call(_Request, _From, State) -> - Reply = ok, {reply, Reply, State}. + Reply = ok, + {reply, Reply, State}. -%%-------------------------------------------------------------------- -%% Function: handle_cast(Msg, State) -> {noreply, State} | -%% {noreply, State, Timeout} | -%% {stop, Reason, State} -%% Description: Handling cast messages -%%-------------------------------------------------------------------- -handle_cast(_Msg, State) -> {noreply, State}. +handle_cast(_Msg, State) -> + {noreply, State}. -%%-------------------------------------------------------------------- -%% Function: handle_info(Info, State) -> {noreply, State} | -%% {noreply, State, Timeout} | -%% {stop, Reason, State} -%% Description: Handling all non call/cast messages -%%-------------------------------------------------------------------- handle_info({route, From, To, Packet}, State) -> case catch do_route(From, To, Packet) of {'EXIT', Reason} -> @@ -335,106 +236,71 @@ handle_info({route, From, To, Packet}, State) -> _ -> ok end, {noreply, State}; -handle_info({mnesia_table_event, - {write, #route{pid = Pid}, _ActivityId}}, - State) -> - erlang:monitor(process, Pid), {noreply, State}; -handle_info({'DOWN', _Ref, _Type, Pid, _Info}, State) -> - F = fun () -> - Es = mnesia:select(route, - [{#route{pid = Pid, _ = '_'}, [], ['$_']}]), - lists:foreach(fun (E) -> - if is_integer(E#route.local_hint) -> - LDomain = E#route.domain, - I = E#route.local_hint, - ServerHost = E#route.server_host, - mnesia:write(#route{domain = - LDomain, - server_host = - ServerHost, - pid = - undefined, - local_hint = - I}), - mnesia:delete_object(E); - true -> mnesia:delete_object(E) - end - end, - Es) - end, - mnesia:transaction(F), - {noreply, State}; -handle_info(_Info, State) -> +handle_info(Info, State) -> + ?ERROR_MSG("unexpected info: ~p", [Info]), {noreply, State}. -%%-------------------------------------------------------------------- -%% Function: terminate(Reason, State) -> void() -%% Description: This function is called by a gen_server when it is about to -%% terminate. It should be the opposite of Module:init/1 and do any necessary -%% cleaning up. When it returns, the gen_server terminates with Reason. -%% The return value is ignored. -%%-------------------------------------------------------------------- terminate(_Reason, _State) -> ok. -%%-------------------------------------------------------------------- -%% Func: code_change(OldVsn, State, Extra) -> {ok, NewState} -%% Description: Convert process state when code is changed -%%-------------------------------------------------------------------- code_change(_OldVsn, State, _Extra) -> {ok, State}. %%-------------------------------------------------------------------- %%% Internal functions %%-------------------------------------------------------------------- --spec do_route(jid(), jid(), xmlel() | xmpp_element()) -> any(). +-spec do_route(jid(), jid(), stanza()) -> any(). do_route(OrigFrom, OrigTo, OrigPacket) -> - ?DEBUG("route~n\tfrom ~p~n\tto ~p~n\tpacket " - "~p~n", - [OrigFrom, OrigTo, OrigPacket]), + ?DEBUG("route:~n~s", [xmpp:pp(OrigPacket)]), case ejabberd_hooks:run_fold(filter_packet, - {OrigFrom, OrigTo, OrigPacket}, []) - of - {From, To, Packet} -> - LDstDomain = To#jid.lserver, - case mnesia:dirty_read(route, LDstDomain) of - [] -> - ejabberd_s2s:route(From, To, Packet); - [R] -> - do_route(From, To, Packet, R); - Rs -> - Value = get_domain_balancing(From, To, LDstDomain), - case get_component_number(LDstDomain) of - undefined -> - case [R || R <- Rs, node(R#route.pid) == node()] of - [] -> - R = lists:nth(erlang:phash(Value, length(Rs)), Rs), - do_route(From, To, Packet, R); - LRs -> - R = lists:nth(erlang:phash(Value, length(LRs)), LRs), - do_route(From, To, Packet, R) - end; - _ -> - SRs = lists:ukeysort(#route.local_hint, Rs), - R = lists:nth(erlang:phash(Value, length(SRs)), SRs), - do_route(From, To, Packet, R) - end - end; - drop -> ok + {OrigFrom, OrigTo, OrigPacket}, []) of + {From, To, Packet} -> + LDstDomain = To#jid.lserver, + Mod = get_backend(), + case Mod:find_routes(LDstDomain) of + [] -> + ejabberd_s2s:route(From, To, Packet); + [Route] -> + do_route(From, To, Packet, Route); + Routes -> + balancing_route(From, To, Packet, Routes) + end; + drop -> + ok end. --spec do_route(jid(), jid(), xmlel() | xmpp_element(), #route{}) -> any(). -do_route(From, To, Packet, #route{local_hint = LocalHint, - pid = Pid}) when is_pid(Pid) -> +-spec do_route(jid(), jid(), stanza(), #route{}) -> any(). +do_route(From, To, Pkt, #route{local_hint = LocalHint, + pid = Pid}) when is_pid(Pid) -> case LocalHint of {apply, Module, Function} when node(Pid) == node() -> - Module:Function(From, To, Packet); + Module:Function(From, To, Pkt); _ -> - Pid ! {route, From, To, Packet} + Pid ! {route, From, To, Pkt} end; -do_route(_From, _To, _Packet, _Route) -> +do_route(_From, _To, _Pkt, _Route) -> drop. +-spec balancing_route(jid(), jid(), stanza(), [#route{}]) -> any(). +balancing_route(From, To, Packet, Rs) -> + LDstDomain = To#jid.lserver, + Value = get_domain_balancing(From, To, LDstDomain), + case get_component_number(LDstDomain) of + undefined -> + case [R || R <- Rs, node(R#route.pid) == node()] of + [] -> + R = lists:nth(erlang:phash(Value, length(Rs)), Rs), + do_route(From, To, Packet, R); + LRs -> + R = lists:nth(erlang:phash(Value, length(LRs)), LRs), + do_route(From, To, Packet, R) + end; + _ -> + SRs = lists:ukeysort(#route.local_hint, Rs), + R = lists:nth(erlang:phash(Value, length(SRs)), SRs), + do_route(From, To, Packet, R) + end. + -spec get_component_number(binary()) -> pos_integer() | undefined. get_component_number(LDomain) -> ejabberd_config:get_option( @@ -454,19 +320,17 @@ get_domain_balancing(From, To, LDomain) -> bare_destination -> jid:remove_resource(jid:tolower(To)) end. --spec update_tables() -> ok. -update_tables() -> - try - mnesia:transform_table(route, ignore, record_info(fields, route)) - catch exit:{aborted, {no_exists, _}} -> - ok - end, - case lists:member(local_route, - mnesia:system_info(tables)) - of - true -> mnesia:delete_table(local_route); - false -> ok - end. +-spec get_backend() -> module(). +get_backend() -> + DBType = case ejabberd_config:get_option( + router_db_type, + fun(T) -> ejabberd_config:v_db(?MODULE, T) end) of + undefined -> + ejabberd_config:default_ram_db(?MODULE); + T -> + T + end, + list_to_atom("ejabberd_router_" ++ atom_to_list(DBType)). opt_type(domain_balancing) -> fun (random) -> random; @@ -477,4 +341,7 @@ opt_type(domain_balancing) -> end; opt_type(domain_balancing_component_number) -> fun (N) when is_integer(N), N > 1 -> N end; -opt_type(_) -> [domain_balancing, domain_balancing_component_number]. +opt_type(router_db_type) -> fun(T) -> ejabberd_config:v_db(?MODULE, T) end; +opt_type(_) -> + [domain_balancing, domain_balancing_component_number, + router_db_type]. diff --git a/src/ejabberd_router_mnesia.erl b/src/ejabberd_router_mnesia.erl new file mode 100644 index 000000000..d9946cef5 --- /dev/null +++ b/src/ejabberd_router_mnesia.erl @@ -0,0 +1,220 @@ +%%%------------------------------------------------------------------- +%%% @author Evgeny Khramtsov +%%% @copyright (C) 2017, Evgeny Khramtsov +%%% @doc +%%% +%%% @end +%%% Created : 11 Jan 2017 by Evgeny Khramtsov +%%%------------------------------------------------------------------- +-module(ejabberd_router_mnesia). +-behaviour(ejabberd_router). +-behaviour(gen_server). + +%% API +-export([init/0, register_route/4, unregister_route/2, find_routes/1, + host_of_route/1, is_my_route/1, is_my_host/1, get_all_routes/0]). +%% gen_server callbacks +-export([init/1, handle_cast/2, handle_call/3, handle_info/2, + terminate/2, code_change/3]). + +-include("ejabberd.hrl"). +-include("ejabberd_router.hrl"). +-include("logger.hrl"). +-include_lib("stdlib/include/ms_transform.hrl"). + +-record(state, {}). + +%%%=================================================================== +%%% API +%%%=================================================================== +init() -> + case gen_server:start_link({local, ?MODULE}, ?MODULE, [], []) of + {ok, _Pid} -> + ok; + Err -> + Err + end. + +register_route(Domain, ServerHost, LocalHint, undefined) -> + F = fun () -> + mnesia:write(#route{domain = Domain, + pid = self(), + server_host = ServerHost, + local_hint = LocalHint}) + end, + transaction(F); +register_route(Domain, ServerHost, _LocalHint, N) -> + Pid = self(), + F = fun () -> + case mnesia:wread({route, Domain}) of + [] -> + mnesia:write(#route{domain = Domain, + server_host = ServerHost, + pid = Pid, + local_hint = 1}), + lists:foreach( + fun (I) -> + mnesia:write( + #route{domain = Domain, + pid = undefined, + server_host = ServerHost, + local_hint = I}) + end, + lists:seq(2, N)); + Rs -> + lists:any( + fun (#route{pid = undefined, + local_hint = I} = R) -> + mnesia:write( + #route{domain = Domain, + pid = Pid, + server_host = ServerHost, + local_hint = I}), + mnesia:delete_object(R), + true; + (_) -> false + end, + Rs) + end + end, + transaction(F). + +unregister_route(Domain, undefined) -> + F = fun () -> + case mnesia:match_object( + #route{domain = Domain, pid = self(), _ = '_'}) of + [R] -> mnesia:delete_object(R); + _ -> ok + end + end, + transaction(F); +unregister_route(Domain, _) -> + F = fun () -> + case mnesia:match_object( + #route{domain = Domain, pid = self(), _ = '_'}) of + [R] -> + I = R#route.local_hint, + ServerHost = R#route.server_host, + mnesia:write(#route{domain = Domain, + server_host = ServerHost, + pid = undefined, + local_hint = I}), + mnesia:delete_object(R); + _ -> ok + end + end, + transaction(F). + +find_routes(Domain) -> + mnesia:dirty_read(route, Domain). + +host_of_route(Domain) -> + case mnesia:dirty_read(route, Domain) of + [#route{server_host = ServerHost}|_] -> + {ok, ServerHost}; + [] -> + error + end. + +is_my_route(Domain) -> + mnesia:dirty_read(route, Domain) /= []. + +is_my_host(Domain) -> + case mnesia:dirty_read(route, Domain) of + [#route{server_host = Host}|_] -> + Host == Domain; + [] -> + false + end. + +get_all_routes() -> + mnesia:dirty_select( + route, + ets:fun2ms( + fun(#route{domain = Domain, server_host = ServerHost}) + when Domain /= ServerHost -> Domain + end)). + +%%%=================================================================== +%%% gen_server callbacks +%%%=================================================================== +init([]) -> + update_tables(), + ejabberd_mnesia:create(?MODULE, route, + [{ram_copies, [node()]}, + {type, bag}, + {attributes, record_info(fields, route)}]), + mnesia:add_table_copy(route, node(), ram_copies), + mnesia:subscribe({table, route, simple}), + lists:foreach( + fun (Pid) -> erlang:monitor(process, Pid) end, + mnesia:dirty_select(route, + [{{route, '_', '$1', '_'}, [], ['$1']}])), + {ok, #state{}}. + +handle_call(_Request, _From, State) -> + Reply = ok, + {reply, Reply, State}. + +handle_cast(_Msg, State) -> + {noreply, State}. + +handle_info({mnesia_table_event, + {write, #route{pid = Pid}, _ActivityId}}, State) -> + erlang:monitor(process, Pid), + {noreply, State}; +handle_info({'DOWN', _Ref, _Type, Pid, _Info}, State) -> + F = fun () -> + Es = mnesia:select(route, + [{#route{pid = Pid, _ = '_'}, [], ['$_']}]), + lists:foreach( + fun(E) -> + if is_integer(E#route.local_hint) -> + LDomain = E#route.domain, + I = E#route.local_hint, + ServerHost = E#route.server_host, + mnesia:write(#route{domain = LDomain, + server_host = ServerHost, + pid = undefined, + local_hint = I}), + mnesia:delete_object(E); + true -> + mnesia:delete_object(E) + end + end, Es) + end, + transaction(F), + {noreply, State}; +handle_info(Info, State) -> + ?ERROR_MSG("unexpected info: ~p", [Info]), + {noreply, State}. + +terminate(_Reason, _State) -> + ok. + +code_change(_OldVsn, State, _Extra) -> + {ok, State}. + +%%%=================================================================== +%%% Internal functions +%%%=================================================================== +transaction(F) -> + case mnesia:transaction(F) of + {atomic, _} -> + ok; + {aborted, Reason} -> + ?ERROR_MSG("Mnesia transaction failed: ~p", [Reason]), + {error, Reason} + end. + +-spec update_tables() -> ok. +update_tables() -> + try + mnesia:transform_table(route, ignore, record_info(fields, route)) + catch exit:{aborted, {no_exists, _}} -> + ok + end, + case lists:member(local_route, mnesia:system_info(tables)) of + true -> mnesia:delete_table(local_route); + false -> ok + end. diff --git a/src/ejabberd_router_multicast.erl b/src/ejabberd_router_multicast.erl index ce744c06f..19c6da144 100644 --- a/src/ejabberd_router_multicast.erl +++ b/src/ejabberd_router_multicast.erl @@ -35,7 +35,7 @@ unregister_route/1 ]). --export([start_link/0]). +-export([start/0, start_link/0]). %% gen_server callbacks -export([init/1, handle_call/3, handle_cast/2, handle_info/2, @@ -56,6 +56,11 @@ %% Function: start_link() -> {ok,Pid} | ignore | {error,Error} %% Description: Starts the server %%-------------------------------------------------------------------- +start() -> + ChildSpec = {?MODULE, {?MODULE, start_link, []}, + transient, 1000, worker, [?MODULE]}, + supervisor:start_child(ejabberd_sup, ChildSpec). + start_link() -> gen_server:start_link({local, ?MODULE}, ?MODULE, [], []). diff --git a/src/ejabberd_s2s.erl b/src/ejabberd_s2s.erl index 07ae5e70e..86cf1a1f5 100644 --- a/src/ejabberd_s2s.erl +++ b/src/ejabberd_s2s.erl @@ -35,16 +35,16 @@ %% API -export([start_link/0, route/3, have_connection/1, - make_key/2, get_connections_pids/1, try_register/1, - remove_connection/2, find_connection/2, + get_connections_pids/1, try_register/1, + remove_connection/2, start_connection/2, start_connection/3, dirty_get_connections/0, allow_host/2, incoming_s2s_number/0, outgoing_s2s_number/0, stop_all_connections/0, clean_temporarily_blocked_table/0, list_temporarily_blocked_hosts/0, external_host_overloaded/1, is_temporarly_blocked/1, - check_peer_certificate/3, - get_commands_spec/0]). + get_commands_spec/0, zlib_enabled/1, get_idle_timeout/1, + tls_required/1, tls_verify/1, tls_enabled/1, tls_options/2]). %% gen_server callbacks -export([init/1, handle_call/3, handle_cast/2, @@ -196,39 +196,94 @@ try_register(FromTo) -> dirty_get_connections() -> mnesia:dirty_all_keys(s2s). -check_peer_certificate(SockMod, Sock, Peer) -> - case SockMod:get_peer_certificate(Sock) of - {ok, Cert} -> - case SockMod:get_verify_result(Sock) of - 0 -> - case ejabberd_idna:domain_utf8_to_ascii(Peer) of - false -> - {error, <<"Cannot decode remote server name">>}; - AsciiPeer -> - case - lists:any(fun(D) -> match_domain(AsciiPeer, D) end, - get_cert_domains(Cert)) of - true -> - {ok, <<"Verification successful">>}; - false -> - {error, <<"Certificate host name mismatch">>} - end - end; - VerifyRes -> - {error, fast_tls:get_cert_verify_string(VerifyRes, Cert)} - end; - {error, _Reason} -> - {error, <<"Cannot get peer certificate">>}; - error -> - {error, <<"Cannot get peer certificate">>} +-spec tls_options(binary(), [proplists:property()]) -> [proplists:property()]. +tls_options(LServer, DefaultOpts) -> + TLSOpts1 = case ejabberd_config:get_option( + {s2s_certfile, LServer}, + fun iolist_to_binary/1, + ejabberd_config:get_option( + {domain_certfile, LServer}, + fun iolist_to_binary/1)) of + undefined -> []; + CertFile -> lists:keystore(certfile, 1, DefaultOpts, + {certfile, CertFile}) + end, + TLSOpts2 = case ejabberd_config:get_option( + {s2s_ciphers, LServer}, + fun iolist_to_binary/1) of + undefined -> TLSOpts1; + Ciphers -> lists:keystore(ciphers, 1, TLSOpts1, + {ciphers, Ciphers}) + end, + TLSOpts3 = case ejabberd_config:get_option( + {s2s_protocol_options, LServer}, + fun (Options) -> str:join(Options, <<$|>>) end) of + undefined -> TLSOpts2; + ProtoOpts -> lists:keystore(protocol_options, 1, TLSOpts2, + {protocol_options, ProtoOpts}) + end, + TLSOpts4 = case ejabberd_config:get_option( + {s2s_dhfile, LServer}, + fun iolist_to_binary/1) of + undefined -> TLSOpts3; + DHFile -> lists:keystore(dhfile, 1, TLSOpts3, + {dhfile, DHFile}) + end, + TLSOpts5 = case ejabberd_config:get_option( + {s2s_cafile, LServer}, + fun iolist_to_binary/1) of + undefined -> TLSOpts4; + CAFile -> lists:keystore(cafile, 1, TLSOpts4, + {cafile, CAFile}) + end, + case ejabberd_config:get_option( + {s2s_tls_compression, LServer}, + fun(B) when is_boolean(B) -> B end) of + undefined -> TLSOpts5; + false -> [compression_none | TLSOpts5]; + true -> lists:delete(compression_none, TLSOpts5) end. --spec make_key({binary(), binary()}, binary()) -> binary(). -make_key({From, To}, StreamID) -> - Secret = ejabberd_config:get_option(shared_key, fun(V) -> V end), - p1_sha:to_hexlist( - crypto:hmac(sha256, p1_sha:to_hexlist(crypto:hash(sha256, Secret)), - [To, " ", From, " ", StreamID])). +-spec tls_required(binary()) -> boolean(). +tls_required(LServer) -> + TLS = use_starttls(LServer), + TLS == required orelse TLS == required_trusted. + +-spec tls_verify(binary()) -> boolean(). +tls_verify(LServer) -> + TLS = use_starttls(LServer), + TLS == required_trusted. + +-spec tls_enabled(binary()) -> boolean(). +tls_enabled(LServer) -> + TLS = use_starttls(LServer), + TLS /= false. + +-spec zlib_enabled(binary()) -> boolean(). +zlib_enabled(LServer) -> + ejabberd_config:get_option( + {s2s_zlib, LServer}, + fun(B) when is_boolean(B) -> B end, + false). + +-spec use_starttls(binary()) -> boolean() | optional | required | required_trusted. +use_starttls(LServer) -> + ejabberd_config:get_option( + {s2s_use_starttls, LServer}, + fun(true) -> true; + (false) -> false; + (optional) -> optional; + (required) -> required; + (required_trusted) -> required_trusted + end, false). + +-spec get_idle_timeout(binary()) -> non_neg_integer() | infinity. +get_idle_timeout(LServer) -> + ejabberd_config:get_option( + {s2s_timeout, LServer}, + fun(I) when is_integer(I), I >= 0 -> timer:seconds(I); + (infinity) -> infinity + end, timer:minutes(10)). %%==================================================================== %% gen_server callbacks @@ -246,6 +301,8 @@ init([]) -> ejabberd_mnesia:create(?MODULE, temporarily_blocked, [{ram_copies, [node()]}, {attributes, record_info(fields, temporarily_blocked)}]), + ejabberd_s2s_in:add_hooks(), + ejabberd_s2s_out:add_hooks(), {ok, #state{}}. handle_call(_Request, _From, State) -> @@ -291,30 +348,40 @@ clean_table_from_bad_node(Node) -> end, mnesia:async_dirty(F). --spec do_route(jid(), jid(), stanza()) -> ok | false. +-spec do_route(jid(), jid(), stanza()) -> ok. do_route(From, To, Packet) -> ?DEBUG("s2s manager~n\tfrom ~p~n\tto ~p~n\tpacket " "~P~n", [From, To, Packet, 8]), - case find_connection(From, To) of - {atomic, Pid} when is_pid(Pid) -> + case start_connection(From, To) of + {ok, Pid} when is_pid(Pid) -> ?DEBUG("sending to process ~p~n", [Pid]), #jid{lserver = MyServer} = From, - ejabberd_hooks:run(s2s_send_packet, MyServer, - [From, To, Packet]), - send_element(Pid, xmpp:set_from_to(Packet, From, To)), - ok; - {aborted, _Reason} -> + ejabberd_hooks:run(s2s_send_packet, MyServer, [From, To, Packet]), + ejabberd_s2s_out:route(Pid, xmpp:set_from_to(Packet, From, To)); + {error, Reason} -> Lang = xmpp:get_lang(Packet), - Txt = <<"No s2s connection found">>, - Err = xmpp:err_service_unavailable(Txt, Lang), - ejabberd_router:route_error(To, From, Packet, Err), - false + Err = case Reason of + policy_violation -> + xmpp:err_policy_violation( + <<"Server connections to local " + "subdomains are forbidden">>, Lang); + forbidden -> + xmpp:err_forbidden(<<"Denied by ACL">>, Lang); + internal_server_error -> + xmpp:err_internal_server_error() + end, + ejabberd_router:route_error(To, From, Packet, Err) end. --spec find_connection(jid(), jid()) -> {aborted, any()} | {atomic, pid()}. +-spec start_connection(jid(), jid()) + -> {ok, pid()} | {error, policy_violation | forbidden | internal_server_error}. +start_connection(From, To) -> + start_connection(From, To, []). -find_connection(From, To) -> +-spec start_connection(jid(), jid(), [proplists:property()]) + -> {ok, pid()} | {error, policy_violation | forbidden | internal_server_error}. +start_connection(From, To, Opts) -> #jid{lserver = MyServer} = From, #jid{lserver = Server} = To, FromTo = {MyServer, Server}, @@ -323,24 +390,29 @@ find_connection(From, To) -> MaxS2SConnectionsNumberPerNode = max_s2s_connections_number_per_node(FromTo), ?DEBUG("Finding connection for ~p~n", [FromTo]), - case catch mnesia:dirty_read(s2s, FromTo) of - {'EXIT', Reason} -> {aborted, Reason}; + case mnesia:dirty_read(s2s, FromTo) of [] -> %% We try to establish all the connections if the host is not a %% service and if the s2s host is not blacklisted or %% is in whitelist: - case not is_service(From, To) andalso - allow_host(MyServer, Server) - of + LServer = ejabberd_router:host_of_route(MyServer), + case is_service(From, To) of true -> - NeededConnections = needed_connections_number([], + {error, policy_violation}; + false -> + case allow_host(LServer, Server) of + true -> + NeededConnections = needed_connections_number( + [], MaxS2SConnectionsNumber, MaxS2SConnectionsNumberPerNode), open_several_connections(NeededConnections, MyServer, Server, From, FromTo, MaxS2SConnectionsNumber, - MaxS2SConnectionsNumberPerNode); - false -> {aborted, error} + MaxS2SConnectionsNumberPerNode, Opts); + false -> + {error, forbidden} + end end; L when is_list(L) -> NeededConnections = needed_connections_number(L, @@ -351,10 +423,10 @@ find_connection(From, To) -> open_several_connections(NeededConnections, MyServer, Server, From, FromTo, MaxS2SConnectionsNumber, - MaxS2SConnectionsNumberPerNode); + MaxS2SConnectionsNumberPerNode, Opts); true -> %% We choose a connexion from the pool of opened ones. - {atomic, choose_connection(From, L)} + {ok, choose_connection(From, L)} end end. @@ -377,20 +449,22 @@ choose_pid(From, Pids) -> open_several_connections(N, MyServer, Server, From, FromTo, MaxS2SConnectionsNumber, - MaxS2SConnectionsNumberPerNode) -> - ConnectionsResult = [new_connection(MyServer, Server, + MaxS2SConnectionsNumberPerNode, Opts) -> + case lists:flatmap( + fun(_) -> + new_connection(MyServer, Server, From, FromTo, MaxS2SConnectionsNumber, - MaxS2SConnectionsNumberPerNode) - || _N <- lists:seq(1, N)], - case [PID || {atomic, PID} <- ConnectionsResult] of - [] -> hd(ConnectionsResult); - PIDs -> {atomic, choose_pid(From, PIDs)} + MaxS2SConnectionsNumberPerNode, Opts) + end, lists:seq(1, N)) of + [] -> + {error, internal_server_error}; + PIDs -> + {ok, choose_pid(From, PIDs)} end. new_connection(MyServer, Server, From, FromTo, - MaxS2SConnectionsNumber, MaxS2SConnectionsNumberPerNode) -> - {ok, Pid} = ejabberd_s2s_out:start( - MyServer, Server, new), + MaxS2SConnectionsNumber, MaxS2SConnectionsNumberPerNode, Opts) -> + {ok, Pid} = ejabberd_s2s_out:start(MyServer, Server, Opts), F = fun() -> L = mnesia:read({s2s, FromTo}), NeededConnections = needed_connections_number(L, @@ -398,17 +472,21 @@ new_connection(MyServer, Server, From, FromTo, MaxS2SConnectionsNumberPerNode), if NeededConnections > 0 -> mnesia:write(#s2s{fromto = FromTo, pid = Pid}), - ?INFO_MSG("New s2s connection started ~p", [Pid]), Pid; true -> choose_connection(From, L) end end, TRes = mnesia:transaction(F), case TRes of - {atomic, Pid} -> ejabberd_s2s_out:start_connection(Pid); - _ -> ejabberd_s2s_out:stop_connection(Pid) - end, - TRes. + {atomic, Pid} -> + ejabberd_s2s_out:connect(Pid), + [Pid]; + {aborted, Reason} -> + ?ERROR_MSG("failed to register connection ~s -> ~s: ~p", + [MyServer, Server, Reason]), + ejabberd_s2s_out:stop(Pid), + [] + end. -spec max_s2s_connections_number({binary(), binary()}) -> integer(). max_s2s_connections_number({From, To}) -> @@ -459,9 +537,6 @@ parent_domains(Domain) -> end, [], lists:reverse(str:tokens(Domain, <<".">>))). -send_element(Pid, El) -> - Pid ! {send_element, El}. - %%%---------------------------------------------------------------------- %%% ejabberd commands @@ -536,24 +611,13 @@ update_tables() -> %% Check if host is in blacklist or white list allow_host(MyServer, S2SHost) -> - allow_host2(MyServer, S2SHost) andalso + allow_host1(MyServer, S2SHost) andalso not is_temporarly_blocked(S2SHost). -allow_host2(MyServer, S2SHost) -> - Hosts = (?MYHOSTS), - case lists:dropwhile(fun (ParentDomain) -> - not lists:member(ParentDomain, Hosts) - end, - parent_domains(MyServer)) - of - [MyHost | _] -> allow_host1(MyHost, S2SHost); - [] -> allow_host1(MyServer, S2SHost) - end. - allow_host1(MyHost, S2SHost) -> Rule = ejabberd_config:get_option( - s2s_access, - fun(A) -> A end, + {s2s_access, MyHost}, + fun acl:access_rules_validator/1, all), JID = jid:make(S2SHost), case acl:match_rule(MyHost, Rule, JID) of @@ -624,133 +688,34 @@ get_s2s_state(S2sPid) -> end, [{s2s_pid, S2sPid} | Infos]. -get_cert_domains(Cert) -> - TBSCert = Cert#'Certificate'.tbsCertificate, - Subject = case TBSCert#'TBSCertificate'.subject of - {rdnSequence, Subj} -> lists:flatten(Subj); - _ -> [] - end, - Extensions = case TBSCert#'TBSCertificate'.extensions of - Exts when is_list(Exts) -> Exts; - _ -> [] - end, - lists:flatmap(fun (#'AttributeTypeAndValue'{type = - ?'id-at-commonName', - value = Val}) -> - case 'OTP-PUB-KEY':decode('X520CommonName', Val) of - {ok, {_, D1}} -> - D = if is_binary(D1) -> D1; - is_list(D1) -> list_to_binary(D1); - true -> error - end, - if D /= error -> - case jid:from_string(D) of - #jid{luser = <<"">>, lserver = LD, - lresource = <<"">>} -> - [LD]; - _ -> [] - end; - true -> [] - end; - _ -> [] - end; - (_) -> [] - end, - Subject) - ++ - lists:flatmap(fun (#'Extension'{extnID = - ?'id-ce-subjectAltName', - extnValue = Val}) -> - BVal = if is_list(Val) -> list_to_binary(Val); - true -> Val - end, - case 'OTP-PUB-KEY':decode('SubjectAltName', BVal) - of - {ok, SANs} -> - lists:flatmap(fun ({otherName, - #'AnotherName'{'type-id' = - ?'id-on-xmppAddr', - value = - XmppAddr}}) -> - case - 'XmppAddr':decode('XmppAddr', - XmppAddr) - of - {ok, D} - when - is_binary(D) -> - case - jid:from_string((D)) - of - #jid{luser = - <<"">>, - lserver = - LD, - lresource = - <<"">>} -> - case - ejabberd_idna:domain_utf8_to_ascii(LD) - of - false -> - []; - PCLD -> - [PCLD] - end; - _ -> [] - end; - _ -> [] - end; - ({dNSName, D}) - when is_list(D) -> - case - jid:from_string(list_to_binary(D)) - of - #jid{luser = <<"">>, - lserver = LD, - lresource = - <<"">>} -> - [LD]; - _ -> [] - end; - (_) -> [] - end, - SANs); - _ -> [] - end; - (_) -> [] - end, - Extensions). - -match_domain(Domain, Domain) -> true; -match_domain(Domain, Pattern) -> - DLabels = str:tokens(Domain, <<".">>), - PLabels = str:tokens(Pattern, <<".">>), - match_labels(DLabels, PLabels). - -match_labels([], []) -> true; -match_labels([], [_ | _]) -> false; -match_labels([_ | _], []) -> false; -match_labels([DL | DLabels], [PL | PLabels]) -> - case lists:all(fun (C) -> - $a =< C andalso C =< $z orelse - $0 =< C andalso C =< $9 orelse - C == $- orelse C == $* - end, - binary_to_list(PL)) - of - true -> - Regexp = ejabberd_regexp:sh_to_awk(PL), - case ejabberd_regexp:run(DL, Regexp) of - match -> match_labels(DLabels, PLabels); - nomatch -> false - end; - false -> false - end. - opt_type(route_subdomains) -> fun (s2s) -> s2s; (local) -> local end; opt_type(s2s_access) -> fun acl:access_rules_validator/1; -opt_type(_) -> [route_subdomains, s2s_access]. +opt_type(domain_certfile) -> fun iolist_to_binary/1; +opt_type(s2s_certfile) -> fun iolist_to_binary/1; +opt_type(s2s_ciphers) -> fun iolist_to_binary/1; +opt_type(s2s_dhfile) -> fun iolist_to_binary/1; +opt_type(s2s_protocol_options) -> + fun (Options) -> str:join(Options, <<"|">>) end; +opt_type(s2s_tls_compression) -> + fun (true) -> true; + (false) -> false + end; +opt_type(s2s_use_starttls) -> + fun (true) -> true; + (false) -> false; + (optional) -> optional; + (required) -> required; + (required_trusted) -> required_trusted + end; +opt_type(s2s_timeout) -> + fun(I) when is_integer(I), I>=0 -> I; + (infinity) -> infinity + end; +opt_type(_) -> + [route_subdomains, s2s_access, s2s_certfile, + s2s_ciphers, s2s_dhfile, s2s_protocol_options, + s2s_tls_compression, s2s_use_starttls, s2s_timeout]. diff --git a/src/ejabberd_s2s_in.erl b/src/ejabberd_s2s_in.erl index ffdadc135..3b4b6a989 100644 --- a/src/ejabberd_s2s_in.erl +++ b/src/ejabberd_s2s_in.erl @@ -1,8 +1,5 @@ -%%%---------------------------------------------------------------------- -%%% File : ejabberd_s2s_in.erl -%%% Author : Alexey Shchepin -%%% Purpose : Serve incoming s2s connection -%%% Created : 6 Dec 2002 by Alexey Shchepin +%%%------------------------------------------------------------------- +%%% Created : 12 Dec 2016 by Evgeny Khramtsov %%% %%% %%% ejabberd, Copyright (C) 2002-2017 ProcessOne @@ -21,645 +18,314 @@ %%% with this program; if not, write to the Free Software Foundation, Inc., %%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. %%% -%%%---------------------------------------------------------------------- - +%%%------------------------------------------------------------------- -module(ejabberd_s2s_in). - +-behaviour(xmpp_stream_in). -behaviour(ejabberd_config). +-behaviour(ejabberd_socket). --author('alexey@process-one.net'). - --behaviour(p1_fsm). - -%% External exports +%% ejabberd_socket callbacks -export([start/2, start_link/2, socket_type/0]). - --export([init/1, wait_for_stream/2, - wait_for_feature_request/2, stream_established/2, - handle_event/3, handle_sync_event/4, code_change/4, - handle_info/3, print_state/1, terminate/3, opt_type/1]). +%% ejabberd_config callbacks +-export([opt_type/1]). +%% xmpp_stream_in callbacks +-export([init/1, handle_call/3, handle_cast/2, + handle_info/2, terminate/2, code_change/3]). +-export([tls_options/1, tls_required/1, tls_verify/1, tls_enabled/1, + compress_methods/1, + unauthenticated_stream_features/1, authenticated_stream_features/1, + handle_stream_start/2, handle_stream_end/2, + handle_stream_established/1, handle_auth_success/4, + handle_auth_failure/4, handle_send/3, handle_recv/3, handle_cdata/2, + handle_unauthenticated_packet/2, handle_authenticated_packet/2]). +%% Hooks +-export([handle_unexpected_info/2, handle_unexpected_cast/2, + reject_unauthenticated_packet/2, process_closed/2]). +%% API +-export([stop/1, close/1, send/2, update_state/2, establish/1, add_hooks/0]). -include("ejabberd.hrl"). +-include("xmpp.hrl"). -include("logger.hrl"). --include("xmpp.hrl"). - --define(DICT, dict). - --record(state, - {socket :: ejabberd_socket:socket_state(), - sockmod = ejabberd_socket :: ejabberd_socket | ejabberd_frontend_socket, - streamid = <<"">> :: binary(), - shaper = none :: shaper:shaper(), - tls = false :: boolean(), - tls_enabled = false :: boolean(), - tls_required = false :: boolean(), - tls_certverify = false :: boolean(), - tls_options = [] :: list(), - server = <<"">> :: binary(), - authenticated = false :: boolean(), - auth_domain = <<"">> :: binary(), - connections = (?DICT):new() :: ?TDICT, - timer = make_ref() :: reference()}). - --type state_name() :: wait_for_stream | wait_for_feature_request | stream_established. --type state() :: #state{}. --type fsm_next() :: {next_state, state_name(), state()}. --type fsm_stop() :: {stop, normal, state()}. --type fsm_transition() :: fsm_stop() | fsm_next(). - -%%-define(DBGFSM, true). --ifdef(DBGFSM). --define(FSMOPTS, [{debug, [trace]}]). --else. --define(FSMOPTS, []). --endif. +-type state() :: map(). +-export_type([state/0]). +%%%=================================================================== +%%% API +%%%=================================================================== start(SockData, Opts) -> - supervisor:start_child(ejabberd_s2s_in_sup, - [SockData, Opts]). + case proplists:get_value(supervisor, Opts, true) of + true -> + supervisor:start_child(ejabberd_s2s_in_sup, [SockData, Opts]); + _ -> + xmpp_stream_in:start(?MODULE, [SockData, Opts], + ejabberd_config:fsm_limit_opts(Opts)) + end. start_link(SockData, Opts) -> - p1_fsm:start_link(ejabberd_s2s_in, [SockData, Opts], - ?FSMOPTS ++ fsm_limit_opts(Opts)). + xmpp_stream_in:start_link(?MODULE, [SockData, Opts], + ejabberd_config:fsm_limit_opts(Opts)). -socket_type() -> xml_stream. +close(Ref) -> + xmpp_stream_in:close(Ref). -%%%---------------------------------------------------------------------- -%%% Callback functions from gen_fsm -%%%---------------------------------------------------------------------- +stop(Ref) -> + xmpp_stream_in:stop(Ref). -init([{SockMod, Socket}, Opts]) -> - ?DEBUG("started: ~p", [{SockMod, Socket}]), - Shaper = case lists:keysearch(shaper, 1, Opts) of - {value, {_, S}} -> S; - _ -> none - end, - {StartTLS, TLSRequired, TLSCertverify} = - case ejabberd_config:get_option( - s2s_use_starttls, - fun(false) -> false; - (true) -> true; - (optional) -> optional; - (required) -> required; - (required_trusted) -> required_trusted - end, - false) of - UseTls - when (UseTls == undefined) or - (UseTls == false) -> - {false, false, false}; - UseTls - when (UseTls == true) or - (UseTls == - optional) -> - {true, false, false}; - required -> {true, true, false}; - required_trusted -> - {true, true, true} - end, - TLSOpts1 = case ejabberd_config:get_option( - s2s_certfile, - fun iolist_to_binary/1) of - undefined -> []; - CertFile -> [{certfile, CertFile}] - end, - TLSOpts2 = case ejabberd_config:get_option( - s2s_ciphers, fun iolist_to_binary/1) of - undefined -> TLSOpts1; - Ciphers -> [{ciphers, Ciphers} | TLSOpts1] - end, - TLSOpts3 = case ejabberd_config:get_option( - s2s_protocol_options, - fun (Options) -> - [_|O] = lists:foldl( - fun(X, Acc) -> X ++ Acc end, [], - [["|" | binary_to_list(Opt)] || Opt <- Options, is_binary(Opt)] - ), - iolist_to_binary(O) - end) of - undefined -> TLSOpts2; - ProtocolOpts -> [{protocol_options, ProtocolOpts} | TLSOpts2] - end, - TLSOpts4 = case ejabberd_config:get_option( - s2s_dhfile, fun iolist_to_binary/1) of - undefined -> TLSOpts3; - DHFile -> [{dhfile, DHFile} | TLSOpts3] - end, - TLSOpts = case proplists:get_bool(tls_compression, Opts) of - false -> [compression_none | TLSOpts4]; - true -> TLSOpts4 - end, - Timer = erlang:start_timer(?S2STIMEOUT, self(), []), - {ok, wait_for_stream, - #state{socket = Socket, sockmod = SockMod, - streamid = new_id(), shaper = Shaper, tls = StartTLS, - tls_enabled = false, tls_required = TLSRequired, - tls_certverify = TLSCertverify, tls_options = TLSOpts, - timer = Timer}}. +socket_type() -> + xml_stream. -%%---------------------------------------------------------------------- -%% Func: StateName/2 -%% Returns: {next_state, NextStateName, NextStateData} | -%% {next_state, NextStateName, NextStateData, Timeout} | -%% {stop, Reason, NewStateData} -%%---------------------------------------------------------------------- -wait_for_stream({xmlstreamstart, Name, Attrs}, StateData) -> - try xmpp:decode(#xmlel{name = Name, attrs = Attrs}) of - #stream_start{xmlns = NS_SERVER, stream_xmlns = NS_STREAM} - when NS_SERVER /= ?NS_SERVER; NS_STREAM /= ?NS_STREAM -> - send_header(StateData, {1,0}), - send_element(StateData, xmpp:serr_invalid_namespace()), - {stop, normal, StateData}; - #stream_start{to = #jid{lserver = Server}, - from = From, version = {1,0}} - when StateData#state.tls and not StateData#state.authenticated -> - send_header(StateData, {1,0}), - Auth = if StateData#state.tls_enabled -> - case From of - #jid{} -> - {Result, Message} = - ejabberd_s2s:check_peer_certificate( - StateData#state.sockmod, - StateData#state.socket, - From#jid.lserver), - {Result, From#jid.lserver, Message}; - undefined -> - {error, <<"(unknown)">>, - <<"Got no valid 'from' attribute">>} - end; +-spec send(pid(), xmpp_element()) -> ok; + (state(), xmpp_element()) -> state(). +send(Stream, Pkt) -> + xmpp_stream_in:send(Stream, Pkt). + +-spec establish(state()) -> state(). +establish(State) -> + xmpp_stream_in:establish(State). + +-spec update_state(pid(), fun((state()) -> state()) | + {module(), atom(), list()}) -> ok. +update_state(Ref, Callback) -> + xmpp_stream_in:cast(Ref, {update_state, Callback}). + +-spec add_hooks() -> ok. +add_hooks() -> + lists:foreach( + fun(Host) -> + ejabberd_hooks:add(s2s_in_closed, Host, ?MODULE, + process_closed, 100), + ejabberd_hooks:add(s2s_in_unauthenticated_packet, Host, ?MODULE, + reject_unauthenticated_packet, 100), + ejabberd_hooks:add(s2s_in_handle_info, Host, ?MODULE, + handle_unexpected_info, 100), + ejabberd_hooks:add(s2s_in_handle_cast, Host, ?MODULE, + handle_unexpected_cast, 100) + end, ?MYHOSTS). + +%%%=================================================================== +%%% Hooks +%%%=================================================================== +handle_unexpected_info(State, Info) -> + ?WARNING_MSG("got unexpected info: ~p", [Info]), + State. + +handle_unexpected_cast(State, Msg) -> + ?WARNING_MSG("got unexpected cast: ~p", [Msg]), + State. + +reject_unauthenticated_packet(State, _Pkt) -> + Err = xmpp:serr_not_authorized(), + send(State, Err). + +process_closed(State, _Reason) -> + stop(State). + +%%%=================================================================== +%%% xmpp_stream_in callbacks +%%%=================================================================== +tls_options(#{tls_options := TLSOpts, server_host := LServer}) -> + ejabberd_s2s:tls_options(LServer, TLSOpts). + +tls_required(#{server_host := LServer}) -> + ejabberd_s2s:tls_required(LServer). + +tls_verify(#{server_host := LServer}) -> + ejabberd_s2s:tls_verify(LServer). + +tls_enabled(#{server_host := LServer}) -> + ejabberd_s2s:tls_enabled(LServer). + +compress_methods(#{server_host := LServer}) -> + case ejabberd_s2s:zlib_enabled(LServer) of + true -> [<<"zlib">>]; + false -> [] + end. + +unauthenticated_stream_features(#{server_host := LServer}) -> + ejabberd_hooks:run_fold(s2s_in_pre_auth_features, LServer, [], [LServer]). + +authenticated_stream_features(#{server_host := LServer}) -> + ejabberd_hooks:run_fold(s2s_in_post_auth_features, LServer, [], [LServer]). + +handle_stream_start(_StreamStart, #{lserver := LServer} = State) -> + case check_to(jid:make(LServer), State) of + false -> + send(State, xmpp:serr_host_unknown()); true -> - {no_verify, <<"(unknown)">>, <<"TLS not (yet) enabled">>} - end, - StartTLS = if StateData#state.tls_enabled -> []; - not StateData#state.tls_enabled and - not StateData#state.tls_required -> - [#starttls{required = false}]; - not StateData#state.tls_enabled and - StateData#state.tls_required -> - [#starttls{required = true}] - end, - case Auth of - {error, RemoteServer, CertError} - when StateData#state.tls_certverify -> - ?INFO_MSG("Closing s2s connection: ~s <--> ~s (~s)", - [StateData#state.server, RemoteServer, CertError]), - send_element(StateData, - xmpp:serr_policy_violation(CertError, ?MYLANG)), - {stop, normal, StateData}; - {VerifyResult, RemoteServer, Msg} -> - {SASL, NewStateData} = - case VerifyResult of - ok -> - {[#sasl_mechanisms{list = [<<"EXTERNAL">>]}], - StateData#state{auth_domain = RemoteServer}}; - error -> - ?DEBUG("Won't accept certificate of ~s: ~s", - [RemoteServer, Msg]), - {[], StateData}; - no_verify -> - {[], StateData} - end, - send_element(NewStateData, - #stream_features{ - sub_els = SASL ++ StartTLS ++ - ejabberd_hooks:run_fold( - s2s_stream_features, Server, [], - [Server])}), - {next_state, wait_for_feature_request, - NewStateData#state{server = Server}} - end; - #stream_start{to = #jid{lserver = Server}, - version = {1,0}} when StateData#state.authenticated -> - send_header(StateData, {1,0}), - send_element(StateData, - #stream_features{ - sub_els = ejabberd_hooks:run_fold( - s2s_stream_features, Server, [], - [Server])}), - {next_state, stream_established, StateData}; - #stream_start{db_xmlns = ?NS_SERVER_DIALBACK} - when (StateData#state.tls_required and StateData#state.tls_enabled) - or (not StateData#state.tls_required) -> - send_header(StateData, undefined), - {next_state, stream_established, StateData}; - #stream_start{} -> - send_header(StateData, {1,0}), - send_element(StateData, xmpp:serr_undefined_condition()), - {stop, normal, StateData}; - _ -> - send_header(StateData, {1,0}), - send_element(StateData, xmpp:serr_invalid_xml()), - {stop, normal, StateData} - catch _:{xmpp_codec, Why} -> - Txt = xmpp:format_error(Why), - send_header(StateData, {1,0}), - send_element(StateData, xmpp:serr_invalid_xml(Txt, ?MYLANG)), - {stop, normal, StateData} - end; -wait_for_stream({xmlstreamerror, _}, StateData) -> - send_header(StateData, {1,0}), - send_element(StateData, xmpp:serr_not_well_formed()), - {stop, normal, StateData}; -wait_for_stream(timeout, StateData) -> - send_header(StateData, {1,0}), - send_element(StateData, xmpp:serr_connection_timeout()), - {stop, normal, StateData}; -wait_for_stream(closed, StateData) -> - {stop, normal, StateData}. + ServerHost = ejabberd_router:host_of_route(LServer), + State#{server_host => ServerHost} + end. -wait_for_feature_request({xmlstreamelement, El}, StateData) -> - decode_element(El, wait_for_feature_request, StateData); -wait_for_feature_request(#starttls{}, - #state{tls = true, tls_enabled = false} = StateData) -> - case (StateData#state.sockmod):get_sockmod(StateData#state.socket) of - gen_tcp -> - ?DEBUG("starttls", []), - Socket = StateData#state.socket, - TLSOpts1 = case - ejabberd_config:get_option( - {domain_certfile, StateData#state.server}, - fun iolist_to_binary/1) of - undefined -> StateData#state.tls_options; - CertFile -> - lists:keystore(certfile, 1, - StateData#state.tls_options, - {certfile, CertFile}) - end, - TLSOpts2 = case ejabberd_config:get_option( - {s2s_cafile, StateData#state.server}, - fun iolist_to_binary/1) of - undefined -> TLSOpts1; - CAFile -> - lists:keystore(cafile, 1, TLSOpts1, - {cafile, CAFile}) - end, - TLSOpts = case ejabberd_config:get_option( - {s2s_tls_compression, StateData#state.server}, - fun(true) -> true; - (false) -> false - end, false) of - true -> lists:delete(compression_none, TLSOpts2); - false -> [compression_none | TLSOpts2] - end, - TLSSocket = (StateData#state.sockmod):starttls( - Socket, TLSOpts, - fxml:element_to_binary( - xmpp:encode(#starttls_proceed{}))), - {next_state, wait_for_stream, - StateData#state{socket = TLSSocket, streamid = new_id(), - tls_enabled = true, tls_options = TLSOpts}}; - _ -> - send_element(StateData, #starttls_failure{}), - {stop, normal, StateData} - end; -wait_for_feature_request(#sasl_auth{mechanism = Mech}, - #state{tls_enabled = true} = StateData) -> - case Mech of - <<"EXTERNAL">> when StateData#state.auth_domain /= <<"">> -> - AuthDomain = StateData#state.auth_domain, - AllowRemoteHost = ejabberd_s2s:allow_host(<<"">>, AuthDomain), - if AllowRemoteHost -> - (StateData#state.sockmod):reset_stream(StateData#state.socket), - send_element(StateData, #sasl_success{}), - ?INFO_MSG("Accepted s2s EXTERNAL authentication for ~s (TLS=~p)", - [AuthDomain, StateData#state.tls_enabled]), - change_shaper(StateData, <<"">>, jid:make(AuthDomain)), - {next_state, wait_for_stream, - StateData#state{streamid = new_id(), - authenticated = true}}; +handle_stream_end(Reason, #{server_host := LServer} = State) -> + ejabberd_hooks:run_fold(s2s_in_closed, LServer, State, [Reason]). + +handle_stream_established(State) -> + set_idle_timeout(State#{established => true}). + +handle_auth_success(RServer, Mech, _AuthModule, + #{sockmod := SockMod, + socket := Socket, ip := IP, + auth_domains := AuthDomains, + server_host := ServerHost, + lserver := LServer} = State) -> + ?INFO_MSG("(~s) Accepted inbound s2s ~s authentication ~s -> ~s (~s)", + [SockMod:pp(Socket), Mech, RServer, LServer, + ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]), + State1 = case ejabberd_s2s:allow_host(ServerHost, RServer) of true -> - Txt = xmpp:mk_text(<<"Denied by ACL">>, ?MYLANG), - send_element(StateData, - #sasl_failure{reason = 'not-authorized', - text = Txt}), - {stop, normal, StateData} - end; - _ -> - send_element(StateData, #sasl_failure{reason = 'invalid-mechanism'}), - {stop, normal, StateData} - end; -wait_for_feature_request({xmlstreamend, _Name}, StateData) -> - {stop, normal, StateData}; -wait_for_feature_request({xmlstreamerror, _}, StateData) -> - send_element(StateData, xmpp:serr_not_well_formed()), - {stop, normal, StateData}; -wait_for_feature_request(closed, StateData) -> - {stop, normal, StateData}; -wait_for_feature_request(_Pkt, #state{tls_required = TLSRequired, - tls_enabled = TLSEnabled} = StateData) - when TLSRequired and not TLSEnabled -> - Txt = <<"Use of STARTTLS required">>, - send_element(StateData, xmpp:serr_policy_violation(Txt, ?MYLANG)), - {stop, normal, StateData}; -wait_for_feature_request(El, StateData) -> - stream_established({xmlstreamelement, El}, StateData). - -stream_established({xmlstreamelement, El}, StateData) -> - cancel_timer(StateData#state.timer), - Timer = erlang:start_timer(?S2STIMEOUT, self(), []), - decode_element(El, stream_established, StateData#state{timer = Timer}); -stream_established(#db_result{to = To, from = From, key = Key}, - StateData) -> - ?DEBUG("GET KEY: ~p", [{To, From, Key}]), - case {ejabberd_s2s:allow_host(To, From), - lists:member(To, ejabberd_router:dirty_get_all_domains())} of - {true, true} -> - ejabberd_s2s_out:terminate_if_waiting_delay(To, From), - ejabberd_s2s_out:start(To, From, - {verify, self(), Key, - StateData#state.streamid}), - Conns = (?DICT):store({From, To}, - wait_for_verification, - StateData#state.connections), - change_shaper(StateData, To, jid:make(From)), - {next_state, stream_established, - StateData#state{connections = Conns}}; - {_, false} -> - send_element(StateData, xmpp:serr_host_unknown()), - {stop, normal, StateData}; - {false, _} -> - send_element(StateData, xmpp:serr_invalid_from()), - {stop, normal, StateData} - end; -stream_established(#db_verify{to = To, from = From, id = Id, key = Key}, - StateData) -> - ?DEBUG("VERIFY KEY: ~p", [{To, From, Id, Key}]), - Type = case ejabberd_s2s:make_key({To, From}, Id) of - Key -> valid; - _ -> invalid + AuthDomains1 = sets:add_element(RServer, AuthDomains), + change_shaper(State, RServer), + State#{auth_domains => AuthDomains1}; + false -> + State end, - send_element(StateData, - #db_verify{from = To, to = From, id = Id, type = Type}), - {next_state, stream_established, StateData}; -stream_established(Pkt, StateData) when ?is_stanza(Pkt) -> + ejabberd_hooks:run_fold(s2s_in_auth_result, ServerHost, State1, [true, RServer]). + +handle_auth_failure(RServer, Mech, Reason, + #{sockmod := SockMod, + socket := Socket, ip := IP, + server_host := ServerHost, + lserver := LServer} = State) -> + ?INFO_MSG("(~s) Failed inbound s2s ~s authentication ~s -> ~s (~s): ~s", + [SockMod:pp(Socket), Mech, RServer, LServer, + ejabberd_config:may_hide_data(jlib:ip_to_list(IP)), Reason]), + ejabberd_hooks:run_fold(s2s_in_auth_result, + ServerHost, State, [false, RServer]). + +handle_unauthenticated_packet(Pkt, #{server_host := LServer} = State) -> + ejabberd_hooks:run_fold(s2s_in_unauthenticated_packet, + LServer, State, [Pkt]). + +handle_authenticated_packet(Pkt, #{server_host := LServer} = State) when not ?is_stanza(Pkt) -> + ejabberd_hooks:run_fold(s2s_in_authenticated_packet, LServer, State, [Pkt]); +handle_authenticated_packet(Pkt, State) -> From = xmpp:get_from(Pkt), To = xmpp:get_to(Pkt), - if To /= undefined, From /= undefined -> - LFrom = From#jid.lserver, - LTo = To#jid.lserver, - if StateData#state.authenticated -> - case LFrom == StateData#state.auth_domain andalso - lists:member(LTo, ejabberd_router:dirty_get_all_domains()) of - true -> - ejabberd_hooks:run(s2s_receive_packet, LTo, - [From, To, Pkt]), - ejabberd_router:route(From, To, Pkt); - false -> - send_error(StateData, Pkt, xmpp:err_not_authorized()) - end; - true -> - case (?DICT):find({LFrom, LTo}, StateData#state.connections) of - {ok, established} -> - ejabberd_hooks:run(s2s_receive_packet, LTo, - [From, To, Pkt]), - ejabberd_router:route(From, To, Pkt); - _ -> - send_error(StateData, Pkt, xmpp:err_not_authorized()) - end - end; - true -> - send_error(StateData, Pkt, xmpp:err_jid_malformed()) + case check_from_to(From, To, State) of + ok -> + LServer = ejabberd_router:host_of_route(To#jid.lserver), + State1 = ejabberd_hooks:run_fold(s2s_in_authenticated_packet, + LServer, State, [Pkt]), + {Pkt1, State2} = ejabberd_hooks:run_fold(s2s_receive_packet, LServer, + {Pkt, State1}, []), + case Pkt1 of + drop -> ok; + _ -> ejabberd_router:route(From, To, Pkt1) end, - ejabberd_hooks:run(s2s_loop_debug, [{xmlstreamelement, Pkt}]), - {next_state, stream_established, StateData}; -stream_established({valid, From, To}, StateData) -> - send_element(StateData, - #db_result{from = To, to = From, type = valid}), - ?INFO_MSG("Accepted s2s dialback authentication for ~s (TLS=~p)", - [From, StateData#state.tls_enabled]), - NSD = StateData#state{connections = - (?DICT):store({From, To}, established, - StateData#state.connections)}, - {next_state, stream_established, NSD}; -stream_established({invalid, From, To}, StateData) -> - send_element(StateData, - #db_result{from = To, to = From, type = invalid}), - NSD = StateData#state{connections = - (?DICT):erase({From, To}, - StateData#state.connections)}, - {next_state, stream_established, NSD}; -stream_established({xmlstreamend, _Name}, StateData) -> - {stop, normal, StateData}; -stream_established({xmlstreamerror, _}, StateData) -> - send_element(StateData, xmpp:serr_not_well_formed()), - {stop, normal, StateData}; -stream_established(timeout, StateData) -> - send_element(StateData, xmpp:serr_connection_timeout()), - {stop, normal, StateData}; -stream_established(closed, StateData) -> - {stop, normal, StateData}; -stream_established(Pkt, StateData) -> - ejabberd_hooks:run(s2s_loop_debug, [{xmlstreamelement, Pkt}]), - {next_state, stream_established, StateData}. + State2; + {error, Err} -> + send(State, Err) + end. -%%---------------------------------------------------------------------- -%% Func: StateName/3 -%% Returns: {next_state, NextStateName, NextStateData} | -%% {next_state, NextStateName, NextStateData, Timeout} | -%% {reply, Reply, NextStateName, NextStateData} | -%% {reply, Reply, NextStateName, NextStateData, Timeout} | -%% {stop, Reason, NewStateData} | -%% {stop, Reason, Reply, NewStateData} -%%---------------------------------------------------------------------- -%state_name(Event, From, StateData) -> -% Reply = ok, -% {reply, Reply, state_name, StateData}. +handle_cdata(Data, #{server_host := LServer} = State) -> + ejabberd_hooks:run_fold(s2s_in_handle_cdata, LServer, State, [Data]). -handle_event(_Event, StateName, StateData) -> - {next_state, StateName, StateData}. +handle_recv(El, Pkt, #{server_host := LServer} = State) -> + State1 = set_idle_timeout(State), + ejabberd_hooks:run_fold(s2s_in_handle_recv, LServer, State1, [El, Pkt]). -handle_sync_event(get_state_infos, _From, StateName, - StateData) -> - SockMod = StateData#state.sockmod, - {Addr, Port} = try - SockMod:peername(StateData#state.socket) - of - {ok, {A, P}} -> {A, P}; - {error, _} -> {unknown, unknown} - catch - _:_ -> {unknown, unknown} +handle_send(Pkt, Result, #{server_host := LServer} = State) -> + ejabberd_hooks:run_fold(s2s_in_handle_send, LServer, + State, [Pkt, Result]). + +init([State, Opts]) -> + Shaper = gen_mod:get_opt(shaper, Opts, fun acl:shaper_rules_validator/1, none), + TLSOpts1 = lists:filter( + fun({certfile, _}) -> true; + ({ciphers, _}) -> true; + ({dhfile, _}) -> true; + ({cafile, _}) -> true; + (_) -> false + end, Opts), + TLSOpts2 = case lists:keyfind(protocol_options, 1, Opts) of + false -> TLSOpts1; + {_, OptString} -> + ProtoOpts = str:join(OptString, <<$|>>), + [{protocol_options, ProtoOpts}|TLSOpts1] end, - Domains = get_external_hosts(StateData), - Infos = [{direction, in}, {statename, StateName}, - {addr, Addr}, {port, Port}, - {streamid, StateData#state.streamid}, - {tls, StateData#state.tls}, - {tls_enabled, StateData#state.tls_enabled}, - {tls_options, StateData#state.tls_options}, - {authenticated, StateData#state.authenticated}, - {shaper, StateData#state.shaper}, {sockmod, SockMod}, - {domains, Domains}], - Reply = {state_infos, Infos}, - {reply, Reply, StateName, StateData}; -%%---------------------------------------------------------------------- -%% Func: handle_sync_event/4 -%% Returns: {next_state, NextStateName, NextStateData} | -%% {next_state, NextStateName, NextStateData, Timeout} | -%% {reply, Reply, NextStateName, NextStateData} | -%% {reply, Reply, NextStateName, NextStateData, Timeout} | -%% {stop, Reason, NewStateData} | -%% {stop, Reason, Reply, NewStateData} -%%---------------------------------------------------------------------- -handle_sync_event(_Event, _From, StateName, - StateData) -> - Reply = ok, {reply, Reply, StateName, StateData}. - -code_change(_OldVsn, StateName, StateData, _Extra) -> - {ok, StateName, StateData}. - -handle_info({send_text, Text}, StateName, StateData) -> - send_text(StateData, Text), - {next_state, StateName, StateData}; -handle_info({timeout, Timer, _}, StateName, - #state{timer = Timer} = StateData) -> - if StateName == wait_for_stream -> - send_header(StateData, undefined); - true -> - ok + TLSOpts3 = case proplists:get_bool(tls_compression, Opts) of + false -> [compression_none | TLSOpts2]; + true -> TLSOpts2 end, - send_element(StateData, xmpp:serr_connection_timeout()), - {stop, normal, StateData}; -handle_info(_, StateName, StateData) -> - {next_state, StateName, StateData}. + State1 = State#{tls_options => TLSOpts3, + auth_domains => sets:new(), + xmlns => ?NS_SERVER, + lang => ?MYLANG, + server => ?MYNAME, + lserver => ?MYNAME, + server_host => ?MYNAME, + established => false, + shaper => Shaper}, + ejabberd_hooks:run_fold(s2s_in_init, {ok, State1}, [Opts]). -terminate(Reason, _StateName, StateData) -> - ?DEBUG("terminated: ~p", [Reason]), +handle_call(Request, From, #{server_host := LServer} = State) -> + ejabberd_hooks:run_fold(s2s_in_handle_call, LServer, State, [Request, From]). + +handle_cast({update_state, Fun}, State) -> + case Fun of + {M, F, A} -> erlang:apply(M, F, [State|A]); + _ when is_function(Fun) -> Fun(State) + end; +handle_cast(Msg, #{server_host := LServer} = State) -> + ejabberd_hooks:run_fold(s2s_in_handle_cast, LServer, State, [Msg]). + +handle_info(Info, #{server_host := LServer} = State) -> + ejabberd_hooks:run_fold(s2s_in_handle_info, LServer, State, [Info]). + +terminate(Reason, #{auth_domains := AuthDomains}) -> case Reason of {process_limit, _} -> - [ejabberd_s2s:external_host_overloaded(Host) - || Host <- get_external_hosts(StateData)]; - _ -> ok - end, - catch send_trailer(StateData), - (StateData#state.sockmod):close(StateData#state.socket), - ok. - -get_external_hosts(StateData) -> - case StateData#state.authenticated of - true -> [StateData#state.auth_domain]; - false -> - Connections = StateData#state.connections, - [D - || {{D, _}, established} <- dict:to_list(Connections)] + sets:fold( + fun(Host, _) -> + ejabberd_s2s:external_host_overloaded(Host) + end, ok, AuthDomains); + _ -> + ok end. -print_state(State) -> State. +code_change(_OldVsn, State, _Extra) -> + {ok, State}. -%%%---------------------------------------------------------------------- +%%%=================================================================== %%% Internal functions -%%%---------------------------------------------------------------------- - --spec send_text(state(), iodata()) -> ok. -send_text(StateData, Text) -> - (StateData#state.sockmod):send(StateData#state.socket, - Text). - --spec send_element(state(), xmpp_element()) -> ok. -send_element(StateData, El) -> - El1 = xmpp:encode(El, ?NS_SERVER), - send_text(StateData, fxml:element_to_binary(El1)). - --spec send_error(state(), xmlel() | stanza(), stanza_error()) -> ok. -send_error(StateData, Stanza, Error) -> - Type = xmpp:get_type(Stanza), - if Type == error; Type == result; - Type == <<"error">>; Type == <<"result">> -> - ok; +%%%=================================================================== +-spec check_from_to(jid(), jid(), state()) -> ok | {error, stream_error()}. +check_from_to(From, To, State) -> + case check_from(From, State) of true -> - send_element(StateData, xmpp:make_error(Stanza, Error)) - end. - --spec send_trailer(state()) -> ok. -send_trailer(StateData) -> - send_text(StateData, <<"">>). - --spec send_header(state(), undefined | {integer(), integer()}) -> ok. -send_header(StateData, Version) -> - Header = xmpp:encode( - #stream_start{xmlns = ?NS_SERVER, - stream_xmlns = ?NS_STREAM, - db_xmlns = ?NS_SERVER_DIALBACK, - id = StateData#state.streamid, - version = Version}), - send_text(StateData, fxml:element_to_header(Header)). - --spec change_shaper(state(), binary(), jid()) -> ok. -change_shaper(StateData, Host, JID) -> - Shaper = acl:match_rule(Host, StateData#state.shaper, - JID), - (StateData#state.sockmod):change_shaper(StateData#state.socket, - Shaper). - --spec new_id() -> binary(). -new_id() -> randoms:get_string(). - --spec cancel_timer(reference()) -> ok. -cancel_timer(Timer) -> - erlang:cancel_timer(Timer), - receive {timeout, Timer, _} -> ok after 0 -> ok end. - -fsm_limit_opts(Opts) -> - case lists:keysearch(max_fsm_queue, 1, Opts) of - {value, {_, N}} when is_integer(N) -> [{max_queue, N}]; - _ -> - case ejabberd_config:get_option( - max_fsm_queue, - fun(I) when is_integer(I), I > 0 -> I end) of - undefined -> []; - N -> [{max_queue, N}] - end - end. - --spec decode_element(xmlel() | xmpp_element(), state_name(), state()) -> fsm_transition(). -decode_element(#xmlel{} = El, StateName, StateData) -> - Opts = if StateName == stream_established -> - [ignore_els]; + case check_to(To, State) of true -> - [] - end, - try xmpp:decode(El, ?NS_SERVER, Opts) of - Pkt -> ?MODULE:StateName(Pkt, StateData) - catch error:{xmpp_codec, Why} -> - case xmpp:is_stanza(El) of - true -> - Lang = xmpp:get_lang(El), - Txt = xmpp:format_error(Why), - send_error(StateData, El, xmpp:err_bad_request(Txt, Lang)); + ok; false -> - ok - end, - {next_state, StateName, StateData} + {error, xmpp:serr_host_unknown()} end; -decode_element(Pkt, StateName, StateData) -> - ?MODULE:StateName(Pkt, StateData). + false -> + {error, xmpp:serr_invalid_from()} + end. + +-spec check_from(jid(), state()) -> boolean(). +check_from(#jid{lserver = S1}, #{auth_domains := AuthDomains}) -> + sets:is_element(S1, AuthDomains). + +-spec check_to(jid(), state()) -> boolean(). +check_to(#jid{lserver = LServer}, _State) -> + ejabberd_router:is_my_route(LServer). + +-spec set_idle_timeout(state()) -> state(). +set_idle_timeout(#{server_host := LServer, + established := true} = State) -> + Timeout = ejabberd_s2s:get_idle_timeout(LServer), + xmpp_stream_in:set_timeout(State, Timeout); +set_idle_timeout(State) -> + State. + +-spec change_shaper(state(), binary()) -> ok. +change_shaper(#{shaper := ShaperName, server_host := ServerHost} = State, + RServer) -> + Shaper = acl:match_rule(ServerHost, ShaperName, jid:make(RServer)), + xmpp_stream_in:change_shaper(State, Shaper). -opt_type(domain_certfile) -> fun iolist_to_binary/1; -opt_type(max_fsm_queue) -> - fun (I) when is_integer(I), I > 0 -> I end; -opt_type(s2s_certfile) -> fun iolist_to_binary/1; -opt_type(s2s_cafile) -> fun iolist_to_binary/1; -opt_type(s2s_ciphers) -> fun iolist_to_binary/1; -opt_type(s2s_dhfile) -> fun iolist_to_binary/1; -opt_type(s2s_protocol_options) -> - fun (Options) -> - [_ | O] = lists:foldl(fun (X, Acc) -> X ++ Acc end, [], - [["|" | binary_to_list(Opt)] - || Opt <- Options, is_binary(Opt)]), - iolist_to_binary(O) - end; -opt_type(s2s_tls_compression) -> - fun (true) -> true; - (false) -> false - end; -opt_type(s2s_use_starttls) -> - fun (false) -> false; - (true) -> true; - (optional) -> optional; - (required) -> required; - (required_trusted) -> required_trusted - end; opt_type(_) -> - [domain_certfile, max_fsm_queue, s2s_certfile, s2s_cafile, - s2s_ciphers, s2s_dhfile, s2s_protocol_options, - s2s_tls_compression, s2s_use_starttls]. + []. diff --git a/src/ejabberd_s2s_out.erl b/src/ejabberd_s2s_out.erl index fb3dfcbd5..a923860f3 100644 --- a/src/ejabberd_s2s_out.erl +++ b/src/ejabberd_s2s_out.erl @@ -1,8 +1,5 @@ -%%%---------------------------------------------------------------------- -%%% File : ejabberd_s2s_out.erl -%%% Author : Alexey Shchepin -%%% Purpose : Manage outgoing server-to-server connections -%%% Created : 6 Dec 2002 by Alexey Shchepin +%%%------------------------------------------------------------------- +%%% Created : 16 Dec 2016 by Evgeny Khramtsov %%% %%% %%% ejabberd, Copyright (C) 2002-2017 ProcessOne @@ -21,953 +18,356 @@ %%% with this program; if not, write to the Free Software Foundation, Inc., %%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. %%% -%%%---------------------------------------------------------------------- - +%%%------------------------------------------------------------------- -module(ejabberd_s2s_out). - +-behaviour(xmpp_stream_out). -behaviour(ejabberd_config). --author('alexey@process-one.net'). - --behaviour(p1_fsm). - -%% External exports --export([start/3, - start_link/3, - start_connection/1, - terminate_if_waiting_delay/2, - stop_connection/1, - transform_options/1]). - --export([init/1, open_socket/2, wait_for_stream/2, - wait_for_validation/2, wait_for_features/2, - wait_for_auth_result/2, wait_for_starttls_proceed/2, - relay_to_bridge/2, reopen_socket/2, wait_before_retry/2, - stream_established/2, handle_event/3, - handle_sync_event/4, handle_info/3, terminate/3, - print_state/1, code_change/4, test_get_addr_port/1, - get_addr_port/1, opt_type/1]). +%% ejabberd_config callbacks +-export([opt_type/1, transform_options/1]). +%% xmpp_stream_out callbacks +-export([tls_options/1, tls_required/1, tls_verify/1, tls_enabled/1, + connect_timeout/1, address_families/1, default_port/1, + dns_retries/1, dns_timeout/1, + handle_auth_success/2, handle_auth_failure/3, handle_packet/2, + handle_stream_end/2, handle_stream_downgraded/2, + handle_recv/3, handle_send/3, handle_cdata/2, + handle_stream_established/1, handle_timeout/1]). +-export([init/1, handle_call/3, handle_cast/2, handle_info/2, + terminate/2, code_change/3]). +%% Hooks +-export([process_auth_result/2, process_closed/2, handle_unexpected_info/2, + handle_unexpected_cast/2, process_downgraded/2]). +%% API +-export([start/3, start_link/3, connect/1, close/1, stop/1, send/2, + route/2, establish/1, update_state/2, add_hooks/0]). -include("ejabberd.hrl"). --include("logger.hrl"). -include("xmpp.hrl"). +-include("logger.hrl"). --record(state, - {socket :: ejabberd_socket:socket_state(), - streamid = <<"">> :: binary(), - remote_streamid = <<"">> :: binary(), - use_v10 = true :: boolean(), - tls = false :: boolean(), - tls_required = false :: boolean(), - tls_certverify = false :: boolean(), - tls_enabled = false :: boolean(), - tls_options = [connect] :: list(), - authenticated = false :: boolean(), - db_enabled = true :: boolean(), - try_auth = true :: boolean(), - myname = <<"">> :: binary(), - server = <<"">> :: binary(), - queue = queue:new() :: ?TQUEUE, - delay_to_retry = undefined_delay :: undefined_delay | non_neg_integer(), - new = false :: boolean(), - verify = false :: false | {pid(), binary(), binary()}, - bridge :: {atom(), atom()}, - timer = make_ref() :: reference()}). +-type state() :: map(). +-export_type([state/0]). --type state_name() :: open_socket | wait_for_stream | - wait_for_validation | wait_for_features | - wait_for_auth_result | wait_for_starttls_proceed | - relay_to_bridge | reopen_socket | wait_before_retry | - stream_established. --type state() :: #state{}. --type fsm_stop() :: {stop, normal, state()}. --type fsm_next() :: {next_state, state_name(), state(), non_neg_integer()} | - {next_state, state_name(), state()}. --type fsm_transition() :: fsm_stop() | fsm_next(). - -%%-define(DBGFSM, true). - --ifdef(DBGFSM). - --define(FSMOPTS, [{debug, [trace]}]). - --else. - --define(FSMOPTS, []). - --endif. - --define(FSMTIMEOUT, 30000). - -%% We do not block on send anymore. --define(TCP_SEND_TIMEOUT, 15000). - -%% Maximum delay to wait before retrying to connect after a failed attempt. -%% Specified in miliseconds. Default value is 5 minutes. --define(MAX_RETRY_DELAY, 300000). - --define(SOCKET_DEFAULT_RESULT, {error, badarg}). - -%%%---------------------------------------------------------------------- +%%%=================================================================== %%% API -%%%---------------------------------------------------------------------- -start(From, Host, Type) -> - supervisor:start_child(ejabberd_s2s_out_sup, - [From, Host, Type]). - -start_link(From, Host, Type) -> - p1_fsm:start_link(ejabberd_s2s_out, [From, Host, Type], - fsm_limit_opts() ++ (?FSMOPTS)). - -start_connection(Pid) -> p1_fsm:send_event(Pid, init). - -stop_connection(Pid) -> p1_fsm:send_event(Pid, closed). - -%%%---------------------------------------------------------------------- -%%% Callback functions from p1_fsm -%%%---------------------------------------------------------------------- - -init([From, Server, Type]) -> - process_flag(trap_exit, true), - ?DEBUG("started: ~p", [{From, Server, Type}]), - {TLS, TLSRequired, TLSCertverify} = - case ejabberd_config:get_option( - s2s_use_starttls, - fun(true) -> true; - (false) -> false; - (optional) -> optional; - (required) -> required; - (required_trusted) -> required_trusted - end) - of - UseTls - when (UseTls == undefined) or (UseTls == false) -> - {false, false, false}; - UseTls - when (UseTls == true) or (UseTls == optional) -> - {true, false, false}; - required -> - {true, true, false}; - required_trusted -> - {true, true, true} - end, - UseV10 = TLS, - TLSOpts1 = case - ejabberd_config:get_option( - s2s_certfile, fun iolist_to_binary/1) - of - undefined -> [connect]; - CertFile -> [{certfile, CertFile}, connect] - end, - TLSOpts2 = case ejabberd_config:get_option( - s2s_ciphers, fun iolist_to_binary/1) of - undefined -> TLSOpts1; - Ciphers -> [{ciphers, Ciphers} | TLSOpts1] - end, - TLSOpts3 = case ejabberd_config:get_option( - s2s_protocol_options, - fun (Options) -> - [_|O] = lists:foldl( - fun(X, Acc) -> X ++ Acc end, [], - [["|" | binary_to_list(Opt)] || Opt <- Options, is_binary(Opt)] - ), - iolist_to_binary(O) - end) of - undefined -> TLSOpts2; - ProtocolOpts -> [{protocol_options, ProtocolOpts} | TLSOpts2] - end, - TLSOpts4 = case ejabberd_config:get_option( - s2s_dhfile, fun iolist_to_binary/1) of - undefined -> TLSOpts3; - DHFile -> [{dhfile, DHFile} | TLSOpts3] - end, - TLSOpts = case ejabberd_config:get_option( - {s2s_tls_compression, From}, - fun(true) -> true; - (false) -> false - end, false) of - false -> [compression_none | TLSOpts4]; - true -> TLSOpts4 - end, - {New, Verify} = case Type of - new -> {true, false}; - {verify, Pid, Key, SID} -> - start_connection(self()), {false, {Pid, Key, SID}} - end, - Timer = erlang:start_timer(?S2STIMEOUT, self(), []), - {ok, open_socket, - #state{use_v10 = UseV10, tls = TLS, - tls_required = TLSRequired, tls_certverify = TLSCertverify, - tls_options = TLSOpts, queue = queue:new(), myname = From, - server = Server, new = New, verify = Verify, timer = Timer}}. - -open_socket(init, StateData) -> - log_s2s_out(StateData#state.new, StateData#state.myname, - StateData#state.server, StateData#state.tls), - ?DEBUG("open_socket: ~p", - [{StateData#state.myname, StateData#state.server, - StateData#state.new, StateData#state.verify}]), - AddrList = case - ejabberd_idna:domain_utf8_to_ascii(StateData#state.server) - of - false -> []; - ASCIIAddr -> get_addr_port(ASCIIAddr) - end, - case lists:foldl(fun ({Addr, Port}, Acc) -> - case Acc of - {ok, Socket} -> {ok, Socket}; - _ -> open_socket1(Addr, Port) - end - end, - ?SOCKET_DEFAULT_RESULT, AddrList) - of - {ok, Socket} -> - Version = if StateData#state.use_v10 -> {1,0}; - true -> undefined - end, - NewStateData = StateData#state{socket = Socket, - tls_enabled = false, - streamid = new_id()}, - send_header(NewStateData, Version), - {next_state, wait_for_stream, NewStateData, - ?FSMTIMEOUT}; - {error, Reason} -> - ?INFO_MSG("s2s connection: ~s -> ~s (remote server " - "not found: ~p)", - [StateData#state.myname, StateData#state.server, Reason]), - case ejabberd_hooks:run_fold(find_s2s_bridge, undefined, - [StateData#state.myname, - StateData#state.server]) - of - {Mod, Fun, Type} -> - ?INFO_MSG("found a bridge to ~s for: ~s -> ~s", - [Type, StateData#state.myname, - StateData#state.server]), - NewStateData = StateData#state{bridge = {Mod, Fun}}, - {next_state, relay_to_bridge, NewStateData}; - _ -> wait_before_reconnect(StateData) - end - end; -open_socket(Event, StateData) -> - handle_unexpected_event(Event, open_socket, StateData). - -open_socket1({_, _, _, _} = Addr, Port) -> - open_socket2(inet, Addr, Port); -%% IPv6 -open_socket1({_, _, _, _, _, _, _, _} = Addr, Port) -> - open_socket2(inet6, Addr, Port); -%% Hostname -open_socket1(Host, Port) -> - lists:foldl(fun (_Family, {ok, _Socket} = R) -> R; - (Family, _) -> - Addrs = get_addrs(Host, Family), - lists:foldl(fun (_Addr, {ok, _Socket} = R) -> R; - (Addr, _) -> open_socket1(Addr, Port) - end, - ?SOCKET_DEFAULT_RESULT, Addrs) - end, - ?SOCKET_DEFAULT_RESULT, outgoing_s2s_families()). - -open_socket2(Type, Addr, Port) -> - ?DEBUG("s2s_out: connecting to ~p:~p~n", [Addr, Port]), - Timeout = outgoing_s2s_timeout(), - case catch ejabberd_socket:connect(Addr, Port, - [binary, {packet, 0}, - {send_timeout, ?TCP_SEND_TIMEOUT}, - {send_timeout_close, true}, - {active, false}, Type], - Timeout) - of - {ok, _Socket} = R -> R; - {error, Reason} = R -> - ?DEBUG("s2s_out: connect return ~p~n", [Reason]), R; - {'EXIT', Reason} -> - ?DEBUG("s2s_out: connect crashed ~p~n", [Reason]), - {error, Reason} - end. - -%%---------------------------------------------------------------------- - -wait_for_stream({xmlstreamstart, Name, Attrs}, StateData0) -> - {CertCheckRes, CertCheckMsg, StateData} = - if StateData0#state.tls_certverify, StateData0#state.tls_enabled -> - {Res, Msg} = - ejabberd_s2s:check_peer_certificate(ejabberd_socket, - StateData0#state.socket, - StateData0#state.server), - ?DEBUG("Certificate verification result for ~s: ~s", - [StateData0#state.server, Msg]), - {Res, Msg, StateData0#state{tls_certverify = false}}; +%%%=================================================================== +start(From, To, Opts) -> + case proplists:get_value(supervisor, Opts, true) of true -> - {no_verify, <<"Not verified">>, StateData0} - end, - try xmpp:decode(#xmlel{name = Name, attrs = Attrs}) of - _ when CertCheckRes == error -> - send_element(StateData, - xmpp:serr_policy_violation(CertCheckMsg, ?MYLANG)), - ?INFO_MSG("Closing s2s connection: ~s -> ~s (~s)", - [StateData#state.myname, StateData#state.server, - CertCheckMsg]), - {stop, normal, StateData}; - #stream_start{xmlns = NS_SERVER, stream_xmlns = NS_STREAM} - when NS_SERVER /= ?NS_SERVER; NS_STREAM /= ?NS_STREAM -> - send_element(StateData, xmpp:serr_invalid_namespace()), - {stop, normal, StateData}; - #stream_start{db_xmlns = ?NS_SERVER_DIALBACK, id = ID, - version = V} when V /= {1,0} -> - send_db_request(StateData#state{remote_streamid = ID}); - #stream_start{db_xmlns = ?NS_SERVER_DIALBACK, id = ID} - when StateData#state.use_v10 -> - {next_state, wait_for_features, - StateData#state{remote_streamid = ID}, ?FSMTIMEOUT}; - #stream_start{db_xmlns = ?NS_SERVER_DIALBACK, id = ID} - when not StateData#state.use_v10 -> - %% Handle Tigase's workaround for an old ejabberd bug: - send_db_request(StateData#state{remote_streamid = ID}); - #stream_start{id = ID} when StateData#state.use_v10 -> - {next_state, wait_for_features, - StateData#state{db_enabled = false, remote_streamid = ID}, - ?FSMTIMEOUT}; - #stream_start{} -> - send_element(StateData, xmpp:serr_invalid_namespace()), - {stop, normal, StateData}; - _ -> - send_element(StateData, xmpp:serr_invalid_xml()), - {stop, normal, StateData} - catch _:{xmpp_codec, Why} -> - Txt = xmpp:format_error(Why), - send_element(StateData, xmpp:serr_invalid_xml(Txt, ?MYLANG)), - {stop, normal, StateData} - end; -wait_for_stream(Event, StateData) -> - handle_unexpected_event(Event, wait_for_stream, StateData). - -wait_for_validation({xmlstreamelement, El}, StateData) -> - decode_element(El, wait_for_validation, StateData); -wait_for_validation(#db_result{to = To, from = From, type = Type}, StateData) -> - ?DEBUG("recv result: ~p", [{From, To, Type}]), - case {Type, StateData#state.tls_enabled, StateData#state.tls_required} of - {valid, Enabled, Required} when (Enabled == true) or (Required == false) -> - send_queue(StateData, StateData#state.queue), - ?INFO_MSG("Connection established: ~s -> ~s with " - "TLS=~p", - [StateData#state.myname, StateData#state.server, - StateData#state.tls_enabled]), - ejabberd_hooks:run(s2s_connect_hook, - [StateData#state.myname, - StateData#state.server]), - {next_state, stream_established, StateData#state{queue = queue:new()}}; - {valid, Enabled, Required} when (Enabled == false) and (Required == true) -> - ?INFO_MSG("Closing s2s connection: ~s -> ~s (TLS " - "is required but unavailable)", - [StateData#state.myname, StateData#state.server]), - {stop, normal, StateData}; - _ -> - ?INFO_MSG("Closing s2s connection: ~s -> ~s (invalid " - "dialback key result)", - [StateData#state.myname, StateData#state.server]), - {stop, normal, StateData} - end; -wait_for_validation(#db_verify{to = To, from = From, id = Id, type = Type}, - StateData) -> - ?DEBUG("recv verify: ~p", [{From, To, Id, Type}]), - case StateData#state.verify of - false -> - NextState = wait_for_validation, - {next_state, NextState, StateData, get_timeout_interval(NextState)}; - {Pid, _Key, _SID} -> - case Type of - valid -> - p1_fsm:send_event(Pid, - {valid, StateData#state.server, - StateData#state.myname}); + supervisor:start_child(ejabberd_s2s_out_sup, + [From, To, Opts]); _ -> - p1_fsm:send_event(Pid, - {invalid, StateData#state.server, - StateData#state.myname}) - end, - if StateData#state.verify == false -> - {stop, normal, StateData}; - true -> - NextState = wait_for_validation, - {next_state, NextState, StateData, get_timeout_interval(NextState)} - end - end; -wait_for_validation(timeout, - #state{verify = {VPid, VKey, SID}} = StateData) - when is_pid(VPid) and is_binary(VKey) and is_binary(SID) -> - ?DEBUG("wait_for_validation: ~s -> ~s (timeout in verify connection)", - [StateData#state.myname, StateData#state.server]), - {stop, normal, StateData}; -wait_for_validation(Event, StateData) -> - handle_unexpected_event(Event, wait_for_validation, StateData). - -wait_for_features({xmlstreamelement, El}, StateData) -> - decode_element(El, wait_for_features, StateData); -wait_for_features(#stream_features{sub_els = Els}, StateData) -> - {SASLEXT, StartTLS, StartTLSRequired} = - lists:foldl( - fun(#sasl_mechanisms{list = Mechs}, {_, STLS, STLSReq}) -> - {lists:member(<<"EXTERNAL">>, Mechs), STLS, STLSReq}; - (#starttls{required = Required}, {SEXT, _, _}) -> - {SEXT, true, Required}; - (_, Acc) -> - Acc - end, {false, false, false}, Els), - if not SASLEXT and not StartTLS and StateData#state.authenticated -> - send_queue(StateData, StateData#state.queue), - ?INFO_MSG("Connection established: ~s -> ~s with " - "SASL EXTERNAL and TLS=~p", - [StateData#state.myname, StateData#state.server, - StateData#state.tls_enabled]), - ejabberd_hooks:run(s2s_connect_hook, - [StateData#state.myname, - StateData#state.server]), - {next_state, stream_established, - StateData#state{queue = queue:new()}}; - SASLEXT and StateData#state.try_auth and - (StateData#state.new /= false) and - (StateData#state.tls_enabled or - not StateData#state.tls_required) -> - send_element(StateData, - #sasl_auth{mechanism = <<"EXTERNAL">>, - text = StateData#state.myname}), - {next_state, wait_for_auth_result, - StateData#state{try_auth = false}, ?FSMTIMEOUT}; - StartTLS and StateData#state.tls and - not StateData#state.tls_enabled -> - send_element(StateData, #starttls{}), - {next_state, wait_for_starttls_proceed, StateData, ?FSMTIMEOUT}; - StartTLSRequired and not StateData#state.tls -> - ?DEBUG("restarted: ~p", - [{StateData#state.myname, StateData#state.server}]), - ejabberd_socket:close(StateData#state.socket), - {next_state, reopen_socket, - StateData#state{socket = undefined, use_v10 = false}, - ?FSMTIMEOUT}; - StateData#state.db_enabled -> - send_db_request(StateData); - true -> - ?DEBUG("restarted: ~p", - [{StateData#state.myname, StateData#state.server}]), - ejabberd_socket:close(StateData#state.socket), - {next_state, reopen_socket, - StateData#state{socket = undefined, use_v10 = false}, ?FSMTIMEOUT} - end; -wait_for_features(Event, StateData) -> - handle_unexpected_event(Event, wait_for_features, StateData). - -wait_for_auth_result({xmlstreamelement, El}, StateData) -> - decode_element(El, wait_for_auth_result, StateData); -wait_for_auth_result(#sasl_success{}, StateData) -> - ?DEBUG("auth: ~p", [{StateData#state.myname, StateData#state.server}]), - ejabberd_socket:reset_stream(StateData#state.socket), - send_header(StateData, {1,0}), - {next_state, wait_for_stream, - StateData#state{streamid = new_id(), authenticated = true}, - ?FSMTIMEOUT}; -wait_for_auth_result(#sasl_failure{}, StateData) -> - ?DEBUG("restarted: ~p", [{StateData#state.myname, StateData#state.server}]), - ejabberd_socket:close(StateData#state.socket), - {next_state, reopen_socket, - StateData#state{socket = undefined}, ?FSMTIMEOUT}; -wait_for_auth_result(Event, StateData) -> - handle_unexpected_event(Event, wait_for_auth_result, StateData). - -wait_for_starttls_proceed({xmlstreamelement, El}, StateData) -> - decode_element(El, wait_for_starttls_proceed, StateData); -wait_for_starttls_proceed(#starttls_proceed{}, StateData) -> - ?DEBUG("starttls: ~p", [{StateData#state.myname, StateData#state.server}]), - Socket = StateData#state.socket, - TLSOpts = case ejabberd_config:get_option( - {domain_certfile, StateData#state.myname}, - fun iolist_to_binary/1) of - undefined -> StateData#state.tls_options; - CertFile -> - [{certfile, CertFile} - | lists:keydelete(certfile, 1, - StateData#state.tls_options)] - end, - TLSSocket = ejabberd_socket:starttls(Socket, TLSOpts), - NewStateData = StateData#state{socket = TLSSocket, - streamid = new_id(), - tls_enabled = true, - tls_options = TLSOpts}, - send_header(NewStateData, {1,0}), - {next_state, wait_for_stream, NewStateData, ?FSMTIMEOUT}; -wait_for_starttls_proceed(Event, StateData) -> - handle_unexpected_event(Event, wait_for_starttls_proceed, StateData). - -reopen_socket({xmlstreamelement, _El}, StateData) -> - {next_state, reopen_socket, StateData, ?FSMTIMEOUT}; -reopen_socket({xmlstreamend, _Name}, StateData) -> - {next_state, reopen_socket, StateData, ?FSMTIMEOUT}; -reopen_socket({xmlstreamerror, _}, StateData) -> - {next_state, reopen_socket, StateData, ?FSMTIMEOUT}; -reopen_socket(timeout, StateData) -> - ?INFO_MSG("reopen socket: timeout", []), - {stop, normal, StateData}; -reopen_socket(closed, StateData) -> - p1_fsm:send_event(self(), init), - {next_state, open_socket, StateData, ?FSMTIMEOUT}. - -%% This state is use to avoid reconnecting to often to bad sockets -wait_before_retry(_Event, StateData) -> - {next_state, wait_before_retry, StateData, ?FSMTIMEOUT}. - -relay_to_bridge(stop, StateData) -> - wait_before_reconnect(StateData); -relay_to_bridge(closed, StateData) -> - ?INFO_MSG("relay to bridge: ~s -> ~s (closed)", - [StateData#state.myname, StateData#state.server]), - {stop, normal, StateData}; -relay_to_bridge(_Event, StateData) -> - {next_state, relay_to_bridge, StateData}. - -stream_established({xmlstreamelement, El}, StateData) -> - decode_element(El, stream_established, StateData); -stream_established(#db_verify{to = VTo, from = VFrom, id = VId, type = VType}, - StateData) -> - ?DEBUG("recv verify: ~p", [{VFrom, VTo, VId, VType}]), - case StateData#state.verify of - {VPid, _VKey, _SID} -> - case VType of - valid -> - p1_fsm:send_event(VPid, - {valid, StateData#state.server, - StateData#state.myname}); - _ -> - p1_fsm:send_event(VPid, - {invalid, StateData#state.server, - StateData#state.myname}) - end; - _ -> ok - end, - {next_state, stream_established, StateData}; -stream_established(Event, StateData) -> - handle_unexpected_event(Event, stream_established, StateData). - --spec handle_unexpected_event(term(), state_name(), state()) -> fsm_transition(). -handle_unexpected_event(Event, StateName, StateData) -> - case Event of - {xmlstreamerror, _} -> - send_element(StateData, xmpp:serr_not_well_formed()), - ?INFO_MSG("Closing s2s connection ~s -> ~s in state ~s: " - "got invalid XML from peer", - [StateData#state.myname, StateData#state.server, - StateName]), - {stop, normal, StateData}; - {xmlstreamend, _} -> - ?INFO_MSG("Closing s2s connection ~s -> ~s in state ~s: " - "XML stream closed by peer", - [StateData#state.myname, StateData#state.server, - StateName]), - {stop, normal, StateData}; - timeout -> - send_element(StateData, xmpp:serr_connection_timeout()), - ?INFO_MSG("Closing s2s connection ~s -> ~s in state ~s: " - "timed out during establishing an XML stream", - [StateData#state.myname, StateData#state.server, - StateName]), - {stop, normal, StateData}; - closed -> - ?INFO_MSG("Closing s2s connection ~s -> ~s in state ~s: " - "connection socket closed", - [StateData#state.myname, StateData#state.server, - StateName]), - {stop, normal, StateData}; - Pkt when StateName == wait_for_stream; - StateName == wait_for_features; - StateName == wait_for_auth_result; - StateName == wait_for_starttls_proceed -> - send_element(StateData, xmpp:serr_bad_format()), - ?INFO_MSG("Closing s2s connection ~s -> ~s in state ~s: " - "got unexpected event ~p", - [StateData#state.myname, StateData#state.server, - StateName, Pkt]), - {stop, normal, StateData}; - _ -> - {next_state, StateName, StateData, get_timeout_interval(StateName)} + xmpp_stream_out:start(?MODULE, [ejabberd_socket, From, To, Opts], + ejabberd_config:fsm_limit_opts([])) end. -%%---------------------------------------------------------------------- -%% Func: StateName/3 -%% Returns: {next_state, NextStateName, NextStateData} | -%% {next_state, NextStateName, NextStateData, Timeout} | -%% {reply, Reply, NextStateName, NextStateData} | -%% {reply, Reply, NextStateName, NextStateData, Timeout} | -%% {stop, Reason, NewStateData} | -%% {stop, Reason, Reply, NewStateData} -%%---------------------------------------------------------------------- -%%state_name(Event, From, StateData) -> -%% Reply = ok, -%% {reply, Reply, state_name, StateData}. +start_link(From, To, Opts) -> + xmpp_stream_out:start_link(?MODULE, [ejabberd_socket, From, To, Opts], + ejabberd_config:fsm_limit_opts([])). -handle_event(_Event, StateName, StateData) -> - {next_state, StateName, StateData, - get_timeout_interval(StateName)}. +connect(Ref) -> + xmpp_stream_out:connect(Ref). -handle_sync_event(get_state_infos, _From, StateName, - StateData) -> - {Addr, Port} = try - ejabberd_socket:peername(StateData#state.socket) - of - {ok, {A, P}} -> {A, P}; - {error, _} -> {unknown, unknown} - catch - _:_ -> {unknown, unknown} - end, - Infos = [{direction, out}, {statename, StateName}, - {addr, Addr}, {port, Port}, - {streamid, StateData#state.streamid}, - {use_v10, StateData#state.use_v10}, - {tls, StateData#state.tls}, - {tls_required, StateData#state.tls_required}, - {tls_enabled, StateData#state.tls_enabled}, - {tls_options, StateData#state.tls_options}, - {authenticated, StateData#state.authenticated}, - {db_enabled, StateData#state.db_enabled}, - {try_auth, StateData#state.try_auth}, - {myname, StateData#state.myname}, - {server, StateData#state.server}, - {delay_to_retry, StateData#state.delay_to_retry}, - {verify, StateData#state.verify}], - Reply = {state_infos, Infos}, - {reply, Reply, StateName, StateData}; -%%---------------------------------------------------------------------- -%% Func: handle_sync_event/4 -%% Returns: {next_state, NextStateName, NextStateData} | -%% {next_state, NextStateName, NextStateData, Timeout} | -%% {reply, Reply, NextStateName, NextStateData} | -%% {reply, Reply, NextStateName, NextStateData, Timeout} | -%% {stop, Reason, NewStateData} | -%% {stop, Reason, Reply, NewStateData} -%%---------------------------------------------------------------------- -handle_sync_event(_Event, _From, StateName, - StateData) -> - Reply = ok, - {reply, Reply, StateName, StateData, - get_timeout_interval(StateName)}. +close(Ref) -> + xmpp_stream_out:close(Ref). -code_change(_OldVsn, StateName, StateData, _Extra) -> - {ok, StateName, StateData}. +stop(Ref) -> + xmpp_stream_out:stop(Ref). -handle_info({send_text, Text}, StateName, StateData) -> - send_text(StateData, Text), - cancel_timer(StateData#state.timer), - Timer = erlang:start_timer(?S2STIMEOUT, self(), []), - {next_state, StateName, StateData#state{timer = Timer}, - get_timeout_interval(StateName)}; -handle_info({send_element, El}, StateName, StateData) -> - case StateName of - stream_established -> - cancel_timer(StateData#state.timer), - Timer = erlang:start_timer(?S2STIMEOUT, self(), []), - send_element(StateData, El), - {next_state, StateName, StateData#state{timer = Timer}}; - %% In this state we bounce all message: We are waiting before - %% trying to reconnect - wait_before_retry -> - bounce_element(El, xmpp:err_remote_server_not_found()), - {next_state, StateName, StateData}; - relay_to_bridge -> - {Mod, Fun} = StateData#state.bridge, - ?DEBUG("relaying stanza via ~p:~p/1", [Mod, Fun]), - case catch Mod:Fun(El) of - {'EXIT', Reason} -> - ?ERROR_MSG("Error while relaying to bridge: ~p", - [Reason]), - bounce_element(El, xmpp:err_internal_server_error()), - wait_before_reconnect(StateData); - _ -> {next_state, StateName, StateData} - end; - _ -> - Q = queue:in(El, StateData#state.queue), - {next_state, StateName, StateData#state{queue = Q}, - get_timeout_interval(StateName)} - end; -handle_info({timeout, Timer, _}, wait_before_retry, - #state{timer = Timer} = StateData) -> - ?INFO_MSG("Reconnect delay expired: Will now retry " - "to connect to ~s when needed.", - [StateData#state.server]), - {stop, normal, StateData}; -handle_info({timeout, Timer, _}, _StateName, - #state{timer = Timer} = StateData) -> - ?INFO_MSG("Closing connection with ~s: timeout", - [StateData#state.server]), - {stop, normal, StateData}; -handle_info(terminate_if_waiting_before_retry, - wait_before_retry, StateData) -> - {stop, normal, StateData}; -handle_info(terminate_if_waiting_before_retry, - StateName, StateData) -> - {next_state, StateName, StateData, - get_timeout_interval(StateName)}; -handle_info(_, StateName, StateData) -> - {next_state, StateName, StateData, - get_timeout_interval(StateName)}. +-spec send(pid(), xmpp_element()) -> ok; + (state(), xmpp_element()) -> state(). +send(Stream, Pkt) -> + xmpp_stream_out:send(Stream, Pkt). -terminate(Reason, StateName, StateData) -> - ?DEBUG("terminated: ~p", [{Reason, StateName}]), - case StateData#state.new of - false -> ok; - true -> - ejabberd_s2s:remove_connection({StateData#state.myname, - StateData#state.server}, - self()) - end, - bounce_queue(StateData#state.queue, xmpp:err_remote_server_not_found()), - bounce_messages(xmpp:err_remote_server_not_found()), - case StateData#state.socket of - undefined -> ok; - _Socket -> - catch send_trailer(StateData), - ejabberd_socket:close(StateData#state.socket) - end, - ok. +-spec route(pid(), xmpp_element()) -> ok. +route(Ref, Pkt) -> + Ref ! {route, Pkt}. -print_state(State) -> State. +-spec establish(state()) -> state(). +establish(State) -> + xmpp_stream_out:establish(State). -%%%---------------------------------------------------------------------- -%%% Internal functions -%%%---------------------------------------------------------------------- +-spec update_state(pid(), fun((state()) -> state()) | + {module(), atom(), list()}) -> ok. +update_state(Ref, Callback) -> + xmpp_stream_out:cast(Ref, {update_state, Callback}). --spec send_text(state(), iodata()) -> ok. -send_text(StateData, Text) -> - ?DEBUG("Send Text on stream = ~s", [Text]), - ejabberd_socket:send(StateData#state.socket, Text). +-spec add_hooks() -> ok. +add_hooks() -> + lists:foreach( + fun(Host) -> + ejabberd_hooks:add(s2s_out_auth_result, Host, ?MODULE, + process_auth_result, 100), + ejabberd_hooks:add(s2s_out_closed, Host, ?MODULE, + process_closed, 100), + ejabberd_hooks:add(s2s_out_handle_info, Host, ?MODULE, + handle_unexpected_info, 100), + ejabberd_hooks:add(s2s_out_handle_cast, Host, ?MODULE, + handle_unexpected_cast, 100), + ejabberd_hooks:add(s2s_out_downgraded, Host, ?MODULE, + process_downgraded, 100) + end, ?MYHOSTS). --spec send_element(state(), xmpp_element()) -> ok. -send_element(StateData, El) -> - El1 = xmpp:encode(El, ?NS_SERVER), - send_text(StateData, fxml:element_to_binary(El1)). +%%%=================================================================== +%%% Hooks +%%%=================================================================== +process_auth_result(#{server := LServer, remote_server := RServer} = State, + {false, Reason}) -> + Delay = get_delay(), + ?INFO_MSG("Failed to establish outbound s2s connection ~s -> ~s: " + "authentication failed; bouncing for ~p seconds", + [LServer, RServer, Delay]), + State1 = State#{on_route => bounce, stop_reason => Reason}, + State2 = close(State1), + State3 = bounce_queue(State2), + xmpp_stream_out:set_timeout(State3, timer:seconds(Delay)); +process_auth_result(State, true) -> + State. --spec send_header(state(), undefined | {integer(), integer()}) -> ok. -send_header(StateData, Version) -> - Header = xmpp:encode( - #stream_start{xmlns = ?NS_SERVER, - stream_xmlns = ?NS_STREAM, - db_xmlns = ?NS_SERVER_DIALBACK, - from = jid:make(StateData#state.myname), - to = jid:make(StateData#state.server), - version = Version}), - send_text(StateData, fxml:element_to_header(Header)). +process_closed(#{server := LServer, remote_server := RServer, + on_route := send} = State, + Reason) -> + ?INFO_MSG("Closing outbound s2s connection ~s -> ~s: ~s", + [LServer, RServer, xmpp_stream_out:format_error(Reason)]), + stop(State); +process_closed(#{server := LServer, remote_server := RServer} = State, + Reason) -> + Delay = get_delay(), + ?INFO_MSG("Failed to establish outbound s2s connection ~s -> ~s: ~s; " + "bouncing for ~p seconds", + [LServer, RServer, xmpp_stream_out:format_error(Reason), Delay]), + State1 = State#{on_route => bounce}, + State2 = bounce_queue(State1), + xmpp_stream_out:set_timeout(State2, timer:seconds(Delay)). --spec send_trailer(state()) -> ok. -send_trailer(StateData) -> - send_text(StateData, <<"">>). +handle_unexpected_info(State, Info) -> + ?WARNING_MSG("got unexpected info: ~p", [Info]), + State. --spec send_queue(state(), queue:queue()) -> ok. -send_queue(StateData, Q) -> - case queue:out(Q) of - {{value, El}, Q1} -> - send_element(StateData, El), send_queue(StateData, Q1); - {empty, _Q1} -> ok - end. +handle_unexpected_cast(State, Msg) -> + ?WARNING_MSG("got unexpected cast: ~p", [Msg]), + State. -%% Bounce a single message (xmlelement) --spec bounce_element(stanza(), stanza_error()) -> ok. -bounce_element(El, Error) -> - From = xmpp:get_from(El), - To = xmpp:get_to(El), - ejabberd_router:route_error(To, From, El, Error). +process_downgraded(State, _StreamStart) -> + send(State, xmpp:serr_unsupported_version()). --spec bounce_queue(queue:queue(), stanza_error()) -> ok. -bounce_queue(Q, Error) -> - case queue:out(Q) of - {{value, El}, Q1} -> - bounce_element(El, Error), bounce_queue(Q1, Error); - {empty, _} -> ok - end. +%%%=================================================================== +%%% gen_server callbacks +%%%=================================================================== +tls_options(#{server := LServer}) -> + ejabberd_s2s:tls_options(LServer, []). --spec new_id() -> binary(). -new_id() -> randoms:get_string(). +tls_required(#{server := LServer}) -> + ejabberd_s2s:tls_required(LServer). --spec cancel_timer(reference()) -> ok. -cancel_timer(Timer) -> - erlang:cancel_timer(Timer), - receive {timeout, Timer, _} -> ok after 0 -> ok end. +tls_verify(#{server := LServer}) -> + ejabberd_s2s:tls_verify(LServer). --spec bounce_messages(stanza_error()) -> ok. -bounce_messages(Error) -> - receive - {send_element, El} -> - bounce_element(El, Error), bounce_messages(Error) - after 0 -> ok - end. +tls_enabled(#{server := LServer}) -> + ejabberd_s2s:tls_enabled(LServer). --spec send_db_request(state()) -> fsm_transition(). -send_db_request(StateData) -> - Server = StateData#state.server, - New = case StateData#state.new of - false -> - ejabberd_s2s:try_register({StateData#state.myname, Server}); - true -> - true - end, - NewStateData = StateData#state{new = New}, - try case New of - false -> ok; - true -> - Key1 = ejabberd_s2s:make_key( - {StateData#state.myname, Server}, - StateData#state.remote_streamid), - send_element(StateData, - #db_result{from = StateData#state.myname, - to = Server, - key = Key1}) - end, - case StateData#state.verify of - false -> ok; - {_Pid, Key2, SID} -> - send_element(StateData, - #db_verify{from = StateData#state.myname, - to = StateData#state.server, - id = SID, - key = Key2}) - end, - {next_state, wait_for_validation, NewStateData, - (?FSMTIMEOUT) * 6} - catch - _:_ -> {stop, normal, NewStateData} - end. - -%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -%% SRV support - --include_lib("kernel/include/inet.hrl"). - --spec get_addr_port(binary()) -> [{binary(), inet:port_number()}]. -get_addr_port(Server) -> - Res = srv_lookup(Server), - case Res of - {error, Reason} -> - ?DEBUG("srv lookup of '~s' failed: ~p~n", - [Server, Reason]), - [{Server, outgoing_s2s_port()}]; - {ok, HEnt} -> - ?DEBUG("srv lookup of '~s': ~p~n", - [Server, HEnt#hostent.h_addr_list]), - AddrList = HEnt#hostent.h_addr_list, - case catch lists:map(fun ({Priority, Weight, Port, - Host}) -> - N = case Weight of - 0 -> 0; - _ -> - (Weight + 1) * randoms:uniform() - end, - {Priority * 65536 - N, Host, Port} - end, - AddrList) - of - SortedList = [_ | _] -> - List = lists:map(fun ({_, Host, Port}) -> - {list_to_binary(Host), Port} - end, - lists:keysort(1, SortedList)), - ?DEBUG("srv lookup of '~s': ~p~n", [Server, List]), - List; - _ -> [{Server, outgoing_s2s_port()}] - end - end. - -srv_lookup(Server) -> - TimeoutMs = timer:seconds( - ejabberd_config:get_option( - s2s_dns_timeout, - fun(I) when is_integer(I), I>=0 -> I end, - 10)), - Retries = ejabberd_config:get_option( - s2s_dns_retries, - fun(I) when is_integer(I), I>=0 -> I end, - 2), - srv_lookup(binary_to_list(Server), TimeoutMs, Retries). - -%% XXX - this behaviour is suboptimal in the case that the domain -%% has a "_xmpp-server._tcp." but not a "_jabber._tcp." record and -%% we don't get a DNS reply for the "_xmpp-server._tcp." lookup. In this -%% case we'll give up when we get the "_jabber._tcp." nxdomain reply. -srv_lookup(_Server, _Timeout, Retries) - when Retries < 1 -> - {error, timeout}; -srv_lookup(Server, Timeout, Retries) -> - case inet_res:getbyname("_xmpp-server._tcp." ++ Server, - srv, Timeout) - of - {error, _Reason} -> - case inet_res:getbyname("_jabber._tcp." ++ Server, srv, - Timeout) - of - {error, timeout} -> - ?ERROR_MSG("The DNS servers~n ~p~ntimed out on " - "request for ~p IN SRV. You should check " - "your DNS configuration.", - [inet_db:res_option(nameserver), Server]), - srv_lookup(Server, Timeout, Retries - 1); - R -> R - end; - {ok, _HEnt} = R -> R - end. - -test_get_addr_port(Server) -> - lists:foldl(fun (_, Acc) -> - [HostPort | _] = get_addr_port(Server), - case lists:keysearch(HostPort, 1, Acc) of - false -> [{HostPort, 1} | Acc]; - {value, {_, Num}} -> - lists:keyreplace(HostPort, 1, Acc, - {HostPort, Num + 1}) - end - end, - [], lists:seq(1, 100000)). - -get_addrs(Host, Family) -> - Type = case Family of - inet4 -> inet; - ipv4 -> inet; - inet6 -> inet6; - ipv6 -> inet6 - end, - case inet:gethostbyname(binary_to_list(Host), Type) of - {ok, #hostent{h_addr_list = Addrs}} -> - ?DEBUG("~s of ~s resolved to: ~p~n", - [Type, Host, Addrs]), - Addrs; - {error, Reason} -> - ?DEBUG("~s lookup of '~s' failed: ~p~n", - [Type, Host, Reason]), - [] - end. - --spec outgoing_s2s_port() -> pos_integer(). -outgoing_s2s_port() -> +connect_timeout(#{server := LServer}) -> ejabberd_config:get_option( - outgoing_s2s_port, + {outgoing_s2s_timeout, LServer}, + fun(TimeOut) when is_integer(TimeOut), TimeOut > 0 -> + timer:seconds(TimeOut); + (infinity) -> + infinity + end, timer:seconds(10)). + +default_port(#{server := LServer}) -> + ejabberd_config:get_option( + {outgoing_s2s_port, LServer}, fun(I) when is_integer(I), I > 0, I =< 65536 -> I end, 5269). --spec outgoing_s2s_families() -> [ipv4 | ipv6]. -outgoing_s2s_families() -> +address_families(#{server := LServer}) -> ejabberd_config:get_option( - outgoing_s2s_families, + {outgoing_s2s_families, LServer}, fun(Families) -> - true = lists:all( - fun(ipv4) -> true; - (ipv6) -> true - end, Families), - Families - end, [ipv4, ipv6]). + lists:map( + fun(ipv4) -> inet; + (ipv6) -> inet6 + end, Families) + end, [inet, inet6]). --spec outgoing_s2s_timeout() -> pos_integer(). -outgoing_s2s_timeout() -> +dns_retries(#{server := LServer}) -> ejabberd_config:get_option( - outgoing_s2s_timeout, - fun(TimeOut) when is_integer(TimeOut), TimeOut > 0 -> - TimeOut; - (infinity) -> - infinity - end, 10000). + {s2s_dns_retries, LServer}, + fun(I) when is_integer(I), I>=0 -> I end, + 2). + +dns_timeout(#{server := LServer}) -> + ejabberd_config:get_option( + {s2s_dns_timeout, LServer}, + fun(I) when is_integer(I), I>=0 -> + timer:seconds(I); + (infinity) -> + infinity + end, timer:seconds(10)). + +handle_auth_success(Mech, #{sockmod := SockMod, + socket := Socket, ip := IP, + remote_server := RServer, + server := LServer} = State) -> + ?INFO_MSG("(~s) Accepted outbound s2s ~s authentication ~s -> ~s (~s)", + [SockMod:pp(Socket), Mech, LServer, RServer, + ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]), + ejabberd_hooks:run_fold(s2s_out_auth_result, LServer, State, [true]). + +handle_auth_failure(Mech, Reason, + #{sockmod := SockMod, + socket := Socket, ip := IP, + remote_server := RServer, + server := LServer} = State) -> + ?INFO_MSG("(~s) Failed outbound s2s ~s authentication ~s -> ~s (~s): ~s", + [SockMod:pp(Socket), Mech, LServer, RServer, + ejabberd_config:may_hide_data(jlib:ip_to_list(IP)), + xmpp_stream_out:format_error(Reason)]), + ejabberd_hooks:run_fold(s2s_out_auth_result, LServer, State, [{false, Reason}]). + +handle_packet(Pkt, #{server := LServer} = State) -> + ejabberd_hooks:run_fold(s2s_out_packet, LServer, State, [Pkt]). + +handle_stream_end(Reason, #{server := LServer} = State) -> + State1 = State#{stop_reason => Reason}, + ejabberd_hooks:run_fold(s2s_out_closed, LServer, State1, [Reason]). + +handle_stream_downgraded(StreamStart, #{server := LServer} = State) -> + ejabberd_hooks:run_fold(s2s_out_downgraded, LServer, State, [StreamStart]). + +handle_stream_established(State) -> + State1 = State#{on_route => send}, + State2 = resend_queue(State1), + set_idle_timeout(State2). + +handle_cdata(Data, #{server := LServer} = State) -> + ejabberd_hooks:run_fold(s2s_out_handle_cdata, LServer, State, [Data]). + +handle_recv(El, Pkt, #{server := LServer} = State) -> + ejabberd_hooks:run_fold(s2s_out_handle_recv, LServer, State, [El, Pkt]). + +handle_send(El, Pkt, #{server := LServer} = State) -> + ejabberd_hooks:run_fold(s2s_out_handle_send, LServer, State, [El, Pkt]). + +handle_timeout(#{on_route := Action} = State) -> + case Action of + bounce -> stop(State); + _ -> send(State, xmpp:serr_connection_timeout()) + end. + +init([#{server := LServer, remote_server := RServer} = State, Opts]) -> + State1 = State#{on_route => queue, + queue => queue:new(), + xmlns => ?NS_SERVER, + lang => ?MYLANG, + shaper => none}, + ?INFO_MSG("Outbound s2s connection started: ~s -> ~s", + [LServer, RServer]), + ejabberd_hooks:run_fold(s2s_out_init, LServer, {ok, State1}, [Opts]). + +handle_call(Request, From, #{server := LServer} = State) -> + ejabberd_hooks:run_fold(s2s_out_handle_call, LServer, State, [Request, From]). + +handle_cast({update_state, Fun}, State) -> + case Fun of + {M, F, A} -> erlang:apply(M, F, [State|A]); + _ when is_function(Fun) -> Fun(State) + end; +handle_cast(Msg, #{server := LServer} = State) -> + ejabberd_hooks:run_fold(s2s_out_handle_cast, LServer, State, [Msg]). + +handle_info({route, Pkt}, #{queue := Q, on_route := Action} = State) -> + case Action of + queue -> State#{queue => queue:in(Pkt, Q)}; + bounce -> bounce_packet(Pkt, State); + send -> set_idle_timeout(send(State, Pkt)) + end; +handle_info(Info, #{server := LServer} = State) -> + ejabberd_hooks:run_fold(s2s_out_handle_info, LServer, State, [Info]). + +terminate(Reason, #{server := LServer, + remote_server := RServer} = State) -> + ejabberd_s2s:remove_connection({LServer, RServer}, self()), + State1 = case Reason of + normal -> State; + _ -> State#{stop_reason => internal_failure} + end, + bounce_queue(State1), + bounce_message_queue(State1). + +code_change(_OldVsn, State, _Extra) -> + {ok, State}. + +%%%=================================================================== +%%% Internal functions +%%%=================================================================== +-spec resend_queue(state()) -> state(). +resend_queue(#{queue := Q} = State) -> + State1 = State#{queue => queue:new()}, + jlib:queue_foldl( + fun(Pkt, AccState) -> + send(AccState, Pkt) + end, State1, Q). + +-spec bounce_queue(state()) -> state(). +bounce_queue(#{queue := Q} = State) -> + State1 = State#{queue => queue:new()}, + jlib:queue_foldl( + fun(Pkt, AccState) -> + bounce_packet(Pkt, AccState) + end, State1, Q). + +-spec bounce_message_queue(state()) -> state(). +bounce_message_queue(State) -> + receive {route, Pkt} -> + State1 = bounce_packet(Pkt, State), + bounce_message_queue(State1) + after 0 -> + State + end. + +-spec bounce_packet(xmpp_element(), state()) -> state(). +bounce_packet(Pkt, State) when ?is_stanza(Pkt) -> + From = xmpp:get_from(Pkt), + To = xmpp:get_to(Pkt), + Lang = xmpp:get_lang(Pkt), + Err = mk_bounce_error(Lang, State), + ejabberd_router:route_error(To, From, Pkt, Err), + State; +bounce_packet(_, State) -> + State. + +-spec mk_bounce_error(binary(), state()) -> stanza_error(). +mk_bounce_error(Lang, #{stop_reason := Why}) -> + Reason = xmpp_stream_out:format_error(Why), + case Why of + internal_failure -> + xmpp:err_internal_server_error(); + {dns, _} -> + xmpp:err_remote_server_not_found(Reason, Lang); + _ -> + xmpp:err_remote_server_timeout(Reason, Lang) + end; +mk_bounce_error(_Lang, _State) -> + %% We should not be here. Probably :) + xmpp:err_remote_server_not_found(). + +-spec get_delay() -> non_neg_integer(). +get_delay() -> + MaxDelay = ejabberd_config:get_option( + s2s_max_retry_delay, + fun(I) when is_integer(I), I > 0 -> I end, + 300), + crypto:rand_uniform(1, MaxDelay). + +-spec set_idle_timeout(state()) -> state(). +set_idle_timeout(#{on_route := send, server := LServer} = State) -> + Timeout = ejabberd_s2s:get_idle_timeout(LServer), + xmpp_stream_out:set_timeout(State, Timeout); +set_idle_timeout(State) -> + State. transform_options(Opts) -> lists:foldl(fun transform_options/2, [], Opts). @@ -978,6 +378,7 @@ transform_options({outgoing_s2s_options, Families, Timeout}, Opts) -> "but it is better to fix your config: " "use 'outgoing_s2s_timeout' and " "'outgoing_s2s_families' instead.", []), + maybe_report_huge_timeout(outgoing_s2s_timeout, Timeout), [{outgoing_s2s_families, Families}, {outgoing_s2s_timeout, Timeout} | Opts]; @@ -989,109 +390,27 @@ transform_options({s2s_dns_options, S2SDNSOpts}, AllOpts) -> "'s2s_dns_retries' instead", []), lists:foldr( fun({timeout, T}, AccOpts) -> + maybe_report_huge_timeout(s2s_dns_timeout, T), [{s2s_dns_timeout, T}|AccOpts]; ({retries, R}, AccOpts) -> [{s2s_dns_retries, R}|AccOpts]; (_, AccOpts) -> AccOpts end, AllOpts, S2SDNSOpts); +transform_options({Opt, T}, Opts) + when Opt == outgoing_s2s_timeout; Opt == s2s_dns_timeout -> + maybe_report_huge_timeout(Opt, T), + [{outgoing_s2s_timeout, T}|Opts]; transform_options(Opt, Opts) -> [Opt|Opts]. -%% Human readable S2S logging: Log only new outgoing connections as INFO -%% Do not log dialback -log_s2s_out(false, _, _, _) -> ok; -%% Log new outgoing connections: -log_s2s_out(_, Myname, Server, Tls) -> - ?INFO_MSG("Trying to open s2s connection: ~s -> " - "~s with TLS=~p", - [Myname, Server, Tls]). +maybe_report_huge_timeout(Opt, T) when is_integer(T), T >= 1000 -> + ?WARNING_MSG("value '~p' of option '~p' is too big, " + "are you sure you have set seconds?", + [T, Opt]); +maybe_report_huge_timeout(_, _) -> + ok. -%% Calculate timeout depending on which state we are in: -%% Can return integer > 0 | infinity --spec get_timeout_interval(state_name()) -> pos_integer() | infinity. -get_timeout_interval(StateName) -> - case StateName of - %% Validation implies dialback: Networking can take longer: - wait_for_validation -> (?FSMTIMEOUT) * 6; - %% When stream is established, we only rely on S2S Timeout timer: - stream_established -> infinity; - relay_to_bridge -> infinity; - open_socket -> infinity; - _ -> ?FSMTIMEOUT - end. - -%% This function is intended to be called at the end of a state -%% function that want to wait for a reconnect delay before stopping. --spec wait_before_reconnect(state()) -> fsm_next(). -wait_before_reconnect(StateData) -> - bounce_queue(StateData#state.queue, xmpp:err_remote_server_not_found()), - bounce_messages(xmpp:err_remote_server_not_found()), - cancel_timer(StateData#state.timer), - Delay = case StateData#state.delay_to_retry of - undefined_delay -> - {_, _, MicroSecs} = p1_time_compat:timestamp(), MicroSecs rem 14000 + 1000; - D1 -> lists:min([D1 * 2, get_max_retry_delay()]) - end, - Timer = erlang:start_timer(Delay, self(), []), - {next_state, wait_before_retry, - StateData#state{timer = Timer, delay_to_retry = Delay, - queue = queue:new()}}. - --spec get_max_retry_delay() -> pos_integer(). -get_max_retry_delay() -> - case ejabberd_config:get_option( - s2s_max_retry_delay, - fun(I) when is_integer(I), I > 0 -> I end) of - undefined -> ?MAX_RETRY_DELAY; - Seconds -> Seconds * 1000 - end. - -%% Terminate s2s_out connections that are in state wait_before_retry --spec terminate_if_waiting_delay(binary(), binary()) -> ok. -terminate_if_waiting_delay(From, To) -> - FromTo = {From, To}, - Pids = ejabberd_s2s:get_connections_pids(FromTo), - lists:foreach(fun (Pid) -> - Pid ! terminate_if_waiting_before_retry - end, - Pids). - --spec fsm_limit_opts() -> [{max_queue, pos_integer()}]. -fsm_limit_opts() -> - case ejabberd_config:get_option( - max_fsm_queue, - fun(I) when is_integer(I), I > 0 -> I end) of - undefined -> []; - N -> [{max_queue, N}] - end. - --spec decode_element(xmlel(), state_name(), state()) -> fsm_next(). -decode_element(#xmlel{} = El, StateName, StateData) -> - Opts = if StateName == stream_established -> - [ignore_els]; - true -> - [] - end, - try xmpp:decode(El, ?NS_SERVER, Opts) of - Pkt -> ?MODULE:StateName(Pkt, StateData) - catch error:{xmpp_codec, Why} -> - Type = xmpp:get_type(El), - case xmpp:is_stanza(El) of - true when Type /= <<"result">>, Type /= <<"error">> -> - Lang = xmpp:get_lang(El), - Txt = xmpp:format_error(Why), - Err = xmpp:make_error(El, xmpp:err_bad_request(Txt, Lang)), - send_element(StateData, Err); - false -> - ok - end, - {next_state, StateName, StateData, get_timeout_interval(StateName)} - end. - -opt_type(domain_certfile) -> fun iolist_to_binary/1; -opt_type(max_fsm_queue) -> - fun (I) when is_integer(I), I > 0 -> I end; opt_type(outgoing_s2s_families) -> fun (Families) -> true = lists:all(fun (ipv4) -> true; @@ -1107,36 +426,15 @@ opt_type(outgoing_s2s_timeout) -> TimeOut; (infinity) -> infinity end; -opt_type(s2s_certfile) -> fun iolist_to_binary/1; -opt_type(s2s_ciphers) -> fun iolist_to_binary/1; -opt_type(s2s_dhfile) -> fun iolist_to_binary/1; opt_type(s2s_dns_retries) -> fun (I) when is_integer(I), I >= 0 -> I end; opt_type(s2s_dns_timeout) -> - fun (I) when is_integer(I), I >= 0 -> I end; + fun (TimeOut) when is_integer(TimeOut), TimeOut > 0 -> + TimeOut; + (infinity) -> infinity + end; opt_type(s2s_max_retry_delay) -> fun (I) when is_integer(I), I > 0 -> I end; -opt_type(s2s_protocol_options) -> - fun (Options) -> - [_ | O] = lists:foldl(fun (X, Acc) -> X ++ Acc end, [], - [["|" | binary_to_list(Opt)] - || Opt <- Options, is_binary(Opt)]), - iolist_to_binary(O) - end; -opt_type(s2s_tls_compression) -> - fun (true) -> true; - (false) -> false - end; -opt_type(s2s_use_starttls) -> - fun (true) -> true; - (false) -> false; - (optional) -> optional; - (required) -> required; - (required_trusted) -> required_trusted - end; opt_type(_) -> - [domain_certfile, max_fsm_queue, outgoing_s2s_families, - outgoing_s2s_port, outgoing_s2s_timeout, s2s_certfile, - s2s_ciphers, s2s_dhfile, s2s_dns_retries, s2s_dns_timeout, - s2s_max_retry_delay, s2s_protocol_options, - s2s_tls_compression, s2s_use_starttls]. + [outgoing_s2s_families, outgoing_s2s_port, outgoing_s2s_timeout, + s2s_dns_retries, s2s_dns_timeout, s2s_max_retry_delay]. diff --git a/src/ejabberd_service.erl b/src/ejabberd_service.erl index 5003ff6ab..dd949f2f9 100644 --- a/src/ejabberd_service.erl +++ b/src/ejabberd_service.erl @@ -1,8 +1,5 @@ -%%%---------------------------------------------------------------------- -%%% File : ejabberd_service.erl -%%% Author : Alexey Shchepin -%%% Purpose : External component management (XEP-0114) -%%% Created : 6 Dec 2002 by Alexey Shchepin +%%%------------------------------------------------------------------- +%%% Created : 11 Dec 2016 by Evgeny Khramtsov %%% %%% %%% ejabberd, Copyright (C) 2002-2017 ProcessOne @@ -21,77 +18,57 @@ %%% with this program; if not, write to the Free Software Foundation, Inc., %%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. %%% -%%%---------------------------------------------------------------------- - +%%%------------------------------------------------------------------- -module(ejabberd_service). - +-behaviour(xmpp_stream_in). -behaviour(ejabberd_config). - --author('alexey@process-one.net'). +-behaviour(ejabberd_socket). -protocol({xep, 114, '1.6'}). --define(GEN_FSM, p1_fsm). - --behaviour(?GEN_FSM). - -%% External exports --export([start/2, start_link/2, send_text/2, - send_element/2, socket_type/0, transform_listen_option/2]). - --export([init/1, wait_for_stream/2, - wait_for_handshake/2, stream_established/2, - handle_event/3, handle_sync_event/4, code_change/4, - handle_info/3, terminate/3, print_state/1, opt_type/1]). +%% ejabberd_socket callbacks +-export([start/2, start_link/2, socket_type/0]). +%% ejabberd_config callbacks +-export([opt_type/1, transform_listen_option/2]). +%% xmpp_stream_in callbacks +-export([init/1, handle_info/2, terminate/2, code_change/3]). +-export([handle_stream_start/2, handle_auth_success/4, handle_auth_failure/4, + handle_authenticated_packet/2, get_password_fun/1]). +%% API +-export([send/2]). -include("ejabberd.hrl"). --include("logger.hrl"). -include("xmpp.hrl"). +-include("logger.hrl"). --record(state, - {socket :: ejabberd_socket:socket_state(), - sockmod = ejabberd_socket :: ejabberd_socket | ejabberd_frontend_socket, - streamid = <<"">> :: binary(), - host_opts = dict:new() :: ?TDICT, - host = <<"">> :: binary(), - access :: atom(), - check_from = true :: boolean()}). +-type state() :: map(). +-export_type([state/0]). --type state_name() :: wait_for_stream | wait_for_handshake | stream_established. --type state() :: #state{}. --type fsm_next() :: {next_state, state_name(), state()}. --type fsm_stop() :: {stop, normal, state()}. --type fsm_transition() :: fsm_stop() | fsm_next(). - -%-define(DBGFSM, true). --ifdef(DBGFSM). --define(FSMOPTS, [{debug, [trace]}]). --else. --define(FSMOPTS, []). --endif. - -%%%---------------------------------------------------------------------- +%%%=================================================================== %%% API -%%%---------------------------------------------------------------------- +%%%=================================================================== start(SockData, Opts) -> - supervisor:start_child(ejabberd_service_sup, - [SockData, Opts]). + xmpp_stream_in:start(?MODULE, [SockData, Opts], + ejabberd_config:fsm_limit_opts(Opts)). start_link(SockData, Opts) -> - (?GEN_FSM):start_link(ejabberd_service, - [SockData, Opts], fsm_limit_opts(Opts) ++ (?FSMOPTS)). + xmpp_stream_in:start_link(?MODULE, [SockData, Opts], + ejabberd_config:fsm_limit_opts(Opts)). -socket_type() -> xml_stream. +socket_type() -> + xml_stream. -%%%---------------------------------------------------------------------- -%%% Callback functions from gen_fsm -%%%---------------------------------------------------------------------- -init([{SockMod, Socket}, Opts]) -> - ?INFO_MSG("(~w) External service connected", [Socket]), - Access = case lists:keysearch(access, 1, Opts) of - {value, {_, A}} -> A; - _ -> all - end, +-spec send(pid(), xmpp_element()) -> ok; + (state(), xmpp_element()) -> state(). +send(Stream, Pkt) -> + xmpp_stream_in:send(Stream, Pkt). + +%%%=================================================================== +%%% xmpp_stream_in callbacks +%%%=================================================================== +init([State, Opts]) -> + Access = gen_mod:get_opt(access, Opts, fun acl:access_rules_validator/1, all), + Shaper = gen_mod:get_opt(shaper_rule, Opts, fun acl:shaper_rules_validator/1, none), HostOpts = case lists:keyfind(hosts, 1, Opts) of {hosts, HOpts} -> lists:foldl( @@ -107,252 +84,141 @@ init([{SockMod, Socket}, Opts]) -> p1_sha:sha(randoms:bytes(20))), dict:from_list([{global, Pass}]) end, - Shaper = case lists:keysearch(shaper_rule, 1, Opts) of - {value, {_, S}} -> S; - _ -> none - end, - CheckFrom = case lists:keysearch(service_check_from, 1, - Opts) - of - {value, {_, CF}} -> CF; - _ -> true - end, - SockMod:change_shaper(Socket, Shaper), - {ok, wait_for_stream, - #state{socket = Socket, sockmod = SockMod, - streamid = new_id(), host_opts = HostOpts, - access = Access, check_from = CheckFrom}}. + CheckFrom = gen_mod:get_opt(check_from, Opts, + fun(Flag) when is_boolean(Flag) -> Flag end, + true), + xmpp_stream_in:change_shaper(State, Shaper), + State1 = State#{access => Access, + xmlns => ?NS_COMPONENT, + lang => ?MYLANG, + server => ?MYNAME, + host_opts => HostOpts, + check_from => CheckFrom}, + ejabberd_hooks:run_fold(component_init, {ok, State1}, [Opts]). -wait_for_stream({xmlstreamstart, Name, Attrs}, StateData) -> - try xmpp:decode(#xmlel{name = Name, attrs = Attrs}) of - #stream_start{xmlns = NS_COMPONENT, stream_xmlns = NS_STREAM} - when NS_COMPONENT /= ?NS_COMPONENT; NS_STREAM /= ?NS_STREAM -> - send_header(StateData, ?MYNAME), - send_element(StateData, xmpp:serr_invalid_namespace()), - {stop, normal, StateData}; - #stream_start{to = To} when is_record(To, jid) -> - Host = To#jid.lserver, - send_header(StateData, Host), - HostOpts = case dict:is_key(Host, StateData#state.host_opts) of +handle_stream_start(_StreamStart, + #{remote_server := RemoteServer, + lang := Lang, + host_opts := HostOpts} = State) -> + case ejabberd_router:is_my_host(RemoteServer) of true -> - StateData#state.host_opts; + Txt = <<"Unable to register route on existing local domain">>, + xmpp_stream_in:send(State, xmpp:serr_conflict(Txt, Lang)); false -> - case dict:find(global, StateData#state.host_opts) of + NewHostOpts = case dict:is_key(RemoteServer, HostOpts) of + true -> + HostOpts; + false -> + case dict:find(global, HostOpts) of {ok, GlobalPass} -> - dict:from_list([{Host, GlobalPass}]); + dict:from_list([{RemoteServer, GlobalPass}]); error -> - StateData#state.host_opts + HostOpts end end, - {next_state, wait_for_handshake, - StateData#state{host = Host, host_opts = HostOpts}}; - #stream_start{} -> - send_header(StateData, ?MYNAME), - send_element(StateData, xmpp:serr_improper_addressing()), - {stop, normal, StateData}; - _ -> - send_header(StateData, ?MYNAME), - send_element(StateData, xmpp:serr_invalid_xml()), - {stop, normal, StateData} - catch _:{xmpp_codec, Why} -> - Txt = xmpp:format_error(Why), - send_header(StateData, ?MYNAME), - send_element(StateData, xmpp:serr_invalid_xml(Txt, ?MYLANG)), - {stop, normal, StateData} - end; -wait_for_stream({xmlstreamerror, _}, StateData) -> - send_header(StateData, ?MYNAME), - send_element(StateData, xmpp:serr_not_well_formed()), - {stop, normal, StateData}; -wait_for_stream(closed, StateData) -> - {stop, normal, StateData}. + State#{host_opts => NewHostOpts} + end. -wait_for_handshake({xmlstreamelement, El}, StateData) -> - decode_element(El, wait_for_handshake, StateData); -wait_for_handshake(#handshake{data = Digest}, StateData) -> - case dict:find(StateData#state.host, StateData#state.host_opts) of +get_password_fun(#{remote_server := RemoteServer, + socket := Socket, sockmod := SockMod, + ip := IP, + host_opts := HostOpts}) -> + fun(_) -> + case dict:find(RemoteServer, HostOpts) of {ok, Password} -> - case p1_sha:sha(<<(StateData#state.streamid)/binary, - Password/binary>>) of - Digest -> - send_element(StateData, #handshake{}), + {Password, undefined}; + error -> + ?ERROR_MSG("(~s) Domain ~s is unconfigured for " + "external component from ~s", + [SockMod:pp(Socket), RemoteServer, + ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]), + {false, undefined} + end + end. + +handle_auth_success(_, Mech, _, + #{remote_server := RemoteServer, host_opts := HostOpts, + socket := Socket, sockmod := SockMod, + ip := IP} = State) -> + ?INFO_MSG("(~s) Accepted external component ~s authentication " + "for ~s from ~s", + [SockMod:pp(Socket), Mech, RemoteServer, + ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]), lists:foreach( fun (H) -> ejabberd_router:register_route(H, ?MYNAME), - ?INFO_MSG("Route registered for service ~p~n", - [H]), ejabberd_hooks:run(component_connected, [H]) - end, dict:fetch_keys(StateData#state.host_opts)), - {next_state, stream_established, StateData}; - _ -> - send_element(StateData, xmpp:serr_not_authorized()), - {stop, normal, StateData} - end; - _ -> - send_element(StateData, xmpp:serr_not_authorized()), - {stop, normal, StateData} - end; -wait_for_handshake({xmlstreamend, _Name}, StateData) -> - {stop, normal, StateData}; -wait_for_handshake({xmlstreamerror, _}, StateData) -> - send_element(StateData, xmpp:serr_not_well_formed()), - {stop, normal, StateData}; -wait_for_handshake(closed, StateData) -> - {stop, normal, StateData}; -wait_for_handshake(_Pkt, StateData) -> - {next_state, wait_for_handshake, StateData}. + end, dict:fetch_keys(HostOpts)), + State. -stream_established({xmlstreamelement, El}, StateData) -> - decode_element(El, stream_established, StateData); -stream_established(El, StateData) when ?is_stanza(El) -> - From = xmpp:get_from(El), - To = xmpp:get_to(El), - Lang = xmpp:get_lang(El), - if From == undefined orelse To == undefined -> - Txt = <<"Missing 'from' or 'to' attribute">>, - send_error(StateData, El, xmpp:err_jid_malformed(Txt, Lang)); - true -> - case check_from(From, StateData) of +handle_auth_failure(_, Mech, Reason, + #{remote_server := RemoteServer, + sockmod := SockMod, + socket := Socket, ip := IP} = State) -> + ?ERROR_MSG("(~s) Failed external component ~s authentication " + "for ~s from ~s: ~s", + [SockMod:pp(Socket), Mech, RemoteServer, + ejabberd_config:may_hide_data(jlib:ip_to_list(IP)), + Reason]), + State. + +handle_authenticated_packet(Pkt, #{lang := Lang} = State) -> + From = xmpp:get_from(Pkt), + case check_from(From, State) of true -> - ejabberd_router:route(From, To, El); + To = xmpp:get_to(Pkt), + ejabberd_router:route(From, To, Pkt), + State; false -> Txt = <<"Improper domain part of 'from' attribute">>, - send_error(StateData, El, xmpp:err_not_allowed(Txt, Lang)) - end - end, - {next_state, stream_established, StateData}; -stream_established({xmlstreamend, _Name}, StateData) -> - {stop, normal, StateData}; -stream_established({xmlstreamerror, _}, StateData) -> - send_element(StateData, xmpp:serr_not_well_formed()), - {stop, normal, StateData}; -stream_established(closed, StateData) -> - {stop, normal, StateData}; -stream_established(_Event, StateData) -> - {next_state, stream_established, StateData}. + Err = xmpp:serr_invalid_from(Txt, Lang), + xmpp_stream_in:send(State, Err) + end. -handle_event(_Event, StateName, StateData) -> - {next_state, StateName, StateData}. - -handle_sync_event(_Event, _From, StateName, - StateData) -> - Reply = ok, {reply, Reply, StateName, StateData}. - -code_change(_OldVsn, StateName, StateData, _Extra) -> - {ok, StateName, StateData}. - -handle_info({send_text, Text}, StateName, StateData) -> - send_text(StateData, Text), - {next_state, StateName, StateData}; -handle_info({send_element, El}, StateName, StateData) -> - send_element(StateData, El), - {next_state, StateName, StateData}; -handle_info({route, From, To, Packet}, StateName, - StateData) -> - case acl:match_rule(global, StateData#state.access, From) of +handle_info({route, From, To, Packet}, #{access := Access} = State) -> + case acl:match_rule(global, Access, From) of allow -> Pkt = xmpp:set_from_to(Packet, From, To), - send_element(StateData, Pkt); + xmpp_stream_in:send(State, Pkt); deny -> Lang = xmpp:get_lang(Packet), Err = xmpp:err_not_allowed(<<"Denied by ACL">>, Lang), - ejabberd_router:route_error(To, From, Packet, Err) - end, - {next_state, StateName, StateData}; -handle_info(Info, StateName, StateData) -> + ejabberd_router:route_error(To, From, Packet, Err), + State + end; +handle_info(Info, State) -> ?ERROR_MSG("Unexpected info: ~p", [Info]), - {next_state, StateName, StateData}. + State. -terminate(Reason, StateName, StateData) -> - ?INFO_MSG("terminated: ~p", [Reason]), - case StateName of - stream_established -> - lists:foreach(fun (H) -> +terminate(Reason, #{stream_state := StreamState, host_opts := HostOpts}) -> + case StreamState of + established -> + lists:foreach( + fun(H) -> ejabberd_router:unregister_route(H), - ejabberd_hooks:run(component_disconnected, - [H, Reason]) - end, - dict:fetch_keys(StateData#state.host_opts)); - _ -> ok - end, - catch send_trailer(StateData), - (StateData#state.sockmod):close(StateData#state.socket), - ok. - -%%---------------------------------------------------------------------- -%% Func: print_state/1 -%% Purpose: Prepare the state to be printed on error log -%% Returns: State to print -%%---------------------------------------------------------------------- -print_state(State) -> State. - -%%%---------------------------------------------------------------------- -%%% Internal functions -%%%---------------------------------------------------------------------- - --spec send_text(state(), iodata()) -> ok. -send_text(StateData, Text) -> - (StateData#state.sockmod):send(StateData#state.socket, - Text). - --spec send_element(state(), xmpp_element()) -> ok. -send_element(StateData, El) -> - El1 = xmpp:encode(El, ?NS_COMPONENT), - send_text(StateData, fxml:element_to_binary(El1)). - --spec send_error(state(), xmlel() | stanza(), stanza_error()) -> ok. -send_error(StateData, Stanza, Error) -> - Type = xmpp:get_type(Stanza), - if Type == error; Type == result; - Type == <<"error">>; Type == <<"result">> -> - ok; - true -> - send_element(StateData, xmpp:make_error(Stanza, Error)) - end. - --spec send_header(state(), binary()) -> ok. -send_header(StateData, Host) -> - Header = xmpp:encode( - #stream_start{xmlns = ?NS_COMPONENT, - stream_xmlns = ?NS_STREAM, - from = jid:make(Host), - id = StateData#state.streamid}), - send_text(StateData, fxml:element_to_header(Header)). - --spec send_trailer(state()) -> ok. -send_trailer(StateData) -> - send_text(StateData, <<"">>). - --spec decode_element(xmlel(), state_name(), state()) -> fsm_transition(). -decode_element(#xmlel{} = El, StateName, StateData) -> - try xmpp:decode(El, ?NS_COMPONENT, [ignore_els]) of - Pkt -> ?MODULE:StateName(Pkt, StateData) - catch error:{xmpp_codec, Why} -> - case xmpp:is_stanza(El) of - true -> - Lang = xmpp:get_lang(El), - Txt = xmpp:format_error(Why), - send_error(StateData, El, xmpp:err_bad_request(Txt, Lang)); - false -> + ejabberd_hooks:run(component_disconnected, [H, Reason]) + end, dict:fetch_keys(HostOpts)); + _ -> ok - end, - {next_state, StateName, StateData} end. +code_change(_OldVsn, State, _Extra) -> + {ok, State}. + +%%%=================================================================== +%%% Internal functions +%%%=================================================================== -spec check_from(jid(), state()) -> boolean(). -check_from(_From, #state{check_from = false}) -> +check_from(_From, #{check_from := false}) -> %% If the admin does not want to check the from field %% when accept packets from any address. %% In this case, the component can send packet of %% behalf of the server users. true; -check_from(From, StateData) -> +check_from(From, #{host_opts := HostOpts}) -> %% The default is the standard behaviour in XEP-0114 Server = From#jid.lserver, - dict:is_key(Server, StateData#state.host_opts). - --spec new_id() -> binary(). -new_id() -> randoms:get_string(). + dict:is_key(Server, HostOpts). transform_listen_option({hosts, Hosts, O}, Opts) -> case lists:keyfind(hosts, 1, Opts) of @@ -372,19 +238,4 @@ transform_listen_option({host, Host, Os}, Opts) -> transform_listen_option(Opt, Opts) -> [Opt|Opts]. -fsm_limit_opts(Opts) -> - case lists:keysearch(max_fsm_queue, 1, Opts) of - {value, {_, N}} when is_integer(N) -> - [{max_queue, N}]; - _ -> - case ejabberd_config:get_option( - max_fsm_queue, - fun(I) when is_integer(I), I > 0 -> I end) of - undefined -> []; - N -> [{max_queue, N}] - end - end. - -opt_type(max_fsm_queue) -> - fun (I) when is_integer(I), I > 0 -> I end; -opt_type(_) -> [max_fsm_queue]. +opt_type(_) -> []. diff --git a/src/ejabberd_sm.erl b/src/ejabberd_sm.erl index 5abcaa572..98aaed573 100644 --- a/src/ejabberd_sm.erl +++ b/src/ejabberd_sm.erl @@ -34,6 +34,7 @@ %% API -export([start/0, start_link/0, + route/2, route/3, process_iq/3, open_session/5, @@ -63,12 +64,14 @@ user_resources/2, kick_user/2, get_session_pid/3, + get_user_info/2, get_user_info/3, get_user_ip/3, get_max_user_sessions/2, get_all_pids/0, is_existing_resource/3, get_commands_spec/0, + c2s_handle_info/2, make_sid/0 ]). @@ -81,7 +84,6 @@ -include("xmpp.hrl"). -include("ejabberd_commands.hrl"). --include("mod_privacy.hrl"). -include("ejabberd_sm.hrl"). -callback init() -> ok | {error, any()}. @@ -98,15 +100,6 @@ %% default value for the maximum number of user connections -define(MAX_USER_SESSIONS, infinity). --type broadcast() :: {broadcast, broadcast_data()}. - --type broadcast_data() :: - {rebind, pid(), binary()} | %% ejabberd_c2s - {item, ljid(), mod_roster:subscription()} | %% mod_roster/mod_shared_roster - {exit, binary()} | %% mod_roster/mod_shared_roster - {privacy_list, mod_privacy:userlist(), binary()} | %% mod_privacy - {blocking, unblock_all | {block | unblock, [ljid()]}}. %% mod_blocking - %%==================================================================== %% API %%==================================================================== @@ -120,7 +113,18 @@ start() -> start_link() -> gen_server:start_link({local, ?MODULE}, ?MODULE, [], []). --spec route(jid(), jid(), stanza() | broadcast()) -> ok. +-spec route(jid(), term()) -> ok. +%% @doc route arbitrary term to c2s process(es) +route(To, Term) -> + case catch do_route(To, Term) of + {'EXIT', Reason} -> + ?ERROR_MSG("route ~p to ~p failed: ~p", + [Term, To, Reason]); + _ -> + ok + end. + +-spec route(jid(), jid(), stanza()) -> ok. route(From, To, Packet) -> case catch do_route(From, To, Packet) of @@ -180,9 +184,7 @@ bounce_offline_message(From, To, Packet) -> -spec disconnect_removed_user(binary(), binary()) -> ok. disconnect_removed_user(User, Server) -> - ejabberd_sm:route(jid:make(<<"">>, <<"">>, <<"">>), - jid:make(User, Server, <<"">>), - {broadcast, {exit, <<"User removed">>}}). + route(jid:make(User, Server, <<"">>), {exit, <<"User removed">>}). get_user_resources(User, Server) -> LUser = jid:nodeprep(User), @@ -214,6 +216,17 @@ get_user_ip(User, Server, Resource) -> proplists:get_value(ip, Session#session.info) end. +-spec get_user_info(binary(), binary()) -> [{binary(), info()}]. +get_user_info(User, Server) -> + LUser = jid:nodeprep(User), + LServer = jid:nameprep(Server), + Mod = get_sm_backend(LServer), + Ss = online(Mod:get_sessions(LUser, LServer)), + [{LResource, [{node, node(Pid)}|Info]} + || #session{usr = {_, _, LResource}, + info = Info, + sid = {_, Pid}} <- clean_session_list(Ss)]. + -spec get_user_info(binary(), binary(), binary()) -> info() | offline. get_user_info(User, Server, Resource) -> @@ -227,9 +240,7 @@ get_user_info(User, Server, Resource) -> Ss -> Session = lists:max(Ss), Node = node(element(2, Session#session.sid)), - Conn = proplists:get_value(conn, Session#session.info), - IP = proplists:get_value(ip, Session#session.info), - [{node, Node}, {conn, Conn}, {ip, IP}] + [{node, Node}|Session#session.info] end. -spec set_presence(sid(), binary(), binary(), binary(), @@ -356,6 +367,21 @@ register_iq_handler(Host, XMLNS, Module, Fun, Opts) -> unregister_iq_handler(Host, XMLNS) -> ejabberd_sm ! {unregister_iq_handler, Host, XMLNS}. +%% Why the hell do we have so many similar kicks? +c2s_handle_info(#{lang := Lang} = State, replaced) -> + State1 = State#{replaced => true}, + Err = xmpp:serr_conflict(<<"Replaced by new connection">>, Lang), + {stop, ejabberd_c2s:send(State1, Err)}; +c2s_handle_info(#{lang := Lang} = State, kick) -> + Err = xmpp:serr_policy_violation(<<"has been kicked">>, Lang), + c2s_handle_info(State, {kick, kicked_by_admin, Err}); +c2s_handle_info(State, {kick, _Reason, Err}) -> + {stop, ejabberd_c2s:send(State, Err)}; +c2s_handle_info(#{lang := Lang} = State, {exit, Reason}) -> + Err = xmpp:serr_conflict(Reason, Lang), + {stop, ejabberd_c2s:send(State, Err)}; +c2s_handle_info(State, _) -> + State. %%==================================================================== %% gen_server callbacks @@ -366,12 +392,15 @@ init([]) -> ets:new(sm_iqtable, [named_table]), lists:foreach( fun(Host) -> + ejabberd_hooks:add(c2s_handle_info, Host, + ejabberd_sm, c2s_handle_info, 50), ejabberd_hooks:add(roster_in_subscription, Host, ejabberd_sm, check_in_subscription, 20), ejabberd_hooks:add(offline_message_hook, Host, ejabberd_sm, bounce_offline_message, 100), ejabberd_hooks:add(remove_user, Host, - ejabberd_sm, disconnect_removed_user, 100) + ejabberd_sm, disconnect_removed_user, 100), + ejabberd_c2s:add_hooks(Host) end, ?MYHOSTS), ejabberd_commands:register_commands(get_commands_spec()), {ok, #state{}}. @@ -411,6 +440,17 @@ handle_info({unregister_iq_handler, Host, XMLNS}, handle_info(_Info, State) -> {noreply, State}. terminate(_Reason, _State) -> + lists:foreach( + fun(Host) -> + ejabberd_hooks:delete(c2s_handle_info, Host, + ejabberd_sm, c2s_handle_info, 50), + ejabberd_hooks:delete(roster_in_subscription, Host, + ejabberd_sm, check_in_subscription, 20), + ejabberd_hooks:delete(offline_message_hook, Host, + ejabberd_sm, bounce_offline_message, 100), + ejabberd_hooks:delete(remove_user, Host, + ejabberd_sm, disconnect_removed_user, 100) + end, ?MYHOSTS), ejabberd_commands:unregister_commands(get_commands_spec()), ok. @@ -444,26 +484,27 @@ is_online(#session{info = Info}) -> not proplists:get_bool(offline, Info). %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% --spec do_route(jid(), jid(), stanza() | broadcast()) -> any(). -do_route(From, #jid{lresource = <<"">>} = To, {broadcast, _} = Packet) -> - ?DEBUG("processing broadcast to bare JID: ~p", [Packet]), +-spec do_route(jid(), term()) -> any(). +do_route(#jid{lresource = <<"">>} = To, Term) -> lists:foreach( fun(R) -> - do_route(From, jid:replace_resource(To, R), Packet) + do_route(jid:replace_resource(To, R), Term) end, get_user_resources(To#jid.user, To#jid.server)); -do_route(From, To, {broadcast, _} = Packet) -> - ?DEBUG("processing broadcast to full JID: ~p", [Packet]), +do_route(To, Term) -> + ?DEBUG("broadcasting ~p to ~s", [Term, jid:to_string(To)]), {U, S, R} = jid:tolower(To), Mod = get_sm_backend(S), case online(Mod:get_sessions(U, S, R)) of [] -> - ?DEBUG("dropping broadcast to unavailable resourse: ~p", [Packet]); + ?DEBUG("dropping broadcast to unavailable resourse: ~p", [Term]); Ss -> Session = lists:max(Ss), Pid = element(2, Session#session.sid), - ?DEBUG("sending to process ~p: ~p", [Pid, Packet]), - Pid ! {route, From, To, Packet} - end; + ?DEBUG("sending to process ~p: ~p", [Pid, Term]), + Pid ! Term + end. + +-spec do_route(jid(), jid(), stanza()) -> any(). do_route(From, To, #presence{type = T, status = Status} = Packet) when T == subscribe; T == subscribed; T == unsubscribe; T == unsubscribed -> ?DEBUG("processing subscription:~n~s", [xmpp:pp(Packet)]), @@ -544,24 +585,10 @@ do_route(From, To, Packet) -> %% or if there are no current sessions for the user. -spec is_privacy_allow(jid(), jid(), stanza()) -> boolean(). is_privacy_allow(From, To, Packet) -> - User = To#jid.user, - Server = To#jid.server, - PrivacyList = - ejabberd_hooks:run_fold(privacy_get_user_list, Server, - #userlist{}, [User, Server]), - is_privacy_allow(From, To, Packet, PrivacyList). - -%% Check if privacy rules allow this delivery -%% Function copied from ejabberd_c2s.erl --spec is_privacy_allow(jid(), jid(), stanza(), #userlist{}) -> boolean(). -is_privacy_allow(From, To, Packet, PrivacyList) -> - User = To#jid.user, - Server = To#jid.server, - allow == - ejabberd_hooks:run_fold(privacy_check_packet, Server, - allow, - [User, Server, PrivacyList, {From, To, Packet}, - in]). + LServer = To#jid.server, + allow == ejabberd_hooks:run_fold( + privacy_check_packet, LServer, allow, + [To, xmpp:set_from_to(Packet, From, To), in]). -spec route_message(jid(), jid(), message(), message_type()) -> any(). route_message(From, To, Packet, Type) -> @@ -725,10 +752,14 @@ process_iq(From, To, #iq{type = T, lang = Lang, sub_els = [El]} = Packet) Err = xmpp:err_service_unavailable(Txt, Lang), ejabberd_router:route_error(To, From, Packet, Err) end; -process_iq(From, To, #iq{type = T} = Packet) when T == get; T == set -> - Err = xmpp:err_bad_request(), - ejabberd_router:route_error(To, From, Packet, Err), - ok; +process_iq(From, To, #iq{type = T, lang = Lang, sub_els = SubEls} = Packet) + when T == get; T == set -> + Txt = case SubEls of + [] -> <<"No child elements found">>; + _ -> <<"Too many child elements">> + end, + Err = xmpp:err_bad_request(Txt, Lang), + ejabberd_router:route_error(To, From, Packet, Err); process_iq(_From, _To, #iq{}) -> ok. @@ -738,17 +769,21 @@ force_update_presence({LUser, LServer}) -> Mod = get_sm_backend(LServer), Ss = online(Mod:get_sessions(LUser, LServer)), lists:foreach(fun (#session{sid = {_, Pid}}) -> - Pid ! {force_update_presence, LUser, LServer} + Pid ! force_update_presence end, Ss). -spec get_sm_backend(binary()) -> module(). get_sm_backend(Host) -> - DBType = ejabberd_config:get_option( + DBType = case ejabberd_config:get_option( {sm_db_type, Host}, - fun(T) -> ejabberd_config:v_db(?MODULE, T) end, - mnesia), + fun(T) -> ejabberd_config:v_db(?MODULE, T) end) of + undefined -> + ejabberd_config:default_ram_db(Host, ?MODULE); + T -> + T + end, list_to_atom("ejabberd_sm_" ++ atom_to_list(DBType)). -spec get_sm_backends() -> [module()]. diff --git a/src/ejabberd_socket.erl b/src/ejabberd_socket.erl index 4cf36a81c..c7b57a6a1 100644 --- a/src/ejabberd_socket.erl +++ b/src/ejabberd_socket.erl @@ -33,10 +33,12 @@ connect/4, connect/5, starttls/2, - starttls/3, compress/1, compress/2, reset_stream/1, + send_element/2, + send_header/2, + send_trailer/1, send/2, send_xml/2, change_shaper/2, @@ -46,9 +48,11 @@ get_peer_certificate/1, get_verify_result/1, close/1, + pp/1, sockname/1, peername/1]). -include("ejabberd.hrl"). +-include("xmpp.hrl"). -include("logger.hrl"). -type sockmod() :: ejabberd_bosh | @@ -68,60 +72,68 @@ -export_type([socket/0, socket_state/0, sockmod/0]). +-callback start({module(), socket_state()}, + [proplists:property()]) -> {ok, pid()} | {error, term()} | ignore. +-callback start_link({module(), socket_state()}, + [proplists:property()]) -> {ok, pid()} | {error, term()} | ignore. +-callback socket_type() -> xml_stream | independent | raw. + +-define(is_http_socket(S), + (S#socket_state.sockmod == ejabberd_bosh orelse + S#socket_state.sockmod == ejabberd_http_ws)). %%==================================================================== %% API %%==================================================================== --spec start(atom(), sockmod(), socket(), [{atom(), any()}]) -> any(). - +-spec start(atom(), sockmod(), socket(), [proplists:propery()]) + -> {ok, pid() | independent} | {error, inet:posix() | any()}. start(Module, SockMod, Socket, Opts) -> case Module:socket_type() of + independent -> {ok, independent}; xml_stream -> - MaxStanzaSize = case lists:keysearch(max_stanza_size, 1, - Opts) - of - {value, {_, Size}} -> Size; - _ -> infinity - end, - {ReceiverMod, Receiver, RecRef} = case catch - SockMod:custom_receiver(Socket) - of + MaxStanzaSize = proplists:get_value(max_stanza_size, Opts, infinity), + {ReceiverMod, Receiver, RecRef} = + try SockMod:custom_receiver(Socket) of {receiver, RecMod, RecPid} -> - {RecMod, RecPid, RecMod}; - _ -> - RecPid = - ejabberd_receiver:start(Socket, - SockMod, - none, - MaxStanzaSize), - {ejabberd_receiver, RecPid, - RecPid} + {RecMod, RecPid, RecMod} + catch _:_ -> + RecPid = ejabberd_receiver:start( + Socket, SockMod, none, MaxStanzaSize), + {ejabberd_receiver, RecPid, RecPid} end, SocketData = #socket_state{sockmod = SockMod, socket = Socket, receiver = RecRef}, case Module:start({?MODULE, SocketData}, Opts) of {ok, Pid} -> case SockMod:controlling_process(Socket, Receiver) of - ok -> ok; - {error, _Reason} -> SockMod:close(Socket) - end, - ReceiverMod:become_controller(Receiver, Pid); - {error, _Reason} -> + ok -> + ReceiverMod:become_controller(Receiver, Pid), + {ok, Receiver}; + Err -> + SockMod:close(Socket), + Err + end; + Err -> SockMod:close(Socket), case ReceiverMod of ejabberd_receiver -> ReceiverMod:close(Receiver); _ -> ok - end + end, + Err end; - independent -> ok; raw -> case Module:start({SockMod, Socket}, Opts) of {ok, Pid} -> case SockMod:controlling_process(Socket, Pid) of - ok -> ok; - {error, _Reason} -> SockMod:close(Socket) + ok -> + {ok, Pid}; + {error, _} = Err -> + SockMod:close(Socket), + Err end; - {error, _Reason} -> SockMod:close(Socket) + Err -> + SockMod:close(Socket), + Err end end. @@ -147,25 +159,31 @@ connect(Addr, Port, Opts, Timeout, Owner) -> {error, _Reason} = Error -> Error end. -starttls(SocketData, TLSOpts) -> - {ok, TLSSocket} = fast_tls:tcp_to_tls(SocketData#socket_state.socket, TLSOpts), - ejabberd_receiver:starttls(SocketData#socket_state.receiver, TLSSocket), - SocketData#socket_state{socket = TLSSocket, sockmod = fast_tls}. - -starttls(SocketData, TLSOpts, Data) -> - {ok, TLSSocket} = fast_tls:tcp_to_tls(SocketData#socket_state.socket, TLSOpts), - ejabberd_receiver:starttls(SocketData#socket_state.receiver, TLSSocket), - send(SocketData, Data), - SocketData#socket_state{socket = TLSSocket, sockmod = fast_tls}. +starttls(#socket_state{socket = Socket, + receiver = Receiver} = SocketData, TLSOpts) -> + case fast_tls:tcp_to_tls(Socket, TLSOpts) of + {ok, TLSSocket} -> + case ejabberd_receiver:starttls(Receiver, TLSSocket) of + ok -> + {ok, SocketData#socket_state{socket = TLSSocket, + sockmod = fast_tls}}; + {error, _} = Err -> + Err + end; + {error, _} = Err -> + Err + end. compress(SocketData) -> compress(SocketData, undefined). compress(SocketData, Data) -> - {ok, ZlibSocket} = - ejabberd_receiver:compress(SocketData#socket_state.receiver, - Data), - SocketData#socket_state{socket = ZlibSocket, - sockmod = ezlib}. + case ejabberd_receiver:compress(SocketData#socket_state.receiver, Data) of + {ok, ZlibSocket} -> + {ok, SocketData#socket_state{socket = ZlibSocket, sockmod = ezlib}}; + Err -> + ?ERROR_MSG("compress failed: ~p", [Err]), + Err + end. reset_stream(SocketData) when is_pid(SocketData#socket_state.receiver) -> @@ -174,29 +192,41 @@ reset_stream(SocketData) when is_atom(SocketData#socket_state.receiver) -> (SocketData#socket_state.receiver):reset_stream(SocketData#socket_state.socket). --spec send(socket_state(), iodata()) -> ok. +-spec send_element(socket_state(), fxml:xmlel()) -> ok | {error, inet:posix()}. +send_element(SocketData, El) when ?is_http_socket(SocketData) -> + send_xml(SocketData, {xmlstreamelement, El}); +send_element(SocketData, El) -> + send(SocketData, fxml:element_to_binary(El)). -send(SocketData, Data) -> - case catch (SocketData#socket_state.sockmod):send( - SocketData#socket_state.socket, Data) of - ok -> ok; - {error, timeout} -> - ?INFO_MSG("Timeout on ~p:send",[SocketData#socket_state.sockmod]), - exit(normal); - Error -> - ?DEBUG("Error in ~p:send: ~p",[SocketData#socket_state.sockmod, Error]), - exit(normal) +-spec send_header(socket_state(), fxml:xmlel()) -> ok | {error, inet:posix()}. +send_header(SocketData, El) when ?is_http_socket(SocketData) -> + send_xml(SocketData, {xmlstreamstart, El#xmlel.name, El#xmlel.attrs}); +send_header(SocketData, El) -> + send(SocketData, fxml:element_to_header(El)). + +-spec send_trailer(socket_state()) -> ok | {error, inet:posix()}. +send_trailer(SocketData) when ?is_http_socket(SocketData) -> + send_xml(SocketData, {xmlstreamend, <<"stream:stream">>}); +send_trailer(SocketData) -> + send(SocketData, <<"">>). + +-spec send(socket_state(), iodata()) -> ok | {error, inet:posix()}. +send(#socket_state{sockmod = SockMod, socket = Socket} = SocketData, Data) -> + ?DEBUG("(~s) Send XML on stream = ~p", [pp(SocketData), Data]), + try SockMod:send(Socket, Data) + catch _:badarg -> + %% Some modules throw badarg exceptions on closed sockets + %% TODO: their code should be improved + {error, einval} end. -%% Can only be called when in c2s StateData#state.xml_socket is true -%% This function is used for HTTP bind -%% sockmod=ejabberd_http_ws|ejabberd_http_bind or any custom module --spec send_xml(socket_state(), fxml:xmlel()) -> any(). - -send_xml(SocketData, Data) -> - catch - (SocketData#socket_state.sockmod):send_xml(SocketData#socket_state.socket, - Data). +-spec send_xml(socket_state(), + {xmlstreamelement, fxml:xmlel()} | + {xmlstreamstart, binary(), [{binary(), binary()}]} | + {xmlstreamend, binary()} | + {xmlstreamraw, iodata()}) -> term(). +send_xml(SocketData, El) -> + (SocketData#socket_state.sockmod):send_xml(SocketData#socket_state.socket, El). change_shaper(SocketData, Shaper) when is_pid(SocketData#socket_state.receiver) -> @@ -254,3 +284,7 @@ peername(#socket_state{sockmod = SockMod, gen_tcp -> inet:peername(Socket); _ -> SockMod:peername(Socket) end. + +pp(#socket_state{receiver = Receiver} = State) -> + Transport = get_transport(State), + io_lib:format("~s|~w", [Transport, Receiver]). diff --git a/src/ejabberd_sup.erl b/src/ejabberd_sup.erl index 5efaaba75..f9a48be4d 100644 --- a/src/ejabberd_sup.erl +++ b/src/ejabberd_sup.erl @@ -41,13 +41,6 @@ init([]) -> brutal_kill, worker, [ejabberd_hooks]}, - NodeGroups = - {ejabberd_node_groups, - {ejabberd_node_groups, start_link, []}, - permanent, - brutal_kill, - worker, - [ejabberd_node_groups]}, SystemMonitor = {ejabberd_system_monitor, {ejabberd_system_monitor, start_link, []}, @@ -55,20 +48,6 @@ init([]) -> brutal_kill, worker, [ejabberd_system_monitor]}, - Router = - {ejabberd_router, - {ejabberd_router, start_link, []}, - permanent, - brutal_kill, - worker, - [ejabberd_router]}, - Router_multicast = - {ejabberd_router_multicast, - {ejabberd_router_multicast, start_link, []}, - permanent, - brutal_kill, - worker, - [ejabberd_router_multicast]}, S2S = {ejabberd_s2s, {ejabberd_s2s, start_link, []}, @@ -76,13 +55,6 @@ init([]) -> brutal_kill, worker, [ejabberd_s2s]}, - Local = - {ejabberd_local, - {ejabberd_local, start_link, []}, - permanent, - brutal_kill, - worker, - [ejabberd_local]}, Captcha = {ejabberd_captcha, {ejabberd_captcha, start_link, []}, @@ -121,14 +93,6 @@ init([]) -> infinity, supervisor, [ejabberd_tmp_sup]}, - FrontendSocketSupervisor = - {ejabberd_frontend_socket_sup, - {ejabberd_tmp_sup, start_link, - [ejabberd_frontend_socket_sup, ejabberd_frontend_socket]}, - permanent, - infinity, - supervisor, - [ejabberd_tmp_sup]}, IQSupervisor = {ejabberd_iq_sup, {ejabberd_tmp_sup, start_link, @@ -139,16 +103,11 @@ init([]) -> [ejabberd_tmp_sup]}, {ok, {{one_for_one, 10, 1}, [Hooks, - NodeGroups, SystemMonitor, - Router, - Router_multicast, S2S, - Local, Captcha, S2SInSupervisor, S2SOutSupervisor, ServiceSupervisor, IQSupervisor, - FrontendSocketSupervisor, Listener]}}. diff --git a/src/ejabberd_web_admin.erl b/src/ejabberd_web_admin.erl index be3c54313..7ab5451c7 100644 --- a/src/ejabberd_web_admin.erl +++ b/src/ejabberd_web_admin.erl @@ -192,7 +192,7 @@ process([<<"server">>, SHost | RPath] = Path, method = Method} = Request) -> Host = jid:nameprep(SHost), - case lists:member(Host, ?MYHOSTS) of + case ejabberd_router:is_my_host(Host) of true -> case get_auth_admin(Auth, HostHTTP, Path, Method) of {ok, {User, Server}} -> diff --git a/src/gen_mod.erl b/src/gen_mod.erl index 0036fe2cd..c77726ef1 100644 --- a/src/gen_mod.erl +++ b/src/gen_mod.erl @@ -31,12 +31,13 @@ -export([start/0, start_module/2, start_module/3, stop_module/2, stop_module_keep_config/2, get_opt/3, - get_opt/4, get_opt_host/3, db_type/2, db_type/3, + get_opt/4, get_opt_host/3, opt_type/1, get_module_opt/4, get_module_opt/5, get_module_opt_host/3, loaded_modules/1, loaded_modules_with_opts/1, get_hosts/2, get_module_proc/2, is_loaded/2, start_modules/0, start_modules/1, stop_modules/0, stop_modules/1, - opt_type/1, db_mod/2, db_mod/3]). + db_mod/2, db_mod/3, ram_db_mod/2, ram_db_mod/3, + db_type/2, db_type/3, ram_db_type/2, ram_db_type/3]). %%-export([behaviour_info/1]). @@ -424,6 +425,43 @@ db_mod(Host, Module) when is_binary(Host) orelse Host == global -> db_mod(Host, Opts, Module) when is_list(Opts) -> db_mod(db_type(Host, Opts, Module), Module). +-spec ram_db_type(binary() | global, module()) -> db_type(); + (opts(), module()) -> db_type(). +ram_db_type(Opts, Module) when is_list(Opts) -> + ram_db_type(global, Opts, Module); +ram_db_type(Host, Module) when is_atom(Module) -> + case catch Module:mod_opt_type(ram_db_type) of + F when is_function(F) -> + case get_module_opt(Host, Module, ram_db_type, F) of + undefined -> ejabberd_config:default_ram_db(Host, Module); + Type -> Type + end; + _ -> + undefined + end. + +-spec ram_db_type(binary(), opts(), module()) -> db_type(). +ram_db_type(Host, Opts, Module) -> + case catch Module:mod_opt_type(ram_db_type) of + F when is_function(F) -> + case get_opt(ram_db_type, Opts, F) of + undefined -> ejabberd_config:default_ram_db(Host, Module); + Type -> Type + end; + _ -> + undefined + end. + +-spec ram_db_mod(binary() | global | db_type(), module()) -> module(). +ram_db_mod(Type, Module) when is_atom(Type), Type /= global -> + list_to_atom(atom_to_list(Module) ++ "_" ++ atom_to_list(Type)); +ram_db_mod(Host, Module) when is_binary(Host) orelse Host == global -> + ram_db_mod(ram_db_type(Host, Module), Module). + +-spec ram_db_mod(binary() | global, opts(), module()) -> module(). +ram_db_mod(Host, Opts, Module) when is_list(Opts) -> + ram_db_mod(ram_db_type(Host, Opts, Module), Module). + -spec loaded_modules(binary()) -> [atom()]. loaded_modules(Host) -> @@ -470,6 +508,5 @@ get_module_proc(Host, Base) -> is_loaded(Host, Module) -> ets:member(ejabberd_modules, {Module, Host}). -opt_type(default_db) -> fun(T) when is_atom(T) -> T end; opt_type(modules) -> fun (L) when is_list(L) -> L end; -opt_type(_) -> [default_db, modules]. +opt_type(_) -> [modules]. diff --git a/src/jlib.erl b/src/jlib.erl index a4d1127f4..33fc7d6bb 100644 --- a/src/jlib.erl +++ b/src/jlib.erl @@ -38,8 +38,8 @@ -export([tolower/1, term_to_base64/1, base64_to_term/1, decode_base64/1, encode_base64/1, ip_to_list/1, atom_to_binary/1, binary_to_atom/1, tuple_to_binary/1, - l2i/1, i2l/1, i2l/2, queue_drop_while/2, - expr_to_term/1, term_to_expr/1]). + l2i/1, i2l/1, i2l/2, expr_to_term/1, term_to_expr/1, + queue_drop_while/2, queue_foldl/3, queue_foldr/3, queue_foreach/2]). %% The following functions are used by gen_iq_handler.erl for providing backward %% compatibility and must not be used in other parts of the code @@ -974,3 +974,33 @@ queue_drop_while(F, Q) -> empty -> Q end. + +-spec queue_foldl(fun((term(), T) -> T), T, ?TQUEUE) -> T. +queue_foldl(F, Acc, Q) -> + case queue:out(Q) of + {{value, Item}, Q1} -> + Acc1 = F(Item, Acc), + queue_foldl(F, Acc1, Q1); + {empty, _} -> + Acc + end. + +-spec queue_foldr(fun((term(), T) -> T), T, ?TQUEUE) -> T. +queue_foldr(F, Acc, Q) -> + case queue:out_r(Q) of + {{value, Item}, Q1} -> + Acc1 = F(Item, Acc), + queue_foldr(F, Acc1, Q1); + {empty, _} -> + Acc + end. + +-spec queue_foreach(fun((_) -> _), ?TQUEUE) -> ok. +queue_foreach(F, Q) -> + case queue:out(Q) of + {{value, Item}, Q1} -> + F(Item), + queue_foreach(F, Q1); + {empty, _} -> + ok + end. diff --git a/src/mod_admin_extra.erl b/src/mod_admin_extra.erl index 7cef6af97..472e9fbe3 100644 --- a/src/mod_admin_extra.erl +++ b/src/mod_admin_extra.erl @@ -913,9 +913,8 @@ kick_session(User, Server, Resource, ReasonText) -> ok. kick_this_session(User, Server, Resource, Reason) -> - ejabberd_sm:route(jid:make(<<"">>, <<"">>, <<"">>), - jid:make(User, Server, Resource), - {broadcast, {exit, Reason}}). + ejabberd_sm:route(jid:make(User, Server, Resource), + {exit, Reason}). status_num(Host, Status) -> length(get_status_list(Host, Status)). @@ -943,7 +942,7 @@ get_status_list(Host, Status_required) -> end, Sessions3 = [ {Pid, Server, Priority} || {{_User, Server, _Resource}, {_, Pid}, Priority} <- Sessions2, apply(Fhost, [Server, Host])], %% For each Pid, get its presence - Sessions4 = [ {catch ejabberd_c2s:get_presence(Pid), Server, Priority} || {Pid, Server, Priority} <- Sessions3], + Sessions4 = [ {catch get_presence(Pid), Server, Priority} || {Pid, Server, Priority} <- Sessions3], %% Filter by status Fstatus = case Status_required of <<"all">> -> @@ -996,6 +995,16 @@ stringize(String) -> %% Replace newline characters with other code ejabberd_regexp:greplace(String, <<"\n">>, <<"\\n">>). +get_presence(Pid) -> + Pres = #presence{from = From} = ejabberd_c2s:get_presence(Pid), + Show = case Pres of + #presence{type = unavailable} -> <<"unavailable">>; + #presence{show = undefined} -> <<"available">>; + #presence{show = S} -> atom_to_binary(S, utf8) + end, + Status = xmpp:get_text(Pres#presence.status), + {From#jid.user, From#jid.resource, Show, Status}. + get_presence(U, S) -> Pids = [ejabberd_sm:get_session_pid(U, S, R) || R <- ejabberd_sm:get_user_resources(U, S)], @@ -1004,8 +1013,7 @@ get_presence(U, S) -> [] -> {jid:to_string({U, S, <<>>}), <<"unavailable">>, <<"">>}; [SessionPid|_] -> - {_User, Resource, Show, Status} = - ejabberd_c2s:get_presence(SessionPid), + {_User, Resource, Show, Status} = get_presence(SessionPid), FullJID = jid:to_string({U, S, Resource}), {FullJID, Show, Status} end. @@ -1048,7 +1056,7 @@ user_sessions_info(User, Host) -> fun(Session) -> {_U, _S, Resource} = Session#session.usr, {Now, Pid} = Session#session.sid, - {_U, _Resource, Status, StatusText} = ejabberd_c2s:get_presence(Pid), + {_U, _Resource, Status, StatusText} = get_presence(Pid), Info = Session#session.info, Priority = Session#session.priority, Conn = proplists:get_value(conn, Info), @@ -1301,7 +1309,7 @@ push_roster_item(LU, LS, U, S, Action) -> push_roster_item(LU, LS, R, U, S, Action) -> LJID = jid:make(LU, LS, R), BroadcastEl = build_broadcast(U, S, Action), - ejabberd_sm:route(LJID, LJID, BroadcastEl), + ejabberd_sm:route(LJID, BroadcastEl), Item = build_roster_item(U, S, Action), ResIQ = build_iq_roster_push(Item), ejabberd_router:route(jid:remove_resource(LJID), LJID, ResIQ). @@ -1326,7 +1334,7 @@ build_broadcast(U, S, remove) -> %% @spec (U::binary(), S::binary(), Subs::atom()) -> any() %% Subs = both | from | to | none build_broadcast(U, S, SubsAtom) when is_atom(SubsAtom) -> - {broadcast, {item, {U, S, <<>>}, SubsAtom}}. + {item, {U, S, <<>>}, SubsAtom}. %%% %%% Last Activity diff --git a/src/mod_announce.erl b/src/mod_announce.erl index a1f1ae157..d4740fa5f 100644 --- a/src/mod_announce.erl +++ b/src/mod_announce.erl @@ -68,7 +68,7 @@ start(Host, Opts) -> ejabberd_hooks:add(disco_local_items, Host, ?MODULE, disco_items, 50), ejabberd_hooks:add(adhoc_local_items, Host, ?MODULE, announce_items, 50), ejabberd_hooks:add(adhoc_local_commands, Host, ?MODULE, announce_commands, 50), - ejabberd_hooks:add(user_available_hook, Host, + ejabberd_hooks:add(c2s_self_presence, Host, ?MODULE, send_motd, 50), register(gen_mod:get_module_proc(Host, ?PROCNAME), proc_lib:spawn(?MODULE, init, [])). @@ -123,7 +123,7 @@ stop(Host) -> ejabberd_hooks:delete(disco_local_items, Host, ?MODULE, disco_items, 50), ejabberd_hooks:delete(local_send_to_resource_hook, Host, ?MODULE, announce, 50), - ejabberd_hooks:delete(user_available_hook, Host, + ejabberd_hooks:delete(c2s_self_presence, Host, ?MODULE, send_motd, 50), Proc = gen_mod:get_module_proc(Host, ?PROCNAME), exit(whereis(Proc), stop), @@ -733,8 +733,13 @@ announce_motd_delete(LServer) -> Mod = gen_mod:db_mod(LServer, ?MODULE), Mod:delete_motd(LServer). --spec send_motd(jid()) -> ok | {atomic, any()}. -send_motd(#jid{luser = LUser, lserver = LServer} = JID) when LUser /= <<>> -> +-spec send_motd({presence(), ejabberd_c2s:state()}) -> {presence(), ejabberd_c2s:state()}. +send_motd({_, #{pres_last := _}} = Acc) -> + %% This is just a presence update, nothing to do + Acc; +send_motd({#presence{type = available}, + #{jid := #jid{luser = LUser, lserver = LServer} = JID}} = Acc) + when LUser /= <<>> -> Mod = gen_mod:db_mod(LServer, ?MODULE), case Mod:get_motd(LServer) of {ok, Packet} -> @@ -754,9 +759,10 @@ send_motd(#jid{luser = LUser, lserver = LServer} = JID) when LUser /= <<>> -> end; error -> ok - end; -send_motd(_) -> - ok. + end, + Acc; +send_motd(Acc) -> + Acc. get_stored_motd(LServer) -> Mod = gen_mod:db_mod(LServer, ?MODULE), diff --git a/src/mod_block_strangers.erl b/src/mod_block_strangers.erl new file mode 100644 index 000000000..c304f20d5 --- /dev/null +++ b/src/mod_block_strangers.erl @@ -0,0 +1,107 @@ +%%%------------------------------------------------------------------- +%%% File : mod_block_strangers.erl +%%% Author : Alexey Shchepin +%%% Purpose : Block packets from non-subscribers +%%% Created : 25 Dec 2016 by Alexey Shchepin +%%% +%%% +%%% ejabberd, Copyright (C) 2002-2017 ProcessOne +%%% +%%% This program is free software; you can redistribute it and/or +%%% modify it under the terms of the GNU General Public License as +%%% published by the Free Software Foundation; either version 2 of the +%%% License, or (at your option) any later version. +%%% +%%% This program is distributed in the hope that it will be useful, +%%% but WITHOUT ANY WARRANTY; without even the implied warranty of +%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +%%% General Public License for more details. +%%% +%%% You should have received a copy of the GNU General Public License along +%%% with this program; if not, write to the Free Software Foundation, Inc., +%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +%%% +%%%------------------------------------------------------------------- +-module(mod_block_strangers). + +-author('alexey@process-one.net'). + +-behaviour(gen_mod). + +%% API +-export([start/2, stop/1, + depends/2, mod_opt_type/1]). + +-export([filter_packet/1]). + +-include("xmpp.hrl"). +-include("ejabberd.hrl"). +-include("logger.hrl"). + +-define(SETS, gb_sets). + +start(Host, _Opts) -> + ejabberd_hooks:add(user_receive_packet, Host, + ?MODULE, filter_packet, 25), + ok. + +stop(Host) -> + ejabberd_hooks:delete(user_receive_packet, Host, + ?MODULE, filter_packet, 25), + ok. + +filter_packet({#message{} = Msg, State} = Acc) -> + From = xmpp:get_from(Msg), + LFrom = jid:tolower(From), + LBFrom = jid:remove_resource(LFrom), + #{pres_a := PresA} = State, + case Msg#message.body == [] + orelse ejabberd_router:is_my_route(From#jid.lserver) + orelse (?SETS):is_element(LFrom, PresA) + orelse (?SETS):is_element(LBFrom, PresA) + orelse sets_bare_member(LBFrom, PresA) of + true -> + Acc; + false -> + #{lserver := LServer} = State, + Drop = + gen_mod:get_module_opt(LServer, ?MODULE, drop, + fun(B) when is_boolean(B) -> B end, + true), + Log = + gen_mod:get_module_opt(LServer, ?MODULE, log, + fun(B) when is_boolean(B) -> B end, + false), + if + Log -> + ?INFO_MSG("Drop packet: ~s", + [fxml:element_to_binary( + xmpp:encode(Msg, ?NS_CLIENT))]); + true -> + ok + end, + if + Drop -> + {stop, {drop, State}}; + true -> + Acc + end + end; +filter_packet(Acc) -> + Acc. + +sets_bare_member({U, S, <<"">>} = LBJID, Set) -> + case ?SETS:next(?SETS:iterator_from(LBJID, Set)) of + {{U, S, _}, _} -> true; + _ -> false + end. + + +depends(_Host, _Opts) -> + []. + +mod_opt_type(drop) -> + fun (B) when is_boolean(B) -> B end; +mod_opt_type(log) -> + fun (B) when is_boolean(B) -> B end; +mod_opt_type(_) -> [drop, log]. diff --git a/src/mod_blocking.erl b/src/mod_blocking.erl index d4192447a..3f9e90256 100644 --- a/src/mod_blocking.erl +++ b/src/mod_blocking.erl @@ -29,8 +29,8 @@ -protocol({xep, 191, '1.2'}). --export([start/2, stop/1, process_iq/1, - process_iq_set/3, process_iq_get/3, mod_opt_type/1, depends/2]). +-export([start/2, stop/1, process_iq/1, mod_opt_type/1, depends/2, + disco_features/5]). -include("ejabberd.hrl"). -include("logger.hrl"). @@ -48,72 +48,72 @@ start(Host, Opts) -> IQDisc = gen_mod:get_opt(iqdisc, Opts, fun gen_iq_handler:check_type/1, one_queue), - ejabberd_hooks:add(privacy_iq_get, Host, ?MODULE, - process_iq_get, 40), - ejabberd_hooks:add(privacy_iq_set, Host, ?MODULE, - process_iq_set, 40), - mod_disco:register_feature(Host, ?NS_BLOCKING), + ejabberd_hooks:add(disco_local_features, Host, ?MODULE, disco_features, 50), gen_iq_handler:add_iq_handler(ejabberd_sm, Host, ?NS_BLOCKING, ?MODULE, process_iq, IQDisc). stop(Host) -> - ejabberd_hooks:delete(privacy_iq_get, Host, ?MODULE, - process_iq_get, 40), - ejabberd_hooks:delete(privacy_iq_set, Host, ?MODULE, - process_iq_set, 40), - mod_disco:unregister_feature(Host, ?NS_BLOCKING), - gen_iq_handler:remove_iq_handler(ejabberd_sm, Host, - ?NS_BLOCKING). + ejabberd_hooks:delete(disco_local_features, Host, ?MODULE, disco_features, 50), + gen_iq_handler:remove_iq_handler(ejabberd_sm, Host, ?NS_BLOCKING). depends(_Host, _Opts) -> [{mod_privacy, hard}]. +-spec disco_features({error, stanza_error()} | {result, [binary()]} | empty, + jid(), jid(), binary(), binary()) -> + {error, stanza_error()} | {result, [binary()]}. +disco_features({error, Err}, _From, _To, _Node, _Lang) -> + {error, Err}; +disco_features(empty, _From, _To, <<"">>, _Lang) -> + {result, [?NS_BLOCKING]}; +disco_features({result, Feats}, _From, _To, <<"">>, _Lang) -> + {result, [?NS_BLOCKING|Feats]}; +disco_features(Acc, _From, _To, _Node, _Lang) -> + Acc. + -spec process_iq(iq()) -> iq(). -process_iq(IQ) -> - xmpp:make_error(IQ, xmpp:err_not_allowed()). +process_iq(#iq{type = Type, + from = #jid{luser = U, lserver = S}, + to = #jid{luser = U, lserver = S}} = IQ) -> + case Type of + get -> process_iq_get(IQ); + set -> process_iq_set(IQ) + end; +process_iq(#iq{lang = Lang} = IQ) -> + Txt = <<"Query to another users is forbidden">>, + xmpp:make_error(IQ, xmpp:err_forbidden(Txt, Lang)). --spec process_iq_get({error, stanza_error()} | {result, xmpp_element() | undefined}, - iq(), userlist()) -> - {error, stanza_error()} | - {result, xmpp_element() | undefined}. -process_iq_get(_, #iq{lang = Lang, from = From, - sub_els = [#block_list{}]}, _) -> - #jid{luser = LUser, lserver = LServer} = From, - process_blocklist_get(LUser, LServer, Lang); -process_iq_get(Acc, _, _) -> Acc. +-spec process_iq_get(iq()) -> iq(). +process_iq_get(#iq{sub_els = [#block_list{}]} = IQ) -> + process_get(IQ); +process_iq_get(#iq{lang = Lang} = IQ) -> + Txt = <<"No module is handling this query">>, + xmpp:make_error(IQ, xmpp:err_service_unavailable(Txt, Lang)). --spec process_iq_set({error, stanza_error()} | - {result, xmpp_element() | undefined} | - {result, xmpp_element() | undefined, userlist()}, - iq(), userlist()) -> - {error, stanza_error()} | - {result, xmpp_element() | undefined} | - {result, xmpp_element() | undefined, userlist()}. -process_iq_set(Acc, #iq{from = From, lang = Lang, sub_els = [SubEl]}, _) -> - #jid{luser = LUser, lserver = LServer} = From, +-spec process_iq_set(iq()) -> iq(). +process_iq_set(#iq{lang = Lang, sub_els = [SubEl]} = IQ) -> case SubEl of #block{items = []} -> Txt = <<"No items found in this query">>, - {error, xmpp:err_bad_request(Txt, Lang)}; + xmpp:make_error(IQ, xmpp:err_bad_request(Txt, Lang)); #block{items = Items} -> JIDs = [jid:tolower(Item) || Item <- Items], - process_blocklist_block(LUser, LServer, JIDs, Lang); + process_block(IQ, JIDs); #unblock{items = []} -> - process_blocklist_unblock_all(LUser, LServer, Lang); + process_unblock_all(IQ); #unblock{items = Items} -> JIDs = [jid:tolower(Item) || Item <- Items], - process_blocklist_unblock(LUser, LServer, JIDs, Lang); + process_unblock(IQ, JIDs); _ -> - Acc - end; -process_iq_set(Acc, _, _) -> Acc. + Txt = <<"No module is handling this query">>, + xmpp:make_error(IQ, xmpp:err_service_unavailable(Txt, Lang)) + end. --spec list_to_blocklist_jids([listitem()], [ljid()]) -> [ljid()]. -list_to_blocklist_jids([], JIDs) -> JIDs; -list_to_blocklist_jids([#listitem{type = jid, - action = deny, value = JID} = - Item - | Items], +-spec listitems_to_jids([listitem()], [ljid()]) -> [ljid()]. +listitems_to_jids([], JIDs) -> + JIDs; +listitems_to_jids([#listitem{type = jid, + action = deny, value = JID} = Item | Items], JIDs) -> Match = case Item of #listitem{match_all = true} -> @@ -126,20 +126,18 @@ list_to_blocklist_jids([#listitem{type = jid, _ -> false end, - if Match -> list_to_blocklist_jids(Items, [JID | JIDs]); - true -> list_to_blocklist_jids(Items, JIDs) + if Match -> listitems_to_jids(Items, [JID | JIDs]); + true -> listitems_to_jids(Items, JIDs) end; % Skip Privacy List items than cannot be mapped to Blocking items -list_to_blocklist_jids([_ | Items], JIDs) -> - list_to_blocklist_jids(Items, JIDs). +listitems_to_jids([_ | Items], JIDs) -> + listitems_to_jids(Items, JIDs). --spec process_blocklist_block(binary(), binary(), [ljid()], - binary()) -> - {error, stanza_error()} | - {result, undefined, userlist()}. -process_blocklist_block(LUser, LServer, JIDs, Lang) -> +-spec process_block(iq(), [ljid()]) -> iq(). +process_block(#iq{from = #jid{luser = LUser, lserver = LServer}, + lang = Lang} = IQ, JIDs) -> Filter = fun (List) -> - AlreadyBlocked = list_to_blocklist_jids(List, []), + AlreadyBlocked = listitems_to_jids(List, []), lists:foldr(fun (JID, List1) -> case lists:member(JID, AlreadyBlocked) of @@ -159,21 +157,19 @@ process_blocklist_block(LUser, LServer, JIDs, Lang) -> case Mod:process_blocklist_block(LUser, LServer, Filter) of {atomic, {ok, Default, List}} -> UserList = make_userlist(Default, List), - broadcast_list_update(LUser, LServer, Default, - UserList), - broadcast_blocklist_event(LUser, LServer, - {block, [jid:make(J) || J <- JIDs]}), - {result, undefined, UserList}; + broadcast_list_update(LUser, LServer, UserList, Default), + broadcast_event(LUser, LServer, + #block{items = [jid:make(J) || J <- JIDs]}), + xmpp:make_iq_result(IQ); _Err -> ?ERROR_MSG("Error processing ~p: ~p", [{LUser, LServer, JIDs}, _Err]), - {error, xmpp:err_internal_server_error(<<"Database failure">>, Lang)} + Err = xmpp:err_internal_server_error(<<"Database failure">>, Lang), + xmpp:make_error(IQ, Err) end. --spec process_blocklist_unblock_all(binary(), binary(), binary()) -> - {error, stanza_error()} | - {result, undefined} | - {result, undefined, userlist()}. -process_blocklist_unblock_all(LUser, LServer, Lang) -> +-spec process_unblock_all(iq()) -> iq(). +process_unblock_all(#iq{from = #jid{luser = LUser, lserver = LServer}, + lang = Lang} = IQ) -> Filter = fun (List) -> lists:filter(fun (#listitem{action = A}) -> A =/= deny end, @@ -181,23 +177,22 @@ process_blocklist_unblock_all(LUser, LServer, Lang) -> end, Mod = db_mod(LServer), case Mod:unblock_by_filter(LUser, LServer, Filter) of - {atomic, ok} -> {result, undefined}; + {atomic, ok} -> + xmpp:make_iq_result(IQ); {atomic, {ok, Default, List}} -> UserList = make_userlist(Default, List), - broadcast_list_update(LUser, LServer, Default, - UserList), - broadcast_blocklist_event(LUser, LServer, unblock_all), - {result, undefined, UserList}; + broadcast_list_update(LUser, LServer, UserList, Default), + broadcast_event(LUser, LServer, #unblock{}), + xmpp:make_iq_result(IQ); _Err -> ?ERROR_MSG("Error processing ~p: ~p", [{LUser, LServer}, _Err]), - {error, xmpp:err_internal_server_error(<<"Database failure">>, Lang)} + Err = xmpp:err_internal_server_error(<<"Database failure">>, Lang), + xmpp:make_error(IQ, Err) end. --spec process_blocklist_unblock(binary(), binary(), [ljid()], binary()) -> - {error, stanza_error()} | - {result, undefined} | - {result, undefined, userlist()}. -process_blocklist_unblock(LUser, LServer, JIDs, Lang) -> +-spec process_unblock(iq(), [ljid()]) -> iq(). +process_unblock(#iq{from = #jid{luser = LUser, lserver = LServer}, + lang = Lang} = IQ, JIDs) -> Filter = fun (List) -> lists:filter(fun (#listitem{action = deny, type = jid, value = JID}) -> @@ -208,17 +203,18 @@ process_blocklist_unblock(LUser, LServer, JIDs, Lang) -> end, Mod = db_mod(LServer), case Mod:unblock_by_filter(LUser, LServer, Filter) of - {atomic, ok} -> {result, undefined}; + {atomic, ok} -> + xmpp:make_iq_result(IQ); {atomic, {ok, Default, List}} -> UserList = make_userlist(Default, List), - broadcast_list_update(LUser, LServer, Default, - UserList), - broadcast_blocklist_event(LUser, LServer, - {unblock, [jid:make(J) || J <- JIDs]}), - {result, undefined, UserList}; + broadcast_list_update(LUser, LServer, UserList, Default), + broadcast_event(LUser, LServer, + #unblock{items = [jid:make(J) || J <- JIDs]}), + xmpp:make_iq_result(IQ); _Err -> ?ERROR_MSG("Error processing ~p: ~p", [{LUser, LServer, JIDs}, _Err]), - {error, xmpp:err_internal_server_error(<<"Database failure">>, Lang)} + Err = xmpp:err_internal_server_error(<<"Database failure">>, Lang), + xmpp:make_error(IQ, Err) end. -spec make_userlist(binary(), [listitem()]) -> userlist(). @@ -226,29 +222,34 @@ make_userlist(Name, List) -> NeedDb = mod_privacy:is_list_needdb(List), #userlist{name = Name, list = List, needdb = NeedDb}. --spec broadcast_list_update(binary(), binary(), binary(), userlist()) -> ok. -broadcast_list_update(LUser, LServer, Name, UserList) -> - ejabberd_sm:route(jid:make(LUser, LServer, <<"">>), - jid:make(LUser, LServer, <<"">>), - {broadcast, {privacy_list, UserList, Name}}). +-spec broadcast_list_update(binary(), binary(), userlist(), binary()) -> ok. +broadcast_list_update(LUser, LServer, UserList, Name) -> + mod_privacy:push_list_update(jid:make(LUser, LServer), UserList, Name). --spec broadcast_blocklist_event(binary(), binary(), block_event()) -> ok. -broadcast_blocklist_event(LUser, LServer, Event) -> - JID = jid:make(LUser, LServer, <<"">>), - ejabberd_sm:route(JID, JID, - {broadcast, {blocking, Event}}). +-spec broadcast_event(binary(), binary(), block_event()) -> ok. +broadcast_event(LUser, LServer, Event) -> + From = jid:make(LUser, LServer), + lists:foreach( + fun(R) -> + To = jid:replace_resource(From, R), + IQ = #iq{type = set, from = From, to = To, + id = <<"push", (randoms:get_string())/binary>>, + sub_els = [Event]}, + ejabberd_router:route(From, To, IQ) + end, ejabberd_sm:get_user_resources(LUser, LServer)). --spec process_blocklist_get(binary(), binary(), binary()) -> - {error, stanza_error()} | {result, block_list()}. -process_blocklist_get(LUser, LServer, Lang) -> +-spec process_get(iq()) -> iq(). +process_get(#iq{from = #jid{luser = LUser, lserver = LServer}, + lang = Lang} = IQ) -> Mod = db_mod(LServer), case Mod:process_blocklist_get(LUser, LServer) of error -> - {error, xmpp:err_internal_server_error(<<"Database failure">>, Lang)}; + Err = xmpp:err_internal_server_error(<<"Database failure">>, Lang), + xmpp:make_error(IQ, Err); List -> - LJIDs = list_to_blocklist_jids(List, []), + LJIDs = listitems_to_jids(List, []), Items = [jid:make(J) || J <- LJIDs], - {result, #block_list{items = Items}} + xmpp:make_iq_result(IQ, #block_list{items = Items}) end. -spec db_mod(binary()) -> module(). diff --git a/src/mod_bosh.erl b/src/mod_bosh.erl index cebf4e6ba..62dc31ac8 100644 --- a/src/mod_bosh.erl +++ b/src/mod_bosh.erl @@ -30,31 +30,25 @@ %%-define(ejabberd_debug, true). --behaviour(gen_server). -behaviour(gen_mod). -export([start_link/0]). -export([start/2, stop/1, process/2, open_session/2, close_session/1, find_session/1]). --export([init/1, handle_call/3, handle_cast/2, - handle_info/2, terminate/2, code_change/3, - depends/2, mod_opt_type/1]). +-export([depends/2, mod_opt_type/1]). -include("ejabberd.hrl"). -include("logger.hrl"). -include_lib("stdlib/include/ms_transform.hrl"). -include("jlib.hrl"). - -include("ejabberd_http.hrl"). - -include("bosh.hrl"). --record(bosh, {sid = <<"">> :: binary() | '_', - timestamp = p1_time_compat:timestamp() :: erlang:timestamp() | '_', - pid = self() :: pid() | '$1'}). - --record(state, {}). +-callback init() -> any(). +-callback open_session(binary(), pid()) -> any(). +-callback close_session(binary()) -> any(). +-callback find_session(binary()) -> {ok, pid()} | error. %%%---------------------------------------------------------------------- %%% API @@ -83,137 +77,35 @@ process(_Path, _Request) -> children = [{xmlcdata, <<"400 Bad Request">>}]}}. open_session(SID, Pid) -> - Session = #bosh{sid = SID, timestamp = p1_time_compat:timestamp(), pid = Pid}, - lists:foreach( - fun(Node) when Node == node() -> - gen_server:call(?MODULE, {write, Session}); - (Node) -> - cluster_send({?MODULE, Node}, {write, Session}) - end, ejabberd_cluster:get_nodes()). + Mod = gen_mod:ram_db_mod(global, ?MODULE), + Mod:open_session(SID, Pid). close_session(SID) -> - case mnesia:dirty_read(bosh, SID) of - [Session] -> - lists:foreach( - fun(Node) when Node == node() -> - gen_server:call(?MODULE, {delete, Session}); - (Node) -> - cluster_send({?MODULE, Node}, {delete, Session}) - end, ejabberd_cluster:get_nodes()); - [] -> - ok - end. - -write_session(#bosh{pid = Pid1, sid = SID, timestamp = T1} = S1) -> - case mnesia:dirty_read(bosh, SID) of - [#bosh{pid = Pid2, timestamp = T2} = S2] -> - if Pid1 == Pid2 -> - mnesia:dirty_write(S1); - T1 < T2 -> - cluster_send(Pid2, replaced), - mnesia:dirty_write(S1); - true -> - cluster_send(Pid1, replaced), - mnesia:dirty_write(S2) - end; - [] -> - mnesia:dirty_write(S1) - end. - -delete_session(#bosh{sid = SID, pid = Pid1}) -> - case mnesia:dirty_read(bosh, SID) of - [#bosh{pid = Pid2}] -> - if Pid1 == Pid2 -> - mnesia:dirty_delete(bosh, SID); - true -> - ok - end; - [] -> - ok - end. + Mod = gen_mod:ram_db_mod(global, ?MODULE), + Mod:close_session(SID). find_session(SID) -> - case mnesia:dirty_read(bosh, SID) of - [#bosh{pid = Pid}] -> - {ok, Pid}; - [] -> - error - end. + Mod = gen_mod:ram_db_mod(global, ?MODULE), + Mod:find_session(SID). start(Host, Opts) -> - setup_database(), start_jiffy(Opts), TmpSup = gen_mod:get_module_proc(Host, ?PROCNAME), TmpSupSpec = {TmpSup, {ejabberd_tmp_sup, start_link, [TmpSup, ejabberd_bosh]}, permanent, infinity, supervisor, [ejabberd_tmp_sup]}, - ProcSpec = {?MODULE, - {?MODULE, start_link, []}, - transient, 2000, worker, [?MODULE]}, - case supervisor:start_child(ejabberd_sup, ProcSpec) of - {ok, _} -> - supervisor:start_child(ejabberd_sup, TmpSupSpec); - {error, {already_started, _}} -> - supervisor:start_child(ejabberd_sup, TmpSupSpec); - Err -> - Err - end. + supervisor:start_child(ejabberd_sup, TmpSupSpec), + Mod = gen_mod:ram_db_mod(global, ?MODULE), + Mod:init(). stop(Host) -> TmpSup = gen_mod:get_module_proc(Host, ?PROCNAME), supervisor:terminate_child(ejabberd_sup, TmpSup), supervisor:delete_child(ejabberd_sup, TmpSup). -%%%=================================================================== -%%% gen_server callbacks -%%%=================================================================== -init([]) -> - {ok, #state{}}. - -handle_call({write, Session}, _From, State) -> - Res = write_session(Session), - {reply, Res, State}; -handle_call({delete, Session}, _From, State) -> - Res = delete_session(Session), - {reply, Res, State}; -handle_call(_Request, _From, State) -> - Reply = ok, - {reply, Reply, State}. - -handle_cast(_Msg, State) -> - {noreply, State}. - -handle_info({write, Session}, State) -> - write_session(Session), - {noreply, State}; -handle_info({delete, Session}, State) -> - delete_session(Session), - {noreply, State}; -handle_info(_Info, State) -> - ?ERROR_MSG("got unexpected info: ~p", [_Info]), - {noreply, State}. - -terminate(_Reason, _State) -> - ok. - -code_change(_OldVsn, State, _Extra) -> - {ok, State}. - %%%=================================================================== %%% Internal functions %%%=================================================================== -setup_database() -> - case catch mnesia:table_info(bosh, attributes) of - [sid, pid] -> - mnesia:delete_table(bosh); - _ -> - ok - end, - ejabberd_mnesia:create(?MODULE, bosh, - [{ram_copies, [node()]}, {local_content, true}, - {attributes, record_info(fields, bosh)}]), - mnesia:add_table_copy(bosh, node(), ram_copies). - start_jiffy(Opts) -> case gen_mod:get_opt(json, Opts, fun(false) -> false; @@ -241,9 +133,6 @@ get_type(Hdrs) -> xml end. -cluster_send(NodePid, Msg) -> - erlang:send(NodePid, Msg, [noconnect, nosuspend]). - depends(_Host, _Opts) -> []. @@ -261,8 +150,10 @@ mod_opt_type(max_pause) -> fun (I) when is_integer(I), I > 0 -> I end; mod_opt_type(prebind) -> fun (B) when is_boolean(B) -> B end; +mod_opt_type(ram_db_type) -> + fun(T) -> ejabberd_config:v_db(?MODULE, T) end; mod_opt_type(_) -> - [json, max_concat, max_inactivity, max_pause, prebind]. + [json, max_concat, max_inactivity, max_pause, prebind, ram_db_type]. %%%---------------------------------------------------------------------- diff --git a/src/mod_bosh_mnesia.erl b/src/mod_bosh_mnesia.erl new file mode 100644 index 000000000..b61ef20a1 --- /dev/null +++ b/src/mod_bosh_mnesia.erl @@ -0,0 +1,163 @@ +%%%------------------------------------------------------------------- +%%% Created : 12 Jan 2017 by Evgeny Khramtsov +%%% +%%% +%%% ejabberd, Copyright (C) 2002-2017 ProcessOne +%%% +%%% This program is free software; you can redistribute it and/or +%%% modify it under the terms of the GNU General Public License as +%%% published by the Free Software Foundation; either version 2 of the +%%% License, or (at your option) any later version. +%%% +%%% This program is distributed in the hope that it will be useful, +%%% but WITHOUT ANY WARRANTY; without even the implied warranty of +%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +%%% General Public License for more details. +%%% +%%% You should have received a copy of the GNU General Public License along +%%% with this program; if not, write to the Free Software Foundation, Inc., +%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +%%% +%%%------------------------------------------------------------------- +-module(mod_bosh_mnesia). + +-behaviour(gen_server). +-behaviour(mod_bosh). + +%% mod_bosh API +-export([init/0, open_session/2, close_session/1, find_session/1]). + +%% gen_server callbacks +-export([init/1, handle_call/3, handle_cast/2, handle_info/2, + terminate/2, code_change/3]). + +-include("logger.hrl"). + +-record(bosh, {sid = <<"">> :: binary() | '_', + timestamp = p1_time_compat:timestamp() :: erlang:timestamp() | '_', + pid = self() :: pid() | '$1'}). + +-record(state, {}). + +%%%=================================================================== +%%% API +%%%=================================================================== +init() -> + case gen_server:start_link({local, ?MODULE}, ?MODULE, [], []) of + {ok, _Pid} -> + ok; + Err -> + Err + end. + +open_session(SID, Pid) -> + Session = #bosh{sid = SID, timestamp = p1_time_compat:timestamp(), pid = Pid}, + lists:foreach( + fun(Node) when Node == node() -> + gen_server:call(?MODULE, {write, Session}); + (Node) -> + cluster_send({?MODULE, Node}, {write, Session}) + end, ejabberd_cluster:get_nodes()). + +close_session(SID) -> + case mnesia:dirty_read(bosh, SID) of + [Session] -> + lists:foreach( + fun(Node) when Node == node() -> + gen_server:call(?MODULE, {delete, Session}); + (Node) -> + cluster_send({?MODULE, Node}, {delete, Session}) + end, ejabberd_cluster:get_nodes()); + [] -> + ok + end. + +find_session(SID) -> + case mnesia:dirty_read(bosh, SID) of + [#bosh{pid = Pid}] -> + {ok, Pid}; + [] -> + error + end. + +%%%=================================================================== +%%% gen_server callbacks +%%%=================================================================== +init([]) -> + setup_database(), + {ok, #state{}}. + +handle_call({write, Session}, _From, State) -> + Res = write_session(Session), + {reply, Res, State}; +handle_call({delete, Session}, _From, State) -> + Res = delete_session(Session), + {reply, Res, State}; +handle_call(_Request, _From, State) -> + Reply = ok, + {reply, Reply, State}. + +handle_cast(_Msg, State) -> + {noreply, State}. + +handle_info({write, Session}, State) -> + write_session(Session), + {noreply, State}; +handle_info({delete, Session}, State) -> + delete_session(Session), + {noreply, State}; +handle_info(_Info, State) -> + ?ERROR_MSG("got unexpected info: ~p", [_Info]), + {noreply, State}. + +terminate(_Reason, _State) -> + ok. + +code_change(_OldVsn, State, _Extra) -> + {ok, State}. + +%%%=================================================================== +%%% Internal functions +%%%=================================================================== +write_session(#bosh{pid = Pid1, sid = SID, timestamp = T1} = S1) -> + case mnesia:dirty_read(bosh, SID) of + [#bosh{pid = Pid2, timestamp = T2} = S2] -> + if Pid1 == Pid2 -> + mnesia:dirty_write(S1); + T1 < T2 -> + cluster_send(Pid2, replaced), + mnesia:dirty_write(S1); + true -> + cluster_send(Pid1, replaced), + mnesia:dirty_write(S2) + end; + [] -> + mnesia:dirty_write(S1) + end. + +delete_session(#bosh{sid = SID, pid = Pid1}) -> + case mnesia:dirty_read(bosh, SID) of + [#bosh{pid = Pid2}] -> + if Pid1 == Pid2 -> + mnesia:dirty_delete(bosh, SID); + true -> + ok + end; + [] -> + ok + end. + +cluster_send(NodePid, Msg) -> + erlang:send(NodePid, Msg, [noconnect, nosuspend]). + +setup_database() -> + case catch mnesia:table_info(bosh, attributes) of + [sid, pid] -> + mnesia:delete_table(bosh); + _ -> + ok + end, + ejabberd_mnesia:create(?MODULE, bosh, + [{ram_copies, [node()]}, {local_content, true}, + {attributes, record_info(fields, bosh)}]), + mnesia:add_table_copy(bosh, node(), ram_copies). diff --git a/src/mod_caps.erl b/src/mod_caps.erl index 132f1ee72..391a3ba74 100644 --- a/src/mod_caps.erl +++ b/src/mod_caps.erl @@ -35,10 +35,10 @@ -behaviour(gen_mod). --export([read_caps/1, caps_stream_features/2, +-export([read_caps/1, list_features/1, caps_stream_features/2, disco_features/5, disco_identity/5, disco_info/5, get_features/2, export/1, import_info/0, import/5, - import_start/2, import_stop/2]). + get_user_caps/2, import_start/2, import_stop/2]). %% gen_mod callbacks -export([start/2, start_link/2, stop/1, depends/2]). @@ -47,9 +47,8 @@ -export([init/1, handle_info/2, handle_call/3, handle_cast/2, terminate/2, code_change/3]). --export([user_send_packet/4, user_receive_packet/5, - c2s_presence_in/2, c2s_filter_packet/6, - c2s_broadcast_recipients/6, mod_opt_type/1]). +-export([user_send_packet/1, user_receive_packet/1, + c2s_presence_in/2, mod_opt_type/1]). -include("ejabberd.hrl"). -include("logger.hrl"). @@ -104,6 +103,22 @@ get_features(Host, #caps{node = Node, version = Version, end, [], SubNodes). +-spec list_features(ejabberd_c2s:state()) -> [{ljid(), caps()}]. +list_features(C2SState) -> + Rs = maps:get(caps_features, C2SState, gb_trees:empty()), + gb_trees:to_list(Rs). + +-spec get_user_caps(jid(), ejabberd_c2s:state()) -> {ok, caps()} | error. +get_user_caps(JID, C2SState) -> + Rs = maps:get(caps_features, C2SState, gb_trees:empty()), + LJID = jid:tolower(JID), + case gb_trees:lookup(LJID, Rs) of + {value, Caps} -> + {ok, Caps}; + none -> + error + end. + -spec read_caps(#presence{}) -> nothing | caps(). read_caps(Presence) -> case xmpp:get_subtag(Presence, #caps{}) of @@ -111,47 +126,51 @@ read_caps(Presence) -> Caps -> Caps end. --spec user_send_packet(stanza(), ejabberd_c2s:state(), jid(), jid()) -> stanza(). -user_send_packet(#presence{type = available} = Pkt, - _C2SState, - #jid{luser = User, lserver = Server} = From, - #jid{luser = User, lserver = Server, - lresource = <<"">>}) -> +-spec user_send_packet({stanza(), ejabberd_c2s:state()}) -> {stanza(), ejabberd_c2s:state()}. +user_send_packet({#presence{type = available, + from = #jid{luser = U, lserver = LServer} = From, + to = #jid{luser = U, lserver = LServer, + lresource = <<"">>}} = Pkt, + State}) -> case read_caps(Pkt) of nothing -> ok; #caps{version = Version, exts = Exts} = Caps -> - feature_request(Server, From, Caps, [Version | Exts]) + feature_request(LServer, From, Caps, [Version | Exts]) end, - Pkt; -user_send_packet(Pkt, _C2SState, _From, _To) -> - Pkt. + {Pkt, State}; +user_send_packet(Acc) -> + Acc. --spec user_receive_packet(stanza(), ejabberd_c2s:state(), - jid(), jid(), jid()) -> stanza(). -user_receive_packet(#presence{type = available} = Pkt, - _C2SState, - #jid{lserver = Server}, - From, _To) -> - IsRemote = not lists:member(From#jid.lserver, ?MYHOSTS), +-spec user_receive_packet({stanza(), ejabberd_c2s:state()}) -> {stanza(), ejabberd_c2s:state()}. +user_receive_packet({#presence{from = From, type = available} = Pkt, + #{lserver := LServer} = State}) -> + IsRemote = not ejabberd_router:is_my_host(From#jid.lserver), if IsRemote -> case read_caps(Pkt) of nothing -> ok; #caps{version = Version, exts = Exts} = Caps -> - feature_request(Server, From, Caps, [Version | Exts]) + feature_request(LServer, From, Caps, [Version | Exts]) end; true -> ok end, - Pkt; -user_receive_packet(Pkt, _C2SState, _JID, _From, _To) -> - Pkt. + {Pkt, State}; +user_receive_packet(Acc) -> + Acc. -spec caps_stream_features([xmpp_element()], binary()) -> [xmpp_element()]. caps_stream_features(Acc, MyHost) -> + case gen_mod:is_loaded(MyHost, ?MODULE) of + true -> case make_my_disco_hash(MyHost) of - <<"">> -> Acc; + <<"">> -> + Acc; Hash -> - [#caps{hash = <<"sha-1">>, node = ?EJABBERD_URI, version = Hash}|Acc] + [#caps{hash = <<"sha-1">>, node = ?EJABBERD_URI, + version = Hash}|Acc] + end; + false -> + Acc end. -spec disco_features({error, stanza_error()} | {result, [binary()]} | empty, @@ -194,23 +213,16 @@ disco_info(Acc, Host, Module, Node, Lang) when is_atom(Module) -> disco_info(Acc, _, _, _Node, _Lang) -> Acc. --spec c2s_presence_in(ejabberd_c2s:state(), {jid(), jid(), presence()}) -> - ejabberd_c2s:state(). +-spec c2s_presence_in(ejabberd_c2s:state(), presence()) -> ejabberd_c2s:state(). c2s_presence_in(C2SState, - {From, To, #presence{type = Type} = Presence}) -> - Subscription = ejabberd_c2s:get_subscription(From, - C2SState), + #presence{from = From, to = To, type = Type} = Presence) -> + Subscription = ejabberd_c2s:get_subscription(From, C2SState), Insert = (Type == available) and ((Subscription == both) or (Subscription == to)), Delete = (Type == unavailable) or (Type == error), if Insert or Delete -> LFrom = jid:tolower(From), - Rs = case ejabberd_c2s:get_aux_field(caps_resources, - C2SState) - of - {ok, Rs1} -> Rs1; - error -> gb_trees:empty() - end, + Rs = maps:get(caps_resources, C2SState, gb_trees:empty()), Caps = read_caps(Presence), NewRs = case Caps of nothing when Insert == true -> Rs; @@ -230,51 +242,11 @@ c2s_presence_in(C2SState, end; _ -> gb_trees:delete_any(LFrom, Rs) end, - ejabberd_c2s:set_aux_field(caps_resources, NewRs, - C2SState); - true -> C2SState + C2SState#{caps_resources => NewRs}; + true -> + C2SState end. --spec c2s_filter_packet(boolean(), binary(), ejabberd_c2s:state(), - {pep_message, binary()}, jid(), stanza()) -> - boolean(). -c2s_filter_packet(InAcc, Host, C2SState, {pep_message, Feature}, To, _Packet) -> - case ejabberd_c2s:get_aux_field(caps_resources, C2SState) of - {ok, Rs} -> - LTo = jid:tolower(To), - case gb_trees:lookup(LTo, Rs) of - {value, Caps} -> - Drop = not lists:member(Feature, get_features(Host, Caps)), - {stop, Drop}; - none -> - {stop, true} - end; - _ -> InAcc - end; -c2s_filter_packet(Acc, _, _, _, _, _) -> Acc. - --spec c2s_broadcast_recipients([ljid()], binary(), ejabberd_c2s:state(), - {pep_message, binary()}, jid(), stanza()) -> - [ljid()]. -c2s_broadcast_recipients(InAcc, Host, C2SState, - {pep_message, Feature}, _From, _Packet) -> - case ejabberd_c2s:get_aux_field(caps_resources, - C2SState) - of - {ok, Rs} -> - gb_trees_fold(fun (USR, Caps, Acc) -> - case lists:member(Feature, - get_features(Host, Caps)) - of - true -> [USR | Acc]; - false -> Acc - end - end, - InAcc, Rs); - _ -> InAcc - end; -c2s_broadcast_recipients(Acc, _, _, _, _, _) -> Acc. - -spec depends(binary(), gen_mod:opts()) -> [{module(), hard | soft}]. depends(_Host, _Opts) -> []. @@ -292,17 +264,13 @@ init([Host, Opts]) -> [{max_size, MaxSize}, {life_time, LifeTime}]), ejabberd_hooks:add(c2s_presence_in, Host, ?MODULE, c2s_presence_in, 75), - ejabberd_hooks:add(c2s_filter_packet, Host, ?MODULE, - c2s_filter_packet, 75), - ejabberd_hooks:add(c2s_broadcast_recipients, Host, - ?MODULE, c2s_broadcast_recipients, 75), ejabberd_hooks:add(user_send_packet, Host, ?MODULE, user_send_packet, 75), ejabberd_hooks:add(user_receive_packet, Host, ?MODULE, user_receive_packet, 75), - ejabberd_hooks:add(c2s_stream_features, Host, ?MODULE, + ejabberd_hooks:add(c2s_post_auth_features, Host, ?MODULE, caps_stream_features, 75), - ejabberd_hooks:add(s2s_stream_features, Host, ?MODULE, + ejabberd_hooks:add(s2s_in_post_auth_features, Host, ?MODULE, caps_stream_features, 75), ejabberd_hooks:add(disco_local_features, Host, ?MODULE, disco_features, 75), @@ -325,17 +293,13 @@ terminate(_Reason, State) -> Host = State#state.host, ejabberd_hooks:delete(c2s_presence_in, Host, ?MODULE, c2s_presence_in, 75), - ejabberd_hooks:delete(c2s_filter_packet, Host, ?MODULE, - c2s_filter_packet, 75), - ejabberd_hooks:delete(c2s_broadcast_recipients, Host, - ?MODULE, c2s_broadcast_recipients, 75), ejabberd_hooks:delete(user_send_packet, Host, ?MODULE, user_send_packet, 75), ejabberd_hooks:delete(user_receive_packet, Host, ?MODULE, user_receive_packet, 75), - ejabberd_hooks:delete(c2s_stream_features, Host, + ejabberd_hooks:delete(c2s_post_auth_features, Host, ?MODULE, caps_stream_features, 75), - ejabberd_hooks:delete(s2s_stream_features, Host, + ejabberd_hooks:delete(s2s_in_post_auth_features, Host, ?MODULE, caps_stream_features, 75), ejabberd_hooks:delete(disco_local_features, Host, ?MODULE, disco_features, 75), @@ -494,20 +458,6 @@ concat_xdata_fields(#xdata{fields = Fields} = X) -> is_binary(Var), Var /= <<"FORM_TYPE">>], [Form, $<, lists:sort(Res)]. --spec gb_trees_fold(fun((_, _, T) -> T), T, gb_trees:tree()) -> T. -gb_trees_fold(F, Acc, Tree) -> - Iter = gb_trees:iterator(Tree), - gb_trees_fold_iter(F, Acc, Iter). - --spec gb_trees_fold_iter(fun((_, _, T) -> T), T, gb_trees:iter()) -> T. -gb_trees_fold_iter(F, Acc, Iter) -> - case gb_trees:next(Iter) of - {Key, Val, NewIter} -> - NewAcc = F(Key, Val, Acc), - gb_trees_fold_iter(F, NewAcc, NewIter); - _ -> Acc - end. - -spec now_ts() -> integer(). now_ts() -> p1_time_compat:system_time(seconds). diff --git a/src/mod_carboncopy.erl b/src/mod_carboncopy.erl index b9c09fab2..202c7005a 100644 --- a/src/mod_carboncopy.erl +++ b/src/mod_carboncopy.erl @@ -35,8 +35,8 @@ -export([start/2, stop/1]). --export([user_send_packet/4, user_receive_packet/5, - iq_handler/1, remove_connection/4, +-export([user_send_packet/1, user_receive_packet/1, + iq_handler/1, remove_connection/4, disco_features/5, is_carbon_copy/1, mod_opt_type/1, depends/2]). -include("ejabberd.hrl"). @@ -59,7 +59,7 @@ is_carbon_copy(_) -> start(Host, Opts) -> IQDisc = gen_mod:get_opt(iqdisc, Opts,fun gen_iq_handler:check_type/1, one_queue), - mod_disco:register_feature(Host, ?NS_CARBONS_2), + ejabberd_hooks:add(disco_local_features, Host, ?MODULE, disco_features, 50), Mod = gen_mod:db_mod(Host, ?MODULE), Mod:init(Host, Opts), ejabberd_hooks:add(unset_presence_hook,Host, ?MODULE, remove_connection, 10), @@ -70,12 +70,24 @@ start(Host, Opts) -> stop(Host) -> gen_iq_handler:remove_iq_handler(ejabberd_sm, Host, ?NS_CARBONS_2), - mod_disco:unregister_feature(Host, ?NS_CARBONS_2), + ejabberd_hooks:delete(disco_local_features, Host, ?MODULE, disco_features, 50), %% why priority 89: to define clearly that we must run BEFORE mod_logdb hook (90) ejabberd_hooks:delete(user_send_packet,Host, ?MODULE, user_send_packet, 89), ejabberd_hooks:delete(user_receive_packet,Host, ?MODULE, user_receive_packet, 89), ejabberd_hooks:delete(unset_presence_hook,Host, ?MODULE, remove_connection, 10). +-spec disco_features({error, stanza_error()} | {result, [binary()]} | empty, + jid(), jid(), binary(), binary()) -> + {error, stanza_error()} | {result, [binary()]}. +disco_features({error, Err}, _From, _To, _Node, _Lang) -> + {error, Err}; +disco_features(empty, _From, _To, <<"">>, _Lang) -> + {result, [?NS_CARBONS_2]}; +disco_features({result, Feats}, _From, _To, <<"">>, _Lang) -> + {result, [?NS_CARBONS_2|Feats]}; +disco_features(Acc, _From, _To, _Node, _Lang) -> + Acc. + -spec iq_handler(iq()) -> iq(). iq_handler(#iq{type = set, lang = Lang, from = From, sub_els = [El]} = IQ) when is_record(El, carbons_enable); @@ -105,16 +117,24 @@ iq_handler(#iq{type = get, lang = Lang} = IQ)-> Txt = <<"Value 'get' of 'type' attribute is not allowed">>, xmpp:make_error(IQ, xmpp:err_not_allowed(Txt, Lang)). --spec user_send_packet(stanza(), ejabberd_c2s:state(), jid(), jid()) -> - stanza() | {stop, stanza()}. -user_send_packet(Packet, _C2SState, From, To) -> - check_and_forward(From, To, Packet, sent). +-spec user_send_packet({stanza(), ejabberd_c2s:state()}) + -> {stanza(), ejabberd_c2s:state()} | {stop, {stanza(), ejabberd_c2s:state()}}. +user_send_packet({Packet, C2SState}) -> + From = xmpp:get_from(Packet), + To = xmpp:get_to(Packet), + case check_and_forward(From, To, Packet, sent) of + {stop, Pkt} -> {stop, {Pkt, C2SState}}; + Pkt -> {Pkt, C2SState} + end. --spec user_receive_packet(stanza(), ejabberd_c2s:state(), - jid(), jid(), jid()) -> - stanza() | {stop, stanza()}. -user_receive_packet(Packet, _C2SState, JID, _From, To) -> - check_and_forward(JID, To, Packet, received). +-spec user_receive_packet({stanza(), ejabberd_c2s:state()}) + -> {stanza(), ejabberd_c2s:state()} | {stop, {stanza(), ejabberd_c2s:state()}}. +user_receive_packet({Packet, #{jid := JID} = C2SState}) -> + To = xmpp:get_to(Packet), + case check_and_forward(JID, To, Packet, received) of + {stop, Pkt} -> {stop, {Pkt, C2SState}}; + Pkt -> {Pkt, C2SState} + end. % Modified from original version: % - registered to the user_send_packet hook, to be called only once even for multicast diff --git a/src/mod_client_state.erl b/src/mod_client_state.erl index aab89f0d6..d38de6832 100644 --- a/src/mod_client_state.erl +++ b/src/mod_client_state.erl @@ -34,8 +34,11 @@ -export([start/2, stop/1, mod_opt_type/1, depends/2]). %% ejabberd_hooks callbacks. --export([filter_presence/4, filter_chat_states/4, filter_pep/4, filter_other/4, - flush_queue/3, add_stream_feature/2]). +-export([filter_presence/1, filter_chat_states/1, + filter_pep/1, filter_other/1, + c2s_stream_started/2, add_stream_feature/2, + c2s_copy_session/2, c2s_authenticated_packet/2, + c2s_session_resumed/1]). -include("ejabberd.hrl"). -include("logger.hrl"). @@ -44,9 +47,10 @@ -define(CSI_QUEUE_MAX, 100). -type csi_type() :: presence | chatstate | {pep, binary()}. --type csi_key() :: {ljid(), csi_type()}. --type csi_stanza() :: {csi_key(), erlang:timestamp(), xmlel()}. --type csi_queue() :: [csi_stanza()]. +-type csi_queue() :: {non_neg_integer(), non_neg_integer(), map()}. +-type csi_timestamp() :: {non_neg_integer(), erlang:timestamp()}. +-type c2s_state() :: ejabberd_c2s:state(). +-type filter_acc() :: {stanza() | drop, c2s_state()}. %%-------------------------------------------------------------------- %% gen_mod callbacks. @@ -68,27 +72,33 @@ start(Host, Opts) -> fun(B) when is_boolean(B) -> B end, true), if QueuePresence; QueueChatStates; QueuePEP -> + ejabberd_hooks:add(c2s_stream_started, Host, ?MODULE, + c2s_stream_started, 50), ejabberd_hooks:add(c2s_post_auth_features, Host, ?MODULE, add_stream_feature, 50), + ejabberd_hooks:add(c2s_authenticated_packet, Host, ?MODULE, + c2s_authenticated_packet, 50), + ejabberd_hooks:add(c2s_copy_session, Host, ?MODULE, + c2s_copy_session, 50), + ejabberd_hooks:add(c2s_session_resumed, Host, ?MODULE, + c2s_session_resumed, 50), if QueuePresence -> - ejabberd_hooks:add(csi_filter_stanza, Host, ?MODULE, + ejabberd_hooks:add(c2s_filter_send, Host, ?MODULE, filter_presence, 50); true -> ok end, if QueueChatStates -> - ejabberd_hooks:add(csi_filter_stanza, Host, ?MODULE, + ejabberd_hooks:add(c2s_filter_send, Host, ?MODULE, filter_chat_states, 50); true -> ok end, if QueuePEP -> - ejabberd_hooks:add(csi_filter_stanza, Host, ?MODULE, + ejabberd_hooks:add(c2s_filter_send, Host, ?MODULE, filter_pep, 50); true -> ok end, - ejabberd_hooks:add(csi_filter_stanza, Host, ?MODULE, - filter_other, 100), - ejabberd_hooks:add(csi_flush_queue, Host, ?MODULE, - flush_queue, 50); + ejabberd_hooks:add(c2s_filter_send, Host, ?MODULE, + filter_other, 75); true -> ok end. @@ -108,27 +118,33 @@ stop(Host) -> fun(B) when is_boolean(B) -> B end, true), if QueuePresence; QueueChatStates; QueuePEP -> + ejabberd_hooks:delete(c2s_stream_started, Host, ?MODULE, + c2s_stream_started, 50), ejabberd_hooks:delete(c2s_post_auth_features, Host, ?MODULE, add_stream_feature, 50), + ejabberd_hooks:delete(c2s_authenticated_packet, Host, ?MODULE, + c2s_authenticated_packet, 50), + ejabberd_hooks:delete(c2s_copy_session, Host, ?MODULE, + c2s_copy_session, 50), + ejabberd_hooks:delete(c2s_session_resumed, Host, ?MODULE, + c2s_session_resumed, 50), if QueuePresence -> - ejabberd_hooks:delete(csi_filter_stanza, Host, ?MODULE, + ejabberd_hooks:delete(c2s_filter_send, Host, ?MODULE, filter_presence, 50); true -> ok end, if QueueChatStates -> - ejabberd_hooks:delete(csi_filter_stanza, Host, ?MODULE, + ejabberd_hooks:delete(c2s_filter_send, Host, ?MODULE, filter_chat_states, 50); true -> ok end, if QueuePEP -> - ejabberd_hooks:delete(csi_filter_stanza, Host, ?MODULE, + ejabberd_hooks:delete(c2s_filter_send, Host, ?MODULE, filter_pep, 50); true -> ok end, - ejabberd_hooks:delete(csi_filter_stanza, Host, ?MODULE, - filter_other, 100), - ejabberd_hooks:delete(csi_flush_queue, Host, ?MODULE, - flush_queue, 50); + ejabberd_hooks:delete(c2s_filter_send, Host, ?MODULE, + filter_other, 75); true -> ok end. @@ -150,29 +166,46 @@ depends(_Host, _Opts) -> %%-------------------------------------------------------------------- %% ejabberd_hooks callbacks. %%-------------------------------------------------------------------- +-spec c2s_stream_started(c2s_state(), stream_start()) -> c2s_state(). +c2s_stream_started(State, _) -> + State#{csi_state => active, csi_queue => queue_new()}. --spec filter_presence({ejabberd_c2s:state(), [stanza()]}, binary(), jid(), stanza()) - -> {ejabberd_c2s:state(), [stanza()]} | - {stop, {ejabberd_c2s:state(), [stanza()]}}. +-spec c2s_authenticated_packet(c2s_state(), xmpp_element()) -> c2s_state(). +c2s_authenticated_packet(C2SState, #csi{type = active}) -> + C2SState1 = C2SState#{csi_state => active}, + flush_queue(C2SState1); +c2s_authenticated_packet(C2SState, #csi{type = inactive}) -> + C2SState#{csi_state => inactive}; +c2s_authenticated_packet(C2SState, _) -> + C2SState. -filter_presence({C2SState, _OutStanzas} = Acc, Host, To, - #presence{type = Type} = Stanza) -> - if Type == available; Type == unavailable -> - ?DEBUG("Got availability presence stanza for ~s", - [jid:to_string(To)]), - queue_add(presence, Stanza, Host, C2SState); - true -> - Acc - end; -filter_presence(Acc, _Host, _To, _Stanza) -> Acc. +-spec c2s_copy_session(c2s_state(), c2s_state()) -> c2s_state(). +c2s_copy_session(C2SState, #{csi_state := State, csi_queue := Q}) -> + C2SState#{csi_state => State, csi_queue => Q}; +c2s_copy_session(C2SState, _) -> + C2SState. --spec filter_chat_states({ejabberd_c2s:state(), [stanza()]}, binary(), jid(), stanza()) - -> {ejabberd_c2s:state(), [stanza()]} | - {stop, {ejabberd_c2s:state(), [stanza()]}}. +-spec c2s_session_resumed(c2s_state()) -> c2s_state(). +c2s_session_resumed(C2SState) -> + flush_queue(C2SState). -filter_chat_states({C2SState, _OutStanzas} = Acc, Host, To, - #message{from = From} = Stanza) -> - case xmpp_util:is_standalone_chat_state(Stanza) of +-spec filter_presence(filter_acc()) -> filter_acc(). +filter_presence({#presence{meta = #{csi_resend := true}}, _} = Acc) -> + Acc; +filter_presence({#presence{to = To, type = Type} = Pres, + #{csi_state := inactive} = C2SState}) + when Type == available; Type == unavailable -> + ?DEBUG("Got availability presence stanza for ~s", [jid:to_string(To)]), + enqueue_stanza(presence, Pres, C2SState); +filter_presence(Acc) -> + Acc. + +-spec filter_chat_states(filter_acc()) -> filter_acc(). +filter_chat_states({#message{meta = #{csi_resend := true}}, _} = Acc) -> + Acc; +filter_chat_states({#message{from = From, to = To} = Msg, + #{csi_state := inactive} = C2SState} = Acc) -> + case xmpp_util:is_standalone_chat_state(Msg) of true -> case {From, To} of {#jid{luser = U, lserver = S}, #jid{luser = U, lserver = S}} -> @@ -183,108 +216,107 @@ filter_chat_states({C2SState, _OutStanzas} = Acc, Host, To, _ -> ?DEBUG("Got standalone chat state notification for ~s", [jid:to_string(To)]), - queue_add(chatstate, Stanza, Host, C2SState) + enqueue_stanza(chatstate, Msg, C2SState) end; false -> Acc end; -filter_chat_states(Acc, _Host, _To, _Stanza) -> Acc. +filter_chat_states(Acc) -> + Acc. --spec filter_pep({ejabberd_c2s:state(), [stanza()]}, binary(), jid(), stanza()) - -> {ejabberd_c2s:state(), [stanza()]} | - {stop, {ejabberd_c2s:state(), [stanza()]}}. - -filter_pep({C2SState, _OutStanzas} = Acc, Host, To, #message{} = Stanza) -> - case get_pep_node(Stanza) of +-spec filter_pep(filter_acc()) -> filter_acc(). +filter_pep({#message{meta = #{csi_resend := true}}, _} = Acc) -> + Acc; +filter_pep({#message{to = To} = Msg, + #{csi_state := inactive} = C2SState} = Acc) -> + case get_pep_node(Msg) of undefined -> Acc; Node -> ?DEBUG("Got PEP notification for ~s", [jid:to_string(To)]), - queue_add({pep, Node}, Stanza, Host, C2SState) + enqueue_stanza({pep, Node}, Msg, C2SState) end; -filter_pep(Acc, _Host, _To, _Stanza) -> Acc. +filter_pep(Acc) -> + Acc. --spec filter_other({ejabberd_c2s:state(), [stanza()]}, binary(), jid(), stanza()) - -> {ejabberd_c2s:state(), [stanza()]}. +-spec filter_other(filter_acc()) -> filter_acc(). +filter_other({Stanza, #{jid := JID} = C2SState} = Acc) when ?is_stanza(Stanza) -> + case xmpp:get_meta(Stanza) of + #{csi_resend := true} -> + Acc; + _ -> + ?DEBUG("Won't add stanza for ~s to CSI queue", [jid:to_string(JID)]), + From = xmpp:get_from(Stanza), + C2SState1 = dequeue_sender(From, C2SState), + {Stanza, C2SState1} + end; +filter_other(Acc) -> + Acc. -filter_other({C2SState, _OutStanzas}, Host, To, Stanza) -> - ?DEBUG("Won't add stanza for ~s to CSI queue", [jid:to_string(To)]), - queue_take(Stanza, Host, C2SState). - --spec flush_queue({ejabberd_c2s:state(), [stanza()]}, binary(), jid()) - -> {ejabberd_c2s:state(), [stanza()]}. - -flush_queue({C2SState, _OutStanzas}, Host, JID) -> - ?DEBUG("Going to flush CSI queue of ~s", [jid:to_string(JID)]), - Queue = get_queue(C2SState), - NewState = set_queue([], C2SState), - {NewState, get_stanzas(Queue, Host)}. - --spec add_stream_feature([stanza()], binary) -> [stanza()]. - -add_stream_feature(Features, _Host) -> - [#feature_csi{xmlns = <<"urn:xmpp:csi:0">>} | Features]. +-spec add_stream_feature([xmpp_element()], binary) -> [xmpp_element()]. +add_stream_feature(Features, Host) -> + case gen_mod:is_loaded(Host, ?MODULE) of + true -> + [#feature_csi{xmlns = <<"urn:xmpp:csi:0">>} | Features]; + false -> + Features + end. %%-------------------------------------------------------------------- %% Internal functions. %%-------------------------------------------------------------------- - --spec queue_add(csi_type(), stanza(), binary(), term()) - -> {stop, {term(), [stanza()]}}. - -queue_add(Type, Stanza, Host, C2SState) -> - case get_queue(C2SState) of - Queue when length(Queue) >= ?CSI_QUEUE_MAX -> +-spec enqueue_stanza(csi_type(), stanza(), c2s_state()) -> filter_acc(). +enqueue_stanza(Type, Stanza, #{csi_state := inactive, + csi_queue := Q} = C2SState) -> + case queue_len(Q) >= ?CSI_QUEUE_MAX of + true -> ?DEBUG("CSI queue too large, going to flush it", []), - NewState = set_queue([], C2SState), - {stop, {NewState, get_stanzas(Queue, Host) ++ [Stanza]}}; - Queue -> - ?DEBUG("Adding stanza to CSI queue", []), - From = xmpp:get_from(Stanza), - Key = {jid:tolower(From), Type}, - Entry = {Key, p1_time_compat:timestamp(), Stanza}, - NewQueue = lists:keystore(Key, 1, Queue, Entry), - NewState = set_queue(NewQueue, C2SState), - {stop, {NewState, []}} - end. + C2SState1 = flush_queue(C2SState), + enqueue_stanza(Type, Stanza, C2SState1); + false -> + #jid{luser = U, lserver = S} = xmpp:get_from(Stanza), + Q1 = queue_in({U, S}, Type, Stanza, Q), + {stop, {drop, C2SState#{csi_queue => Q1}}} + end; +enqueue_stanza(_Type, Stanza, State) -> + {Stanza, State}. --spec queue_take(stanza(), binary(), term()) -> {term(), [stanza()]}. - -queue_take(Stanza, Host, C2SState) -> - From = xmpp:get_from(Stanza), - {LUser, LServer, _LResource} = jid:tolower(From), - {Selected, Rest} = lists:partition( - fun({{{U, S, _R}, _Type}, _Time, _Stanza}) -> - U == LUser andalso S == LServer - end, get_queue(C2SState)), - NewState = set_queue(Rest, C2SState), - {NewState, get_stanzas(Selected, Host) ++ [Stanza]}. - --spec set_queue(csi_queue(), term()) -> term(). - -set_queue(Queue, C2SState) -> - ejabberd_c2s:set_aux_field(csi_queue, Queue, C2SState). - --spec get_queue(term()) -> csi_queue(). - -get_queue(C2SState) -> - case ejabberd_c2s:get_aux_field(csi_queue, C2SState) of - {ok, Queue} -> - Queue; +-spec dequeue_sender(jid(), c2s_state()) -> c2s_state(). +dequeue_sender(#jid{luser = U, lserver = S}, + #{csi_queue := Q, jid := JID} = C2SState) -> + ?DEBUG("Flushing packets of ~s@~s from CSI queue of ~s", + [U, S, jid:to_string(JID)]), + case queue_take({U, S}, Q) of + {Stanzas, Q1} -> + C2SState1 = flush_stanzas(C2SState, Stanzas), + C2SState1#{csi_queue => Q1}; error -> - [] + C2SState end. --spec get_stanzas(csi_queue(), binary()) -> [stanza()]. +-spec flush_queue(c2s_state()) -> c2s_state(). +flush_queue(#{csi_queue := Q, jid := JID} = C2SState) -> + ?DEBUG("Flushing CSI queue of ~s", [jid:to_string(JID)]), + C2SState1 = flush_stanzas(C2SState, queue_to_list(Q)), + C2SState1#{csi_queue => queue_new()}. -get_stanzas(Queue, Host) -> - lists:map(fun({_Key, Time, Stanza}) -> - xmpp_util:add_delay_info(Stanza, jid:make(Host), Time, - <<"Client Inactive">>) - end, Queue). +-spec flush_stanzas(c2s_state(), + [{csi_type(), csi_timestamp(), stanza()}]) -> c2s_state(). +flush_stanzas(#{lserver := LServer} = C2SState, Elems) -> + lists:foldl( + fun({_Type, Time, Stanza}, AccState) -> + Stanza1 = add_delay_info(Stanza, LServer, Time), + ejabberd_c2s:send(AccState, Stanza1) + end, C2SState, Elems). + +-spec add_delay_info(stanza(), binary(), csi_timestamp()) -> stanza(). +add_delay_info(Stanza, LServer, {_Seq, TimeStamp}) -> + Stanza1 = xmpp_util:add_delay_info( + Stanza, jid:make(LServer), TimeStamp, + <<"Client Inactive">>), + xmpp:put_meta(Stanza1, csi_resend, true). -spec get_pep_node(message()) -> binary() | undefined. - get_pep_node(#message{from = #jid{luser = <<>>}}) -> %% It's not PEP. undefined; @@ -295,3 +327,53 @@ get_pep_node(#message{} = Msg) -> _ -> undefined end. + +%%-------------------------------------------------------------------- +%% Queue interface +%%-------------------------------------------------------------------- +-spec queue_new() -> csi_queue(). +queue_new() -> + {0, 0, #{}}. + +-spec queue_in(term(), term(), term(), csi_queue()) -> csi_queue(). +queue_in(Key, Type, Val, {N, Seq, Q}) -> + Seq1 = Seq + 1, + Time = {Seq1, p1_time_compat:timestamp()}, + try maps:get(Key, Q) of + TypeVals -> + case lists:keymember(Type, 1, TypeVals) of + true -> + TypeVals1 = lists:keyreplace( + Type, 1, TypeVals, {Type, Time, Val}), + Q1 = maps:put(Key, TypeVals1, Q), + {N, Seq1, Q1}; + false -> + TypeVals1 = [{Type, Time, Val}|TypeVals], + Q1 = maps:put(Key, TypeVals1, Q), + {N + 1, Seq1, Q1} + end + catch _:{badkey, _} -> + Q1 = maps:put(Key, [{Type, Time, Val}], Q), + {N + 1, Seq1, Q1} + end. + +-spec queue_take(term(), csi_queue()) -> {list(), csi_queue()} | error. +queue_take(Key, {N, Seq, Q}) -> + case maps:take(Key, Q) of + {TypeVals, Q1} -> + {lists:keysort(2, TypeVals), {N-length(TypeVals), Seq, Q1}}; + error -> + error + end. + +-spec queue_len(csi_queue()) -> non_neg_integer(). +queue_len({N, _, _}) -> + N. + +-spec queue_to_list(csi_queue()) -> [term()]. +queue_to_list({_, _, Q}) -> + TypeVals = maps:fold( + fun(_, Vals, Acc) -> + Vals ++ Acc + end, [], Q), + lists:keysort(2, TypeVals). diff --git a/src/mod_disco.erl b/src/mod_disco.erl index b6e1a4a16..73f691dc6 100644 --- a/src/mod_disco.erl +++ b/src/mod_disco.erl @@ -37,9 +37,7 @@ get_local_features/5, get_local_services/5, process_sm_iq_items/1, process_sm_iq_info/1, get_sm_identity/5, get_sm_features/5, get_sm_items/5, - get_info/5, register_feature/2, unregister_feature/2, - register_extra_domain/2, unregister_extra_domain/2, - transform_module_options/1, mod_opt_type/1, depends/2]). + get_info/5, transform_module_options/1, mod_opt_type/1, depends/2]). -include("ejabberd.hrl"). -include("logger.hrl"). @@ -48,8 +46,10 @@ -include_lib("stdlib/include/ms_transform.hrl"). -include("mod_roster.hrl"). +-type features_acc() :: {error, stanza_error()} | {result, [binary()]} | empty. +-type items_acc() :: {error, stanza_error()} | {result, [disco_item()]} | empty. + start(Host, Opts) -> - ejabberd_local:refresh_iq_handlers(), IQDisc = gen_mod:get_opt(iqdisc, Opts, fun gen_iq_handler:check_type/1, one_queue), gen_iq_handler:add_iq_handler(ejabberd_local, Host, @@ -64,12 +64,9 @@ start(Host, Opts) -> gen_iq_handler:add_iq_handler(ejabberd_sm, Host, ?NS_DISCO_INFO, ?MODULE, process_sm_iq_info, IQDisc), - catch ets:new(disco_features, - [named_table, ordered_set, public]), - register_feature(Host, <<"iq">>), - register_feature(Host, <<"presence">>), catch ets:new(disco_extra_domains, - [named_table, ordered_set, public]), + [named_table, ordered_set, public, + {heir, erlang:group_leader(), none}]), ExtraDomains = gen_mod:get_opt(extra_domains, Opts, fun(Hs) -> [iolist_to_binary(H) || H <- Hs] @@ -78,10 +75,6 @@ start(Host, Opts) -> register_extra_domain(Host, Domain) end, ExtraDomains), - catch ets:new(disco_sm_features, - [named_table, ordered_set, public]), - catch ets:new(disco_sm_nodes, - [named_table, ordered_set, public]), ejabberd_hooks:add(disco_local_items, Host, ?MODULE, get_local_services, 100), ejabberd_hooks:add(disco_local_features, Host, ?MODULE, @@ -121,35 +114,14 @@ stop(Host) -> ?NS_DISCO_ITEMS), gen_iq_handler:remove_iq_handler(ejabberd_sm, Host, ?NS_DISCO_INFO), - catch ets:match_delete(disco_features, {{'_', Host}}), catch ets:match_delete(disco_extra_domains, {{'_', Host}}), ok. --spec register_feature(binary(), binary()) -> true. -register_feature(Host, Feature) -> - catch ets:new(disco_features, - [named_table, ordered_set, public]), - ets:insert(disco_features, {{Feature, Host}}). - --spec unregister_feature(binary(), binary()) -> true. -unregister_feature(Host, Feature) -> - catch ets:new(disco_features, - [named_table, ordered_set, public]), - ets:delete(disco_features, {Feature, Host}). - -spec register_extra_domain(binary(), binary()) -> true. register_extra_domain(Host, Domain) -> - catch ets:new(disco_extra_domains, - [named_table, ordered_set, public]), ets:insert(disco_extra_domains, {{Domain, Host}}). --spec unregister_extra_domain(binary(), binary()) -> true. -unregister_extra_domain(Host, Domain) -> - catch ets:new(disco_extra_domains, - [named_table, ordered_set, public]), - ets:delete(disco_extra_domains, {Domain, Host}). - -spec process_local_iq_items(iq()) -> iq(). process_local_iq_items(#iq{type = set, lang = Lang} = IQ) -> Txt = <<"Value 'set' of 'type' attribute is not allowed">>, @@ -198,22 +170,18 @@ get_local_identity(Acc, _From, _To, <<"">>, _Lang) -> get_local_identity(Acc, _From, _To, _Node, _Lang) -> Acc. --spec get_local_features({error, stanza_error()} | {result, [binary()]} | empty, - jid(), jid(), binary(), binary()) -> +-spec get_local_features(features_acc(), jid(), jid(), binary(), binary()) -> {error, stanza_error()} | {result, [binary()]}. get_local_features({error, _Error} = Acc, _From, _To, _Node, _Lang) -> Acc; -get_local_features(Acc, _From, To, <<"">>, _Lang) -> +get_local_features(Acc, _From, _To, <<"">>, _Lang) -> Feats = case Acc of {result, Features} -> Features; empty -> [] end, - Host = To#jid.lserver, - {result, - ets:select(disco_features, - ets:fun2ms(fun({{F, H}}) when H == Host -> F end)) - ++ Feats}; + {result, [<<"iq">>, <<"presence">>, + ?NS_DISCO_INFO, ?NS_DISCO_ITEMS |Feats]}; get_local_features(Acc, _From, _To, _Node, Lang) -> case Acc of {result, _Features} -> Acc; @@ -222,9 +190,7 @@ get_local_features(Acc, _From, _To, _Node, Lang) -> {error, xmpp:err_item_not_found(Txt, Lang)} end. --spec get_local_services({error, stanza_error()} | {result, [disco_item()]} | empty, - jid(), jid(), - binary(), binary()) -> +-spec get_local_services(items_acc(), jid(), jid(), binary(), binary()) -> {error, stanza_error()} | {result, [disco_item()]}. get_local_services({error, _Error} = Acc, _From, _To, _Node, _Lang) -> @@ -269,7 +235,7 @@ get_vh_services(Host) -> [VH | _] -> VH == Host end end, - ejabberd_router:dirty_get_all_routes()). + ejabberd_router:get_all_routes()). %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% @@ -296,9 +262,7 @@ process_sm_iq_items(#iq{type = get, lang = Lang, xmpp:make_error(IQ, xmpp:err_subscription_required(Txt, Lang)) end. --spec get_sm_items({error, stanza_error()} | {result, [disco_item()]} | empty, - jid(), jid(), - binary(), binary()) -> +-spec get_sm_items(items_acc(), jid(), jid(), binary(), binary()) -> {error, stanza_error()} | {result, [disco_item()]}. get_sm_items({error, _Error} = Acc, _From, _To, _Node, _Lang) -> @@ -383,8 +347,7 @@ get_sm_identity(Acc, _From, _ -> [] end. --spec get_sm_features({error, stanza_error()} | {result, [binary()]} | empty, - jid(), jid(), binary(), binary()) -> +-spec get_sm_features(features_acc(), jid(), jid(), binary(), binary()) -> {error, stanza_error()} | {result, [binary()]}. get_sm_features(empty, From, To, _Node, Lang) -> #jid{luser = LFrom, lserver = LSFrom} = From, diff --git a/src/mod_fail2ban.erl b/src/mod_fail2ban.erl index 2b5e0bfc5..2c6ff618c 100644 --- a/src/mod_fail2ban.erl +++ b/src/mod_fail2ban.erl @@ -28,7 +28,8 @@ -behaviour(gen_server). %% API --export([start_link/2, start/2, stop/1, c2s_auth_result/4, check_bl_c2s/3]). +-export([start_link/2, start/2, stop/1, c2s_auth_result/3, + c2s_stream_started/2]). -export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3, @@ -37,6 +38,7 @@ -include_lib("stdlib/include/ms_transform.hrl"). -include("ejabberd.hrl"). -include("logger.hrl"). +-include("xmpp.hrl"). -define(C2S_AUTH_BAN_LIFETIME, 3600). %% 1 hour -define(C2S_MAX_AUTH_FAILURES, 20). @@ -51,12 +53,12 @@ start_link(Host, Opts) -> Proc = gen_mod:get_module_proc(Host, ?MODULE), gen_server:start_link({local, Proc}, ?MODULE, [Host, Opts], []). --spec c2s_auth_result(boolean(), binary(), binary(), - {inet:ip_address(), non_neg_integer()}) -> ok. -c2s_auth_result(false, _User, LServer, {Addr, _Port}) -> +-spec c2s_auth_result(ejabberd_c2s:state(), boolean(), binary()) + -> ejabberd_c2s:state() | {stop, ejabberd_c2s:state()}. +c2s_auth_result(#{ip := {Addr, _}, lserver := LServer} = State, false, _User) -> case is_whitelisted(LServer, Addr) of true -> - ok; + State; false -> BanLifetime = gen_mod:get_module_opt( LServer, ?MODULE, c2s_auth_ban_lifetime, @@ -67,47 +69,41 @@ c2s_auth_result(false, _User, LServer, {Addr, _Port}) -> fun(I) when is_integer(I), I > 0 -> I end, ?C2S_MAX_AUTH_FAILURES), UnbanTS = p1_time_compat:system_time(seconds) + BanLifetime, - case ets:lookup(failed_auth, Addr) of + Attempts = case ets:lookup(failed_auth, Addr) of [{Addr, N, _, _}] -> - ets:insert(failed_auth, {Addr, N+1, UnbanTS, MaxFailures}); + ets:insert(failed_auth, + {Addr, N+1, UnbanTS, MaxFailures}), + N+1; [] -> - ets:insert(failed_auth, {Addr, 1, UnbanTS, MaxFailures}) + ets:insert(failed_auth, + {Addr, 1, UnbanTS, MaxFailures}), + 1 end, - ok + if Attempts >= MaxFailures -> + log_and_disconnect(State, Attempts, UnbanTS); + true -> + State + end end; -c2s_auth_result(true, _User, _Server, _AddrPort) -> - ok. +c2s_auth_result(#{ip := {Addr, _}} = State, true, _User) -> + ets:delete(failed_auth, Addr), + State. --spec check_bl_c2s({true, binary(), binary()} | false, - {inet:ip_address(), non_neg_integer()}, - binary()) -> {stop, {true, binary(), binary()}} | false. -check_bl_c2s(_Acc, Addr, Lang) -> +-spec c2s_stream_started(ejabberd_c2s:state(), stream_start()) + -> ejabberd_c2s:state() | {stop, ejabberd_c2s:state()}. +c2s_stream_started(#{ip := {Addr, _}} = State, _) -> + ets:tab2list(failed_auth), case ets:lookup(failed_auth, Addr) of [{Addr, N, TS, MaxFailures}] when N >= MaxFailures -> case TS > p1_time_compat:system_time(seconds) of true -> - IP = jlib:ip_to_list(Addr), - UnbanDate = format_date( - calendar:now_to_universal_time(seconds_to_now(TS))), - LogReason = io_lib:fwrite( - "Too many (~p) failed authentications " - "from this IP address (~s). The address " - "will be unblocked at ~s UTC", - [N, IP, UnbanDate]), - ReasonT = io_lib:fwrite( - translate:translate( - Lang, - <<"Too many (~p) failed authentications " - "from this IP address (~s). The address " - "will be unblocked at ~s UTC">>), - [N, IP, UnbanDate]), - {stop, {true, LogReason, ReasonT}}; + log_and_disconnect(State, N, TS); false -> ets:delete(failed_auth, Addr), - false + State end; _ -> - false + State end. %%==================================================================== @@ -133,7 +129,7 @@ depends(_Host, _Opts) -> %%%=================================================================== init([Host, _Opts]) -> ejabberd_hooks:add(c2s_auth_result, Host, ?MODULE, c2s_auth_result, 100), - ejabberd_hooks:add(check_bl_c2s, ?MODULE, check_bl_c2s, 100), + ejabberd_hooks:add(c2s_stream_started, Host, ?MODULE, c2s_stream_started, 100), erlang:send_after(?CLEAN_INTERVAL, self(), clean), {ok, #state{host = Host}}. @@ -159,11 +155,11 @@ handle_info(_Info, State) -> terminate(_Reason, #state{host = Host}) -> ejabberd_hooks:delete(c2s_auth_result, Host, ?MODULE, c2s_auth_result, 100), + ejabberd_hooks:delete(c2s_stream_started, Host, ?MODULE, c2s_stream_started, 100), case is_loaded_at_other_hosts(Host) of true -> ok; false -> - ejabberd_hooks:delete(check_bl_c2s, ?MODULE, check_bl_c2s, 100), ets:delete(failed_auth) end. @@ -173,6 +169,21 @@ code_change(_OldVsn, State, _Extra) -> %%%=================================================================== %%% Internal functions %%%=================================================================== +-spec log_and_disconnect(ejabberd_c2s:state(), pos_integer(), erlang:timestamp()) + -> {stop, ejabberd_c2s:state()}. +log_and_disconnect(#{ip := {Addr, _}, lang := Lang} = State, Attempts, UnbanTS) -> + IP = jlib:ip_to_list(Addr), + UnbanDate = format_date( + calendar:now_to_universal_time(seconds_to_now(UnbanTS))), + Format = <<"Too many (~p) failed authentications " + "from this IP address (~s). The address " + "will be unblocked at ~s UTC">>, + Args = [Attempts, IP, UnbanDate], + ?INFO_MSG("Connection attempt from blacklisted IP ~s: ~s", + [IP, io_lib:fwrite(Format, Args)]), + Err = xmpp:serr_policy_violation({Format, Args}, Lang), + {stop, ejabberd_c2s:send(State, Err)}. + is_whitelisted(Host, Addr) -> Access = gen_mod:get_module_opt(Host, ?MODULE, access, fun(A) -> A end, diff --git a/src/mod_http_fileserver.erl b/src/mod_http_fileserver.erl index 239c8bd39..f837e8689 100644 --- a/src/mod_http_fileserver.erl +++ b/src/mod_http_fileserver.erl @@ -46,7 +46,7 @@ %% utility for other http modules -export([content_type/3]). --export([reopen_log/1, mod_opt_type/1, depends/2]). +-export([reopen_log/0, mod_opt_type/1, depends/2]). -include("ejabberd.hrl"). -include("logger.hrl"). @@ -236,7 +236,7 @@ check_docroot_is_readable(DRInfo, DocRoot) -> try_open_log(undefined, _Host) -> undefined; -try_open_log(FN, Host) -> +try_open_log(FN, _Host) -> FD = try open_log(FN) of FD1 -> FD1 catch @@ -244,7 +244,7 @@ try_open_log(FN, Host) -> ?ERROR_MSG("Cannot open access log file: ~p~nReason: ~p", [FN, Reason]), undefined end, - ejabberd_hooks:add(reopen_log_hook, Host, ?MODULE, reopen_log, 50), + ejabberd_hooks:add(reopen_log_hook, ?MODULE, reopen_log, 50), FD. %%-------------------------------------------------------------------- @@ -298,7 +298,8 @@ handle_info(_Info, State) -> %%-------------------------------------------------------------------- terminate(_Reason, State) -> close_log(State#state.accesslogfd), - ejabberd_hooks:delete(reopen_log_hook, State#state.host, ?MODULE, reopen_log, 50), + %% TODO: unregister the hook gracefully + %% ejabberd_hooks:delete(reopen_log_hook, State#state.host, ?MODULE, reopen_log, 50), ok. %%-------------------------------------------------------------------- @@ -410,8 +411,11 @@ reopen_log(FN, FD) -> close_log(FD), open_log(FN). -reopen_log(Host) -> - gen_server:cast(get_proc_name(Host), reopen_log). +reopen_log() -> + lists:foreach( + fun(Host) -> + gen_server:cast(get_proc_name(Host), reopen_log) + end, ?MYHOSTS). add_to_log(FileSize, Code, Request) -> gen_server:cast(get_proc_name(Request#request.host), diff --git a/src/mod_http_upload.erl b/src/mod_http_upload.erl index 021f8d3bc..55efc1ab0 100644 --- a/src/mod_http_upload.erl +++ b/src/mod_http_upload.erl @@ -139,8 +139,6 @@ start(ServerHost, Opts) -> true) of true -> ejabberd_hooks:add(remove_user, ServerHost, ?MODULE, - remove_user, 50), - ejabberd_hooks:add(anonymous_purge_hook, ServerHost, ?MODULE, remove_user, 50); false -> ok @@ -162,8 +160,6 @@ stop(ServerHost) -> true) of true -> ejabberd_hooks:delete(remove_user, ServerHost, ?MODULE, - remove_user, 50), - ejabberd_hooks:delete(anonymous_purge_hook, ServerHost, ?MODULE, remove_user, 50); false -> ok diff --git a/src/mod_ip_blacklist.erl b/src/mod_ip_blacklist.erl deleted file mode 100644 index b4a0b1aa0..000000000 --- a/src/mod_ip_blacklist.erl +++ /dev/null @@ -1,139 +0,0 @@ -%%%---------------------------------------------------------------------- -%%% File : mod_ip_blacklist.erl -%%% Author : Mickael Remond -%%% Purpose : Download blacklists from ProcessOne -%%% Created : 5 May 2008 by Mickael Remond -%%% Usage : Add the following line in modules section of ejabberd.cfg: -%%% {mod_ip_blacklist, []} -%%% -%%% -%%% ejabberd, Copyright (C) 2002-2017 ProcessOne -%%% -%%% This program is free software; you can redistribute it and/or -%%% modify it under the terms of the GNU General Public License as -%%% published by the Free Software Foundation; either version 2 of the -%%% License, or (at your option) any later version. -%%% -%%% This program is distributed in the hope that it will be useful, -%%% but WITHOUT ANY WARRANTY; without even the implied warranty of -%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -%%% General Public License for more details. -%%% -%%% You should have received a copy of the GNU General Public License along -%%% with this program; if not, write to the Free Software Foundation, Inc., -%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. -%%% -%%%---------------------------------------------------------------------- - --module(mod_ip_blacklist). - --author('mremond@process-one.net'). - --behaviour(gen_mod). - -%% API: --export([start/2, preinit/2, init/1, stop/1]). - --export([update_bl_c2s/0]). - --export([is_ip_in_c2s_blacklist/3, mod_opt_type/1, depends/2]). - --include("ejabberd.hrl"). --include("logger.hrl"). - --define(PROCNAME, ?MODULE). - --define(BLC2S, - <<"http://xaai.process-one.net/bl_c2s.txt">>). - --define(UPDATE_INTERVAL, 6). - --record(state, {timer}). - -%% Start once for all vhost --record(bl_c2s, {ip = <<"">> :: binary()}). - -start(_Host, _Opts) -> - Pid = spawn(?MODULE, preinit, [self(), #state{}]), - receive {ok, Pid, PreinitResult} -> PreinitResult end. - -preinit(Parent, State) -> - Pid = self(), - try register(?PROCNAME, Pid) of - true -> Parent ! {ok, Pid, true}, init(State) - catch - error:_ -> Parent ! {ok, Pid, true} - end. - -depends(_Host, _Opts) -> - []. - -%% TODO: -stop(_Host) -> ok. - -init(State) -> - ets:new(bl_c2s, - [named_table, public, {keypos, #bl_c2s.ip}]), - update_bl_c2s(), - ejabberd_hooks:add(check_bl_c2s, ?MODULE, - is_ip_in_c2s_blacklist, 50), - timer:apply_interval(timer:hours(?UPDATE_INTERVAL), - ?MODULE, update_bl_c2s, []), - loop(State). - -%% Remove timer when stop is received. -loop(_State) -> receive stop -> ok end. - -%% Download blacklist file from ProcessOne XAAI -%% and update the table internal table -%% TODO: Support comment lines starting by % -update_bl_c2s() -> - ?INFO_MSG("Updating C2S Blacklist", []), - case p1_http:get(?BLC2S) of - {ok, 200, _Headers, Body} -> - IPs = str:tokens(iolist_to_binary(Body), <<"\n">>), - ets:delete_all_objects(bl_c2s), - lists:foreach(fun (IP) -> - ets:insert(bl_c2s, - #bl_c2s{ip = IP}) - end, - IPs); - {error, Reason} -> - ?ERROR_MSG("Cannot download C2S blacklist file. " - "Reason: ~p", - [Reason]) - end. - -%% Hook is run with: -%% ejabberd_hooks:run_fold(check_bl_c2s, false, [IP]), -%% Return: false: IP not blacklisted -%% true: IP is blacklisted -%% IPV4 IP tuple: --spec is_ip_in_c2s_blacklist( - {true, binary(), binary()} | false, - {inet:ip_address(), non_neg_integer()}, - binary()) -> {stop, {true, binary(), binary()}} | false. -is_ip_in_c2s_blacklist(_Val, IP, Lang) when is_tuple(IP) -> - BinaryIP = jlib:ip_to_list(IP), - case ets:lookup(bl_c2s, BinaryIP) of - [] -> %% Not in blacklist - false; - [_] -> - LogReason = io_lib:fwrite( - "This IP address is blacklisted in ~s", - [?BLC2S]), - ReasonT = io_lib:fwrite( - translate:translate( - Lang, - <<"This IP address is blacklisted in ~s">>), - [?BLC2S]), - {stop, {true, LogReason, ReasonT}} - end; -is_ip_in_c2s_blacklist(_Val, _IP, _Lang) -> false. - -%% TODO: -%% - For now, we do not kick user already logged on a given IP after -%% we update the blacklist. - - -mod_opt_type(_) -> []. diff --git a/src/mod_last.erl b/src/mod_last.erl index 6d62a2e2f..b5d17311e 100644 --- a/src/mod_last.erl +++ b/src/mod_last.erl @@ -37,7 +37,7 @@ process_sm_iq/1, on_presence_update/4, import_info/0, import/5, import_start/2, store_last_info/4, get_last_info/2, remove_user/2, transform_options/1, mod_opt_type/1, - opt_type/1, register_user/2, depends/2]). + opt_type/1, register_user/2, depends/2, privacy_check_packet/4]). -include("ejabberd.hrl"). -include("logger.hrl"). @@ -64,6 +64,8 @@ start(Host, Opts) -> ?NS_LAST, ?MODULE, process_local_iq, IQDisc), gen_iq_handler:add_iq_handler(ejabberd_sm, Host, ?NS_LAST, ?MODULE, process_sm_iq, IQDisc), + ejabberd_hooks:add(privacy_check_packet, Host, ?MODULE, + privacy_check_packet, 30), ejabberd_hooks:add(register_user, Host, ?MODULE, register_user, 50), ejabberd_hooks:add(remove_user, Host, ?MODULE, @@ -128,13 +130,10 @@ process_sm_iq(#iq{from = From, to = To, lang = Lang} = IQ) -> if (Subscription == both) or (Subscription == from) or (From#jid.luser == To#jid.luser) and (From#jid.lserver == To#jid.lserver) -> - UserListRecord = - ejabberd_hooks:run_fold(privacy_get_user_list, Server, - #userlist{}, [User, Server]), + Pres = xmpp:set_from_to(#presence{}, To, From), case ejabberd_hooks:run_fold(privacy_check_packet, Server, allow, - [User, Server, UserListRecord, - {To, From, #presence{}}, out]) of + [To, Pres, out]) of allow -> get_last_iq(IQ, User, Server); deny -> xmpp:make_error(IQ, xmpp:err_forbidden()) end; @@ -143,6 +142,31 @@ process_sm_iq(#iq{from = From, to = To, lang = Lang} = IQ) -> xmpp:make_error(IQ, xmpp:err_subscription_required(Txt, Lang)) end. +privacy_check_packet(allow, C2SState, + #iq{from = From, to = To, type = T} = IQ, in) + when T == get; T == set -> + case xmpp:has_subtag(IQ, #last{}) of + true -> + Sub = ejabberd_c2s:get_subscription(From, C2SState), + if Sub == from; Sub == both -> + Pres = #presence{from = To, to = From}, + case ejabberd_hooks:run_fold( + privacy_check_packet, allow, + [C2SState, Pres, out]) of + allow -> + allow; + deny -> + {stop, deny} + end; + true -> + {stop, deny} + end; + false -> + allow + end; +privacy_check_packet(Acc, _, _, _) -> + Acc. + %% @spec (LUser::string(), LServer::string()) -> %% {ok, TimeStamp::integer(), Status::string()} | not_found | {error, Reason} -spec get_last(binary(), binary()) -> {ok, non_neg_integer(), binary()} | diff --git a/src/mod_legacy_auth.erl b/src/mod_legacy_auth.erl new file mode 100644 index 000000000..e9057b432 --- /dev/null +++ b/src/mod_legacy_auth.erl @@ -0,0 +1,144 @@ +%%%------------------------------------------------------------------- +%%% Created : 11 Dec 2016 by Evgeny Khramtsov +%%% +%%% +%%% ejabberd, Copyright (C) 2002-2016 ProcessOne +%%% +%%% This program is free software; you can redistribute it and/or +%%% modify it under the terms of the GNU General Public License as +%%% published by the Free Software Foundation; either version 2 of the +%%% License, or (at your option) any later version. +%%% +%%% This program is distributed in the hope that it will be useful, +%%% but WITHOUT ANY WARRANTY; without even the implied warranty of +%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +%%% General Public License for more details. +%%% +%%% You should have received a copy of the GNU General Public License along +%%% with this program; if not, write to the Free Software Foundation, Inc., +%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +%%% +%%%------------------------------------------------------------------- +-module(mod_legacy_auth). +-behaviour(gen_mod). + +-protocol({xep, 78, '2.5'}). + +%% gen_mod API +-export([start/2, stop/1, depends/2, mod_opt_type/1]). +%% hooks +-export([c2s_unauthenticated_packet/2, c2s_stream_features/2]). + +-include("xmpp.hrl"). + +%%%=================================================================== +%%% API +%%%=================================================================== +start(Host, _Opts) -> + ejabberd_hooks:add(c2s_unauthenticated_packet, Host, ?MODULE, + c2s_unauthenticated_packet, 50), + ejabberd_hooks:add(c2s_pre_auth_features, Host, ?MODULE, + c2s_stream_features, 50). + +stop(Host) -> + ejabberd_hooks:delete(c2s_unauthenticated_packet, Host, ?MODULE, + c2s_unauthenticated_packet, 50), + ejabberd_hooks:delete(c2s_pre_auth_features, Host, ?MODULE, + c2s_stream_features, 50). + +depends(_Host, _Opts) -> + []. + +mod_opt_type(_) -> + []. + +c2s_unauthenticated_packet(State, #iq{type = T, sub_els = [_]} = IQ) + when T == get; T == set -> + case xmpp:get_subtag(IQ, #legacy_auth{}) of + #legacy_auth{} = Auth -> + {stop, authenticate(State, xmpp:set_els(IQ, [Auth]))}; + false -> + State + end; +c2s_unauthenticated_packet(State, _) -> + State. + +c2s_stream_features(Acc, LServer) -> + case gen_mod:is_loaded(LServer, ?MODULE) of + true -> + [#legacy_auth_feature{}|Acc]; + false -> + Acc + end. + +%%%=================================================================== +%%% Internal functions +%%%=================================================================== +authenticate(#{server := Server} = State, + #iq{type = get, sub_els = [#legacy_auth{}]} = IQ) -> + LServer = jid:nameprep(Server), + Auth = #legacy_auth{username = <<>>, password = <<>>, resource = <<>>}, + Res = case ejabberd_auth:plain_password_required(LServer) of + false -> + xmpp:make_iq_result(IQ, Auth#legacy_auth{digest = <<>>}); + true -> + xmpp:make_iq_result(IQ, Auth) + end, + ejabberd_c2s:send(State, Res); +authenticate(State, + #iq{type = set, lang = Lang, + sub_els = [#legacy_auth{username = U, + resource = R}]} = IQ) + when U == undefined; R == undefined; U == <<"">>; R == <<"">> -> + Txt = <<"Both the username and the resource are required">>, + Err = xmpp:make_error(IQ, xmpp:err_not_acceptable(Txt, Lang)), + ejabberd_c2s:send(State, Err); +authenticate(#{stream_id := StreamID, server := Server, + access := Access, ip := IP} = State, + #iq{type = set, lang = Lang, + sub_els = [#legacy_auth{username = U, + password = P0, + digest = D0, + resource = R}]} = IQ) -> + P = if is_binary(P0) -> P0; true -> <<>> end, + D = if is_binary(D0) -> D0; true -> <<>> end, + DGen = fun (PW) -> p1_sha:sha(<>) end, + JID = jid:make(U, Server, R), + case JID /= error andalso + acl:access_matches(Access, + #{usr => jid:split(JID), ip => IP}, + JID#jid.lserver) == allow of + true -> + case ejabberd_auth:check_password_with_authmodule( + U, U, JID#jid.lserver, P, D, DGen) of + {true, AuthModule} -> + State1 = ejabberd_c2s:handle_auth_success( + U, <<"legacy">>, AuthModule, State), + State2 = State1#{user := U}, + open_session(State2, IQ, R); + _ -> + Err = xmpp:make_error(IQ, xmpp:err_not_authorized()), + process_auth_failure(State, U, Err, 'not-authorized') + end; + false when JID == error -> + Err = xmpp:make_error(IQ, xmpp:err_jid_malformed()), + process_auth_failure(State, U, Err, 'jid-malformed'); + false -> + Txt = <<"Denied by ACL">>, + Err = xmpp:make_error(IQ, xmpp:err_forbidden(Txt, Lang)), + process_auth_failure(State, U, Err, 'forbidden') + end. + +open_session(State, IQ, R) -> + case ejabberd_c2s:bind(R, State) of + {ok, State1} -> + Res = xmpp:make_iq_result(IQ), + ejabberd_c2s:send(State1, Res); + {error, Err, State1} -> + Res = xmpp:make_error(IQ, Err), + ejabberd_c2s:send(State1, Res) + end. + +process_auth_failure(State, User, StanzaErr, Reason) -> + State1 = ejabberd_c2s:send(State, StanzaErr), + ejabberd_c2s:handle_auth_failure(User, <<"legacy">>, Reason, State1). diff --git a/src/mod_mam.erl b/src/mod_mam.erl index 721b06f03..f55c1ccf2 100644 --- a/src/mod_mam.erl +++ b/src/mod_mam.erl @@ -32,10 +32,10 @@ %% API -export([start/2, stop/1, depends/2]). --export([user_send_packet/4, user_send_packet_strip_tag/4, user_receive_packet/5, +-export([user_send_packet/1, user_send_packet_strip_tag/1, user_receive_packet/1, process_iq_v0_2/1, process_iq_v0_3/1, disco_sm_features/5, remove_user/2, remove_room/3, mod_opt_type/1, muc_process_iq/2, - muc_filter_message/5, message_is_archived/5, delete_old_messages/2, + muc_filter_message/5, message_is_archived/3, delete_old_messages/2, get_commands_spec/0, msg_to_el/4, get_room_config/4, set_room_option/3]). -include("xmpp.hrl"). @@ -103,8 +103,6 @@ start(Host, Opts) -> get_room_config, 50), ejabberd_hooks:add(set_room_option, Host, ?MODULE, set_room_option, 50), - ejabberd_hooks:add(anonymous_purge_hook, Host, ?MODULE, - remove_user, 50), case gen_mod:get_opt(assume_mam_usage, Opts, fun(B) when is_boolean(B) -> B end, false) of true -> @@ -153,8 +151,6 @@ stop(Host) -> get_room_config, 50), ejabberd_hooks:delete(set_room_option, Host, ?MODULE, set_room_option, 50), - ejabberd_hooks:delete(anonymous_purge_hook, Host, - ?MODULE, remove_user, 50), case gen_mod:get_module_opt(Host, ?MODULE, assume_mam_usage, fun(B) when is_boolean(B) -> B end, false) of true -> @@ -199,46 +195,50 @@ set_room_option(_Acc, {mam, Val}, _Lang) -> set_room_option(Acc, _Property, _Lang) -> Acc. --spec user_receive_packet(stanza(), ejabberd_c2s:state(), jid(), jid(), jid()) -> stanza(). -user_receive_packet(Pkt, C2SState, JID, Peer, _To) -> +-spec user_receive_packet({stanza(), ejabberd_c2s:state()}) -> {stanza(), ejabberd_c2s:state()}. +user_receive_packet({Pkt, #{jid := JID} = C2SState}) -> + Peer = xmpp:get_from(Pkt), LUser = JID#jid.luser, LServer = JID#jid.lserver, - case should_archive(Pkt, LServer) of + Pkt2 = case should_archive(Pkt, LServer) of true -> - NewPkt = strip_my_archived_tag(Pkt, LServer), - case store_msg(C2SState, NewPkt, LUser, LServer, Peer, recv) of + Pkt1 = strip_my_archived_tag(Pkt, LServer), + case store_msg(C2SState, Pkt1, LUser, LServer, Peer, recv) of {ok, ID} -> - set_stanza_id(NewPkt, JID, ID); + set_stanza_id(Pkt1, JID, ID); _ -> - NewPkt + Pkt1 end; _ -> Pkt - end. + end, + {Pkt2, C2SState}. --spec user_send_packet(stanza(), ejabberd_c2s:state(), jid(), jid()) -> stanza(). -user_send_packet(Pkt, C2SState, JID, Peer) -> +-spec user_send_packet({stanza(), ejabberd_c2s:state()}) -> {stanza(), ejabberd_c2s:state()}. +user_send_packet({Pkt, #{jid := JID} = C2SState}) -> + Peer = xmpp:get_to(Pkt), LUser = JID#jid.luser, LServer = JID#jid.lserver, - case should_archive(Pkt, LServer) of + Pkt2 = case should_archive(Pkt, LServer) of true -> - NewPkt = strip_my_archived_tag(Pkt, LServer), - case store_msg(C2SState, xmpp:set_from_to(NewPkt, JID, Peer), + Pkt1 = strip_my_archived_tag(Pkt, LServer), + case store_msg(C2SState, xmpp:set_from_to(Pkt1, JID, Peer), LUser, LServer, Peer, send) of {ok, ID} -> - set_stanza_id(NewPkt, JID, ID); + set_stanza_id(Pkt1, JID, ID); _ -> - NewPkt + Pkt1 end; false -> Pkt - end. + end, + {Pkt2, C2SState}. --spec user_send_packet_strip_tag(stanza(), ejabberd_c2s:state(), - jid(), jid()) -> stanza(). -user_send_packet_strip_tag(Pkt, _C2SState, JID, _Peer) -> +-spec user_send_packet_strip_tag({stanza(), ejabberd_c2s:state()}) -> + {stanza(), ejabberd_c2s:state()}. +user_send_packet_strip_tag({Pkt, #{jid := JID} = C2SState}) -> LServer = JID#jid.lserver, - strip_my_archived_tag(Pkt, LServer). + {strip_my_archived_tag(Pkt, LServer), C2SState}. -spec muc_filter_message(message(), mod_muc_room:state(), jid(), jid(), binary()) -> message(). @@ -337,12 +337,12 @@ disco_sm_features({result, OtherFeatures}, disco_sm_features(Acc, _From, _To, _Node, _Lang) -> Acc. --spec message_is_archived(boolean(), ejabberd_c2s:state(), - jid(), jid(), message()) -> boolean(). -message_is_archived(true, _C2SState, _Peer, _JID, _Pkt) -> +-spec message_is_archived(boolean(), ejabberd_c2s:state(), message()) -> boolean(). +message_is_archived(true, _C2SState, _Pkt) -> true; -message_is_archived(false, C2SState, Peer, - #jid{luser = LUser, lserver = LServer}, Pkt) -> +message_is_archived(false, #{jid := JID} = C2SState, Pkt) -> + #jid{luser = LUser, lserver = LServer} = JID, + Peer = xmpp:get_from(Pkt), case gen_mod:get_module_opt(LServer, ?MODULE, assume_mam_usage, fun(B) when is_boolean(B) -> B end, false) of true -> diff --git a/src/mod_metrics.erl b/src/mod_metrics.erl index 7ee8a3b65..7f35a6a50 100644 --- a/src/mod_metrics.erl +++ b/src/mod_metrics.erl @@ -38,8 +38,8 @@ -export([offline_message_hook/3, sm_register_connection_hook/3, sm_remove_connection_hook/3, - user_send_packet/4, user_receive_packet/5, - s2s_send_packet/3, s2s_receive_packet/3, + user_send_packet/1, user_receive_packet/1, + s2s_send_packet/3, s2s_receive_packet/1, remove_user/2, register_user/2]). %%==================================================================== @@ -86,23 +86,27 @@ sm_register_connection_hook(_SID, #jid{lserver=LServer}, _Info) -> sm_remove_connection_hook(_SID, #jid{lserver=LServer}, _Info) -> push(LServer, sm_remove_connection). --spec user_send_packet(stanza(), ejabberd_c2s:state(), jid(), jid()) -> stanza(). -user_send_packet(Packet, _C2SState, #jid{lserver=LServer}, _To) -> +-spec user_send_packet({stanza(), ejabberd_c2s:state()}) -> {stanza(), ejabberd_c2s:state()}. +user_send_packet({Packet, #{jid := #jid{lserver = LServer}} = C2SState}) -> push(LServer, user_send_packet), - Packet. + {Packet, C2SState}. --spec user_receive_packet(stanza(), ejabberd_c2s:state(), jid(), jid(), jid()) -> stanza(). -user_receive_packet(Packet, _C2SState, _JID, _From, #jid{lserver=LServer}) -> +-spec user_receive_packet({stanza(), ejabberd_c2s:state()}) -> {stanza(), ejabberd_c2s:state()}. +user_receive_packet({Packet, #{jid := #jid{lserver = LServer}} = C2SState}) -> push(LServer, user_receive_packet), - Packet. + {Packet, C2SState}. -spec s2s_send_packet(jid(), jid(), stanza()) -> any(). s2s_send_packet(#jid{lserver=LServer}, _To, _Packet) -> push(LServer, s2s_send_packet). --spec s2s_receive_packet(jid(), jid(), stanza()) -> any(). -s2s_receive_packet(_From, #jid{lserver=LServer}, _Packet) -> - push(LServer, s2s_receive_packet). +-spec s2s_receive_packet({stanza(), ejabberd_s2s_in:state()}) -> + {stanza(), ejabberd_s2s_in:state()}. +s2s_receive_packet({Packet, S2SState}) -> + To = xmpp:get_to(Packet), + LServer = ejabberd_router:host_of_route(To#jid.lserver), + push(LServer, s2s_receive_packet), + {Packet, S2SState}. -spec remove_user(binary(), binary()) -> any(). remove_user(_User, Server) -> diff --git a/src/mod_muc.erl b/src/mod_muc.erl index dadc3e0f7..a91fcc810 100644 --- a/src/mod_muc.erl +++ b/src/mod_muc.erl @@ -49,12 +49,20 @@ process_register/1, process_muc_unique/1, process_mucsub/1, - broadcast_service_message/2, + broadcast_service_message/3, export/1, import_info/0, import/5, import_start/2, opts_to_binary/1, + find_online_room/2, + register_online_room/3, + get_online_rooms/1, + count_online_rooms/1, + register_online_user/4, + unregister_online_user/4, + count_online_rooms_by_user/3, + get_online_rooms_by_user/3, can_use_nick/4]). -export([init/1, handle_call/3, handle_cast/2, @@ -63,7 +71,6 @@ -include("ejabberd.hrl"). -include("logger.hrl"). --include_lib("stdlib/include/ms_transform.hrl"). -include("xmpp.hrl"). -include("mod_muc.hrl"). @@ -88,6 +95,16 @@ -callback get_rooms(binary(), binary()) -> [#muc_room{}]. -callback get_nick(binary(), binary(), jid()) -> binary() | error. -callback set_nick(binary(), binary(), jid(), binary()) -> {atomic, ok | false}. +-callback register_online_room(binary(), binary(), pid()) -> any(). +-callback unregister_online_room(binary(), binary(), pid()) -> any(). +-callback find_online_room(binary(), binary()) -> {ok, pid()} | error. +-callback get_online_rooms(binary(), undefined | rsm_set()) -> [{binary(), binary(), pid()}]. +-callback count_online_rooms(binary()) -> non_neg_integer(). +-callback rsm_supported() -> boolean(). +-callback register_online_user(ljid(), binary(), binary()) -> any(). +-callback unregister_online_user(ljid(), binary(), binary()) -> any(). +-callback count_online_rooms_by_user(binary(), binary()) -> non_neg_integer(). +-callback get_online_rooms_by_user(binary(), binary()) -> [{binary(), binary()}]. %%==================================================================== %% API @@ -114,16 +131,17 @@ depends(_Host, _Opts) -> [{mod_mam, soft}]. shutdown_rooms(Host) -> + RMod = gen_mod:ram_db_mod(Host, ?MODULE), MyHost = gen_mod:get_module_opt_host(Host, mod_muc, <<"conference.@HOST@">>), - Rooms = mnesia:dirty_select(muc_online_room, - [{#muc_online_room{name_host = '$1', - pid = '$2'}, - [{'==', {element, 2, '$1'}, MyHost}, - {'==', {node, '$2'}, node()}], - ['$2']}]), - [Pid ! shutdown || Pid <- Rooms], - Rooms. + Rooms = RMod:get_online_rooms(MyHost, undefined), + lists:filter( + fun({_, _, Pid}) when node(Pid) == node() -> + Pid ! shutdown, + true; + (_) -> + false + end, Rooms). %% This function is called by a room in three situations: %% A) The owner of the room destroyed it @@ -165,6 +183,48 @@ can_use_nick(ServerHost, Host, JID, Nick) -> Mod = gen_mod:db_mod(LServer, ?MODULE), Mod:can_use_nick(LServer, Host, JID, Nick). +-spec find_online_room(binary(), binary()) -> {ok, pid()} | error. +find_online_room(Room, Host) -> + ServerHost = ejabberd_router:host_of_route(Host), + RMod = gen_mod:ram_db_mod(ServerHost, ?MODULE), + RMod:find_online_room(Room, Host). + +-spec register_online_room(binary(), binary(), pid()) -> any(). +register_online_room(Room, Host, Pid) -> + ServerHost = ejabberd_router:host_of_route(Host), + RMod = gen_mod:ram_db_mod(ServerHost, ?MODULE), + RMod:register_online_room(Room, Host, Pid). + +-spec get_online_rooms(binary()) -> [{binary(), binary(), pid()}]. +get_online_rooms(Host) -> + ServerHost = ejabberd_router:host_of_route(Host), + get_online_rooms(ServerHost, Host). + +-spec count_online_rooms(binary()) -> non_neg_integer(). +count_online_rooms(Host) -> + ServerHost = ejabberd_router:host_of_route(Host), + count_online_rooms(ServerHost, Host). + +-spec register_online_user(binary(), ljid(), binary(), binary()) -> any(). +register_online_user(ServerHost, LJID, Name, Host) -> + RMod = gen_mod:ram_db_mod(ServerHost, ?MODULE), + RMod:register_online_user(LJID, Name, Host). + +-spec unregister_online_user(binary(), ljid(), binary(), binary()) -> any(). +unregister_online_user(ServerHost, LJID, Name, Host) -> + RMod = gen_mod:ram_db_mod(ServerHost, ?MODULE), + RMod:unregister_online_user(LJID, Name, Host). + +-spec count_online_rooms_by_user(binary(), binary(), binary()) -> non_neg_integer(). +count_online_rooms_by_user(ServerHost, LUser, LServer) -> + RMod = gen_mod:ram_db_mod(ServerHost, ?MODULE), + RMod:count_online_rooms_by_user(LUser, LServer). + +-spec get_online_rooms_by_user(binary(), binary(), binary()) -> [{binary(), binary()}]. +get_online_rooms_by_user(ServerHost, LUser, LServer) -> + RMod = gen_mod:ram_db_mod(ServerHost, ?MODULE), + RMod:get_online_rooms_by_user(LUser, LServer). + %%==================================================================== %% gen_server callbacks %%==================================================================== @@ -175,16 +235,9 @@ init([Host, Opts]) -> MyHost = gen_mod:get_opt_host(Host, Opts, <<"conference.@HOST@">>), Mod = gen_mod:db_mod(Host, Opts, ?MODULE), + RMod = gen_mod:ram_db_mod(Host, Opts, ?MODULE), Mod:init(Host, [{host, MyHost}|Opts]), - update_tables(), - ejabberd_mnesia:create(?MODULE, muc_online_room, - [{ram_copies, [node()]}, - {type, ordered_set}, - {attributes, record_info(fields, muc_online_room)}]), - mnesia:add_table_copy(muc_online_room, node(), ram_copies), - catch ets:new(muc_online_users, [bag, named_table, public, {keypos, 2}]), - clean_table_from_bad_node(node(), MyHost), - mnesia:subscribe(system), + RMod:init(Host, [{host, MyHost}|Opts]), Access = gen_mod:get_opt(access, Opts, fun acl:access_rules_validator/1, all), AccessCreate = gen_mod:get_opt(access_create, Opts, @@ -298,7 +351,8 @@ handle_call({create, Room, From, Nick, Opts}, _From, Room, HistorySize, RoomShaper, From, Nick, NewOpts), - register_room(Host, Room, Pid), + RMod = gen_mod:ram_db_mod(ServerHost, ?MODULE), + RMod:register_online_room(Room, Host, Pid), {reply, ok, State}. handle_cast(_Msg, State) -> {noreply, State}. @@ -317,18 +371,14 @@ handle_info({route, From, To, Packet}, ok end, {noreply, State}; -handle_info({room_destroyed, RoomHost, Pid}, State) -> - F = fun () -> - mnesia:delete_object(#muc_online_room{name_host = - RoomHost, - pid = Pid}) - end, - mnesia:transaction(F), +handle_info({room_destroyed, {Room, Host}, Pid}, State) -> + ServerHost = State#state.server_host, + RMod = gen_mod:ram_db_mod(ServerHost, ?MODULE), + RMod:unregister_online_room(Room, Host, Pid), {noreply, State}; -handle_info({mnesia_system_event, {mnesia_down, Node}}, State) -> - clean_table_from_bad_node(Node), - {noreply, State}; -handle_info(_Info, State) -> {noreply, State}. +handle_info(Info, State) -> + ?ERROR_MSG("unexpected info: ~p", [Info]), + {noreply, State}. terminate(_Reason, #state{host = MyHost}) -> ejabberd_router:unregister_route(MyHost), @@ -374,7 +424,7 @@ do_route1(Host, ServerHost, Access, _HistorySize, _RoomShaper, case acl:match_rule(ServerHost, AccessAdmin, From) of allow -> Msg = xmpp:get_text(Body), - broadcast_service_message(Host, Msg); + broadcast_service_message(ServerHost, Host, Msg); deny -> ErrText = <<"Only service administrators are allowed " "to send service messages">>, @@ -390,8 +440,9 @@ do_route1(Host, ServerHost, Access, HistorySize, RoomShaper, From, To, Packet, DefRoomOpts) -> {_AccessRoute, AccessCreate, _AccessAdmin, _AccessPersistent} = Access, {Room, _, Nick} = jid:tolower(To), - case mnesia:dirty_read(muc_online_room, {Room, Host}) of - [] -> + RMod = gen_mod:ram_db_mod(ServerHost, ?MODULE), + case RMod:find_online_room(Room, Host) of + error -> case is_create_request(Packet) of true -> case check_user_can_create_room( @@ -402,7 +453,7 @@ do_route1(Host, ServerHost, Access, HistorySize, RoomShaper, Host, ServerHost, Access, Room, HistorySize, RoomShaper, From, Nick, DefRoomOpts), - register_room(Host, Room, Pid), + RMod:register_online_room(Room, Host, Pid), mod_muc_room:route(Pid, From, Nick, Packet), ok; false -> @@ -417,8 +468,7 @@ do_route1(Host, ServerHost, Access, HistorySize, RoomShaper, Err = xmpp:err_item_not_found(ErrText, Lang), ejabberd_router:route_error(To, From, Packet, Err) end; - [R] -> - Pid = R#muc_online_room.pid, + {ok, Pid} -> ?DEBUG("MUC: send to process ~p~n", [Pid]), mod_muc_room:route(Pid, From, Nick, Packet), ok @@ -462,15 +512,20 @@ process_disco_info(#iq{type = set, lang = Lang} = IQ) -> process_disco_info(#iq{type = get, to = To, lang = Lang, sub_els = [#disco_info{node = <<"">>}]} = IQ) -> ServerHost = ejabberd_router:host_of_route(To#jid.lserver), + RMod = gen_mod:ram_db_mod(ServerHost, ?MODULE), X = ejabberd_hooks:run_fold(disco_info, ServerHost, [], [ServerHost, ?MODULE, <<"">>, Lang]), MAMFeatures = case gen_mod:is_loaded(ServerHost, mod_mam) of true -> [?NS_MAM_TMP, ?NS_MAM_0, ?NS_MAM_1]; false -> [] end, + RSMFeatures = case RMod:rsm_supported() of + true -> [?NS_RSM]; + false -> [] + end, Features = [?NS_DISCO_INFO, ?NS_DISCO_ITEMS, - ?NS_REGISTER, ?NS_MUC, ?NS_RSM, - ?NS_VCARD, ?NS_MUCSUB, ?NS_MUC_UNIQUE | MAMFeatures], + ?NS_REGISTER, ?NS_MUC, ?NS_VCARD, ?NS_MUCSUB, ?NS_MUC_UNIQUE + | RSMFeatures ++ MAMFeatures], Identity = #identity{category = <<"conference">>, type = <<"text">>, name = translate:translate(Lang, <<"Chatrooms">>)}, @@ -497,7 +552,8 @@ process_disco_items(#iq{type = get, from = From, to = To, lang = Lang, ServerHost, ?MODULE, max_rooms_discoitems, fun(I) when is_integer(I), I>=0 -> I end, 100), - case iq_disco_items(Host, From, Lang, MaxRoomsDiscoItems, Node, RSM) of + case iq_disco_items(ServerHost, Host, From, Lang, + MaxRoomsDiscoItems, Node, RSM) of {error, Err} -> xmpp:make_error(IQ, Err); {result, Result} -> @@ -564,17 +620,19 @@ get_rooms(ServerHost, Host) -> load_permanent_rooms(Host, ServerHost, Access, HistorySize, RoomShaper) -> + RMod = gen_mod:ram_db_mod(ServerHost, ?MODULE), lists:foreach( fun(R) -> {Room, Host} = R#muc_room.name_host, - case mnesia:dirty_read(muc_online_room, {Room, Host}) of - [] -> + case RMod:find_online_room(Room, Host) of + error -> {ok, Pid} = mod_muc_room:start(Host, ServerHost, Access, Room, HistorySize, RoomShaper, R#muc_room.opts), - register_room(Host, Room, Pid); - _ -> ok + RMod:register_online_room(Room, Host, Pid); + {ok, _} -> + ok end end, get_rooms(ServerHost, Host)). @@ -594,19 +652,12 @@ start_new_room(Host, ServerHost, Access, Room, HistorySize, RoomShaper, Opts) end. -register_room(Host, Room, Pid) -> - F = fun() -> - mnesia:write(#muc_online_room{name_host = {Room, Host}, - pid = Pid}) - end, - mnesia:transaction(F). - --spec iq_disco_items(binary(), jid(), binary(), integer(), binary(), +-spec iq_disco_items(binary(), binary(), jid(), binary(), integer(), binary(), rsm_set() | undefined) -> {result, disco_items()} | {error, stanza_error()}. -iq_disco_items(Host, From, Lang, MaxRoomsDiscoItems, Node, RSM) +iq_disco_items(ServerHost, Host, From, Lang, MaxRoomsDiscoItems, Node, RSM) when Node == <<"">>; Node == <<"nonemptyrooms">>; Node == <<"emptyrooms">> -> - Count = get_vh_rooms_count(Host), + Count = count_online_rooms(ServerHost, Host), Query = if Node == <<"">>, RSM == undefined, Count > MaxRoomsDiscoItems -> {get_disco_item, only_non_empty, From, Lang}; Node == <<"nonemptyrooms">> -> @@ -616,7 +667,13 @@ iq_disco_items(Host, From, Lang, MaxRoomsDiscoItems, Node, RSM) true -> {get_disco_item, all, From, Lang} end, - Items = get_vh_rooms(Host, Query, RSM), + Items = lists:flatmap( + fun(R) -> + case get_room_disco_item(R, Query) of + {ok, Item} -> [Item]; + {error, _} -> [] + end + end, get_online_rooms(ServerHost, Host, RSM)), ResRSM = case Items of [_|_] when RSM /= undefined -> #disco_item{jid = #jid{luser = First}} = hd(Items), @@ -630,69 +687,13 @@ iq_disco_items(Host, From, Lang, MaxRoomsDiscoItems, Node, RSM) undefined end, {result, #disco_items{node = Node, items = Items, rsm = ResRSM}}; -iq_disco_items(_Host, _From, Lang, _MaxRoomsDiscoItems, _Node, _RSM) -> +iq_disco_items(_ServerHost, _Host, _From, Lang, _MaxRoomsDiscoItems, _Node, _RSM) -> {error, xmpp:err_item_not_found(<<"Node not found">>, Lang)}. --spec get_vh_rooms(binary, term(), rsm_set() | undefined) -> [disco_item()]. -get_vh_rooms(Host, Query, - #rsm_set{max = Max, 'after' = After, before = undefined}) - when is_binary(After), After /= <<"">> -> - lists:reverse(get_vh_rooms(next, {After, Host}, Host, Query, 0, Max, [])); -get_vh_rooms(Host, Query, - #rsm_set{max = Max, 'after' = undefined, before = Before}) - when is_binary(Before), Before /= <<"">> -> - get_vh_rooms(prev, {Before, Host}, Host, Query, 0, Max, []); -get_vh_rooms(Host, Query, - #rsm_set{max = Max, 'after' = undefined, before = <<"">>}) -> - get_vh_rooms(last, {<<"">>, Host}, Host, Query, 0, Max, []); -get_vh_rooms(Host, Query, #rsm_set{max = Max}) -> - lists:reverse(get_vh_rooms(first, {<<"">>, Host}, Host, Query, 0, Max, [])); -get_vh_rooms(Host, Query, undefined) -> - lists:reverse(get_vh_rooms(first, {<<"">>, Host}, Host, Query, 0, undefined, [])). - --spec get_vh_rooms(prev | next | last | first, - {binary(), binary()}, binary(), term(), - non_neg_integer(), non_neg_integer() | undefined, - [disco_item()]) -> [disco_item()]. -get_vh_rooms(_Action, _Key, _Host, _Query, Count, Max, Items) when Count >= Max -> - Items; -get_vh_rooms(Action, Key, Host, Query, Count, Max, Items) -> - Call = fun() -> - case Action of - prev -> mnesia:dirty_prev(muc_online_room, Key); - next -> mnesia:dirty_next(muc_online_room, Key); - last -> mnesia:dirty_last(muc_online_room); - first -> mnesia:dirty_first(muc_online_room) - end - end, - NewAction = case Action of - last -> prev; - first -> next; - _ -> Action - end, - try Call() of - '$end_of_table' -> - Items; - {_, Host} = NewKey -> - case get_room_disco_item(NewKey, Query) of - {ok, Item} -> - get_vh_rooms(NewAction, NewKey, Host, Query, - Count + 1, Max, [Item|Items]); - {error, _} -> - get_vh_rooms(NewAction, NewKey, Host, Query, - Count, Max, Items) - end; - NewKey -> - get_vh_rooms(NewAction, NewKey, Host, Query, Count, Max, Items) - catch _:{aborted, {badarg, _}} -> - Items - end. - --spec get_room_disco_item({binary(), binary()}, term()) -> {ok, disco_item()} | +-spec get_room_disco_item({binary(), binary(), pid()}, + term()) -> {ok, disco_item()} | {error, timeout | notfound}. -get_room_disco_item({Name, Host}, Query) -> - case mnesia:dirty_read(muc_online_room, {Name, Host}) of - [#muc_online_room{pid = Pid}|_] -> +get_room_disco_item({Name, Host, Pid}, Query) -> RoomJID = jid:make(Name, Host), try gen_fsm:sync_send_all_state_event(Pid, Query, 100) of {item, Desc} -> @@ -703,16 +704,13 @@ get_room_disco_item({Name, Host}, Query) -> {error, timeout}; _:{noproc, _} -> {error, notfound} - end; - _ -> - {error, notfound} end. -get_subscribed_rooms(_ServerHost, Host, From) -> - Rooms = get_vh_rooms(Host), +get_subscribed_rooms(ServerHost, Host, From) -> + Rooms = get_online_rooms(ServerHost, Host), BareFrom = jid:remove_resource(From), lists:flatmap( - fun(#muc_online_room{name_host = {Name, _}, pid = Pid}) -> + fun({Name, _, Pid}) -> case gen_fsm:sync_send_all_state_event(Pid, {is_subscribed, BareFrom}) of true -> [jid:make(Name, Host)]; false -> [] @@ -793,72 +791,28 @@ process_iq_register_set(ServerHost, Host, From, {error, xmpp:err_not_acceptable(ErrText, Lang)} end. -broadcast_service_message(Host, Msg) -> +-spec broadcast_service_message(binary(), binary(), message()) -> ok. +broadcast_service_message(ServerHost, Host, Msg) -> lists:foreach( - fun(#muc_online_room{pid = Pid}) -> + fun({_, _, Pid}) -> gen_fsm:send_all_state_event( Pid, {service_message, Msg}) - end, get_vh_rooms(Host)). + end, get_online_rooms(ServerHost, Host)). +-spec get_online_rooms(binary(), binary()) -> [{binary(), binary(), pid()}]. +get_online_rooms(ServerHost, Host) -> + get_online_rooms(ServerHost, Host, undefined). -get_vh_rooms(Host) -> - mnesia:dirty_select(muc_online_room, - [{#muc_online_room{name_host = '$1', _ = '_'}, - [{'==', {element, 2, '$1'}, Host}], - ['$_']}]). +-spec get_online_rooms(binary(), binary(), undefined | rsm_set()) -> + [{binary(), binary(), pid()}]. +get_online_rooms(ServerHost, Host, RSM) -> + RMod = gen_mod:ram_db_mod(ServerHost, ?MODULE), + RMod:get_online_rooms(Host, RSM). --spec get_vh_rooms_count(binary()) -> non_neg_integer(). -get_vh_rooms_count(Host) -> - ets:select_count(muc_online_room, - ets:fun2ms( - fun(#muc_online_room{name_host = {_, H}}) -> - H == Host - end)). - -clean_table_from_bad_node(Node) -> - F = fun() -> - Es = mnesia:select( - muc_online_room, - [{#muc_online_room{pid = '$1', _ = '_'}, - [{'==', {node, '$1'}, Node}], - ['$_']}]), - lists:foreach(fun(E) -> - mnesia:delete_object(E) - end, Es) - end, - mnesia:async_dirty(F). - -clean_table_from_bad_node(Node, Host) -> - F = fun() -> - Es = mnesia:select( - muc_online_room, - [{#muc_online_room{pid = '$1', - name_host = {'_', Host}, - _ = '_'}, - [{'==', {node, '$1'}, Node}], - ['$_']}]), - lists:foreach(fun(E) -> - mnesia:delete_object(E) - end, Es) - end, - mnesia:async_dirty(F). - -update_tables() -> - try - case mnesia:table_info(muc_online_room, type) of - ordered_set -> ok; - _ -> - case mnesia:delete_table(muc_online_room) of - {atomic, ok} -> ok; - Err -> erlang:error(Err) - end - end - catch _:{aborted, {no_exists, muc_online_room}} -> ok; - _:{aborted, {no_exists, muc_online_room, type}} -> ok; - E:R -> - ?ERROR_MSG("failed to update mnesia table '~s': ~p", - [muc_online_room, {E, R}]) - end. +-spec count_online_rooms(binary(), binary()) -> non_neg_integer(). +count_online_rooms(ServerHost, Host) -> + RMod = gen_mod:ram_db_mod(ServerHost, ?MODULE), + RMod:count_online_rooms(Host). opts_to_binary(Opts) -> lists:map( @@ -922,6 +876,7 @@ mod_opt_type(access_create) -> mod_opt_type(access_persistent) -> fun acl:access_rules_validator/1; mod_opt_type(db_type) -> fun(T) -> ejabberd_config:v_db(?MODULE, T) end; +mod_opt_type(ram_db_type) -> fun(T) -> ejabberd_config:v_db(?MODULE, T) end; mod_opt_type(default_room_options) -> fun (L) when is_list(L) -> L end; mod_opt_type(history_size) -> @@ -963,7 +918,7 @@ mod_opt_type(user_presence_shaper) -> fun (A) when is_atom(A) -> A end; mod_opt_type(_) -> [access, access_admin, access_create, access_persistent, - db_type, default_room_options, history_size, host, + db_type, ram_db_type, default_room_options, history_size, host, max_room_desc, max_room_id, max_room_name, max_rooms_discoitems, max_user_conferences, max_users, max_users_admin_threshold, max_users_presence, diff --git a/src/mod_muc_admin.erl b/src/mod_muc_admin.erl index 7606dcfa3..55f182ba3 100644 --- a/src/mod_muc_admin.erl +++ b/src/mod_muc_admin.erl @@ -45,7 +45,6 @@ -include("logger.hrl"). -include("xmpp.hrl"). -include("mod_muc_room.hrl"). --include("mod_muc.hrl"). -include("ejabberd_http.hrl"). -include("ejabberd_web_admin.hrl"). -include("ejabberd_commands.hrl"). @@ -224,22 +223,12 @@ get_commands_spec() -> %%% muc_online_rooms(ServerHost) -> - MUCHost = find_host(ServerHost), - Rooms = ets:tab2list(muc_online_room), - lists:foldl( - fun(Room, Results) -> - {Roomname, Host} = Room#muc_online_room.name_host, - case MUCHost of - global -> - [<> | Results]; - Host -> - [<> | Results]; - _ -> - Results - end - end, - [], - Rooms). + Hosts = find_hosts(ServerHost), + lists:flatmap( + fun(Host) -> + [{<>} + || {Name, _, _} <- mod_muc:get_online_rooms(Host)] + end, Hosts). muc_unregister_nick(Nick) -> F2 = fun(N) -> @@ -254,14 +243,18 @@ muc_unregister_nick(Nick) -> end. get_user_rooms(LUser, LServer) -> - US = {LUser, LServer}, - case catch ets:select(muc_online_users, - [{#muc_online_users{us = US, room='$1', host='$2', _ = '_'}, [], [{{'$1', '$2'}}]}]) - of - Res when is_list(Res) -> - [<> || {R, H} <- Res]; - _ -> [] - end. + lists:flatmap( + fun(ServerHost) -> + case gen_mod:is_loaded(ServerHost, mod_muc) of + true -> + Rooms = mod_muc:get_online_rooms_by_user( + ServerHost, LUser, LServer), + [<> + || {Name, Host} <- Rooms]; + false -> + [] + end + end, ?MYHOSTS). %%---------------------------- %% Ad-hoc commands @@ -291,10 +284,14 @@ web_menu_host(Acc, _Host, Lang) -> ])). web_page_main(_, #request{path=[<<"muc">>], lang = Lang} = _Request) -> + OnlineRoomsNumber = lists:foldl( + fun(Host, Acc) -> + Acc ++ mod_muc:count_online_rooms(Host) + end, 0, find_hosts(global)), Res = [?XCT(<<"h1">>, <<"Multi-User Chat">>), ?XCT(<<"h3">>, <<"Statistics">>), ?XAE(<<"table">>, [], - [?XE(<<"tbody">>, [?TDTD(<<"Total rooms">>, ets:info(muc_online_room, size)), + [?XE(<<"tbody">>, [?TDTD(<<"Total rooms">>, OnlineRoomsNumber), ?TDTD(<<"Permanent rooms">>, mnesia:table_info(muc_room, size)), ?TDTD(<<"Registered nicknames">>, mnesia:table_info(muc_registered, size)) ]) @@ -473,8 +470,8 @@ create_room_with_opts(Name1, Host1, ServerHost, CustomRoomOpts) -> RoomShaper = gen_mod:get_module_opt(ServerHost, mod_muc, room_shaper, fun(X) -> X end, none), %% If the room does not exist yet in the muc_online_room - case mnesia:dirty_read(muc_online_room, {Name, Host}) of - [] -> + case mod_muc:find_online_room(Name, Host) of + error -> %% Start the room {ok, Pid} = mod_muc_room:start( Host, @@ -484,19 +481,12 @@ create_room_with_opts(Name1, Host1, ServerHost, CustomRoomOpts) -> HistorySize, RoomShaper, RoomOpts), - {atomic, ok} = register_room(Host, Name, Pid), + mod_muc:register_online_room(Host, Name, Pid), ok; - _ -> + {ok, _} -> error end. -register_room(Host, Name, Pid) -> - F = fun() -> - mnesia:write(#muc_online_room{name_host = {Name, Host}, - pid = Pid}) - end, - mnesia:transaction(F). - %% Create the room only in the database. %% It is required to restart the MUC service for the room to appear. muc_create_room(ServerHost, {Name, Host, _}, DefRoomOpts) -> @@ -509,12 +499,11 @@ muc_create_room(ServerHost, {Name, Host, _}, DefRoomOpts) -> %% If the room has participants, they are not notified that the room was destroyed; %% they will notice when they try to chat and receive an error that the room doesn't exist. destroy_room(Name, Service) -> - case mnesia:dirty_read(muc_online_room, {Name, Service}) of - [R] -> - Pid = R#muc_online_room.pid, + case mod_muc:find_online_room(Name, Service) of + {ok, Pid} -> gen_fsm:send_all_state_event(Pid, destroy), ok; - [] -> + error -> error end. @@ -619,19 +608,12 @@ muc_unused2(Action, ServerHost, Host, Last_allowed) -> %%--------------- %% Get info -get_rooms(Host) -> - Get_room_names = fun(Room_reg, Names) -> - Pid = Room_reg#muc_online_room.pid, - case {Host, Room_reg#muc_online_room.name_host} of - {Host, {Name1, Host}} -> - [{Name1, Host, Pid} | Names]; - {global, {Name1, Host1}} -> - [{Name1, Host1, Pid} | Names]; - _ -> - Names - end - end, - ets:foldr(Get_room_names, [], muc_online_room). +get_rooms(ServerHost) -> + Hosts = find_hosts(ServerHost), + lists:flatmap( + fun(Host) -> + mod_muc:get_online_rooms(Host) + end, Hosts). get_room_config(Room_pid) -> {ok, R} = gen_fsm:sync_send_all_state_event(Room_pid, get_config), @@ -830,11 +812,11 @@ format_room_option(OptionString, ValueString) -> %% @doc Get the Pid of an existing MUC room, or 'room_not_found'. get_room_pid(Name, Service) -> - case mnesia:dirty_read(muc_online_room, {Name, Service}) of - [] -> + case mod_muc:find_online_room(Name, Service) of + error -> room_not_found; - [Room] -> - Room#muc_online_room.pid + {ok, Pid} -> + Pid end. %% It is required to put explicitely all the options because @@ -901,10 +883,9 @@ get_options(Config) -> %% [{JID::string(), Domain::string(), Role::string(), Reason::string()}] %% @doc Get the affiliations of the room Name@Service. get_room_affiliations(Name, Service) -> - case mnesia:dirty_read(muc_online_room, {Name, Service}) of - [R] -> + case mod_muc:find_online_room(Name, Service) of + {ok, Pid} -> %% Get the PID of the online room, then request its state - Pid = R#muc_online_room.pid, {ok, StateData} = gen_fsm:sync_send_all_state_event(Pid, get_state), Affiliations = ?DICT:to_list(StateData#state.affiliations), lists:map( @@ -913,7 +894,7 @@ get_room_affiliations(Name, Service) -> ({{Uname, Domain, _Res}, Aff}) when is_atom(Aff)-> {Uname, Domain, Aff, <<>>} end, Affiliations); - [] -> + error -> throw({error, "The room does not exist."}) end. @@ -931,14 +912,13 @@ get_room_affiliations(Name, Service) -> %% In any other case the action will be to create the affiliation. set_room_affiliation(Name, Service, JID, AffiliationString) -> Affiliation = jlib:binary_to_atom(AffiliationString), - case mnesia:dirty_read(muc_online_room, {Name, Service}) of - [R] -> + case mod_muc:find_online_room(Name, Service) of + {ok, Pid} -> %% Get the PID for the online room so we can get the state of the room - Pid = R#muc_online_room.pid, {ok, StateData} = gen_fsm:sync_send_all_state_event(Pid, {process_item_change, {jid:from_string(JID), affiliation, Affiliation, <<"">>}, <<"">>}), mod_muc:store_room(StateData#state.server_host, StateData#state.host, StateData#state.room, make_opts(StateData)), ok; - [] -> + error -> error end. @@ -1074,4 +1054,28 @@ find_host(ServerHost) when is_list(ServerHost) -> find_host(ServerHost) -> gen_mod:get_module_opt_host(ServerHost, mod_muc, <<"conference.@HOST@">>). +find_hosts(Global) when Global == global; + Global == "global"; + Global == <<"global">> -> + lists:flatmap( + fun(ServerHost) -> + case gen_mod:is_loaded(ServerHost, mod_muc) of + true -> + [gen_mod:get_module_opt_host( + ServerHost, mod_muc, <<"conference.@HOST@">>)]; + false -> + [] + end + end, ?MYHOSTS); +find_hosts(ServerHost) when is_list(ServerHost) -> + find_hosts(list_to_binary(ServerHost)); +find_hosts(ServerHost) -> + case gen_mod:is_loaded(ServerHost, mod_muc) of + true -> + [gen_mod:get_module_opt_host( + ServerHost, mod_muc, <<"conference.@HOST@">>)]; + false -> + [] + end. + mod_opt_type(_) -> []. diff --git a/src/mod_muc_log.erl b/src/mod_muc_log.erl index f1089212e..700f7284e 100644 --- a/src/mod_muc_log.erl +++ b/src/mod_muc_log.erl @@ -47,7 +47,6 @@ -include("logger.hrl"). -include("xmpp.hrl"). --include("mod_muc.hrl"). -include("mod_muc_room.hrl"). -define(T(Text), translate:translate(Lang, Text)). @@ -1169,13 +1168,11 @@ get_room_occupants(RoomJIDString) -> -spec get_room_state(binary(), binary()) -> mod_muc_room:state(). get_room_state(RoomName, MucService) -> - case mnesia:dirty_read(muc_online_room, - {RoomName, MucService}) - of - [R] -> - RoomPid = R#muc_online_room.pid, + case mod_muc:find_online_room(RoomName, MucService) of + {ok, RoomPid} -> get_room_state(RoomPid); - [] -> #state{} + error -> + #state{} end. -spec get_room_state(pid()) -> mod_muc_room:state(). diff --git a/src/mod_muc_mnesia.erl b/src/mod_muc_mnesia.erl index 9c6ebf924..9fdd1dce2 100644 --- a/src/mod_muc_mnesia.erl +++ b/src/mod_muc_mnesia.erl @@ -30,28 +30,34 @@ %% API -export([init/2, import/3, store_room/4, restore_room/3, forget_room/3, can_use_nick/4, get_rooms/2, get_nick/3, set_nick/4]). +-export([register_online_room/3, unregister_online_room/3, find_online_room/2, + get_online_rooms/2, count_online_rooms/1, rsm_supported/0, + register_online_user/3, unregister_online_user/3, + count_online_rooms_by_user/2, get_online_rooms_by_user/2]). -export([set_affiliation/6, set_affiliations/4, get_affiliation/5, get_affiliations/3, search_affiliation/4]). +%% gen_server callbacks +-export([init/1, handle_cast/2, handle_call/3, handle_info/2, + terminate/2, code_change/3]). --include("jlib.hrl"). -include("mod_muc.hrl"). -include("logger.hrl"). +-include("xmpp.hrl"). +-include_lib("stdlib/include/ms_transform.hrl"). + +-record(state, {}). %%%=================================================================== %%% API %%%=================================================================== -init(_Host, Opts) -> - MyHost = proplists:get_value(host, Opts), - ejabberd_mnesia:create(?MODULE, muc_room, - [{disc_copies, [node()]}, - {attributes, - record_info(fields, muc_room)}]), - ejabberd_mnesia:create(?MODULE, muc_registered, - [{disc_copies, [node()]}, - {attributes, - record_info(fields, muc_registered)}, - {index, [nick]}]), - update_tables(MyHost). +init(Host, Opts) -> + Name = gen_mod:get_module_proc(Host, ?MODULE), + case gen_server:start_link({local, Name}, ?MODULE, [Host, Opts], []) of + {ok, _Pid} -> + ok; + Err -> + Err + end. store_room(_LServer, Host, Name, Opts) -> F = fun () -> @@ -148,6 +154,123 @@ get_affiliations(_ServerHost, _Room, _Host) -> search_affiliation(_ServerHost, _Room, _Host, _Affiliation) -> {error, not_implemented}. +register_online_room(Room, Host, Pid) -> + F = fun() -> + mnesia:write( + #muc_online_room{name_host = {Room, Host}, pid = Pid}) + end, + mnesia:transaction(F). + +unregister_online_room(Room, Host, Pid) -> + F = fun () -> + mnesia:delete_object( + #muc_online_room{name_host = {Room, Host}, pid = Pid}) + end, + mnesia:transaction(F). + +find_online_room(Room, Host) -> + case mnesia:dirty_read(muc_online_room, {Room, Host}) of + [] -> error; + [#muc_online_room{pid = Pid}] -> {ok, Pid} + end. + +count_online_rooms(Host) -> + ets:select_count( + muc_online_room, + ets:fun2ms( + fun(#muc_online_room{name_host = {_, H}}) -> + H == Host + end)). + +get_online_rooms(Host, + #rsm_set{max = Max, 'after' = After, before = undefined}) + when is_binary(After), After /= <<"">> -> + lists:reverse(get_online_rooms(next, {After, Host}, Host, 0, Max, [])); +get_online_rooms(Host, + #rsm_set{max = Max, 'after' = undefined, before = Before}) + when is_binary(Before), Before /= <<"">> -> + get_online_rooms(prev, {Before, Host}, Host, 0, Max, []); +get_online_rooms(Host, + #rsm_set{max = Max, 'after' = undefined, before = <<"">>}) -> + get_online_rooms(last, {<<"">>, Host}, Host, 0, Max, []); +get_online_rooms(Host, #rsm_set{max = Max}) -> + lists:reverse(get_online_rooms(first, {<<"">>, Host}, Host, 0, Max, [])); +get_online_rooms(Host, undefined) -> + mnesia:dirty_select( + muc_online_room, + ets:fun2ms( + fun(#muc_online_room{name_host = {Name, H}, pid = Pid}) + when H == Host -> {Name, Host, Pid} + end)). + +-spec get_online_rooms(prev | next | last | first, + {binary(), binary()}, binary(), + non_neg_integer(), non_neg_integer() | undefined, + [{binary(), binary(), pid()}]) -> + [{binary(), binary(), pid()}]. +get_online_rooms(_Action, _Key, _Host, Count, Max, Items) when Count >= Max -> + Items; +get_online_rooms(Action, Key, Host, Count, Max, Items) -> + Call = fun() -> + case Action of + prev -> mnesia:dirty_prev(muc_online_room, Key); + next -> mnesia:dirty_next(muc_online_room, Key); + last -> mnesia:dirty_last(muc_online_room); + first -> mnesia:dirty_first(muc_online_room) + end + end, + NewAction = case Action of + last -> prev; + first -> next; + _ -> Action + end, + try Call() of + '$end_of_table' -> + Items; + {Room, Host} = NewKey -> + case find_online_room(Room, Host) of + {ok, Pid} -> + get_online_rooms(NewAction, NewKey, Host, + Count + 1, Max, [{Room, Host, Pid}|Items]); + {error, _} -> + get_online_rooms(NewAction, NewKey, Host, + Count, Max, Items) + end; + NewKey -> + get_online_rooms(NewAction, NewKey, Host, Count, Max, Items) + catch _:{aborted, {badarg, _}} -> + Items + end. + +rsm_supported() -> + true. + +register_online_user({U, S, R}, Room, Host) -> + ets:insert(muc_online_users, + #muc_online_users{us = {U, S}, resource = R, + room = Room, host = Host}). + +unregister_online_user({U, S, R}, Room, Host) -> + ets:delete_object(muc_online_users, + #muc_online_users{us = {U, S}, resource = R, + room = Room, host = Host}). + +count_online_rooms_by_user(U, S) -> + ets:select_count( + muc_online_users, + ets:fun2ms( + fun(#muc_online_users{us = {U1, S1}}) -> + U == U1 andalso S == S1 + end)). + +get_online_rooms_by_user(U, S) -> + ets:select( + muc_online_users, + ets:fun2ms( + fun(#muc_online_users{us = {U1, S1}, room = Room, host = Host}) + when U == U1 andalso S == S1 -> {Room, Host} + end)). + import(_LServer, <<"muc_room">>, [Name, RoomHost, SOpts, _TimeStamp]) -> Opts = mod_muc:opts_to_binary(ejabberd_sql:decode_term(SOpts)), @@ -161,9 +284,93 @@ import(_LServer, <<"muc_registered">>, #muc_registered{us_host = {{U, S}, RoomHost}, nick = Nick}). +%%%=================================================================== +%%% gen_server callbacks +%%%=================================================================== +init([Host, Opts]) -> + MyHost = proplists:get_value(host, Opts), + case gen_mod:db_mod(Host, Opts, mod_muc) of + ?MODULE -> + ejabberd_mnesia:create(?MODULE, muc_room, + [{disc_copies, [node()]}, + {attributes, + record_info(fields, muc_room)}]), + ejabberd_mnesia:create(?MODULE, muc_registered, + [{disc_copies, [node()]}, + {attributes, + record_info(fields, muc_registered)}, + {index, [nick]}]), + update_tables(MyHost); + _ -> + ok + end, + case gen_mod:ram_db_mod(Host, Opts, mod_muc) of + ?MODULE -> + update_muc_online_table(), + ejabberd_mnesia:create(?MODULE, muc_online_room, + [{ram_copies, [node()]}, + {type, ordered_set}, + {attributes, record_info(fields, muc_online_room)}]), + mnesia:add_table_copy(muc_online_room, node(), ram_copies), + catch ets:new(muc_online_users, [bag, named_table, public, {keypos, 2}]), + clean_table_from_bad_node(node(), MyHost), + mnesia:subscribe(system); + _ -> + ok + end, + {ok, #state{}}. + +handle_call(_Request, _From, State) -> + Reply = ok, + {reply, Reply, State}. + +handle_cast(_Msg, State) -> + {noreply, State}. + +handle_info({mnesia_system_event, {mnesia_down, Node}}, State) -> + clean_table_from_bad_node(Node), + {noreply, State}; +handle_info(Info, State) -> + ?ERROR_MSG("unexpected info: ~p", [Info]), + {noreply, State}. + +terminate(_Reason, _State) -> + ok. + +code_change(_OldVsn, State, _Extra) -> + {ok, State}. + %%%=================================================================== %%% Internal functions %%%=================================================================== +clean_table_from_bad_node(Node) -> + F = fun() -> + Es = mnesia:select( + muc_online_room, + [{#muc_online_room{pid = '$1', _ = '_'}, + [{'==', {node, '$1'}, Node}], + ['$_']}]), + lists:foreach(fun(E) -> + mnesia:delete_object(E) + end, Es) + end, + mnesia:async_dirty(F). + +clean_table_from_bad_node(Node, Host) -> + F = fun() -> + Es = mnesia:select( + muc_online_room, + [{#muc_online_room{pid = '$1', + name_host = {'_', Host}, + _ = '_'}, + [{'==', {node, '$1'}, Node}], + ['$_']}]), + lists:foreach(fun(E) -> + mnesia:delete_object(E) + end, Es) + end, + mnesia:async_dirty(F). + update_tables(Host) -> update_muc_room_table(Host), update_muc_registered_table(Host). @@ -204,3 +411,20 @@ update_muc_registered_table(_Host) -> ?INFO_MSG("Recreating muc_registered table", []), mnesia:transform_table(muc_registered, ignore, Fields) end. + +update_muc_online_table() -> + try + case mnesia:table_info(muc_online_room, type) of + ordered_set -> ok; + _ -> + case mnesia:delete_table(muc_online_room) of + {atomic, ok} -> ok; + Err -> erlang:error(Err) + end + end + catch _:{aborted, {no_exists, muc_online_room}} -> ok; + _:{aborted, {no_exists, muc_online_room, type}} -> ok; + E:R -> + ?ERROR_MSG("failed to update mnesia table '~s': ~p", + [muc_online_room, {E, R, erlang:get_stacktrace()}]) + end. diff --git a/src/mod_muc_riak.erl b/src/mod_muc_riak.erl index 156396caa..23681e883 100644 --- a/src/mod_muc_riak.erl +++ b/src/mod_muc_riak.erl @@ -30,10 +30,14 @@ %% API -export([init/2, import/3, store_room/4, restore_room/3, forget_room/3, can_use_nick/4, get_rooms/2, get_nick/3, set_nick/4]). +-export([register_online_room/3, unregister_online_room/3, find_online_room/2, + get_online_rooms/2, count_online_rooms/1, rsm_supported/0, + register_online_user/3, unregister_online_user/3, + count_online_rooms_by_user/2, get_online_rooms_by_user/2]). -export([set_affiliation/6, set_affiliations/4, get_affiliation/5, get_affiliations/3, search_affiliation/4]). --include("jlib.hrl"). +-include("jid.hrl"). -include("mod_muc.hrl"). %%%=================================================================== @@ -136,6 +140,36 @@ get_affiliations(_ServerHost, _Room, _Host) -> search_affiliation(_ServerHost, _Room, _Host, _Affiliation) -> {error, not_implemented}. +register_online_room(_, _, _) -> + erlang:error(not_implemented). + +unregister_online_room(_, _, _) -> + erlang:error(not_implemented). + +find_online_room(_, _) -> + erlang:error(not_implemented). + +count_online_rooms(_) -> + erlang:error(not_implemented). + +get_online_rooms(_, _) -> + erlang:error(not_implemented). + +rsm_supported() -> + false. + +register_online_user(_, _, _) -> + erlang:error(not_implemented). + +unregister_online_user(_, _, _) -> + erlang:error(not_implemented). + +count_online_rooms_by_user(_, _) -> + erlang:error(not_implemented). + +get_online_rooms_by_user(_, _) -> + erlang:error(not_implemented). + import(_LServer, <<"muc_room">>, [Name, RoomHost, SOpts, _TimeStamp]) -> Opts = mod_muc:opts_to_binary(ejabberd_sql:decode_term(SOpts)), diff --git a/src/mod_muc_room.erl b/src/mod_muc_room.erl index f524fb7cc..40e9633b1 100644 --- a/src/mod_muc_room.erl +++ b/src/mod_muc_room.erl @@ -1791,7 +1791,7 @@ add_new_user(From, Nick, Packet, StateData) -> Affiliation = get_affiliation(From, StateData), ServiceAffiliation = get_service_affiliation(From, StateData), - NConferences = tab_count_user(From), + NConferences = tab_count_user(From, StateData), MaxConferences = gen_mod:get_module_opt(StateData#state.server_host, mod_muc, max_user_conferences, @@ -4000,38 +4000,25 @@ add_to_log(Type, Data, StateData) -> %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% %% Users number checking --spec tab_add_online_user(jid(), state()) -> ok. +-spec tab_add_online_user(jid(), state()) -> any(). tab_add_online_user(JID, StateData) -> - {LUser, LServer, LResource} = jid:tolower(JID), - US = {LUser, LServer}, Room = StateData#state.room, Host = StateData#state.host, - catch ets:insert(muc_online_users, - #muc_online_users{us = US, resource = LResource, - room = Room, host = Host}), - ok. + ServerHost = StateData#state.server_host, + mod_muc:register_online_user(ServerHost, jid:tolower(JID), Room, Host). --spec tab_remove_online_user(jid(), state()) -> ok. +-spec tab_remove_online_user(jid(), state()) -> any(). tab_remove_online_user(JID, StateData) -> - {LUser, LServer, LResource} = jid:tolower(JID), - US = {LUser, LServer}, Room = StateData#state.room, Host = StateData#state.host, - catch ets:delete_object(muc_online_users, - #muc_online_users{us = US, resource = LResource, - room = Room, host = Host}), - ok. + ServerHost = StateData#state.server_host, + mod_muc:unregister_online_user(ServerHost, jid:tolower(JID), Room, Host). --spec tab_count_user(jid()) -> non_neg_integer(). -tab_count_user(JID) -> +-spec tab_count_user(jid(), state()) -> non_neg_integer(). +tab_count_user(JID, StateData) -> + ServerHost = StateData#state.server_host, {LUser, LServer, _} = jid:tolower(JID), - US = {LUser, LServer}, - case catch ets:select(muc_online_users, - [{#muc_online_users{us = US, _ = '_'}, [], [[]]}]) - of - Res when is_list(Res) -> length(Res); - _ -> 0 - end. + mod_muc:count_online_rooms_by_user(ServerHost, LUser, LServer). -spec element_size(stanza()) -> non_neg_integer(). element_size(El) -> diff --git a/src/mod_muc_sql.erl b/src/mod_muc_sql.erl index 3771e28b7..f02cc77a8 100644 --- a/src/mod_muc_sql.erl +++ b/src/mod_muc_sql.erl @@ -33,6 +33,10 @@ -export([init/2, store_room/4, restore_room/3, forget_room/3, can_use_nick/4, get_rooms/2, get_nick/3, set_nick/4, import/3, export/1]). +-export([register_online_room/3, unregister_online_room/3, find_online_room/2, + get_online_rooms/2, count_online_rooms/1, rsm_supported/0, + register_online_user/3, unregister_online_user/3, + count_online_rooms_by_user/2, get_online_rooms_by_user/2]). -export([set_affiliation/6, set_affiliations/4, get_affiliation/5, get_affiliations/3, search_affiliation/4]). @@ -161,6 +165,36 @@ get_affiliations(_ServerHost, _Room, _Host) -> search_affiliation(_ServerHost, _Room, _Host, _Affiliation) -> {error, not_implemented}. +register_online_room(_, _, _) -> + erlang:error(not_implemented). + +unregister_online_room(_, _, _) -> + erlang:error(not_implemented). + +find_online_room(_, _) -> + erlang:error(not_implemented). + +count_online_rooms(_) -> + erlang:error(not_implemented). + +get_online_rooms(_, _) -> + erlang:error(not_implemented). + +rsm_supported() -> + false. + +register_online_user(_, _, _) -> + erlang:error(not_implemented). + +unregister_online_user(_, _, _) -> + erlang:error(not_implemented). + +count_online_rooms_by_user(_, _) -> + erlang:error(not_implemented). + +get_online_rooms_by_user(_, _) -> + erlang:error(not_implemented). + export(_Server) -> [{muc_room, fun(Host, #muc_room{name_host = {Name, RoomHost}, opts = Opts}) -> diff --git a/src/mod_offline.erl b/src/mod_offline.erl index 432214f2e..b34572ba8 100644 --- a/src/mod_offline.erl +++ b/src/mod_offline.erl @@ -44,7 +44,7 @@ store_packet/3, store_offline_msg/5, resend_offline_messages/2, - pop_offline_messages/3, + c2s_self_presence/1, get_sm_features/5, get_sm_identity/5, get_sm_items/5, @@ -61,6 +61,8 @@ count_offline_messages/2, get_offline_els/2, find_x_expire/2, + c2s_handle_info/2, + c2s_copy_session/2, webadmin_page/3, webadmin_user/4, webadmin_user_parse_query/5]). @@ -90,6 +92,8 @@ -define(MAX_USER_MESSAGES, infinity). -type us() :: {binary(), binary()}. +-type c2s_state() :: ejabberd_c2s:state(). + -callback init(binary(), gen_mod:opts()) -> any(). -callback import(#offline_msg{}) -> ok. -callback store_messages(binary(), us(), [#offline_msg{}], @@ -140,12 +144,9 @@ init([Host, Opts]) -> no_queue), ejabberd_hooks:add(offline_message_hook, Host, ?MODULE, store_packet, 50), - ejabberd_hooks:add(resend_offline_messages_hook, Host, - ?MODULE, pop_offline_messages, 50), + ejabberd_hooks:add(c2s_self_presence, Host, ?MODULE, c2s_self_presence, 50), ejabberd_hooks:add(remove_user, Host, ?MODULE, remove_user, 50), - ejabberd_hooks:add(anonymous_purge_hook, Host, - ?MODULE, remove_user, 50), ejabberd_hooks:add(disco_sm_features, Host, ?MODULE, get_sm_features, 50), ejabberd_hooks:add(disco_local_features, Host, @@ -155,6 +156,8 @@ init([Host, Opts]) -> ejabberd_hooks:add(disco_sm_items, Host, ?MODULE, get_sm_items, 50), ejabberd_hooks:add(disco_info, Host, ?MODULE, get_info, 50), + ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE, c2s_handle_info, 50), + ejabberd_hooks:add(c2s_copy_session, Host, ?MODULE, c2s_copy_session, 50), ejabberd_hooks:add(webadmin_page_host, Host, ?MODULE, webadmin_page, 50), ejabberd_hooks:add(webadmin_user, Host, @@ -199,17 +202,16 @@ terminate(_Reason, State) -> Host = State#state.host, ejabberd_hooks:delete(offline_message_hook, Host, ?MODULE, store_packet, 50), - ejabberd_hooks:delete(resend_offline_messages_hook, - Host, ?MODULE, pop_offline_messages, 50), + ejabberd_hooks:delete(c2s_self_presence, Host, ?MODULE, c2s_self_presence, 50), ejabberd_hooks:delete(remove_user, Host, ?MODULE, remove_user, 50), - ejabberd_hooks:delete(anonymous_purge_hook, Host, - ?MODULE, remove_user, 50), ejabberd_hooks:delete(disco_sm_features, Host, ?MODULE, get_sm_features, 50), ejabberd_hooks:delete(disco_local_features, Host, ?MODULE, get_sm_features, 50), ejabberd_hooks:delete(disco_sm_identity, Host, ?MODULE, get_sm_identity, 50), ejabberd_hooks:delete(disco_sm_items, Host, ?MODULE, get_sm_items, 50), ejabberd_hooks:delete(disco_info, Host, ?MODULE, get_info, 50), + ejabberd_hooks:delete(c2s_handle_info, Host, ?MODULE, c2s_handle_info, 50), + ejabberd_hooks:delete(c2s_copy_session, Host, ?MODULE, c2s_copy_session, 50), ejabberd_hooks:delete(webadmin_page_host, Host, ?MODULE, webadmin_page, 50), ejabberd_hooks:delete(webadmin_user, Host, @@ -276,15 +278,13 @@ get_sm_identity(Acc, #jid{luser = U, lserver = S}, #jid{luser = U, lserver = S}, get_sm_identity(Acc, _From, _To, _Node, _Lang) -> Acc. -get_sm_items(_Acc, #jid{luser = U, lserver = S, lresource = R} = JID, +get_sm_items(_Acc, #jid{luser = U, lserver = S} = JID, #jid{luser = U, lserver = S}, ?NS_FLEX_OFFLINE, _Lang) -> - case ejabberd_sm:get_session_pid(U, S, R) of - Pid when is_pid(Pid) -> + ejabberd_sm:route(JID, {resend_offline, false}), Mod = gen_mod:db_mod(S, ?MODULE), Hdrs = Mod:read_message_headers(U, S), BareJID = jid:remove_resource(JID), - Pid ! dont_ask_offline, {result, lists:map( fun({Seq, From, _To, _TS, _El}) -> Node = integer_to_binary(Seq), @@ -292,22 +292,14 @@ get_sm_items(_Acc, #jid{luser = U, lserver = S, lresource = R} = JID, node = Node, name = jid:to_string(From)} end, Hdrs)}; - none -> - {result, []} - end; get_sm_items(Acc, _From, _To, _Node, _Lang) -> Acc. -spec get_info([xdata()], binary(), module(), binary(), binary()) -> [xdata()]; ([xdata()], jid(), jid(), binary(), binary()) -> [xdata()]. -get_info(_Acc, #jid{luser = U, lserver = S, lresource = R}, +get_info(_Acc, #jid{luser = U, lserver = S} = JID, #jid{luser = U, lserver = S}, ?NS_FLEX_OFFLINE, Lang) -> - case ejabberd_sm:get_session_pid(U, S, R) of - Pid when is_pid(Pid) -> - Pid ! dont_ask_offline; - none -> - ok - end, + ejabberd_sm:route(JID, {resend_offline, false}), [#xdata{type = result, fields = flex_offline:encode( [{number_of_messages, count_offline_messages(U, S)}], @@ -315,6 +307,18 @@ get_info(_Acc, #jid{luser = U, lserver = S, lresource = R}, get_info(Acc, _From, _To, _Node, _Lang) -> Acc. +-spec c2s_handle_info(c2s_state(), term()) -> c2s_state(). +c2s_handle_info(State, {resend_offline, Flag}) -> + {stop, State#{resend_offline => Flag}}; +c2s_handle_info(State, _) -> + State. + +-spec c2s_copy_session(c2s_state(), c2s_state()) -> c2s_state(). +c2s_copy_session(State, #{resend_offline := Flag}) -> + State#{resend_offline => Flag}; +c2s_copy_session(State, _) -> + State. + -spec handle_offline_query(iq()) -> iq(). handle_offline_query(#iq{from = #jid{luser = U1, lserver = S1}, to = #jid{luser = U2, lserver = S2}, @@ -394,18 +398,15 @@ set_offline_tag(Msg, Node) -> xmpp:set_subtag(Msg, #offline{items = [#offline_item{node = Node}]}). -spec handle_offline_fetch(jid()) -> ok. -handle_offline_fetch(#jid{luser = U, lserver = S, lresource = R}) -> - case ejabberd_sm:get_session_pid(U, S, R) of - none -> - ok; - Pid when is_pid(Pid) -> - Pid ! dont_ask_offline, +handle_offline_fetch(#jid{luser = U, lserver = S} = JID) -> + ejabberd_sm:route(JID, {resend_offline, false}), lists:foreach( fun({Node, El}) -> - NewEl = set_offline_tag(El, Node), - Pid ! {route, xmpp:get_from(El), xmpp:get_to(El), NewEl} - end, read_messages(U, S)) - end. + El1 = set_offline_tag(El, Node), + From = xmpp:get_from(El1), + To = xmpp:get_to(El1), + ejabberd_router:route(From, To, El1) + end, read_messages(U, S)). -spec fetch_msg_by_node(jid(), binary()) -> error | {ok, #offline_msg{}}. fetch_msg_by_node(To, Seq) -> @@ -560,43 +561,67 @@ resend_offline_messages(User, Server) -> _ -> ok end. --spec pop_offline_messages([{route, jid(), jid(), message()}], - binary(), binary()) -> - [{route, jid(), jid(), message()}]. -pop_offline_messages(Ls, User, Server) -> - LUser = jid:nodeprep(User), - LServer = jid:nameprep(Server), +c2s_self_presence({#presence{type = available} = NewPres, State} = Acc) -> + NewPrio = get_priority_from_presence(NewPres), + LastPrio = try maps:get(pres_last, State) of + LastPres -> get_priority_from_presence(LastPres) + catch _:{badkey, _} -> + -1 + end, + if LastPrio < 0 andalso NewPrio >= 0 -> + route_offline_messages(State); + true -> + ok + end, + Acc; +c2s_self_presence(Acc) -> + Acc. + +-spec route_offline_messages(c2s_state()) -> ok. +route_offline_messages(#{jid := #jid{luser = LUser, lserver = LServer}} = State) -> Mod = gen_mod:db_mod(LServer, ?MODULE), case Mod:pop_messages(LUser, LServer) of - {ok, Rs} -> - Ls ++ - lists:flatmap( - fun(#offline_msg{expire = Expire} = R) -> - case offline_msg_to_route(LServer, R) of - error -> - []; - {route, _From, _To, Msg} = RouteMsg -> - case is_expired_message(Expire, Msg) of - true -> []; - false -> [RouteMsg] - end - end - end, Rs); + {ok, OffMsgs} -> + lists:foreach( + fun(OffMsg) -> + route_offline_message(State, OffMsg) + end, OffMsgs); _ -> - Ls + ok end. -is_expired_message(Expire, Pkt) -> - TS = p1_time_compat:timestamp(), - Exp = case Expire of - undefined -> find_x_expire(TS, Pkt); - _ -> Expire - end, - case Exp of - never -> false; - TimeStamp -> TS >= TimeStamp +-spec route_offline_message(c2s_state(), #offline_msg{}) -> ok. +route_offline_message(#{lserver := LServer} = State, + #offline_msg{expire = Expire} = OffMsg) -> + case offline_msg_to_route(LServer, OffMsg) of + error -> + ok; + {route, From, To, Msg} -> + case is_message_expired(Expire, Msg) of + true -> + ok; + false -> + case privacy_check_packet(State, Msg, in) of + allow -> ejabberd_router:route(From, To, Msg); + false -> ok + end + end end. +-spec is_message_expired(erlang:timestamp() | never, message()) -> boolean(). +is_message_expired(Expire, Msg) -> + TS = p1_time_compat:timestamp(), + Expire1 = case Expire of + undefined -> find_x_expire(TS, Msg); + _ -> Expire + end, + Expire1 /= never andalso Expire1 =< TS. + +-spec privacy_check_packet(c2s_state(), stanza(), in | out) -> allow | deny. +privacy_check_packet(#{lserver := LServer} = State, Pkt, Dir) -> + ejabberd_hooks:run_fold(privacy_check_packet, + LServer, allow, [State, Pkt, Dir]). + remove_expired_messages(Server) -> LServer = jid:nameprep(Server), Mod = gen_mod:db_mod(LServer, ?MODULE), @@ -640,14 +665,15 @@ get_offline_els(LUser, LServer) -> -spec offline_msg_to_route(binary(), #offline_msg{}) -> {route, jid(), jid(), message()} | error. -offline_msg_to_route(LServer, #offline_msg{} = R) -> +offline_msg_to_route(LServer, #offline_msg{from = From, to = To} = R) -> try xmpp:decode(R#offline_msg.packet, ?NS_CLIENT, [ignore_els]) of Pkt -> - NewPkt = add_delay_info(Pkt, LServer, R#offline_msg.timestamp), - {route, R#offline_msg.from, R#offline_msg.to, NewPkt} + Pkt1 = xmpp:set_from_to(Pkt, From, To), + Pkt2 = add_delay_info(Pkt1, LServer, R#offline_msg.timestamp), + {route, From, To, Pkt2} catch _:{xmpp_codec, Why} -> ?ERROR_MSG("failed to decode packet ~p of user ~s: ~s", - [R#offline_msg.packet, jid:to_string(R#offline_msg.to), + [R#offline_msg.packet, jid:to_string(To), xmpp:format_error(Why)]), error end. @@ -847,9 +873,17 @@ add_delay_info(Packet, LServer, TS) -> undefined -> p1_time_compat:timestamp(); _ -> TS end, - xmpp_util:add_delay_info(Packet, jid:make(LServer), NewTS, + Packet1 = xmpp:put_meta(Packet, from_offline, true), + xmpp_util:add_delay_info(Packet1, jid:make(LServer), NewTS, <<"Offline storage">>). +-spec get_priority_from_presence(presence()) -> integer(). +get_priority_from_presence(#presence{priority = Prio}) -> + case Prio of + undefined -> 0; + _ -> Prio + end. + export(LServer) -> Mod = gen_mod:db_mod(LServer, ?MODULE), Mod:export(LServer). diff --git a/src/mod_ping.erl b/src/mod_ping.erl index 4cdd7c46d..49664f338 100644 --- a/src/mod_ping.erl +++ b/src/mod_ping.erl @@ -54,8 +54,8 @@ -export([init/1, terminate/2, handle_call/3, handle_cast/2, handle_info/2, code_change/3]). --export([iq_ping/1, user_online/3, user_offline/3, - user_send/4, mod_opt_type/1, depends/2]). +-export([iq_ping/1, user_online/3, user_offline/3, disco_features/5, + user_send/1, mod_opt_type/1, depends/2]). -record(state, {host = <<"">>, @@ -116,7 +116,7 @@ init([Host, Opts]) -> end, none), IQDisc = gen_mod:get_opt(iqdisc, Opts, fun gen_iq_handler:check_type/1, no_queue), - mod_disco:register_feature(Host, ?NS_PING), + ejabberd_hooks:add(disco_local_features, Host, ?MODULE, disco_features, 50), gen_iq_handler:add_iq_handler(ejabberd_sm, Host, ?NS_PING, ?MODULE, iq_ping, IQDisc), gen_iq_handler:add_iq_handler(ejabberd_local, Host, @@ -145,11 +145,12 @@ terminate(_Reason, #state{host = Host}) -> ?MODULE, user_online, 100), ejabberd_hooks:delete(user_send_packet, Host, ?MODULE, user_send, 100), + ejabberd_hooks:delete(disco_local_features, Host, ?MODULE, + disco_features, 50), gen_iq_handler:remove_iq_handler(ejabberd_local, Host, ?NS_PING), gen_iq_handler:remove_iq_handler(ejabberd_sm, Host, - ?NS_PING), - mod_disco:unregister_feature(Host, ?NS_PING). + ?NS_PING). handle_call(stop, _From, State) -> {stop, normal, ok, State}; @@ -215,10 +216,22 @@ user_online(_SID, JID, _Info) -> user_offline(_SID, JID, _Info) -> stop_ping(JID#jid.lserver, JID). --spec user_send(stanza(), ejabberd_c2s:state(), jid(), jid()) -> stanza(). -user_send(Packet, _C2SState, JID, _From) -> +-spec user_send({stanza(), ejabberd_c2s:state()}) -> {stanza(), ejabberd_c2s:state()}. +user_send({Packet, #{jid := JID} = C2SState}) -> start_ping(JID#jid.lserver, JID), - Packet. + {Packet, C2SState}. + +-spec disco_features({error, stanza_error()} | {result, [binary()]} | empty, + jid(), jid(), binary(), binary()) -> + {error, stanza_error()} | {result, [binary()]}. +disco_features({error, Err}, _From, _To, _Node, _Lang) -> + {error, Err}; +disco_features(empty, _From, _To, <<"">>, _Lang) -> + {result, [?NS_PING]}; +disco_features({result, Feats}, _From, _To, <<"">>, _Lang) -> + {result, [?NS_PING|Feats]}; +disco_features(Acc, _From, _To, _Node, _Lang) -> + Acc. %%==================================================================== %% Internal functions diff --git a/src/mod_pres_counter.erl b/src/mod_pres_counter.erl index f1057e2d8..c22562a11 100644 --- a/src/mod_pres_counter.erl +++ b/src/mod_pres_counter.erl @@ -27,7 +27,7 @@ -behavior(gen_mod). --export([start/2, stop/1, check_packet/6, +-export([start/2, stop/1, check_packet/4, mod_opt_type/1, depends/2]). -include("ejabberd.hrl"). @@ -51,10 +51,12 @@ stop(Host) -> depends(_Host, _Opts) -> []. --spec check_packet(allow | deny, binary(), binary(), _, - {jid(), jid(), stanza()}, in | out) -> allow | deny. -check_packet(_, _User, Server, _PrivacyList, - {From, To, #presence{type = Type}}, Dir) -> +-spec check_packet(allow | deny, ejabberd_c2s:state() | jid(), + stanza(), in | out) -> allow | deny. +check_packet(Acc, #{jid := JID}, Packet, Dir) -> + check_packet(Acc, JID, Packet, Dir); +check_packet(_, #jid{lserver = LServer}, + #presence{from = From, to = To, type = Type}, Dir) -> IsSubscription = case Type of subscribe -> true; subscribed -> true; @@ -67,11 +69,11 @@ check_packet(_, _User, Server, _PrivacyList, in -> To; out -> From end, - update(Server, JID, Dir); + update(LServer, JID, Dir); true -> allow end; -check_packet(_, _User, _Server, _PrivacyList, _Pkt, _Dir) -> - allow. +check_packet(Acc, _, _, _) -> + Acc. update(Server, JID, Dir) -> StormCount = gen_mod:get_module_opt(Server, ?MODULE, count, diff --git a/src/mod_privacy.erl b/src/mod_privacy.erl index 97de32eee..cfced6d06 100644 --- a/src/mod_privacy.erl +++ b/src/mod_privacy.erl @@ -32,10 +32,10 @@ -behaviour(gen_mod). -export([start/2, stop/1, process_iq/1, export/1, import_info/0, - process_iq_set/3, process_iq_get/3, get_user_list/3, - check_packet/6, remove_user/2, encode_list_item/1, - is_list_needdb/1, updated_list/3, - import_start/2, import_stop/2, + c2s_session_opened/1, c2s_copy_session/2, push_list_update/3, + user_send_packet/1, user_receive_packet/1, disco_features/5, + check_packet/4, remove_user/2, encode_list_item/1, + is_list_needdb/1, import_start/2, import_stop/2, item_to_xml/1, get_user_lists/2, import/5, set_privacy_list/1, mod_opt_type/1, depends/2]). @@ -64,102 +64,124 @@ start(Host, Opts) -> one_queue), Mod = gen_mod:db_mod(Host, Opts, ?MODULE), Mod:init(Host, Opts), - mod_disco:register_feature(Host, ?NS_PRIVACY), - ejabberd_hooks:add(privacy_iq_get, Host, ?MODULE, - process_iq_get, 50), - ejabberd_hooks:add(privacy_iq_set, Host, ?MODULE, - process_iq_set, 50), - ejabberd_hooks:add(privacy_get_user_list, Host, ?MODULE, - get_user_list, 50), + ejabberd_hooks:add(disco_local_features, Host, ?MODULE, + disco_features, 50), + ejabberd_hooks:add(c2s_session_opened, Host, ?MODULE, + c2s_session_opened, 50), + ejabberd_hooks:add(c2s_copy_session, Host, ?MODULE, + c2s_copy_session, 50), + ejabberd_hooks:add(user_send_packet, Host, ?MODULE, + user_send_packet, 50), + ejabberd_hooks:add(user_receive_packet, Host, ?MODULE, + user_receive_packet, 50), ejabberd_hooks:add(privacy_check_packet, Host, ?MODULE, check_packet, 50), - ejabberd_hooks:add(privacy_updated_list, Host, ?MODULE, - updated_list, 50), ejabberd_hooks:add(remove_user, Host, ?MODULE, remove_user, 50), gen_iq_handler:add_iq_handler(ejabberd_sm, Host, ?NS_PRIVACY, ?MODULE, process_iq, IQDisc). stop(Host) -> - mod_disco:unregister_feature(Host, ?NS_PRIVACY), - ejabberd_hooks:delete(privacy_iq_get, Host, ?MODULE, - process_iq_get, 50), - ejabberd_hooks:delete(privacy_iq_set, Host, ?MODULE, - process_iq_set, 50), - ejabberd_hooks:delete(privacy_get_user_list, Host, - ?MODULE, get_user_list, 50), + ejabberd_hooks:delete(disco_local_features, Host, ?MODULE, + disco_features, 50), + ejabberd_hooks:delete(c2s_session_opened, Host, ?MODULE, + c2s_session_opened, 50), + ejabberd_hooks:delete(c2s_copy_session, Host, ?MODULE, + c2s_copy_session, 50), + ejabberd_hooks:delete(user_send_packet, Host, ?MODULE, + user_send_packet, 50), + ejabberd_hooks:delete(user_receive_packet, Host, ?MODULE, + user_receive_packet, 50), ejabberd_hooks:delete(privacy_check_packet, Host, ?MODULE, check_packet, 50), - ejabberd_hooks:delete(privacy_updated_list, Host, - ?MODULE, updated_list, 50), ejabberd_hooks:delete(remove_user, Host, ?MODULE, remove_user, 50), gen_iq_handler:remove_iq_handler(ejabberd_sm, Host, ?NS_PRIVACY). --spec process_iq(iq()) -> iq(). -process_iq(IQ) -> - xmpp:make_error(IQ, xmpp:err_not_allowed()). - --spec process_iq_get({error, stanza_error()} | {result, xmpp_element() | undefined}, - iq(), userlist()) -> {error, stanza_error()} | - {result, xmpp_element() | undefined}. -process_iq_get(_, #iq{lang = Lang, - sub_els = [#privacy_query{default = Default, - active = Active}]}, - _) when Default /= undefined; Active /= undefined -> - Txt = <<"Only element is allowed in this query">>, - {error, xmpp:err_bad_request(Txt, Lang)}; -process_iq_get(_, #iq{from = From, lang = Lang, - sub_els = [#privacy_query{lists = Lists}]}, - #userlist{name = Active}) -> - #jid{luser = LUser, lserver = LServer} = From, - case Lists of - [] -> - process_lists_get(LUser, LServer, Active, Lang); - [#privacy_list{name = ListName}] -> - process_list_get(LUser, LServer, ListName, Lang); - _ -> - Txt = <<"Too many elements">>, - {error, xmpp:err_bad_request(Txt, Lang)} - end; -process_iq_get(Acc, _, _) -> +-spec disco_features({error, stanza_error()} | {result, [binary()]} | empty, + jid(), jid(), binary(), binary()) -> + {error, stanza_error()} | {result, [binary()]}. +disco_features({error, Err}, _From, _To, _Node, _Lang) -> + {error, Err}; +disco_features(empty, _From, _To, <<"">>, _Lang) -> + {result, [?NS_PRIVACY]}; +disco_features({result, Feats}, _From, _To, <<"">>, _Lang) -> + {result, [?NS_PRIVACY|Feats]}; +disco_features(Acc, _From, _To, _Node, _Lang) -> Acc. --spec process_lists_get(binary(), binary(), binary(), binary()) -> - {error, stanza_error()} | {result, privacy_query()}. -process_lists_get(LUser, LServer, Active, Lang) -> +-spec process_iq(iq()) -> iq(). +process_iq(#iq{type = Type, + from = #jid{luser = U, lserver = S}, + to = #jid{luser = U, lserver = S}} = IQ) -> + case Type of + get -> process_iq_get(IQ); + set -> process_iq_set(IQ) + end; +process_iq(#iq{lang = Lang} = IQ) -> + Txt = <<"Query to another users is forbidden">>, + xmpp:make_error(IQ, xmpp:err_forbidden(Txt, Lang)). + +-spec process_iq_get(iq()) -> iq(). +process_iq_get(#iq{lang = Lang, + sub_els = [#privacy_query{default = Default, + active = Active}]} = IQ) + when Default /= undefined; Active /= undefined -> + Txt = <<"Only element is allowed in this query">>, + xmpp:make_error(IQ, xmpp:err_bad_request(Txt, Lang)); +process_iq_get(#iq{lang = Lang, + sub_els = [#privacy_query{lists = Lists}]} = IQ) -> + case Lists of + [] -> + process_lists_get(IQ); + [#privacy_list{name = ListName}] -> + process_list_get(IQ, ListName); + _ -> + Txt = <<"Too many elements">>, + xmpp:make_error(IQ, xmpp:err_bad_request(Txt, Lang)) + end; +process_iq_get(#iq{lang = Lang} = IQ) -> + Txt = <<"No module is handling this query">>, + xmpp:make_error(IQ, xmpp:err_service_unavailable(Txt, Lang)). + +-spec process_lists_get(iq()) -> iq(). +process_lists_get(#iq{from = #jid{luser = LUser, lserver = LServer}, + lang = Lang, + meta = #{privacy_active_list := Active}} = IQ) -> Mod = gen_mod:db_mod(LServer, ?MODULE), case Mod:process_lists_get(LUser, LServer) of error -> Txt = <<"Database failure">>, - {error, xmpp:err_internal_server_error(Txt, Lang)}; + xmpp:make_error(IQ, xmpp:err_internal_server_error(Txt, Lang)); {_Default, []} -> - {result, #privacy_query{}}; + xmpp:make_iq_result(IQ, #privacy_query{}); {Default, ListNames} -> - {result, + xmpp:make_iq_result( + IQ, #privacy_query{active = Active, default = Default, lists = [#privacy_list{name = ListName} - || ListName <- ListNames]}} + || ListName <- ListNames]}) end. --spec process_list_get(binary(), binary(), binary(), binary()) -> - {error, stanza_error()} | {result, privacy_query()}. -process_list_get(LUser, LServer, Name, Lang) -> +-spec process_list_get(iq(), binary()) -> iq(). +process_list_get(#iq{from = #jid{luser = LUser, lserver = LServer}, + lang = Lang} = IQ, Name) -> Mod = gen_mod:db_mod(LServer, ?MODULE), case Mod:process_list_get(LUser, LServer, Name) of error -> Txt = <<"Database failure">>, - {error, xmpp:err_internal_server_error(Txt, Lang)}; + xmpp:make_error(IQ, xmpp:err_internal_server_error(Txt, Lang)); not_found -> Txt = <<"No privacy list with this name found">>, - {error, xmpp:err_item_not_found(Txt, Lang)}; + xmpp:make_error(IQ, xmpp:err_item_not_found(Txt, Lang)); Items -> LItems = lists:map(fun encode_list_item/1, Items), - {result, + xmpp:make_iq_result( + IQ, #privacy_query{ - lists = [#privacy_list{name = Name, items = LItems}]}} + lists = [#privacy_list{name = Name, items = LItems}]}) end. -spec item_to_xml(listitem()) -> xmlel(). @@ -224,69 +246,61 @@ decode_value(Type, Value) -> undefined -> none end. --spec process_iq_set({error, stanza_error()} | - {result, xmpp_element() | undefined} | - {result, xmpp_element() | undefined, userlist()}, - iq(), #userlist{}) -> - {error, stanza_error()} | - {result, xmpp_element() | undefined} | - {result, xmpp_element() | undefined, userlist()}. -process_iq_set(_, #iq{from = From, lang = Lang, +-spec process_iq_set(iq()) -> iq(). +process_iq_set(#iq{lang = Lang, sub_els = [#privacy_query{default = Default, active = Active, - lists = Lists}]}, - #userlist{} = UserList) -> - #jid{luser = LUser, lserver = LServer} = From, + lists = Lists}]} = IQ) -> case Lists of [#privacy_list{items = Items, name = ListName}] when Default == undefined, Active == undefined -> - process_lists_set(LUser, LServer, ListName, Items, UserList, Lang); + process_lists_set(IQ, ListName, Items); [] when Default == undefined, Active /= undefined -> - process_active_set(LUser, LServer, Active, Lang); + process_active_set(IQ, Active); [] when Active == undefined, Default /= undefined -> - process_default_set(LUser, LServer, Default, Lang); + process_default_set(IQ, Default); _ -> Txt = <<"The stanza MUST contain only one element, " "one element, or one element">>, - {error, xmpp:err_bad_request(Txt, Lang)} + xmpp:make_error(IQ, xmpp:err_bad_request(Txt, Lang)) end; -process_iq_set(Acc, _, _) -> - Acc. +process_iq_set(#iq{lang = Lang} = IQ) -> + Txt = <<"No module is handling this query">>, + xmpp:make_error(IQ, xmpp:err_service_unavailable(Txt, Lang)). --spec process_default_set(binary(), binary(), none | binary(), - binary()) -> {error, stanza_error()} | {result, undefined}. -process_default_set(LUser, LServer, Value, Lang) -> +-spec process_default_set(iq(), binary()) -> iq(). +process_default_set(#iq{from = #jid{luser = LUser, lserver = LServer}, + lang = Lang} = IQ, Value) -> Mod = gen_mod:db_mod(LServer, ?MODULE), case Mod:process_default_set(LUser, LServer, Value) of {atomic, error} -> Txt = <<"Database failure">>, - {error, xmpp:err_internal_server_error(Txt, Lang)}; + xmpp:make_error(IQ, xmpp:err_internal_server_error(Txt, Lang)); {atomic, not_found} -> Txt = <<"No privacy list with this name found">>, - {error, xmpp:err_item_not_found(Txt, Lang)}; + xmpp:make_error(IQ, xmpp:err_item_not_found(Txt, Lang)); {atomic, ok} -> - {result, undefined}; + xmpp:make_iq_result(IQ); Err -> ?ERROR_MSG("failed to set default list '~s' for user ~s@~s: ~p", [Value, LUser, LServer, Err]), - {error, xmpp:err_internal_server_error()} + xmpp:make_error(IQ, xmpp:err_internal_server_error()) end. --spec process_active_set(binary(), binary(), none | binary(), binary()) -> - {error, stanza_error()} | - {result, undefined, userlist()}. -process_active_set(_LUser, _LServer, none, _Lang) -> - {result, undefined, #userlist{}}; -process_active_set(LUser, LServer, Name, Lang) -> +-spec process_active_set(IQ, none | binary()) -> IQ. +process_active_set(IQ, none) -> + xmpp:make_iq_result(xmpp:put_meta(IQ, privacy_list, #userlist{})); +process_active_set(#iq{from = #jid{luser = LUser, lserver = LServer}, + lang = Lang} = IQ, Name) -> Mod = gen_mod:db_mod(LServer, ?MODULE), case Mod:process_active_set(LUser, LServer, Name) of error -> Txt = <<"No privacy list with this name found">>, - {error, xmpp:err_item_not_found(Txt, Lang)}; + xmpp:make_error(IQ, xmpp:err_item_not_found(Txt, Lang)); Items -> NeedDb = is_list_needdb(Items), - {result, undefined, - #userlist{name = Name, list = Items, needdb = NeedDb}} + List = #userlist{name = Name, list = Items, needdb = NeedDb}, + xmpp:make_iq_result(xmpp:put_meta(IQ, privacy_list, List)) end. -spec set_privacy_list(privacy()) -> any(). @@ -294,64 +308,100 @@ set_privacy_list(#privacy{us = {_, LServer}} = Privacy) -> Mod = gen_mod:db_mod(LServer, ?MODULE), Mod:set_privacy_list(Privacy). --spec process_lists_set(binary(), binary(), binary(), [privacy_item()], - #userlist{}, binary()) -> {error, stanza_error()} | - {result, undefined}. -process_lists_set(_LUser, _LServer, Name, [], #userlist{name = Name}, Lang) -> +-spec process_lists_set(iq(), binary(), [privacy_item()]) -> iq(). +process_lists_set(#iq{meta = #{privacy_active_list := Name}, + lang = Lang} = IQ, Name, []) -> Txt = <<"Cannot remove active list">>, - {error, xmpp:err_conflict(Txt, Lang)}; -process_lists_set(LUser, LServer, Name, [], _UserList, Lang) -> + xmpp:make_error(IQ, xmpp:err_conflict(Txt, Lang)); +process_lists_set(#iq{from = #jid{luser = LUser, lserver = LServer} = From, + lang = Lang} = IQ, Name, []) -> Mod = gen_mod:db_mod(LServer, ?MODULE), case Mod:remove_privacy_list(LUser, LServer, Name) of {atomic, conflict} -> Txt = <<"Cannot remove default list">>, - {error, xmpp:err_conflict(Txt, Lang)}; + xmpp:make_error(IQ, xmpp:err_conflict(Txt, Lang)); {atomic, not_found} -> Txt = <<"No privacy list with this name found">>, - {error, xmpp:err_item_not_found(Txt, Lang)}; + xmpp:make_error(IQ, xmpp:err_item_not_found(Txt, Lang)); {atomic, ok} -> - ejabberd_sm:route(jid:make(LUser, LServer, - <<"">>), - jid:make(LUser, LServer, <<"">>), - {broadcast, {privacy_list, - #userlist{name = Name, - list = []}, - Name}}), - {result, undefined}; + push_list_update(From, #userlist{name = Name}, Name), + xmpp:make_iq_result(IQ); Err -> ?ERROR_MSG("failed to remove privacy list '~s' for user ~s@~s: ~p", [Name, LUser, LServer, Err]), Txt = <<"Database failure">>, - {error, xmpp:err_internal_server_error(Txt, Lang)} + xmpp:make_error(IQ, xmpp:err_internal_server_error(Txt, Lang)) end; -process_lists_set(LUser, LServer, Name, Items, _UserList, Lang) -> +process_lists_set(#iq{from = #jid{luser = LUser, lserver = LServer} = From, + lang = Lang} = IQ, Name, Items) -> case catch lists:map(fun decode_item/1, Items) of {error, Why} -> Txt = xmpp:format_error(Why), - {error, xmpp:err_bad_request(Txt, Lang)}; + xmpp:make_error(IQ, xmpp:err_bad_request(Txt, Lang)); List -> Mod = gen_mod:db_mod(LServer, ?MODULE), case Mod:set_privacy_list(LUser, LServer, Name, List) of {atomic, ok} -> - NeedDb = is_list_needdb(List), - ejabberd_sm:route(jid:make(LUser, LServer, - <<"">>), - jid:make(LUser, LServer, <<"">>), - {broadcast, {privacy_list, - #userlist{name = Name, - list = List, - needdb = NeedDb}, - Name}}), - {result, undefined}; + UserList = #userlist{name = Name, list = List, + needdb = is_list_needdb(List)}, + push_list_update(From, UserList, Name), + xmpp:make_iq_result(IQ); Err -> ?ERROR_MSG("failed to set privacy list '~s' " "for user ~s@~s: ~p", [Name, LUser, LServer, Err]), Txt = <<"Database failure">>, - {error, xmpp:err_internal_server_error(Txt, Lang)} + xmpp:make_error(IQ, xmpp:err_internal_server_error(Txt, Lang)) end end. +-spec push_list_update(jid(), #userlist{}, binary() | none) -> ok. +push_list_update(From, List, Name) -> + BareFrom = jid:remove_resource(From), + lists:foreach( + fun(R) -> + To = jid:replace_resource(From, R), + IQ = #iq{type = set, from = BareFrom, to = To, + id = <<"push", (randoms:get_string())/binary>>, + sub_els = [#privacy_query{ + lists = [#privacy_list{name = Name}]}], + meta = #{privacy_updated_list => List}}, + ejabberd_router:route(BareFrom, To, IQ) + end, ejabberd_sm:get_user_resources(From#jid.luser, From#jid.lserver)). + +-spec user_send_packet({stanza(), ejabberd_c2s:state()}) -> {stanza(), ejabberd_c2s:state()}. +user_send_packet({#iq{type = Type, + to = #jid{luser = U, lserver = S, lresource = <<"">>}, + from = #jid{luser = U, lserver = S}, + sub_els = [_]} = IQ, + #{privacy_list := #userlist{name = Name}} = State}) + when Type == get; Type == set -> + NewIQ = case xmpp:has_subtag(IQ, #privacy_query{}) of + true -> xmpp:put_meta(IQ, privacy_active_list, Name); + false -> IQ + end, + {NewIQ, State}; +user_send_packet(Acc) -> + Acc. + +-spec user_receive_packet({stanza(), ejabberd_c2s:state()}) -> {stanza(), ejabberd_c2s:state()}. +user_receive_packet({#iq{type = result, meta = #{privacy_list := List}} = IQ, + State}) -> + {IQ, State#{privacy_list => List}}; +user_receive_packet({#iq{type = set, meta = #{privacy_updated_list := New}} = IQ, + #{user := U, server := S, resource := R, + privacy_list := Old} = State}) -> + State1 = if Old#userlist.name == New#userlist.name -> + State#{privacy_list => New}; + true -> + State + end, + From = jid:make(U, S, <<"">>), + To = jid:make(U, S, R), + {xmpp:set_from_to(IQ, From, To), State1}; +user_receive_packet(Acc) -> + Acc. + -spec decode_item(privacy_item()) -> listitem(). decode_item(#privacy_item{order = Order, action = Action, @@ -394,15 +444,20 @@ is_list_needdb(Items) -> end, Items). --spec get_user_list(userlist(), binary(), binary()) -> userlist(). -get_user_list(_Acc, User, Server) -> - LUser = jid:nodeprep(User), - LServer = jid:nameprep(Server), +-spec get_user_list(binary(), binary()) -> #userlist{}. +get_user_list(LUser, LServer) -> Mod = gen_mod:db_mod(LServer, ?MODULE), {Default, Items} = Mod:get_user_list(LUser, LServer), NeedDb = is_list_needdb(Items), - #userlist{name = Default, list = Items, - needdb = NeedDb}. + #userlist{name = Default, list = Items, needdb = NeedDb}. + +-spec c2s_session_opened(ejabberd_c2s:state()) -> ejabberd_c2s:state(). +c2s_session_opened(#{jid := #jid{luser = LUser, lserver = LServer}} = State) -> + State#{privacy_list => get_user_list(LUser, LServer)}. + +-spec c2s_copy_session(ejabberd_c2s:state(), ejabberd_c2s:state()) -> ejabberd_c2s:state(). +c2s_copy_session(State, #{privacy_list := List}) -> + State#{privacy_list => List}. -spec get_user_lists(binary(), binary()) -> {ok, privacy()} | error. get_user_lists(User, Server) -> @@ -414,28 +469,32 @@ get_user_lists(User, Server) -> %% From is the sender, To is the destination. %% If Dir = out, User@Server is the sender account (From). %% If Dir = in, User@Server is the destination account (To). --spec check_packet(allow | deny, binary(), binary(), userlist(), - {jid(), jid(), stanza()}, in | out) -> allow | deny. -check_packet(_, _User, _Server, _UserList, - {#jid{luser = <<"">>, lserver = Server} = _From, - #jid{lserver = Server} = _To, _}, - in) -> +-spec check_packet(allow | deny, ejabberd_c2s:state() | jid(), + stanza(), in | out) -> allow | deny. +check_packet(_, #{jid := #jid{luser = LUser, lserver = LServer}, + privacy_list := #userlist{list = List, needdb = NeedDb}}, + Packet, Dir) -> + From = xmpp:get_from(Packet), + To = xmpp:get_to(Packet), + case {From, To} of + {#jid{luser = <<"">>, lserver = LServer}, + #jid{lserver = LServer}} when Dir == in -> + %% Allow any packets from local server + allow; + {#jid{lserver = LServer}, + #jid{luser = <<"">>, lserver = LServer}} when Dir == out -> + %% Allow any packets to local server allow; -check_packet(_, _User, _Server, _UserList, - {#jid{lserver = Server} = _From, - #jid{luser = <<"">>, lserver = Server} = _To, _}, - out) -> + {#jid{luser = LUser, lserver = LServer, lresource = <<"">>}, + #jid{luser = LUser, lserver = LServer}} when Dir == in -> + %% Allow incoming packets from user's bare jid to his full jid allow; -check_packet(_, _User, _Server, _UserList, - {#jid{luser = User, lserver = Server} = _From, - #jid{luser = User, lserver = Server} = _To, _}, - _Dir) -> + {#jid{luser = LUser, lserver = LServer}, + #jid{luser = LUser, lserver = LServer, lresource = <<"">>}} when Dir == out -> + %% Allow outgoing packets from user's full jid to his bare JID + allow; + _ when List == [] -> allow; -check_packet(_, User, Server, - #userlist{list = List, needdb = NeedDb}, - {From, To, Packet}, Dir) -> - case List of - [] -> allow; _ -> PType = case Packet of #message{} -> message; @@ -455,18 +514,21 @@ check_packet(_, User, Server, in -> jid:tolower(From); out -> jid:tolower(To) end, - {Subscription, Groups} = case NeedDb of + {Subscription, Groups} = + case NeedDb of true -> ejabberd_hooks:run_fold(roster_get_jid_info, - jid:nameprep(Server), + LServer, {none, []}, - [User, Server, - LJID]); - false -> {[], []} + [LUser, LServer, LJID]); + false -> + {[], []} end, - check_packet_aux(List, PType2, LJID, Subscription, - Groups) - end. + check_packet_aux(List, PType2, LJID, Subscription, Groups) + end; +check_packet(Acc, #jid{luser = LUser, lserver = LServer} = JID, Packet, Dir) -> + List = get_user_list(LUser, LServer), + check_packet(Acc, #{jid => JID, privacy_list => List}, Packet, Dir). -spec check_packet_aux([listitem()], message | iq | presence_in | presence_out | other, @@ -538,13 +600,6 @@ remove_user(User, Server) -> Mod = gen_mod:db_mod(LServer, ?MODULE), Mod:remove_user(LUser, LServer). --spec updated_list(userlist(), userlist(), userlist()) -> userlist(). -updated_list(_, #userlist{name = OldName} = Old, - #userlist{name = NewName} = New) -> - if OldName == NewName -> New; - true -> Old - end. - numeric_to_binary(<<0, 0, _/binary>>) -> <<"0">>; numeric_to_binary(<<0, _, _:6/binary, T/binary>>) -> diff --git a/src/mod_privilege.erl b/src/mod_privilege.erl index 60f1be7cb..b860d9a39 100644 --- a/src/mod_privilege.erl +++ b/src/mod_privilege.erl @@ -38,7 +38,7 @@ terminate/2, code_change/3]). -export([component_connected/1, component_disconnected/2, roster_access/2, process_message/3, - process_presence_out/4, process_presence_in/5]). + process_presence_out/1, process_presence_in/1]). -include("ejabberd.hrl"). -include("logger.hrl"). @@ -133,10 +133,11 @@ roster_access(false, #iq{from = From, to = To, type = Type}) -> false end. --spec process_presence_out(stanza(), ejabberd_c2s:state(), jid(), jid()) -> stanza(). -process_presence_out(#presence{type = Type} = Pres, _C2SState, - #jid{luser = LUser, lserver = LServer} = From, - #jid{luser = LUser, lserver = LServer, lresource = <<"">>}) +-spec process_presence_out({stanza(), ejabberd_c2s:state()}) -> {stanza(), ejabberd_c2s:state()}. +process_presence_out({#presence{ + from = #jid{luser = LUser, lserver = LServer} = From, + to = #jid{luser = LUser, lserver = LServer, lresource = <<"">>}, + type = Type} = Pres, C2SState}) when Type == available; Type == unavailable -> %% Self-presence processing Permissions = get_permissions(LServer), @@ -151,15 +152,15 @@ process_presence_out(#presence{type = Type} = Pres, _C2SState, ok end end, dict:to_list(Permissions)), - Pres; -process_presence_out(Acc, _, _, _) -> + {Pres, C2SState}; +process_presence_out(Acc) -> Acc. --spec process_presence_in(stanza(), ejabberd_c2s:state(), - jid(), jid(), jid()) -> stanza(). -process_presence_in(#presence{type = Type} = Pres, _C2SState, _, - #jid{luser = U, lserver = S} = From, - #jid{luser = LUser, lserver = LServer}) +-spec process_presence_in({stanza(), ejabberd_c2s:state()}) -> {stanza(), ejabberd_c2s:state()}. +process_presence_in({#presence{ + from = #jid{luser = U, lserver = S} = From, + to = #jid{luser = LUser, lserver = LServer}, + type = Type} = Pres, C2SState}) when {U, S} /= {LUser, LServer} andalso (Type == available orelse Type == unavailable) -> Permissions = get_permissions(LServer), @@ -179,8 +180,8 @@ process_presence_in(#presence{type = Type} = Pres, _C2SState, _, ok end end, dict:to_list(Permissions)), - Pres; -process_presence_in(Acc, _, _, _, _) -> + {Pres, C2SState}; +process_presence_in(Acc) -> Acc. %%%=================================================================== diff --git a/src/mod_proxy65.erl b/src/mod_proxy65.erl index d9e84f376..8486802d0 100644 --- a/src/mod_proxy65.erl +++ b/src/mod_proxy65.erl @@ -43,6 +43,12 @@ -define(PROCNAME, ejabberd_mod_proxy65). +-callback init() -> any(). +-callback register_stream(binary(), pid()) -> ok | {error, any()}. +-callback unregister_stream(binary()) -> ok | {error, any()}. +-callback activate_stream(binary(), binary(), pos_integer() | infinity, node()) -> + ok | {error, limit | conflict | notfound | term()}. + start(Host, Opts) -> case mod_proxy65_service:add_listener(Host, Opts) of {error, _} = Err -> erlang:error(Err); @@ -50,7 +56,12 @@ start(Host, Opts) -> Proc = gen_mod:get_module_proc(Host, ?PROCNAME), ChildSpec = {Proc, {?MODULE, start_link, [Host, Opts]}, transient, infinity, supervisor, [?MODULE]}, - supervisor:start_child(ejabberd_sup, ChildSpec) + case supervisor:start_child(ejabberd_sup, ChildSpec) of + {error, _} = Err -> erlang:error(Err); + _ -> + Mod = gen_mod:ram_db_mod(global, ?MODULE), + Mod:init() + end end. stop(Host) -> @@ -77,12 +88,9 @@ init([Host, Opts]) -> ejabberd_mod_proxy65_sup), mod_proxy65_stream]}, transient, infinity, supervisor, [ejabberd_tmp_sup]}, - StreamManager = {mod_proxy65_sm, - {mod_proxy65_sm, start_link, [Host, Opts]}, transient, - 5000, worker, [mod_proxy65_sm]}, {ok, {{one_for_one, 10, 1}, - [StreamManager, StreamSupervisor, Service]}}. + [StreamSupervisor, Service]}}. depends(_Host, _Opts) -> []. @@ -112,7 +120,9 @@ mod_opt_type(max_connections) -> fun (I) when is_integer(I), I > 0 -> I; (infinity) -> infinity end; +mod_opt_type(ram_db_type) -> + fun(T) -> ejabberd_config:v_db(?MODULE, T) end; mod_opt_type(_) -> [auth_type, recbuf, shaper, sndbuf, access, host, hostname, ip, name, port, - max_connections]. + max_connections, ram_db_type]. diff --git a/src/mod_proxy65_mnesia.erl b/src/mod_proxy65_mnesia.erl new file mode 100644 index 000000000..e50b29c98 --- /dev/null +++ b/src/mod_proxy65_mnesia.erl @@ -0,0 +1,145 @@ +%%%------------------------------------------------------------------- +%%% @author Evgeny Khramtsov +%%% @copyright (C) 2017, Evgeny Khramtsov +%%% @doc +%%% +%%% @end +%%% Created : 16 Jan 2017 by Evgeny Khramtsov +%%%------------------------------------------------------------------- +-module(mod_proxy65_mnesia). +-behaviour(gen_server). +-behaviour(mod_proxy65). + +%% API +-export([init/0, register_stream/2, unregister_stream/1, activate_stream/4]). +-export([start_link/0]). +%% gen_server callbacks +-export([init/1, handle_call/3, handle_cast/2, handle_info/2, + terminate/2, code_change/3]). + +-include("logger.hrl"). + +-record(bytestream, + {sha1 = <<"">> :: binary() | '$1', + target :: pid() | '_', + initiator :: pid() | '_', + active = false :: boolean() | '_', + jid_i :: undefined | binary() | '_'}). + +-record(state, {}). + +%%%=================================================================== +%%% API +%%%=================================================================== +start_link() -> + gen_server:start_link({local, ?MODULE}, ?MODULE, [], []). + +init() -> + Spec = {?MODULE, {?MODULE, start_link, []}, transient, + 5000, worker, [?MODULE]}, + supervisor:start_child(ejabberd_sup, Spec). + +register_stream(SHA1, StreamPid) -> + F = fun () -> + case mnesia:read(bytestream, SHA1, write) of + [] -> + mnesia:write(#bytestream{sha1 = SHA1, + target = StreamPid}); + [#bytestream{target = Pid, initiator = undefined} = + ByteStream] when is_pid(Pid), Pid /= StreamPid -> + mnesia:write(ByteStream#bytestream{ + initiator = StreamPid}) + end + end, + case mnesia:transaction(F) of + {atomic, ok} -> + ok; + {aborted, Reason} -> + ?ERROR_MSG("Mnesia transaction failed: ~p", [Reason]), + {error, Reason} + end. + +unregister_stream(SHA1) -> + F = fun () -> mnesia:delete({bytestream, SHA1}) end, + case mnesia:transaction(F) of + {atomic, ok} -> + ok; + {aborted, Reason} -> + ?ERROR_MSG("Mnesia transaction failed: ~p", [Reason]), + {error, Reason} + end. + +activate_stream(SHA1, Initiator, MaxConnections, _Node) -> + case gen_server:call(?MODULE, + {activate_stream, SHA1, Initiator, MaxConnections}) of + {atomic, {ok, IPid, TPid}} -> + {ok, IPid, TPid}; + {atomic, {limit, IPid, TPid}} -> + {error, {limit, IPid, TPid}}; + {atomic, conflict} -> + {error, conflict}; + {atomic, notfound} -> + {error, notfound}; + Err -> + {error, Err} + end. + +%%%=================================================================== +%%% gen_server callbacks +%%%=================================================================== +init([]) -> + ejabberd_mnesia:create(?MODULE, bytestream, + [{ram_copies, [node()]}, + {attributes, record_info(fields, bytestream)}]), + mnesia:add_table_copy(bytestream, node(), ram_copies), + {ok, #state{}}. + +handle_call({activate_stream, SHA1, Initiator, MaxConnections}, _From, State) -> + F = fun () -> + case mnesia:read(bytestream, SHA1, write) of + [#bytestream{target = TPid, initiator = IPid} = + ByteStream] when is_pid(TPid), is_pid(IPid) -> + ActiveFlag = ByteStream#bytestream.active, + if ActiveFlag == false -> + ConnsPerJID = mnesia:select( + bytestream, + [{#bytestream{sha1 = '$1', + jid_i = Initiator, + _ = '_'}, + [], ['$1']}]), + if length(ConnsPerJID) < MaxConnections -> + mnesia:write( + ByteStream#bytestream{active = true, + jid_i = Initiator}), + {ok, IPid, TPid}; + true -> + {limit, IPid, TPid} + end; + true -> + conflict + end; + _ -> + notfound + end + end, + Reply = mnesia:transaction(F), + {reply, Reply, State}; +handle_call(_Request, _From, State) -> + Reply = ok, + {reply, Reply, State}. + +handle_cast(_Msg, State) -> + {noreply, State}. + +handle_info(_Info, State) -> + {noreply, State}. + +terminate(_Reason, _State) -> + ok. + +code_change(_OldVsn, State, _Extra) -> + {ok, State}. + +%%%=================================================================== +%%% Internal functions +%%%=================================================================== diff --git a/src/mod_proxy65_service.erl b/src/mod_proxy65_service.erl index 7964c3238..f51b33db2 100644 --- a/src/mod_proxy65_service.erl +++ b/src/mod_proxy65_service.erl @@ -175,31 +175,39 @@ process_bytestreams(#iq{type = set, lang = Lang, from = InitiatorJID, to = To, all), case acl:match_rule(ServerHost, ACL, InitiatorJID) of allow -> + Node = ejabberd_cluster:get_node_by_id(To#jid.lresource), Target = jid:to_string(jid:tolower(TargetJID)), Initiator = jid:to_string(jid:tolower(InitiatorJID)), SHA1 = p1_sha:sha(<>), - case mod_proxy65_sm:activate_stream(SHA1, InitiatorJID, - TargetJID, ServerHost) of - ok -> + Mod = gen_mod:ram_db_mod(global, mod_proxy65), + MaxConnections = max_connections(ServerHost), + case Mod:activate_stream(SHA1, Initiator, MaxConnections, Node) of + {ok, InitiatorPid, TargetPid} -> + mod_proxy65_stream:activate( + {InitiatorPid, InitiatorJID}, {TargetPid, TargetJID}), xmpp:make_iq_result(IQ); - false -> + {error, notfound} -> Txt = <<"Failed to activate bytestream">>, xmpp:make_error(IQ, xmpp:err_item_not_found(Txt, Lang)); - limit -> + {error, {limit, InitiatorPid, TargetPid}} -> + mod_proxy65_stream:stop(InitiatorPid), + mod_proxy65_stream:stop(TargetPid), Txt = <<"Too many active bytestreams">>, xmpp:make_error(IQ, xmpp:err_resource_constraint(Txt, Lang)); - conflict -> + {error, conflict} -> Txt = <<"Bytestream already activated">>, xmpp:make_error(IQ, xmpp:err_conflict(Txt, Lang)); - Err -> + {error, Err} -> ?ERROR_MSG("failed to activate bytestream from ~s to ~s: ~p", [Initiator, Target, Err]), - xmpp:make_error(IQ, xmpp:err_internal_server_error()) + Txt = <<"Database failure">>, + xmpp:make_error(IQ, xmpp:err_internal_server_error(Txt, Lang)) end; deny -> Txt = <<"Denied by ACL">>, xmpp:make_error(IQ, xmpp:err_forbidden(Txt, Lang)) end. + %%%------------------------- %%% Auxiliary functions. %%%------------------------- @@ -219,7 +227,8 @@ get_streamhost(Host, ServerHost) -> HostName = gen_mod:get_module_opt(ServerHost, mod_proxy65, hostname, fun iolist_to_binary/1, jlib:ip_to_list(IP)), - #streamhost{jid = jid:make(Host), + Resource = ejabberd_cluster:node_id(), + #streamhost{jid = jid:make(<<"">>, Host, Resource), host = HostName, port = Port}. @@ -246,3 +255,9 @@ get_my_ip() -> {ok, Addr} -> Addr; {error, _} -> {127, 0, 0, 1} end. + +max_connections(ServerHost) -> + gen_mod:get_module_opt(ServerHost, mod_proxy65, max_connections, + fun(I) when is_integer(I), I>0 -> I; + (infinity) -> infinity + end, infinity). diff --git a/src/mod_proxy65_sm.erl b/src/mod_proxy65_sm.erl deleted file mode 100644 index f363fbdf8..000000000 --- a/src/mod_proxy65_sm.erl +++ /dev/null @@ -1,171 +0,0 @@ -%%%---------------------------------------------------------------------- -%%% File : mod_proxy65_sm.erl -%%% Author : Evgeniy Khramtsov -%%% Purpose : Bytestreams manager. -%%% Created : 12 Oct 2006 by Evgeniy Khramtsov -%%% -%%% -%%% ejabberd, Copyright (C) 2002-2017 ProcessOne -%%% -%%% This program is free software; you can redistribute it and/or -%%% modify it under the terms of the GNU General Public License as -%%% published by the Free Software Foundation; either version 2 of the -%%% License, or (at your option) any later version. -%%% -%%% This program is distributed in the hope that it will be useful, -%%% but WITHOUT ANY WARRANTY; without even the implied warranty of -%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -%%% General Public License for more details. -%%% -%%% You should have received a copy of the GNU General Public License along -%%% with this program; if not, write to the Free Software Foundation, Inc., -%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. -%%% -%%%---------------------------------------------------------------------- - --module(mod_proxy65_sm). - --author('xram@jabber.ru'). - --behaviour(gen_server). - -%% gen_server callbacks. --export([init/1, handle_info/2, handle_call/3, - handle_cast/2, terminate/2, code_change/3]). - --export([start_link/2, register_stream/1, - unregister_stream/1, activate_stream/4]). - --record(state, {max_connections = infinity :: non_neg_integer() | infinity}). - --record(bytestream, - {sha1 = <<"">> :: binary() | '$1', - target :: pid() | '_', - initiator :: pid() | '_', - active = false :: boolean() | '_', - jid_i = {<<"">>, <<"">>, <<"">>} :: jid:ljid() | '_'}). - --define(PROCNAME, ejabberd_mod_proxy65_sm). - -%% Unused callbacks. -handle_cast(_Request, State) -> {noreply, State}. - -code_change(_OldVsn, State, _Extra) -> {ok, State}. - -handle_info(_Info, State) -> {noreply, State}. - -%%---------------- - -start_link(Host, Opts) -> - Proc = gen_mod:get_module_proc(Host, ?PROCNAME), - gen_server:start_link({local, Proc}, ?MODULE, [Opts], - []). - -init([Opts]) -> - ejabberd_mnesia:create(?MODULE, bytestream, - [{ram_copies, [node()]}, - {attributes, record_info(fields, bytestream)}]), - mnesia:add_table_copy(bytestream, node(), ram_copies), - MaxConnections = gen_mod:get_opt(max_connections, Opts, - fun(I) when is_integer(I), I>0 -> - I; - (infinity) -> - infinity - end, infinity), - {ok, #state{max_connections = MaxConnections}}. - -terminate(_Reason, _State) -> ok. - -handle_call({activate, SHA1, IJid}, _From, State) -> - MaxConns = State#state.max_connections, - F = fun () -> - case mnesia:read(bytestream, SHA1, write) of - [#bytestream{target = TPid, initiator = IPid} = - ByteStream] - when is_pid(TPid), is_pid(IPid) -> - ActiveFlag = ByteStream#bytestream.active, - if ActiveFlag == false -> - ConnsPerJID = mnesia:select(bytestream, - [{#bytestream{sha1 = - '$1', - jid_i = - IJid, - _ = '_'}, - [], ['$1']}]), - if length(ConnsPerJID) < MaxConns -> - mnesia:write(ByteStream#bytestream{active = - true, - jid_i = - IJid}), - {ok, IPid, TPid}; - true -> {limit, IPid, TPid} - end; - true -> conflict - end; - _ -> false - end - end, - Reply = mnesia:transaction(F), - {reply, Reply, State}; -handle_call(_Request, _From, State) -> - {reply, ok, State}. - -%%%---------------------- -%%% API. -%%%---------------------- -%%%--------------------------------------------------- -%%% register_stream(SHA1) -> {atomic, ok} | -%%% {atomic, error} | -%%% transaction abort -%%% SHA1 = string() -%%%--------------------------------------------------- -register_stream(SHA1) when is_binary(SHA1) -> - StreamPid = self(), - F = fun () -> - case mnesia:read(bytestream, SHA1, write) of - [] -> - mnesia:write(#bytestream{sha1 = SHA1, - target = StreamPid}); - [#bytestream{target = Pid, initiator = undefined} = - ByteStream] - when is_pid(Pid), Pid /= StreamPid -> - mnesia:write(ByteStream#bytestream{initiator = - StreamPid}); - _ -> error - end - end, - mnesia:transaction(F). - -%%%---------------------------------------------------- -%%% unregister_stream(SHA1) -> ok | transaction abort -%%% SHA1 = string() -%%%---------------------------------------------------- -unregister_stream(SHA1) when is_binary(SHA1) -> - F = fun () -> mnesia:delete({bytestream, SHA1}) end, - mnesia:transaction(F). - -%%%-------------------------------------------------------- -%%% activate_stream(SHA1, IJid, TJid, Host) -> ok | -%%% false | -%%% limit | -%%% conflict | -%%% error -%%% SHA1 = string() -%%% IJid = TJid = jid() -%%% Host = string() -%%%-------------------------------------------------------- -activate_stream(SHA1, IJid, TJid, Host) - when is_binary(SHA1) -> - Proc = gen_mod:get_module_proc(Host, ?PROCNAME), - case catch gen_server:call(Proc, {activate, SHA1, IJid}) - of - {atomic, {ok, IPid, TPid}} -> - mod_proxy65_stream:activate({IPid, IJid}, {TPid, TJid}); - {atomic, {limit, IPid, TPid}} -> - mod_proxy65_stream:stop(IPid), - mod_proxy65_stream:stop(TPid), - limit; - {atomic, conflict} -> conflict; - {atomic, false} -> false; - _ -> error - end. diff --git a/src/mod_proxy65_stream.erl b/src/mod_proxy65_stream.erl index 66e481c47..a04e9e94b 100644 --- a/src/mod_proxy65_stream.erl +++ b/src/mod_proxy65_stream.erl @@ -99,9 +99,10 @@ init([Socket, Host, Opts]) -> socket = Socket, shaper = Shaper, timer = TRef}}. terminate(_Reason, StateName, #state{sha1 = SHA1}) -> - catch mod_proxy65_sm:unregister_stream(SHA1), + Mod = gen_mod:ram_db_mod(global, mod_proxy65), + Mod:unregister_stream(SHA1), if StateName == stream_established -> - ?INFO_MSG("Bytestream terminated", []); + ?INFO_MSG("(~w) Bytestream terminated", [self()]); true -> ok end. @@ -168,8 +169,9 @@ wait_for_request(Packet, Request = mod_proxy65_lib:unpack_request(Packet), case Request of #s5_request{sha1 = SHA1, cmd = connect} -> - case catch mod_proxy65_sm:register_stream(SHA1) of - {atomic, ok} -> + Mod = gen_mod:ram_db_mod(global, mod_proxy65), + case Mod:register_stream(SHA1, self()) of + ok -> inet:setopts(Socket, [{active, false}]), gen_tcp:send(Socket, mod_proxy65_lib:make_reply(Request)), diff --git a/src/mod_pubsub.erl b/src/mod_pubsub.erl index 108c0b593..eba2cab29 100644 --- a/src/mod_pubsub.erl +++ b/src/mod_pubsub.erl @@ -54,7 +54,8 @@ on_user_offline/3, remove_user/2, disco_local_identity/5, disco_local_features/5, disco_local_items/5, disco_sm_identity/5, - disco_sm_features/5, disco_sm_items/5]). + disco_sm_features/5, disco_sm_items/5, + c2s_handle_info/2]). %% exported iq handlers -export([iq_sm/1, process_disco_info/1, process_disco_items/1, @@ -274,7 +275,6 @@ init([ServerHost, Opts]) -> ejabberd_mnesia:create(?MODULE, pubsub_last_item, [{ram_copies, [node()]}, {attributes, record_info(fields, pubsub_last_item)}]), - mod_disco:register_feature(ServerHost, ?NS_PUBSUB), lists:foreach( fun(H) -> T = gen_mod:get_module_proc(H, config), @@ -306,8 +306,8 @@ init([ServerHost, Opts]) -> ?MODULE, out_subscription, 50), ejabberd_hooks:add(remove_user, ServerHost, ?MODULE, remove_user, 50), - ejabberd_hooks:add(anonymous_purge_hook, ServerHost, - ?MODULE, remove_user, 50), + ejabberd_hooks:add(c2s_handle_info, ServerHost, + ?MODULE, c2s_handle_info, 50), gen_iq_handler:add_iq_handler(ejabberd_local, Host, ?NS_DISCO_INFO, ?MODULE, process_disco_info, IQDisc), gen_iq_handler:add_iq_handler(ejabberd_local, Host, ?NS_DISCO_ITEMS, @@ -542,7 +542,7 @@ disco_local_features(Acc, _From, To, <<>>, _Lang) -> {result, I} -> I; _ -> [] end, - {result, Feats ++ [feature(F) || F <- features(Host, <<>>)]}; + {result, Feats ++ [?NS_PUBSUB|[feature(F) || F <- features(Host, <<>>)]]}; disco_local_features(Acc, _From, _To, _Node, _Lang) -> Acc. @@ -922,15 +922,14 @@ terminate(_Reason, ?MODULE, out_subscription, 50), ejabberd_hooks:delete(remove_user, ServerHost, ?MODULE, remove_user, 50), - ejabberd_hooks:delete(anonymous_purge_hook, ServerHost, - ?MODULE, remove_user, 50), + ejabberd_hooks:delete(c2s_handle_info, ServerHost, + ?MODULE, c2s_handle_info, 50), gen_iq_handler:remove_iq_handler(ejabberd_local, Host, ?NS_DISCO_INFO), gen_iq_handler:remove_iq_handler(ejabberd_local, Host, ?NS_DISCO_ITEMS), gen_iq_handler:remove_iq_handler(ejabberd_local, Host, ?NS_PUBSUB), gen_iq_handler:remove_iq_handler(ejabberd_local, Host, ?NS_PUBSUB_OWNER), gen_iq_handler:remove_iq_handler(ejabberd_local, Host, ?NS_VCARD), gen_iq_handler:remove_iq_handler(ejabberd_local, Host, ?NS_COMMANDS), - mod_disco:unregister_feature(ServerHost, ?NS_PUBSUB), case whereis(gen_mod:get_module_proc(ServerHost, ?LOOPNAME)) of undefined -> ?ERROR_MSG("~s process is dead, pubsub was broken", [?LOOPNAME]); @@ -2236,10 +2235,9 @@ dispatch_items({FromU, FromS, FromR} = From, {ToU, ToS, ToR} = To, end, if C2SPid == undefined -> ok; true -> - ejabberd_c2s:send_filtered(C2SPid, - {pep_message, <>}, + C2SPid ! {send_filtered, {pep_message, <>}, service_jid(From), jid:make(To), - Stanza) + Stanza} end; dispatch_items(From, To, _Node, Stanza) -> ejabberd_router:route(service_jid(From), jid:make(To), Stanza). @@ -2773,8 +2771,9 @@ get_resource_state({U, S, R}, ShowValues, JIDs) -> lists:append([{U, S, R}], JIDs); Pid -> Show = case ejabberd_c2s:get_presence(Pid) of - {_, _, <<"available">>, _} -> <<"online">>; - {_, _, State, _} -> State + #presence{type = unavailable} -> <<"unavailable">>; + #presence{show = undefined} -> <<"online">>; + #presence{show = S} -> atom_to_binary(S, latin1) end, case lists:member(Show, ShowValues) of %% If yes, item can be delivered @@ -3020,25 +3019,56 @@ broadcast_stanza({LUser, LServer, LResource}, Publisher, Node, Nidx, Type, NodeO broadcast_stanza({LUser, LServer, <<>>}, Node, Nidx, Type, NodeOptions, SubsByDepth, NotifyType, BaseStanza, SHIM), %% Handles implicit presence subscriptions SenderResource = user_resource(LUser, LServer, LResource), - case ejabberd_sm:get_session_pid(LUser, LServer, SenderResource) of - C2SPid when is_pid(C2SPid) -> NotificationType = get_option(NodeOptions, notification_type, headline), Stanza = add_message_type(BaseStanza, NotificationType), %% set the from address on the notification to the bare JID of the account owner %% Also, add "replyto" if entity has presence subscription to the account owner %% See XEP-0163 1.1 section 4.3.1 - ejabberd_c2s:broadcast(C2SPid, - {pep_message, <<((Node))/binary, "+notify">>}, - _Sender = jid:make(LUser, LServer, <<"">>), - _StanzaToSend = add_extended_headers( - Stanza, - _ReplyTo = extended_headers([Publisher]))); - _ -> - ?DEBUG("~p@~p has no session; can't deliver ~p to contacts", [LUser, LServer, BaseStanza]) - end; + ejabberd_sm:route(jid:make(LUser, LServer, SenderResource), + {pep_message, <<((Node))/binary, "+notify">>, + jid:make(LUser, LServer, <<"">>), + add_extended_headers( + Stanza, extended_headers([Publisher]))}); broadcast_stanza(Host, _Publisher, Node, Nidx, Type, NodeOptions, SubsByDepth, NotifyType, BaseStanza, SHIM) -> broadcast_stanza(Host, Node, Nidx, Type, NodeOptions, SubsByDepth, NotifyType, BaseStanza, SHIM). +-spec c2s_handle_info(ejabberd_c2s:state(), term()) -> ejabberd_c2s:state(). +c2s_handle_info(#{server := Server} = C2SState, + {pep_message, Feature, From, Packet}) -> + LServer = jid:nameprep(Server), + lists:foreach( + fun({USR, Caps}) -> + Features = mod_caps:get_features(LServer, Caps), + case lists:member(Feature, Features) of + true -> + To = jid:make(USR), + NewPacket = xmpp:set_from_to(Packet, From, To), + ejabberd_router:route(From, To, NewPacket); + false -> + ok + end + end, mod_caps:list_features(C2SState)), + {stop, C2SState}; +c2s_handle_info(#{server := Server} = C2SState, + {send_filtered, {pep_message, Feature}, From, To, Packet}) -> + LServer = jid:nameprep(Server), + case mod_caps:get_user_caps(To, C2SState) of + {ok, Caps} -> + Features = mod_caps:get_features(LServer, Caps), + case lists:member(Feature, Features) of + true -> + NewPacket = xmpp:set_from_to(Packet, From, To), + ejabberd_router:route(From, To, NewPacket); + false -> + ok + end; + error -> + ok + end, + {stop, C2SState}; +c2s_handle_info(C2SState, _) -> + C2SState. + subscribed_nodes_by_jid(NotifyType, SubsByDepth) -> NodesToDeliver = fun (Depth, Node, Subs, Acc) -> NodeName = case Node#pubsub_node.nodeid of diff --git a/src/mod_register.erl b/src/mod_register.erl index 3a544ccfa..7bb60753b 100644 --- a/src/mod_register.erl +++ b/src/mod_register.erl @@ -34,7 +34,7 @@ -behaviour(gen_mod). -export([start/2, stop/1, stream_feature_register/2, - unauthenticated_iq_register/4, try_register/5, + c2s_unauthenticated_packet/2, try_register/5, process_iq/1, send_registration_notifications/3, transform_options/1, transform_module_options/1, mod_opt_type/1, opt_type/1, depends/2]). @@ -50,10 +50,10 @@ start(Host, Opts) -> ?NS_REGISTER, ?MODULE, process_iq, IQDisc), gen_iq_handler:add_iq_handler(ejabberd_sm, Host, ?NS_REGISTER, ?MODULE, process_iq, IQDisc), - ejabberd_hooks:add(c2s_stream_features, Host, ?MODULE, + ejabberd_hooks:add(c2s_pre_auth_features, Host, ?MODULE, stream_feature_register, 50), - ejabberd_hooks:add(c2s_unauthenticated_iq, Host, - ?MODULE, unauthenticated_iq_register, 50), + ejabberd_hooks:add(c2s_unauthenticated_packet, Host, + ?MODULE, c2s_unauthenticated_packet, 50), ejabberd_mnesia:create(?MODULE, mod_register_ip, [{ram_copies, [node()]}, {local_content, true}, {attributes, [key, value]}]), @@ -62,10 +62,10 @@ start(Host, Opts) -> ok. stop(Host) -> - ejabberd_hooks:delete(c2s_stream_features, Host, + ejabberd_hooks:delete(c2s_pre_auth_features, Host, ?MODULE, stream_feature_register, 50), - ejabberd_hooks:delete(c2s_unauthenticated_iq, Host, - ?MODULE, unauthenticated_iq_register, 50), + ejabberd_hooks:delete(c2s_unauthenticated_packet, Host, + ?MODULE, c2s_unauthenticated_packet, 50), gen_iq_handler:remove_iq_handler(ejabberd_local, Host, ?NS_REGISTER), gen_iq_handler:remove_iq_handler(ejabberd_sm, Host, @@ -86,20 +86,21 @@ stream_feature_register(Acc, Host) -> Acc end. --spec unauthenticated_iq_register(empty | iq(), binary(), iq(), - {inet:ip_address(), non_neg_integer()}) -> - empty | iq(). -unauthenticated_iq_register(_Acc, Server, - #iq{sub_els = [#register{}]} = IQ, IP) -> - Address = case IP of - {A, _Port} -> A; - _ -> undefined - end, - ResIQ = process_iq(xmpp:set_from_to(IQ, jid:make(<<>>), jid:make(Server)), - Address), - xmpp:set_from_to(ResIQ, jid:make(Server), undefined); -unauthenticated_iq_register(Acc, _Server, _IQ, _IP) -> - Acc. +c2s_unauthenticated_packet(#{ip := IP, server := Server} = State, + #iq{type = T, sub_els = [_]} = IQ) + when T == set; T == get -> + case xmpp:get_subtag(IQ, #register{}) of + #register{} -> + {Address, _} = IP, + IQ1 = xmpp:set_from_to(IQ, jid:make(<<>>), jid:make(Server)), + ResIQ = process_iq(IQ1, Address), + ResIQ1 = xmpp:set_from_to(ResIQ, jid:make(Server), undefined), + {stop, ejabberd_c2s:send(State, ResIQ1)}; + false -> + State + end; +c2s_unauthenticated_packet(State, _) -> + State. process_iq(#iq{from = From} = IQ) -> process_iq(IQ, jid:tolower(From)). diff --git a/src/mod_roster.erl b/src/mod_roster.erl index 792bc1d43..44649631a 100644 --- a/src/mod_roster.erl +++ b/src/mod_roster.erl @@ -43,9 +43,9 @@ -export([start/2, stop/1, process_iq/1, export/1, import_info/0, process_local_iq/1, get_user_roster/2, - import/5, get_subscription_lists/3, get_roster/2, - import_start/2, import_stop/2, - get_in_pending_subscriptions/3, in_subscription/6, + import/5, c2s_session_opened/1, get_roster/2, + import_start/2, import_stop/2, user_receive_packet/1, + c2s_self_presence/1, in_subscription/6, out_subscription/4, set_items/3, remove_user/2, get_jid_info/4, encode_item/1, webadmin_page/3, webadmin_user/4, get_versioning_feature/2, @@ -63,6 +63,8 @@ -include("ejabberd_web_admin.hrl"). +-define(SETS, gb_sets). + -export_type([subscription/0]). -callback init(binary(), gen_mod:opts()) -> any(). @@ -92,22 +94,22 @@ start(Host, Opts) -> ?MODULE, in_subscription, 50), ejabberd_hooks:add(roster_out_subscription, Host, ?MODULE, out_subscription, 50), - ejabberd_hooks:add(roster_get_subscription_lists, Host, - ?MODULE, get_subscription_lists, 50), + ejabberd_hooks:add(c2s_session_opened, Host, ?MODULE, + c2s_session_opened, 50), ejabberd_hooks:add(roster_get_jid_info, Host, ?MODULE, get_jid_info, 50), ejabberd_hooks:add(remove_user, Host, ?MODULE, remove_user, 50), - ejabberd_hooks:add(anonymous_purge_hook, Host, ?MODULE, - remove_user, 50), - ejabberd_hooks:add(resend_subscription_requests_hook, - Host, ?MODULE, get_in_pending_subscriptions, 50), - ejabberd_hooks:add(roster_get_versioning_feature, Host, + ejabberd_hooks:add(c2s_self_presence, Host, ?MODULE, + c2s_self_presence, 50), + ejabberd_hooks:add(c2s_post_auth_features, Host, ?MODULE, get_versioning_feature, 50), ejabberd_hooks:add(webadmin_page_host, Host, ?MODULE, webadmin_page, 50), ejabberd_hooks:add(webadmin_user, Host, ?MODULE, webadmin_user, 50), + ejabberd_hooks:add(user_receive_packet, Host, ?MODULE, + user_receive_packet, 50), gen_iq_handler:add_iq_handler(ejabberd_sm, Host, ?NS_ROSTER, ?MODULE, process_iq, IQDisc). @@ -118,22 +120,22 @@ stop(Host) -> ?MODULE, in_subscription, 50), ejabberd_hooks:delete(roster_out_subscription, Host, ?MODULE, out_subscription, 50), - ejabberd_hooks:delete(roster_get_subscription_lists, - Host, ?MODULE, get_subscription_lists, 50), + ejabberd_hooks:delete(c2s_session_opened, Host, ?MODULE, + c2s_session_opened, 50), ejabberd_hooks:delete(roster_get_jid_info, Host, ?MODULE, get_jid_info, 50), ejabberd_hooks:delete(remove_user, Host, ?MODULE, remove_user, 50), - ejabberd_hooks:delete(anonymous_purge_hook, Host, - ?MODULE, remove_user, 50), - ejabberd_hooks:delete(resend_subscription_requests_hook, - Host, ?MODULE, get_in_pending_subscriptions, 50), - ejabberd_hooks:delete(roster_get_versioning_feature, + ejabberd_hooks:delete(c2s_self_presence, Host, ?MODULE, + c2s_self_presence, 50), + ejabberd_hooks:delete(c2s_post_auth_features, Host, ?MODULE, get_versioning_feature, 50), ejabberd_hooks:delete(webadmin_page_host, Host, ?MODULE, webadmin_page, 50), ejabberd_hooks:delete(webadmin_user, Host, ?MODULE, webadmin_user, 50), + ejabberd_hooks:delete(user_receive_packet, Host, ?MODULE, + user_receive_packet, 50), gen_iq_handler:remove_iq_handler(ejabberd_sm, Host, ?NS_ROSTER). @@ -214,10 +216,16 @@ roster_version_on_db(Host) -> %% Returns a list that may contain an xmlelement with the XEP-237 feature if it's enabled. -spec get_versioning_feature([xmpp_element()], binary()) -> [xmpp_element()]. get_versioning_feature(Acc, Host) -> + case gen_mod:is_loaded(Host, ?MODULE) of + true -> case roster_versioning_enabled(Host) of true -> [#rosterver_feature{}|Acc]; - false -> [] + false -> + Acc + end; + false -> + Acc end. roster_version(LServer, LUser) -> @@ -417,10 +425,6 @@ process_iq_set(#iq{from = From, to = To, end. push_item(User, Server, From, Item) -> - ejabberd_sm:route(jid:make(<<"">>, <<"">>, <<"">>), - jid:make(User, Server, <<"">>), - {broadcast, {item, Item#roster.jid, - Item#roster.subscription}}), case roster_versioning_enabled(Server) of true -> push_item_version(Server, User, From, Item, @@ -442,15 +446,12 @@ push_item(User, Server, Resource, From, Item, not_found -> undefined; _ -> RosterVersion end, - ResIQ = #iq{type = set, -%% @doc Roster push, calculate and include the version attribute. -%% TODO: don't push to those who didn't load roster + To = jid:make(User, Server, Resource), + ResIQ = #iq{type = set, from = From, to = To, id = <<"push", (randoms:get_string())/binary>>, sub_els = [#roster_query{ver = Ver, items = [encode_item(Item)]}]}, - ejabberd_router:route(From, - jid:make(User, Server, Resource), - ResIQ). + ejabberd_router:route(From, To, xmpp:put_meta(ResIQ, roster_item, Item)). push_item_version(Server, User, From, Item, RosterVersion) -> @@ -460,26 +461,88 @@ push_item_version(Server, User, From, Item, end, ejabberd_sm:get_user_resources(User, Server)). --spec get_subscription_lists({[ljid()], [ljid()]}, binary(), binary()) - -> {[ljid()], [ljid()]}. -get_subscription_lists(_Acc, User, Server) -> - LUser = jid:nodeprep(User), - LServer = jid:nameprep(Server), +-spec user_receive_packet({stanza(), ejabberd_c2s:state()}) -> {stanza(), ejabberd_c2s:state()}. +user_receive_packet({#iq{type = set, meta = #{roster_item := Item}} = IQ, State}) -> + {IQ, roster_change(State, Item)}; +user_receive_packet(Acc) -> + Acc. + +-spec roster_change(ejabberd_c2s:state(), #roster{}) -> ejabberd_c2s:state(). +roster_change(#{user := U, server := S, resource := R, + pres_a := PresA, pres_f := PresF, pres_t := PresT} = State, + #roster{jid = IJID, subscription = ISubscription}) -> + LIJID = jid:tolower(IJID), + IsFrom = (ISubscription == both) or (ISubscription == from), + IsTo = (ISubscription == both) or (ISubscription == to), + OldIsFrom = ?SETS:is_element(LIJID, PresF), + FSet = if IsFrom -> ?SETS:add_element(LIJID, PresF); + true -> ?SETS:del_element(LIJID, PresF) + end, + TSet = if IsTo -> ?SETS:add_element(LIJID, PresT); + true -> ?SETS:del_element(LIJID, PresT) + end, + State1 = State#{pres_f => FSet, pres_t => TSet}, + case maps:get(pres_last, State, undefined) of + undefined -> + State1; + LastPres -> + From = jid:make(U, S, R), + To = jid:make(IJID), + Cond1 = IsFrom andalso not OldIsFrom, + Cond2 = not IsFrom andalso OldIsFrom andalso + ?SETS:is_element(LIJID, PresA), + if Cond1 -> + case ejabberd_hooks:run_fold( + privacy_check_packet, allow, + [State1, LastPres, out]) of + deny -> + ok; + allow -> + Pres = xmpp:set_from_to(LastPres, From, To), + ejabberd_router:route(From, To, Pres) + end, + A = ?SETS:add_element(LIJID, PresA), + State1#{pres_a => A}; + Cond2 -> + PU = #presence{from = From, to = To, type = unavailable}, + case ejabberd_hooks:run_fold( + privacy_check_packet, allow, + [State1, PU, out]) of + deny -> + ok; + allow -> + ejabberd_router:route(From, To, PU) + end, + A = ?SETS:del_element(LIJID, PresA), + State1#{pres_a => A}; + true -> + State1 + end + end. + +-spec c2s_session_opened(ejabberd_c2s:state()) -> ejabberd_c2s:state(). +c2s_session_opened(#{jid := #jid{luser = LUser, lserver = LServer} = JID, + pres_f := PresF, pres_t := PresT} = State) -> Mod = gen_mod:db_mod(LServer, ?MODULE), Items = Mod:get_only_items(LUser, LServer), - fill_subscription_lists(LServer, Items, [], []). + {F, T} = fill_subscription_lists(Items, PresF, PresT), + LJID = jid:tolower(jid:remove_resource(JID)), + State#{pres_f => ?SETS:add(LJID, F), pres_t => ?SETS:add(LJID, T)}. -fill_subscription_lists(LServer, [I | Is], F, T) -> +fill_subscription_lists([I | Is], F, T) -> J = element(3, I#roster.usj), - case I#roster.subscription of + {F1, T1} = case I#roster.subscription of both -> - fill_subscription_lists(LServer, Is, [J | F], [J | T]); + {?SETS:add_element(J, F), ?SETS:add_element(J, T)}; from -> - fill_subscription_lists(LServer, Is, [J | F], T); - to -> fill_subscription_lists(LServer, Is, F, [J | T]); - _ -> fill_subscription_lists(LServer, Is, F, T) - end; -fill_subscription_lists(_LServer, [], F, T) -> + {?SETS:add_element(J, F), T}; + to -> + {F, ?SETS:add_element(J, T)}; + _ -> + {F, T} + end, + fill_subscription_lists(Is, F1, T1); +fill_subscription_lists([], F, T) -> {F, T}. ask_to_pending(subscribe) -> out; @@ -772,27 +835,47 @@ process_item_set_t(LUser, LServer, #roster_item{jid = JID1} = QueryItem) -> end; process_item_set_t(_LUser, _LServer, _) -> ok. --spec get_in_pending_subscriptions([presence()], binary(), binary()) -> [presence()]. -get_in_pending_subscriptions(Ls, User, Server) -> - LServer = jid:nameprep(Server), +-spec c2s_self_presence({presence(), ejabberd_c2s:state()}) + -> {presence(), ejabberd_c2s:state()}. +c2s_self_presence({_, #{pres_last := _}} = Acc) -> + Acc; +c2s_self_presence({#presence{type = available} = Pkt, + #{lserver := LServer} = State}) -> + Prio = get_priority_from_presence(Pkt), + if Prio >= 0 -> Mod = gen_mod:db_mod(LServer, ?MODULE), - get_in_pending_subscriptions(Ls, User, Server, Mod). + State1 = resend_pending_subscriptions(State, Mod), + {Pkt, State1}; + true -> + {Pkt, State} + end; +c2s_self_presence(Acc) -> + Acc. -get_in_pending_subscriptions(Ls, User, Server, Mod) -> - JID = jid:make(User, Server, <<"">>), +-spec resend_pending_subscriptions(ejabberd_c2s:state(), module()) -> ejabberd_c2s:state(). +resend_pending_subscriptions(#{jid := JID} = State, Mod) -> + BareJID = jid:remove_resource(JID), Result = Mod:get_only_items(JID#jid.luser, JID#jid.lserver), - Ls ++ lists:flatmap( - fun(#roster{ask = Ask} = R) when Ask == in; Ask == both -> + lists:foldl( + fun(#roster{ask = Ask} = R, AccState) when Ask == in; Ask == both -> Message = R#roster.askmessage, Status = if is_binary(Message) -> (Message); true -> <<"">> end, - [#presence{from = R#roster.jid, to = JID, + Sub = #presence{from = R#roster.jid, to = BareJID, type = subscribe, - status = xmpp:mk_text(Status)}]; - (_) -> - [] - end, Result). + status = xmpp:mk_text(Status)}, + ejabberd_c2s:send(AccState, Sub); + (_, AccState) -> + AccState + end, State, Result). + +-spec get_priority_from_presence(presence()) -> integer(). +get_priority_from_presence(#presence{priority = Prio}) -> + case Prio of + undefined -> 0; + _ -> Prio + end. %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% diff --git a/src/mod_s2s_dialback.erl b/src/mod_s2s_dialback.erl new file mode 100644 index 000000000..4be58d42c --- /dev/null +++ b/src/mod_s2s_dialback.erl @@ -0,0 +1,334 @@ +%%%------------------------------------------------------------------- +%%% Created : 16 Dec 2016 by Evgeny Khramtsov +%%% +%%% +%%% ejabberd, Copyright (C) 2002-2016 ProcessOne +%%% +%%% This program is free software; you can redistribute it and/or +%%% modify it under the terms of the GNU General Public License as +%%% published by the Free Software Foundation; either version 2 of the +%%% License, or (at your option) any later version. +%%% +%%% This program is distributed in the hope that it will be useful, +%%% but WITHOUT ANY WARRANTY; without even the implied warranty of +%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +%%% General Public License for more details. +%%% +%%% You should have received a copy of the GNU General Public License along +%%% with this program; if not, write to the Free Software Foundation, Inc., +%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +%%% +%%%------------------------------------------------------------------- +-module(mod_s2s_dialback). +-behaviour(gen_mod). + +-protocol({xep, 220, '1.1.1'}). +-protocol({xep, 185, '1.0'}). + +%% gen_mod API +-export([start/2, stop/1, depends/2, mod_opt_type/1]). +%% Hooks +-export([s2s_out_auth_result/2, s2s_out_downgraded/2, + s2s_in_packet/2, s2s_out_packet/2, s2s_in_recv/3, + s2s_in_features/2, s2s_out_init/2, s2s_out_closed/2]). + +-include("ejabberd.hrl"). +-include("xmpp.hrl"). +-include("logger.hrl"). + +%%%=================================================================== +%%% API +%%%=================================================================== +start(Host, _Opts) -> + case ejabberd_s2s:tls_verify(Host) of + true -> + ?ERROR_MSG("disabling ~s for host ~s because option " + "'s2s_use_starttls' is set to 'required_trusted'", + [?MODULE, Host]); + false -> + ejabberd_hooks:add(s2s_out_init, Host, ?MODULE, s2s_out_init, 50), + ejabberd_hooks:add(s2s_out_closed, Host, ?MODULE, s2s_out_closed, 50), + ejabberd_hooks:add(s2s_in_pre_auth_features, Host, ?MODULE, + s2s_in_features, 50), + ejabberd_hooks:add(s2s_in_post_auth_features, Host, ?MODULE, + s2s_in_features, 50), + ejabberd_hooks:add(s2s_in_handle_recv, Host, ?MODULE, + s2s_in_recv, 50), + ejabberd_hooks:add(s2s_in_unauthenticated_packet, Host, ?MODULE, + s2s_in_packet, 50), + ejabberd_hooks:add(s2s_in_authenticated_packet, Host, ?MODULE, + s2s_in_packet, 50), + ejabberd_hooks:add(s2s_out_packet, Host, ?MODULE, + s2s_out_packet, 50), + ejabberd_hooks:add(s2s_out_downgraded, Host, ?MODULE, + s2s_out_downgraded, 50), + ejabberd_hooks:add(s2s_out_auth_result, Host, ?MODULE, + s2s_out_auth_result, 50) + end. + +stop(Host) -> + ejabberd_hooks:delete(s2s_out_init, Host, ?MODULE, s2s_out_init, 50), + ejabberd_hooks:delete(s2s_out_closed, Host, ?MODULE, s2s_out_closed, 50), + ejabberd_hooks:delete(s2s_in_pre_auth_features, Host, ?MODULE, + s2s_in_features, 50), + ejabberd_hooks:delete(s2s_in_post_auth_features, Host, ?MODULE, + s2s_in_features, 50), + ejabberd_hooks:delete(s2s_in_handle_recv, Host, ?MODULE, + s2s_in_recv, 50), + ejabberd_hooks:delete(s2s_in_unauthenticated_packet, Host, ?MODULE, + s2s_in_packet, 50), + ejabberd_hooks:delete(s2s_in_authenticated_packet, Host, ?MODULE, + s2s_in_packet, 50), + ejabberd_hooks:delete(s2s_out_packet, Host, ?MODULE, + s2s_out_packet, 50), + ejabberd_hooks:delete(s2s_out_downgraded, Host, ?MODULE, + s2s_out_downgraded, 50), + ejabberd_hooks:delete(s2s_out_auth_result, Host, ?MODULE, + s2s_out_auth_result, 50). + +depends(_Host, _Opts) -> + []. + +mod_opt_type(_) -> + []. + +s2s_in_features(Acc, _) -> + [#db_feature{errors = true}|Acc]. + +s2s_out_init({ok, State}, Opts) -> + case proplists:get_value(db_verify, Opts) of + {StreamID, Key, Pid} -> + %% This is an outbound s2s connection created at step 1. + %% The purpose of this connection is to verify dialback key ONLY. + %% The connection is not registered in s2s table and thus is not + %% seen by anyone. + %% The connection will be closed immediately after receiving the + %% verification response (at step 3) + {ok, State#{db_verify => {StreamID, Key, Pid}}}; + undefined -> + {ok, State#{db_enabled => true}} + end; +s2s_out_init(Acc, _Opts) -> + Acc. + +s2s_out_closed(#{server := LServer, + remote_server := RServer, + db_verify := {StreamID, _Key, _Pid}} = State, Reason) -> + %% Outbound s2s verificating connection (created at step 1) is + %% closed suddenly without receiving the response. + %% Building a response on our own + Response = #db_verify{from = RServer, to = LServer, + id = StreamID, type = error, + sub_els = [mk_error(Reason)]}, + s2s_out_packet(State, Response); +s2s_out_closed(State, _Reason) -> + State. + +s2s_out_auth_result(#{db_verify := _} = State, _) -> + %% The temporary outbound s2s connect (intended for verification) + %% has passed authentication state (either successfully or not, no matter) + %% and at this point we can send verification request as described + %% in section 2.1.2, step 2 + {stop, send_verify_request(State)}; +s2s_out_auth_result(#{db_enabled := true, + sockmod := SockMod, + socket := Socket, ip := IP, + server := LServer, + remote_server := RServer} = State, {false, _}) -> + %% SASL authentication has failed, retrying with dialback + %% Sending dialback request, section 2.1.1, step 1 + ?INFO_MSG("(~s) Retrying with s2s dialback authentication: ~s -> ~s (~s)", + [SockMod:pp(Socket), LServer, RServer, + ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]), + State1 = maps:remove(stop_reason, State#{on_route => queue}), + {stop, send_db_request(State1)}; +s2s_out_auth_result(State, _) -> + State. + +s2s_out_downgraded(#{db_verify := _} = State, _) -> + %% The verifying outbound s2s connection detected non-RFC compliant + %% server, send verification request immediately without auth phase, + %% section 2.1.2, step 2 + {stop, send_verify_request(State)}; +s2s_out_downgraded(#{db_enabled := true, + sockmod := SockMod, + socket := Socket, ip := IP, + server := LServer, + remote_server := RServer} = State, _) -> + %% non-RFC compliant server detected, send dialback request instantly, + %% section 2.1.1, step 1 + ?INFO_MSG("(~s) Trying s2s dialback authentication with " + "non-RFC compliant server: ~s -> ~s (~s)", + [SockMod:pp(Socket), LServer, RServer, + ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]), + {stop, send_db_request(State)}; +s2s_out_downgraded(State, _) -> + State. + +s2s_in_packet(#{stream_id := StreamID} = State, + #db_result{from = From, to = To, key = Key, type = undefined}) -> + %% Received dialback request, section 2.2.1, step 1 + try + ok = check_from_to(From, To), + %% We're creating a temporary outbound s2s connection to + %% send verification request and to receive verification response + {ok, Pid} = ejabberd_s2s_out:start( + To, From, [{db_verify, {StreamID, Key, self()}}]), + ejabberd_s2s_out:connect(Pid), + State + catch _:{badmatch, {error, Reason}} -> + send_db_result(State, + #db_verify{from = From, to = To, type = error, + sub_els = [mk_error(Reason)]}) + end; +s2s_in_packet(State, #db_verify{to = To, from = From, key = Key, + id = StreamID, type = undefined}) -> + %% Received verification request, section 2.2.2, step 2 + Type = case make_key(To, From, StreamID) of + Key -> valid; + _ -> invalid + end, + Response = #db_verify{from = To, to = From, id = StreamID, type = Type}, + ejabberd_s2s_in:send(State, Response); +s2s_in_packet(State, Pkt) when is_record(Pkt, db_result); + is_record(Pkt, db_verify) -> + ?WARNING_MSG("Got stray dialback packet:~n~s", [xmpp:pp(Pkt)]), + State; +s2s_in_packet(State, _) -> + State. + +s2s_in_recv(State, El, {error, Why}) -> + case xmpp:get_name(El) of + Tag when Tag == <<"db:result">>; + Tag == <<"db:verify">> -> + case xmpp:get_type(El) of + T when T /= <<"valid">>, + T /= <<"invalid">>, + T /= <<"error">> -> + Err = xmpp:make_error(El, mk_error({codec_error, Why})), + {stop, ejabberd_s2s_in:send(State, Err)}; + _ -> + State + end; + _ -> + State + end; +s2s_in_recv(State, _El, _Pkt) -> + State. + +s2s_out_packet(#{server := LServer, + remote_server := RServer, + db_verify := {StreamID, _Key, Pid}} = State, + #db_verify{from = RServer, to = LServer, + id = StreamID, type = Type} = Response) + when Type /= undefined -> + %% Received verification response, section 2.1.2, step 3 + %% This is a response for the request sent at step 2 + ejabberd_s2s_in:update_state( + Pid, fun(S) -> send_db_result(S, Response) end), + %% At this point the connection is no longer needed and we can terminate it + ejabberd_s2s_out:stop(State); +s2s_out_packet(#{server := LServer, remote_server := RServer} = State, + #db_result{to = LServer, from = RServer, + type = Type} = Result) when Type /= undefined -> + %% Received dialback response, section 2.1.1, step 4 + %% This is a response to the request sent at step 1 + State1 = maps:remove(db_enabled, State), + case Type of + valid -> + State2 = ejabberd_s2s_out:handle_auth_success(<<"dialback">>, State1), + ejabberd_s2s_out:establish(State2); + _ -> + Reason = format_error(Result), + ejabberd_s2s_out:handle_auth_failure(<<"dialback">>, {auth, Reason}, State1) + end; +s2s_out_packet(State, Pkt) when is_record(Pkt, db_result); + is_record(Pkt, db_verify) -> + ?WARNING_MSG("Got stray dialback packet:~n~s", [xmpp:pp(Pkt)]), + State; +s2s_out_packet(State, _) -> + State. + +%%%=================================================================== +%%% Internal functions +%%%=================================================================== +-spec make_key(binary(), binary(), binary()) -> binary(). +make_key(From, To, StreamID) -> + Secret = ejabberd_config:get_option(shared_key, fun(V) -> V end), + p1_sha:to_hexlist( + crypto:hmac(sha256, p1_sha:to_hexlist(crypto:hash(sha256, Secret)), + [To, " ", From, " ", StreamID])). + +-spec send_verify_request(ejabberd_s2s_out:state()) -> ejabberd_s2s_out:state(). +send_verify_request(#{server := LServer, + remote_server := RServer, + db_verify := {StreamID, Key, _Pid}} = State) -> + Request = #db_verify{from = LServer, to = RServer, + key = Key, id = StreamID}, + ejabberd_s2s_out:send(State, Request). + +-spec send_db_request(ejabberd_s2s_out:state()) -> ejabberd_s2s_out:state(). +send_db_request(#{server := LServer, + remote_server := RServer, + stream_remote_id := StreamID} = State) -> + Key = make_key(LServer, RServer, StreamID), + ejabberd_s2s_out:send(State, #db_result{from = LServer, + to = RServer, + key = Key}). + +-spec send_db_result(ejabberd_s2s_in:state(), db_verify()) -> ejabberd_s2s_in:state(). +send_db_result(State, #db_verify{from = From, to = To, + type = Type, sub_els = Els}) -> + %% Sending dialback response, section 2.2.1, step 4 + %% This is a response to the request received at step 1 + Response = #db_result{from = To, to = From, type = Type, sub_els = Els}, + State1 = ejabberd_s2s_in:send(State, Response), + case Type of + valid -> + State2 = ejabberd_s2s_in:handle_auth_success( + From, <<"dialback">>, undefined, State1), + ejabberd_s2s_in:establish(State2); + _ -> + Reason = format_error(Response), + ejabberd_s2s_in:handle_auth_failure( + From, <<"dialback">>, Reason, State1) + end. + +-spec check_from_to(binary(), binary()) -> ok | {error, forbidden | host_unknown}. +check_from_to(From, To) -> + case ejabberd_router:is_my_route(To) of + false -> {error, host_unknown}; + true -> + LServer = ejabberd_router:host_of_route(To), + case ejabberd_s2s:allow_host(LServer, From) of + true -> ok; + false -> {error, forbidden} + end + end. + +-spec mk_error(term()) -> stanza_error(). +mk_error(forbidden) -> + xmpp:err_forbidden(<<"Denied by ACL">>, ?MYLANG); +mk_error(host_unknown) -> + xmpp:err_not_allowed(<<"Host unknown">>, ?MYLANG); +mk_error({codec_error, Why}) -> + xmpp:err_bad_request(xmpp:io_format_error(Why), ?MYLANG); +mk_error({_Class, _Reason} = Why) -> + Txt = xmpp_stream_out:format_error(Why), + xmpp:err_remote_server_not_found(Txt, ?MYLANG); +mk_error(_) -> + xmpp:err_internal_server_error(). + +-spec format_error(db_result()) -> binary(). +format_error(#db_result{type = invalid}) -> + <<"invalid dialback key">>; +format_error(#db_result{type = error, sub_els = Els}) -> + %% TODO: improve xmpp.erl + case xmpp:get_error(#message{sub_els = Els}) of + #stanza_error{reason = Reason} -> + erlang:atom_to_binary(Reason, latin1); + undefined -> + <<"unrecognized error">> + end; +format_error(_) -> + <<"unexpected dialback result">>. diff --git a/src/mod_service_log.erl b/src/mod_service_log.erl index 1d64f1471..d29cd1329 100644 --- a/src/mod_service_log.erl +++ b/src/mod_service_log.erl @@ -29,8 +29,8 @@ -behaviour(gen_mod). --export([start/2, stop/1, log_user_send/4, - log_user_receive/5, mod_opt_type/1, depends/2]). +-export([start/2, stop/1, log_user_send/1, + log_user_receive/1, mod_opt_type/1, depends/2]). -include("ejabberd.hrl"). -include("logger.hrl"). @@ -54,15 +54,19 @@ stop(Host) -> depends(_Host, _Opts) -> []. --spec log_user_send(stanza(), ejabberd_c2s:state(), jid(), jid()) -> stanza(). -log_user_send(Packet, _C2SState, From, To) -> +-spec log_user_send({stanza(), ejabberd_c2s:state()}) -> {stanza(), ejabberd_c2s:state()}. +log_user_send({Packet, C2SState}) -> + From = xmpp:get_from(Packet), + To = xmpp:get_to(Packet), log_packet(From, To, Packet, From#jid.lserver), - Packet. + {Packet, C2SState}. --spec log_user_receive(stanza(), ejabberd_c2s:state(), jid(), jid(), jid()) -> stanza(). -log_user_receive(Packet, _C2SState, _JID, From, To) -> +-spec log_user_receive({stanza(), ejabberd_c2s:state()}) -> {stanza(), ejabberd_c2s:state()}. +log_user_receive({Packet, C2SState}) -> + From = xmpp:get_from(Packet), + To = xmpp:get_to(Packet), log_packet(From, To, Packet, To#jid.lserver), - Packet. + {Packet, C2SState}. -spec log_packet(jid(), jid(), stanza(), binary()) -> ok. log_packet(From, To, Packet, Host) -> diff --git a/src/mod_shared_roster.erl b/src/mod_shared_roster.erl index 8400823f4..077f9bfab 100644 --- a/src/mod_shared_roster.erl +++ b/src/mod_shared_roster.erl @@ -31,9 +31,9 @@ -export([start/2, stop/1, export/1, import_info/0, webadmin_menu/3, webadmin_page/3, - get_user_roster/2, get_subscription_lists/3, + get_user_roster/2, c2s_session_opened/1, get_jid_info/4, import/5, process_item/2, import_start/2, - in_subscription/6, out_subscription/4, user_available/1, + in_subscription/6, out_subscription/4, c2s_self_presence/1, unset_presence/4, register_user/2, remove_user/2, list_groups/1, create_group/2, create_group/3, delete_group/2, get_group_opts/2, set_group_opts/3, @@ -54,6 +54,8 @@ -include("mod_shared_roster.hrl"). +-define(SETS, gb_sets). + -type group_options() :: [{atom(), any()}]. -callback init(binary(), gen_mod:opts()) -> any(). -callback import(binary(), binary(), [binary()]) -> ok. @@ -84,20 +86,18 @@ start(Host, Opts) -> ?MODULE, in_subscription, 30), ejabberd_hooks:add(roster_out_subscription, Host, ?MODULE, out_subscription, 30), - ejabberd_hooks:add(roster_get_subscription_lists, Host, - ?MODULE, get_subscription_lists, 70), + ejabberd_hooks:add(c2s_session_opened, Host, + ?MODULE, c2s_session_opened, 70), ejabberd_hooks:add(roster_get_jid_info, Host, ?MODULE, get_jid_info, 70), ejabberd_hooks:add(roster_process_item, Host, ?MODULE, process_item, 50), - ejabberd_hooks:add(user_available_hook, Host, ?MODULE, - user_available, 50), + ejabberd_hooks:add(c2s_self_presence, Host, ?MODULE, + c2s_self_presence, 50), ejabberd_hooks:add(unset_presence_hook, Host, ?MODULE, unset_presence, 50), ejabberd_hooks:add(register_user, Host, ?MODULE, register_user, 50), - ejabberd_hooks:add(anonymous_purge_hook, Host, ?MODULE, - remove_user, 50), ejabberd_hooks:add(remove_user, Host, ?MODULE, remove_user, 50). @@ -112,20 +112,18 @@ stop(Host) -> ?MODULE, in_subscription, 30), ejabberd_hooks:delete(roster_out_subscription, Host, ?MODULE, out_subscription, 30), - ejabberd_hooks:delete(roster_get_subscription_lists, - Host, ?MODULE, get_subscription_lists, 70), + ejabberd_hooks:delete(c2s_session_opened, + Host, ?MODULE, c2s_session_opened, 70), ejabberd_hooks:delete(roster_get_jid_info, Host, ?MODULE, get_jid_info, 70), ejabberd_hooks:delete(roster_process_item, Host, ?MODULE, process_item, 50), - ejabberd_hooks:delete(user_available_hook, Host, - ?MODULE, user_available, 50), + ejabberd_hooks:delete(c2s_self_presence, Host, + ?MODULE, c2s_self_presence, 50), ejabberd_hooks:delete(unset_presence_hook, Host, ?MODULE, unset_presence, 50), ejabberd_hooks:delete(register_user, Host, ?MODULE, register_user, 50), - ejabberd_hooks:delete(anonymous_purge_hook, Host, - ?MODULE, remove_user, 50), ejabberd_hooks:delete(remove_user, Host, ?MODULE, remove_user, 50). @@ -294,19 +292,21 @@ set_item(User, Server, Resource, Item) -> jid:make(Server), ResIQ). --spec get_subscription_lists({[ljid()], [ljid()]}, binary(), binary()) - -> {[ljid()], [ljid()]}. -get_subscription_lists({F, T}, User, Server) -> - LUser = jid:nodeprep(User), - LServer = jid:nameprep(Server), +c2s_session_opened(#{jid := #jid{luser = LUser, lserver = LServer} = JID, + pres_f := PresF, pres_t := PresT} = State) -> US = {LUser, LServer}, DisplayedGroups = get_user_displayed_groups(US), - SRUsers = lists:usort(lists:flatmap(fun (Group) -> + SRUsers = lists:flatmap(fun(Group) -> get_group_users(LServer, Group) end, - DisplayedGroups)), - SRJIDs = [{U1, S1, <<"">>} || {U1, S1} <- SRUsers], - {lists:usort(SRJIDs ++ F), lists:usort(SRJIDs ++ T)}. + DisplayedGroups), + BareLJID = jid:tolower(jid:remove_resource(JID)), + PresBoth = lists:foldl( + fun({U, S}, Acc) -> + ?SETS:add_element({U, S, <<"">>}, Acc) + end, ?SETS:new(), [BareLJID|SRUsers]), + State#{pres_f => ?SETS:union(PresBoth, PresF), + pres_t => ?SETS:union(PresBoth, PresT)}. -spec get_jid_info({subscription(), [binary()]}, binary(), binary(), jid()) -> {subscription(), [binary()]}. @@ -739,12 +739,15 @@ push_roster_item(User, Server, ContactU, ContactS, groups = [GroupName]}, push_item(User, Server, Item). --spec user_available(jid()) -> ok. -user_available(New) -> +-spec c2s_self_presence({presence(), ejabberd_c2s:state()}) + -> {presence(), ejabberd_c2s:state()}. +c2s_self_presence({_, #{pres_last := _}} = Acc) -> + %% This is just a presence update, nothing to do + Acc; +c2s_self_presence({#presence{type = available}, #{jid := New}} = Acc) -> LUser = New#jid.luser, LServer = New#jid.lserver, - Resources = ejabberd_sm:get_user_resources(LUser, - LServer), + Resources = ejabberd_sm:get_user_resources(LUser, LServer), ?DEBUG("user_available for ~p @ ~p (~p resources)", [LUser, LServer, length(Resources)]), case length(Resources) of @@ -761,7 +764,10 @@ user_available(New) -> end, UserGroups); _ -> ok - end. + end, + Acc; +c2s_self_presence(Acc) -> + Acc. -spec unset_presence(binary(), binary(), binary(), binary()) -> ok. unset_presence(LUser, LServer, Resource, Status) -> @@ -1038,11 +1044,8 @@ split_grouphost(Host, Group) -> end. broadcast_subscription(User, Server, ContactJid, Subscription) -> - ejabberd_sm:route( - jid:make(<<"">>, Server, <<"">>), - jid:make(User, Server, <<"">>), - {broadcast, {item, ContactJid, - Subscription}}). + ejabberd_sm:route(jid:make(User, Server, <<"">>), + {item, ContactJid, Subscription}). displayed_groups_update(Members, DisplayedGroups, Subscription) -> lists:foreach(fun({U, S}) -> diff --git a/src/mod_shared_roster_ldap.erl b/src/mod_shared_roster_ldap.erl index 7ebceb9b3..e79bcc5c0 100644 --- a/src/mod_shared_roster_ldap.erl +++ b/src/mod_shared_roster_ldap.erl @@ -39,7 +39,7 @@ -export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]). --export([get_user_roster/2, get_subscription_lists/3, +-export([get_user_roster/2, c2s_session_opened/1, get_jid_info/4, process_item/2, in_subscription/6, out_subscription/4, mod_opt_type/1, opt_type/1, depends/2]). @@ -49,6 +49,7 @@ -include("mod_roster.hrl"). -include("eldap.hrl"). +-define(SETS, gb_sets). -define(CACHE_SIZE, 1000). -define(USER_CACHE_VALIDITY, 300). %% in seconds -define(GROUP_CACHE_VALIDITY, 300). @@ -160,19 +161,21 @@ process_item(RosterItem, _Host) -> _ -> RosterItem#roster{subscription = both, ask = none} end. --spec get_subscription_lists({[ljid()], [ljid()]}, binary(), binary()) - -> {[ljid()], [ljid()]}. -get_subscription_lists({F, T}, User, Server) -> - LUser = jid:nodeprep(User), - LServer = jid:nameprep(Server), +c2s_session_opened(#{jid := #jid{luser = LUser, lserver = LServer} = JID, + pres_f := PresF, pres_t := PresT} = State) -> US = {LUser, LServer}, DisplayedGroups = get_user_displayed_groups(US), - SRUsers = lists:usort(lists:flatmap(fun (Group) -> + SRUsers = lists:flatmap(fun(Group) -> get_group_users(LServer, Group) end, - DisplayedGroups)), - SRJIDs = [{U1, S1, <<"">>} || {U1, S1} <- SRUsers], - {lists:usort(SRJIDs ++ F), lists:usort(SRJIDs ++ T)}. + DisplayedGroups), + BareLJID = jid:tolower(jid:remove_resource(JID)), + PresBoth = lists:foldl( + fun({U, S}, Acc) -> + ?SETS:add_element({U, S, <<"">>}, Acc) + end, ?SETS:new(), [BareLJID|SRUsers]), + State#{pres_f => ?SETS:union(PresBoth, PresF), + pres_t => ?SETS:union(PresBoth, PresT)}. -spec get_jid_info({subscription(), [binary()]}, binary(), binary(), jid()) -> {subscription(), [binary()]}. @@ -246,8 +249,8 @@ init([Host, Opts]) -> ?MODULE, in_subscription, 30), ejabberd_hooks:add(roster_out_subscription, Host, ?MODULE, out_subscription, 30), - ejabberd_hooks:add(roster_get_subscription_lists, Host, - ?MODULE, get_subscription_lists, 70), + ejabberd_hooks:add(c2s_session_opened, Host, + ?MODULE, c2s_session_opened, 70), ejabberd_hooks:add(roster_get_jid_info, Host, ?MODULE, get_jid_info, 70), ejabberd_hooks:add(roster_process_item, Host, ?MODULE, @@ -275,8 +278,8 @@ terminate(_Reason, State) -> ?MODULE, in_subscription, 30), ejabberd_hooks:delete(roster_out_subscription, Host, ?MODULE, out_subscription, 30), - ejabberd_hooks:delete(roster_get_subscription_lists, - Host, ?MODULE, get_subscription_lists, 70), + ejabberd_hooks:delete(c2s_session_opened, + Host, ?MODULE, c2s_session_opened, 70), ejabberd_hooks:delete(roster_get_jid_info, Host, ?MODULE, get_jid_info, 70), ejabberd_hooks:delete(roster_process_item, Host, diff --git a/src/mod_sm.erl b/src/mod_sm.erl new file mode 100644 index 000000000..aa5b2be54 --- /dev/null +++ b/src/mod_sm.erl @@ -0,0 +1,760 @@ +%%%------------------------------------------------------------------- +%%% Author : Holger Weiss +%%% Created : 25 Dec 2016 by Evgeny Khramtsov +%%% +%%% +%%% ejabberd, Copyright (C) 2002-2016 ProcessOne +%%% +%%% This program is free software; you can redistribute it and/or +%%% modify it under the terms of the GNU General Public License as +%%% published by the Free Software Foundation; either version 2 of the +%%% License, or (at your option) any later version. +%%% +%%% This program is distributed in the hope that it will be useful, +%%% but WITHOUT ANY WARRANTY; without even the implied warranty of +%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +%%% General Public License for more details. +%%% +%%% You should have received a copy of the GNU General Public License along +%%% with this program; if not, write to the Free Software Foundation, Inc., +%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +%%% +%%%------------------------------------------------------------------- +-module(mod_sm). +-behaviour(gen_mod). +-author('holger@zedat.fu-berlin.de'). +-protocol({xep, 198, '1.5.2'}). + +%% gen_mod API +-export([start/2, stop/1, depends/2, mod_opt_type/1]). +%% hooks +-export([c2s_stream_init/2, c2s_stream_started/2, c2s_stream_features/2, + c2s_authenticated_packet/2, c2s_unauthenticated_packet/2, + c2s_unbinded_packet/2, c2s_closed/2, c2s_terminated/2, + c2s_handle_send/3, c2s_handle_info/2, c2s_handle_call/3, + c2s_handle_recv/3]). + +-include("xmpp.hrl"). +-include("logger.hrl"). + +-define(is_sm_packet(Pkt), + is_record(Pkt, sm_enable) or + is_record(Pkt, sm_resume) or + is_record(Pkt, sm_a) or + is_record(Pkt, sm_r)). + +-type state() :: ejabberd_c2s:state(). +-type lqueue() :: {non_neg_integer(), queue:queue()}. + +%%%=================================================================== +%%% API +%%%=================================================================== +start(Host, _Opts) -> + ejabberd_hooks:add(c2s_init, ?MODULE, c2s_stream_init, 50), + ejabberd_hooks:add(c2s_stream_started, Host, ?MODULE, + c2s_stream_started, 50), + ejabberd_hooks:add(c2s_post_auth_features, Host, ?MODULE, + c2s_stream_features, 50), + ejabberd_hooks:add(c2s_unauthenticated_packet, Host, ?MODULE, + c2s_unauthenticated_packet, 50), + ejabberd_hooks:add(c2s_unbinded_packet, Host, ?MODULE, + c2s_unbinded_packet, 50), + ejabberd_hooks:add(c2s_authenticated_packet, Host, ?MODULE, + c2s_authenticated_packet, 50), + ejabberd_hooks:add(c2s_handle_send, Host, ?MODULE, c2s_handle_send, 50), + ejabberd_hooks:add(c2s_handle_recv, Host, ?MODULE, c2s_handle_recv, 50), + ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE, c2s_handle_info, 50), + ejabberd_hooks:add(c2s_handle_call, Host, ?MODULE, c2s_handle_call, 50), + ejabberd_hooks:add(c2s_closed, Host, ?MODULE, c2s_closed, 50), + ejabberd_hooks:add(c2s_terminated, Host, ?MODULE, c2s_terminated, 50). + +stop(Host) -> + %% TODO: do something with global 'c2s_init' hook + ejabberd_hooks:delete(c2s_stream_started, Host, ?MODULE, + c2s_stream_started, 50), + ejabberd_hooks:delete(c2s_post_auth_features, Host, ?MODULE, + c2s_stream_features, 50), + ejabberd_hooks:delete(c2s_unauthenticated_packet, Host, ?MODULE, + c2s_unauthenticated_packet, 50), + ejabberd_hooks:delete(c2s_unbinded_packet, Host, ?MODULE, + c2s_unbinded_packet, 50), + ejabberd_hooks:delete(c2s_authenticated_packet, Host, ?MODULE, + c2s_authenticated_packet, 50), + ejabberd_hooks:delete(c2s_handle_send, Host, ?MODULE, c2s_handle_send, 50), + ejabberd_hooks:delete(c2s_handle_recv, Host, ?MODULE, c2s_handle_recv, 50), + ejabberd_hooks:delete(c2s_handle_info, Host, ?MODULE, c2s_handle_info, 50), + ejabberd_hooks:delete(c2s_handle_call, Host, ?MODULE, c2s_handle_call, 50), + ejabberd_hooks:delete(c2s_closed, Host, ?MODULE, c2s_closed, 50), + ejabberd_hooks:delete(c2s_terminated, Host, ?MODULE, c2s_terminated, 50). + +depends(_Host, _Opts) -> + []. + +c2s_stream_init({ok, State}, Opts) -> + MgmtOpts = lists:filter( + fun({stream_management, _}) -> true; + ({max_ack_queue, _}) -> true; + ({resume_timeout, _}) -> true; + ({max_resume_timeout, _}) -> true; + ({ack_timeout, _}) -> true; + ({resend_on_timeout, _}) -> true; + (_) -> false + end, Opts), + {ok, State#{mgmt_options => MgmtOpts}}; +c2s_stream_init(Acc, _Opts) -> + Acc. + +c2s_stream_started(#{lserver := LServer, mgmt_options := Opts} = State, + _StreamStart) -> + State1 = maps:remove(mgmt_options, State), + ResumeTimeout = get_resume_timeout(LServer, Opts), + MaxResumeTimeout = get_max_resume_timeout(LServer, Opts, ResumeTimeout), + State1#{mgmt_state => inactive, + mgmt_max_queue => get_max_ack_queue(LServer, Opts), + mgmt_timeout => ResumeTimeout, + mgmt_max_timeout => MaxResumeTimeout, + mgmt_ack_timeout => get_ack_timeout(LServer, Opts), + mgmt_resend => get_resend_on_timeout(LServer, Opts), + mgmt_stanzas_in => 0, + mgmt_stanzas_out => 0, + mgmt_stanzas_req => 0}; +c2s_stream_started(State, _StreamStart) -> + State. + +c2s_stream_features(Acc, Host) -> + case gen_mod:is_loaded(Host, ?MODULE) of + true -> + [#feature_sm{xmlns = ?NS_STREAM_MGMT_2}, + #feature_sm{xmlns = ?NS_STREAM_MGMT_3}|Acc]; + false -> + Acc + end. + +c2s_unauthenticated_packet(State, Pkt) when ?is_sm_packet(Pkt) -> + %% XEP-0198 says: "For client-to-server connections, the client MUST NOT + %% attempt to enable stream management until after it has completed Resource + %% Binding unless it is resuming a previous session". However, it also + %% says: "Stream management errors SHOULD be considered recoverable", so we + %% won't bail out. + Err = #sm_failed{reason = 'unexpected-request', xmlns = ?NS_STREAM_MGMT_3}, + {stop, send(State, Err)}; +c2s_unauthenticated_packet(State, _Pkt) -> + State. + +c2s_unbinded_packet(State, #sm_resume{} = Pkt) -> + case handle_resume(State, Pkt) of + {ok, ResumedState} -> + {stop, ResumedState}; + {error, State1} -> + {stop, State1} + end; +c2s_unbinded_packet(State, Pkt) when ?is_sm_packet(Pkt) -> + c2s_unauthenticated_packet(State, Pkt); +c2s_unbinded_packet(State, _Pkt) -> + State. + +c2s_authenticated_packet(#{mgmt_state := MgmtState} = State, Pkt) + when ?is_sm_packet(Pkt) -> + if MgmtState == pending; MgmtState == active -> + {stop, perform_stream_mgmt(Pkt, State)}; + true -> + {stop, negotiate_stream_mgmt(Pkt, State)} + end; +c2s_authenticated_packet(State, Pkt) -> + update_num_stanzas_in(State, Pkt). + +c2s_handle_recv(#{lang := Lang} = State, El, {error, Why}) -> + Xmlns = xmpp:get_ns(El), + if Xmlns == ?NS_STREAM_MGMT_2; Xmlns == ?NS_STREAM_MGMT_3 -> + Txt = xmpp:io_format_error(Why), + Err = #sm_failed{reason = 'bad-request', + text = xmpp:mk_text(Txt, Lang), + xmlns = Xmlns}, + send(State, Err); + true -> + State + end; +c2s_handle_recv(State, _, _) -> + State. + +c2s_handle_send(#{mgmt_state := MgmtState, mod := Mod, + lang := Lang} = State, Pkt, SendResult) + when MgmtState == pending; MgmtState == active -> + case xmpp:is_stanza(Pkt) of + true -> + case mgmt_queue_add(State, Pkt) of + #{mgmt_max_queue := exceeded} = State1 -> + State2 = State1#{mgmt_resend => false}, + case MgmtState of + active -> + Err = xmpp:serr_policy_violation( + <<"Too many unacked stanzas">>, Lang), + send(State2, Err); + _ -> + Mod:stop(State2) + end; + State1 when SendResult == ok -> + send_rack(State1); + State1 -> + State1 + end; + false -> + State + end; +c2s_handle_send(State, _Pkt, _Result) -> + State. + +c2s_handle_call(#{sid := {Time, _}, mod := Mod} = State, + {resume_session, Time}, From) -> + Mod:reply(From, {resume, State}), + {stop, State#{mgmt_state => resumed}}; +c2s_handle_call(#{mod := Mod} = State, {resume_session, _}, From) -> + Mod:reply(From, {error, <<"Previous session not found">>}), + {stop, State}; +c2s_handle_call(State, _Call, _From) -> + State. + +c2s_handle_info(#{mgmt_ack_timer := TRef, jid := JID, mod := Mod} = State, + {timeout, TRef, ack_timeout}) -> + ?DEBUG("Timed out waiting for stream management acknowledgement of ~s", + [jid:to_string(JID)]), + State1 = State#{stop_reason => {socket, timeout}}, + State2 = Mod:close(State1, _SendTrailer = false), + {stop, transition_to_pending(State2)}; +c2s_handle_info(#{mgmt_state := pending, jid := JID, mod := Mod} = State, + {timeout, _, pending_timeout}) -> + ?DEBUG("Timed out waiting for resumption of stream for ~s", + [jid:to_string(JID)]), + Mod:stop(State#{mgmt_state => timeout}); +c2s_handle_info(#{jid := JID} = State, {_Ref, {resume, OldState}}) -> + %% This happens if the resume_session/1 request timed out; the new session + %% now receives the late response. + ?DEBUG("Received old session state for ~s after failed resumption", + [jid:to_string(JID)]), + route_unacked_stanzas(OldState#{mgmt_resend => false}), + State; +c2s_handle_info(State, _) -> + State. + +c2s_closed(State, {stream, _}) -> + State; +c2s_closed(#{mgmt_state := active} = State, _Reason) -> + {stop, transition_to_pending(State)}; +c2s_closed(State, _Reason) -> + State. + +c2s_terminated(#{mgmt_state := resumed, jid := JID} = State, _Reason) -> + ?INFO_MSG("Closing former stream of resumed session for ~s", + [jid:to_string(JID)]), + bounce_message_queue(), + {stop, State}; +c2s_terminated(#{mgmt_state := MgmtState, mgmt_stanzas_in := In, sid := SID, + user := U, server := S, resource := R} = State, _Reason) -> + case MgmtState of + timeout -> + Info = [{num_stanzas_in, In}], + ejabberd_sm:set_offline_info(SID, U, S, R, Info); + _ -> + ok + end, + route_unacked_stanzas(State), + State; +c2s_terminated(State, _Reason) -> + State. + +%%%=================================================================== +%%% Internal functions +%%%=================================================================== +-spec negotiate_stream_mgmt(xmpp_element(), state()) -> state(). +negotiate_stream_mgmt(Pkt, State) -> + Xmlns = xmpp:get_ns(Pkt), + case Pkt of + #sm_enable{} -> + handle_enable(State#{mgmt_xmlns => Xmlns}, Pkt); + _ when is_record(Pkt, sm_a); + is_record(Pkt, sm_r); + is_record(Pkt, sm_resume) -> + Err = #sm_failed{reason = 'unexpected-request', xmlns = Xmlns}, + send(State, Err); + _ -> + Err = #sm_failed{reason = 'bad-request', xmlns = Xmlns}, + send(State, Err) + end. + +-spec perform_stream_mgmt(xmpp_element(), state()) -> state(). +perform_stream_mgmt(Pkt, #{mgmt_xmlns := Xmlns} = State) -> + case xmpp:get_ns(Pkt) of + Xmlns -> + case Pkt of + #sm_r{} -> + handle_r(State); + #sm_a{} -> + handle_a(State, Pkt); + _ when is_record(Pkt, sm_enable); + is_record(Pkt, sm_resume) -> + send(State, #sm_failed{reason = 'unexpected-request', + xmlns = Xmlns}); + _ -> + send(State, #sm_failed{reason = 'bad-request', + xmlns = Xmlns}) + end; + _ -> + send(State, #sm_failed{reason = 'unsupported-version', xmlns = Xmlns}) + end. + +-spec handle_enable(state(), sm_enable()) -> state(). +handle_enable(#{mgmt_timeout := DefaultTimeout, + mgmt_max_timeout := MaxTimeout, + mgmt_xmlns := Xmlns, jid := JID} = State, + #sm_enable{resume = Resume, max = Max}) -> + Timeout = if Resume == false -> + 0; + Max /= undefined, Max > 0, Max =< MaxTimeout -> + Max; + true -> + DefaultTimeout + end, + Res = if Timeout > 0 -> + ?INFO_MSG("Stream management with resumption enabled for ~s", + [jid:to_string(JID)]), + #sm_enabled{xmlns = Xmlns, + id = make_resume_id(State), + resume = true, + max = Timeout}; + true -> + ?INFO_MSG("Stream management without resumption enabled for ~s", + [jid:to_string(JID)]), + #sm_enabled{xmlns = Xmlns} + end, + State1 = State#{mgmt_state => active, + mgmt_queue => queue_new(), + mgmt_timeout => Timeout}, + send(State1, Res). + +-spec handle_r(state()) -> state(). +handle_r(#{mgmt_xmlns := Xmlns, mgmt_stanzas_in := H} = State) -> + Res = #sm_a{xmlns = Xmlns, h = H}, + send(State, Res). + +-spec handle_a(state(), sm_a()) -> state(). +handle_a(State, #sm_a{h = H}) -> + State1 = check_h_attribute(State, H), + resend_rack(State1). + +-spec handle_resume(state(), sm_resume()) -> {ok, state()} | {error, state()}. +handle_resume(#{user := User, lserver := LServer, sockmod := SockMod, + lang := Lang, socket := Socket} = State, + #sm_resume{h = H, previd = PrevID, xmlns = Xmlns}) -> + R = case inherit_session_state(State, PrevID) of + {ok, InheritedState} -> + {ok, InheritedState, H}; + {error, Err, InH} -> + {error, #sm_failed{reason = 'item-not-found', + text = xmpp:mk_text(Err, Lang), + h = InH, xmlns = Xmlns}, Err}; + {error, Err} -> + {error, #sm_failed{reason = 'item-not-found', + text = xmpp:mk_text(Err, Lang), + xmlns = Xmlns}, Err} + end, + case R of + {ok, #{jid := JID} = ResumedState, NumHandled} -> + State1 = check_h_attribute(ResumedState, NumHandled), + #{mgmt_xmlns := AttrXmlns, mgmt_stanzas_in := AttrH} = State1, + AttrId = make_resume_id(State1), + State2 = send(State1, #sm_resumed{xmlns = AttrXmlns, + h = AttrH, + previd = AttrId}), + State3 = resend_unacked_stanzas(State2), + State4 = send(State3, #sm_r{xmlns = AttrXmlns}), + %% TODO: move this to mod_client_state + %% csi_flush_queue(State4), + State5 = ejabberd_hooks:run_fold(c2s_session_resumed, LServer, State4, []), + ?INFO_MSG("(~s) Resumed session for ~s", + [SockMod:pp(Socket), jid:to_string(JID)]), + {ok, State5}; + {error, El, Msg} -> + ?INFO_MSG("Cannot resume session for ~s@~s: ~s", + [User, LServer, Msg]), + {error, send(State, El)} + end. + +-spec transition_to_pending(state()) -> state(). +transition_to_pending(#{mgmt_state := active, mod := Mod, + mgmt_timeout := 0} = State) -> + Mod:stop(State); +transition_to_pending(#{mgmt_state := active, jid := JID, + lserver := LServer, mgmt_timeout := Timeout} = State) -> + State1 = cancel_ack_timer(State), + ?INFO_MSG("Waiting for resumption of stream for ~s", [jid:to_string(JID)]), + erlang:start_timer(timer:seconds(Timeout), self(), pending_timeout), + State2 = State1#{mgmt_state => pending}, + ejabberd_hooks:run_fold(c2s_session_pending, LServer, State2, []); +transition_to_pending(State) -> + State. + +-spec check_h_attribute(state(), non_neg_integer()) -> state(). +check_h_attribute(#{mgmt_stanzas_out := NumStanzasOut, jid := JID} = State, H) + when H > NumStanzasOut -> + ?DEBUG("~s acknowledged ~B stanzas, but only ~B were sent", + [jid:to_string(JID), H, NumStanzasOut]), + mgmt_queue_drop(State#{mgmt_stanzas_out => H}, NumStanzasOut); +check_h_attribute(#{mgmt_stanzas_out := NumStanzasOut, jid := JID} = State, H) -> + ?DEBUG("~s acknowledged ~B of ~B stanzas", + [jid:to_string(JID), H, NumStanzasOut]), + mgmt_queue_drop(State, H). + +-spec update_num_stanzas_in(state(), xmpp_element()) -> state(). +update_num_stanzas_in(#{mgmt_state := MgmtState, + mgmt_stanzas_in := NumStanzasIn} = State, El) + when MgmtState == active; MgmtState == pending -> + NewNum = case {xmpp:is_stanza(El), NumStanzasIn} of + {true, 4294967295} -> + 0; + {true, Num} -> + Num + 1; + {false, Num} -> + Num + end, + State#{mgmt_stanzas_in => NewNum}; +update_num_stanzas_in(State, _El) -> + State. + +send_rack(#{mgmt_ack_timer := _} = State) -> + State; +send_rack(#{mgmt_xmlns := Xmlns, + mgmt_stanzas_out := NumStanzasOut, + mgmt_ack_timeout := AckTimeout} = State) -> + TRef = erlang:start_timer(AckTimeout, self(), ack_timeout), + State1 = State#{mgmt_ack_timer => TRef, mgmt_stanzas_req => NumStanzasOut}, + send(State1, #sm_r{xmlns = Xmlns}). + +resend_rack(#{mgmt_ack_timer := _, + mgmt_queue := Queue, + mgmt_stanzas_out := NumStanzasOut, + mgmt_stanzas_req := NumStanzasReq} = State) -> + State1 = cancel_ack_timer(State), + case NumStanzasReq < NumStanzasOut andalso not queue_is_empty(Queue) of + true -> send_rack(State1); + false -> State1 + end; +resend_rack(State) -> + State. + +-spec mgmt_queue_add(state(), xmpp_element()) -> state(). +mgmt_queue_add(#{mgmt_stanzas_out := NumStanzasOut, + mgmt_queue := Queue} = State, Pkt) -> + NewNum = case NumStanzasOut of + 4294967295 -> 0; + Num -> Num + 1 + end, + Queue1 = queue_in({NewNum, p1_time_compat:timestamp(), Pkt}, Queue), + State1 = State#{mgmt_queue => Queue1, mgmt_stanzas_out => NewNum}, + check_queue_length(State1). + +-spec mgmt_queue_drop(state(), non_neg_integer()) -> state(). +mgmt_queue_drop(#{mgmt_queue := Queue} = State, NumHandled) -> + NewQueue = queue_dropwhile( + fun({N, _T, _E}) -> N =< NumHandled end, Queue), + State#{mgmt_queue => NewQueue}. + +-spec check_queue_length(state()) -> state(). +check_queue_length(#{mgmt_max_queue := Limit} = State) + when Limit == infinity; Limit == exceeded -> + State; +check_queue_length(#{mgmt_queue := Queue, mgmt_max_queue := Limit} = State) -> + case queue_len(Queue) > Limit of + true -> + State#{mgmt_max_queue => exceeded}; + false -> + State + end. + +-spec resend_unacked_stanzas(state()) -> state(). +resend_unacked_stanzas(#{mgmt_state := MgmtState, + mgmt_queue := {QueueLen, _} = Queue, + jid := JID} = State) + when (MgmtState == active orelse + MgmtState == pending orelse + MgmtState == timeout) andalso QueueLen > 0 -> + ?DEBUG("Resending ~B unacknowledged stanza(s) to ~s", + [QueueLen, jid:to_string(JID)]), + queue_foldl( + fun({_, Time, Pkt}, AccState) -> + NewPkt = add_resent_delay_info(AccState, Pkt, Time), + send(AccState, NewPkt) + end, State, Queue); +resend_unacked_stanzas(State) -> + State. + +-spec route_unacked_stanzas(state()) -> ok. +route_unacked_stanzas(#{mgmt_state := MgmtState, + mgmt_resend := MgmtResend, + lang := Lang, user := User, + jid := JID, lserver := LServer, + mgmt_queue := {QueueLen, _} = Queue, + resource := Resource} = State) + when (MgmtState == active orelse + MgmtState == pending orelse + MgmtState == timeout) andalso QueueLen > 0 -> + ResendOnTimeout = case MgmtResend of + Resend when is_boolean(Resend) -> + Resend; + if_offline -> + case ejabberd_sm:get_user_resources(User, Resource) of + [Resource] -> + %% Same resource opened new session + true; + [] -> true; + _ -> false + end + end, + ?DEBUG("Re-routing ~B unacknowledged stanza(s) to ~s", + [QueueLen, jid:to_string(JID)]), + queue_foreach( + fun({_, _Time, #presence{from = From}}) -> + ?DEBUG("Dropping presence stanza from ~s", [jid:to_string(From)]); + ({_, _Time, #iq{} = El}) -> + Txt = <<"User session terminated">>, + route_error(El, xmpp:err_service_unavailable(Txt, Lang)); + ({_, _Time, #message{from = From, meta = #{carbon_copy := true}}}) -> + %% XEP-0280 says: "When a receiving server attempts to deliver a + %% forked message, and that message bounces with an error for + %% any reason, the receiving server MUST NOT forward that error + %% back to the original sender." Resending such a stanza could + %% easily lead to unexpected results as well. + ?DEBUG("Dropping forwarded message stanza from ~s", + [jid:to_string(From)]); + ({_, Time, #message{} = Msg}) -> + case ejabberd_hooks:run_fold(message_is_archived, + LServer, false, + [State, Msg]) of + true -> + ?DEBUG("Dropping archived message stanza from ~s", + [jid:to_string(xmpp:get_from(Msg))]); + false when ResendOnTimeout -> + NewEl = add_resent_delay_info(State, Msg, Time), + route(NewEl); + false -> + Txt = <<"User session terminated">>, + route_error(Msg, xmpp:err_service_unavailable(Txt, Lang)) + end; + ({_, _Time, El}) -> + %% Raw element of type 'error' resulting from a validation error + %% We cannot pass it to the router, it will generate an error + ?DEBUG("Do not route raw element from ack queue: ~p", [El]) + end, Queue); +route_unacked_stanzas(_State) -> + ok. + +-spec inherit_session_state(state(), binary()) -> {ok, state()} | + {error, binary()} | + {error, binary(), non_neg_integer()}. +inherit_session_state(#{user := U, server := S} = State, ResumeID) -> + case jlib:base64_to_term(ResumeID) of + {term, {R, Time}} -> + case ejabberd_sm:get_session_pid(U, S, R) of + none -> + case ejabberd_sm:get_offline_info(Time, U, S, R) of + none -> + {error, <<"Previous session PID not found">>}; + Info -> + case proplists:get_value(num_stanzas_in, Info) of + undefined -> + {error, <<"Previous session timed out">>}; + H -> + {error, <<"Previous session timed out">>, H} + end + end; + OldPID -> + OldSID = {Time, OldPID}, + try resume_session(OldSID, State) of + {resume, #{mgmt_xmlns := Xmlns, + mgmt_queue := Queue, + mgmt_timeout := Timeout, + mgmt_stanzas_in := NumStanzasIn, + mgmt_stanzas_out := NumStanzasOut} = OldState} -> + State1 = ejabberd_c2s:copy_state(State, OldState), + State2 = State1#{mgmt_xmlns => Xmlns, + mgmt_queue => Queue, + mgmt_timeout => Timeout, + mgmt_stanzas_in => NumStanzasIn, + mgmt_stanzas_out => NumStanzasOut, + mgmt_state => active}, + ejabberd_sm:close_session(OldSID, U, S, R), + State3 = ejabberd_c2s:open_session(State2), + ejabberd_c2s:stop(OldPID), + {ok, State3}; + {error, Msg} -> + {error, Msg} + catch exit:{noproc, _} -> + {error, <<"Previous session PID is dead">>}; + exit:{timeout, _} -> + {error, <<"Session state copying timed out">>} + end + end; + _ -> + {error, <<"Invalid 'previd' value">>} + end. + +-spec resume_session({integer(), pid()}, state()) -> {resume, state()} | + {error, binary()}. +resume_session({Time, Pid}, _State) -> + ejabberd_c2s:call(Pid, {resume_session, Time}, timer:seconds(15)). + +-spec make_resume_id(state()) -> binary(). +make_resume_id(#{sid := {Time, _}, resource := Resource}) -> + jlib:term_to_base64({Resource, Time}). + +-spec add_resent_delay_info(state(), stanza(), erlang:timestamp()) -> stanza(); + (state(), xmlel(), erlang:timestamp()) -> xmlel(). +add_resent_delay_info(#{lserver := LServer}, El, Time) + when is_record(El, message); is_record(El, presence) -> + xmpp_util:add_delay_info(El, jid:make(LServer), Time, <<"Resent">>); +add_resent_delay_info(_State, El, _Time) -> + El. + +-spec route(stanza()) -> ok. +route(Pkt) -> + From = xmpp:get_from(Pkt), + To = xmpp:get_to(Pkt), + ejabberd_router:route(From, To, Pkt). + +-spec route_error(stanza(), stanza_error()) -> ok. +route_error(Pkt, Err) -> + From = xmpp:get_from(Pkt), + To = xmpp:get_to(Pkt), + ejabberd_router:route_error(To, From, Pkt, Err). + +-spec send(state(), xmpp_element()) -> state(). +send(#{mod := Mod} = State, Pkt) -> + Mod:send(State, Pkt). + +-spec queue_new() -> lqueue(). +queue_new() -> + {0, queue:new()}. + +-spec queue_in(term(), lqueue()) -> lqueue(). +queue_in(Elem, {N, Q}) -> + {N+1, queue:in(Elem, Q)}. + +-spec queue_len(lqueue()) -> non_neg_integer(). +queue_len({N, _}) -> + N. + +-spec queue_foldl(fun((term(), T) -> T), T, lqueue()) -> T. +queue_foldl(F, Acc, {_N, Q}) -> + jlib:queue_foldl(F, Acc, Q). + +-spec queue_foreach(fun((_) -> _), lqueue()) -> ok. +queue_foreach(F, {_N, Q}) -> + jlib:queue_foreach(F, Q). + +-spec queue_dropwhile(fun((term()) -> boolean()), lqueue()) -> lqueue(). +queue_dropwhile(F, {N, Q}) -> + case queue:peek(Q) of + {value, Item} -> + case F(Item) of + true -> + queue_dropwhile(F, {N-1, queue:drop(Q)}); + false -> + {N, Q} + end; + empty -> + {N, Q} + end. + +-spec queue_is_empty(lqueue()) -> boolean(). +queue_is_empty({N, _Q}) -> + N == 0. + +-spec cancel_ack_timer(state()) -> state(). +cancel_ack_timer(#{mgmt_ack_timer := TRef} = State) -> + case erlang:cancel_timer(TRef) of + false -> + receive {timeout, TRef, _} -> ok + after 0 -> ok + end; + _ -> + ok + end, + maps:remove(mgmt_ack_timer, State); +cancel_ack_timer(State) -> + State. + +-spec bounce_message_queue() -> ok. +bounce_message_queue() -> + receive {route, From, To, Pkt} -> + ejabberd_router:route(From, To, Pkt), + bounce_message_queue() + after 0 -> + ok + end. + +%%%=================================================================== +%%% Configuration processing +%%%=================================================================== +get_max_ack_queue(Host, Opts) -> + VFun = mod_opt_type(max_ack_queue), + case gen_mod:get_module_opt(Host, ?MODULE, max_ack_queue, VFun) of + undefined -> gen_mod:get_opt(max_ack_queue, Opts, VFun, 1000); + Limit -> Limit + end. + +get_resume_timeout(Host, Opts) -> + VFun = mod_opt_type(resume_timeout), + case gen_mod:get_module_opt(Host, ?MODULE, resume_timeout, VFun) of + undefined -> gen_mod:get_opt(resume_timeout, Opts, VFun, 300); + Timeout -> Timeout + end. + +get_max_resume_timeout(Host, Opts, ResumeTimeout) -> + VFun = mod_opt_type(max_resume_timeout), + case gen_mod:get_module_opt(Host, ?MODULE, max_resume_timeout, VFun) of + undefined -> + case gen_mod:get_opt(max_resume_timeout, Opts, VFun) of + undefined -> ResumeTimeout; + Max when Max >= ResumeTimeout -> Max; + _ -> ResumeTimeout + end; + Max when Max >= ResumeTimeout -> Max; + _ -> ResumeTimeout + end. + +get_ack_timeout(Host, Opts) -> + VFun = mod_opt_type(ack_timeout), + T = case gen_mod:get_module_opt(Host, ?MODULE, ack_timeout, VFun) of + undefined -> gen_mod:get_opt(ack_timeout, Opts, VFun, 60); + AckTimeout -> AckTimeout + end, + case T of + infinity -> infinity; + _ -> timer:seconds(T) + end. + +get_resend_on_timeout(Host, Opts) -> + VFun = mod_opt_type(resend_on_timeout), + case gen_mod:get_module_opt(Host, ?MODULE, resend_on_timeout, VFun) of + undefined -> gen_mod:get_opt(resend_on_timeout, Opts, VFun, false); + Resend -> Resend + end. + +mod_opt_type(max_ack_queue) -> + fun(I) when is_integer(I), I > 0 -> I; + (infinity) -> infinity + end; +mod_opt_type(resume_timeout) -> + fun(I) when is_integer(I), I >= 0 -> I end; +mod_opt_type(max_resume_timeout) -> + fun(I) when is_integer(I), I >= 0 -> I end; +mod_opt_type(ack_timeout) -> + fun(I) when is_integer(I), I > 0 -> I; + (infinity) -> infinity + end; +mod_opt_type(resend_on_timeout) -> + fun(B) when is_boolean(B) -> B; + (if_offline) -> if_offline + end; +mod_opt_type(_) -> + [max_ack_queue, resume_timeout, max_resume_timeout, ack_timeout, + resend_on_timeout]. diff --git a/src/mod_vcard_xupdate.erl b/src/mod_vcard_xupdate.erl index 378faeba9..12e31b8dc 100644 --- a/src/mod_vcard_xupdate.erl +++ b/src/mod_vcard_xupdate.erl @@ -30,7 +30,7 @@ %% gen_mod callbacks -export([start/2, stop/1]). --export([update_presence/3, vcard_set/3, export/1, +-export([update_presence/1, vcard_set/3, export/1, import_info/0, import/5, import_start/2, mod_opt_type/1, depends/2]). @@ -51,14 +51,14 @@ start(Host, Opts) -> Mod = gen_mod:db_mod(Host, Opts, ?MODULE), Mod:init(Host, Opts), - ejabberd_hooks:add(c2s_update_presence, Host, ?MODULE, + ejabberd_hooks:add(c2s_self_presence, Host, ?MODULE, update_presence, 100), ejabberd_hooks:add(vcard_set, Host, ?MODULE, vcard_set, 100), ok. stop(Host) -> - ejabberd_hooks:delete(c2s_update_presence, Host, + ejabberd_hooks:delete(c2s_self_presence, Host, ?MODULE, update_presence, 100), ejabberd_hooks:delete(vcard_set, Host, ?MODULE, vcard_set, 100), @@ -70,10 +70,15 @@ depends(_Host, _Opts) -> %%==================================================================== %% Hooks %%==================================================================== --spec update_presence(presence(), binary(), binary()) -> presence(). -update_presence(#presence{type = available} = Packet, User, Host) -> - presence_with_xupdate(Packet, User, Host); -update_presence(Packet, _User, _Host) -> Packet. +-spec update_presence({presence(), ejabberd_c2s:state()}) + -> {presence(), ejabberd_c2s:state()}. +update_presence({#presence{type = available} = Pres, + #{jid := #jid{luser = LUser, lserver = LServer}} = State}) -> + Hash = get_xupdate(LUser, LServer), + Pres1 = xmpp:set_subtag(Pres, #vcard_xupdate{hash = Hash}), + {Pres1, State}; +update_presence(Acc) -> + Acc. -spec vcard_set(binary(), binary(), xmlel()) -> ok. vcard_set(LUser, LServer, VCARD) -> @@ -104,15 +109,6 @@ remove_xupdate(LUser, LServer) -> Mod = gen_mod:db_mod(LServer, ?MODULE), Mod:remove_xupdate(LUser, LServer). -%%%---------------------------------------------------------------------- -%%% Presence stanza rebuilding -%%%---------------------------------------------------------------------- - -presence_with_xupdate(Presence, User, Host) -> - Hash = get_xupdate(User, Host), - Presence1 = xmpp:remove_subtag(Presence, #vcard_xupdate{}), - xmpp:set_subtag(Presence1, #vcard_xupdate{hash = Hash}). - import_info() -> [{<<"vcard_xupdate">>, 3}]. @@ -128,5 +124,8 @@ export(LServer) -> Mod = gen_mod:db_mod(LServer, ?MODULE), Mod:export(LServer). +%%==================================================================== +%% Options +%%==================================================================== mod_opt_type(db_type) -> fun(T) -> ejabberd_config:v_db(?MODULE, T) end; mod_opt_type(_) -> [db_type]. diff --git a/src/scram.erl b/src/scram.erl index cd62112b2..ee7960475 100644 --- a/src/scram.erl +++ b/src/scram.erl @@ -60,9 +60,7 @@ client_signature(StoredKey, AuthMessage) -> -spec client_key(binary(), binary()) -> binary(). client_key(ClientProof, ClientSignature) -> - list_to_binary(lists:zipwith(fun (X, Y) -> X bxor Y end, - binary_to_list(ClientProof), - binary_to_list(ClientSignature))). + crypto:exor(ClientProof, ClientSignature). -spec server_signature(binary(), binary()) -> binary(). @@ -71,19 +69,13 @@ server_signature(ServerKey, AuthMessage) -> hi(Password, Salt, IterationCount) -> U1 = sha_mac(Password, <>), - list_to_binary(lists:zipwith(fun (X, Y) -> X bxor Y end, - binary_to_list(U1), - binary_to_list(hi_round(Password, U1, - IterationCount - 1)))). + crypto:exor(U1, hi_round(Password, U1, IterationCount - 1)). hi_round(Password, UPrev, 1) -> sha_mac(Password, UPrev); hi_round(Password, UPrev, IterationCount) -> U = sha_mac(Password, UPrev), - list_to_binary(lists:zipwith(fun (X, Y) -> X bxor Y end, - binary_to_list(U), - binary_to_list(hi_round(Password, U, - IterationCount - 1)))). + crypto:exor(U, hi_round(Password, U, IterationCount - 1)). sha_mac(Key, Data) -> crypto:hmac(sha, Key, Data). diff --git a/src/xmpp_stream_in.erl b/src/xmpp_stream_in.erl new file mode 100644 index 000000000..dd135df1e --- /dev/null +++ b/src/xmpp_stream_in.erl @@ -0,0 +1,1167 @@ +%%%------------------------------------------------------------------- +%%% Created : 26 Nov 2016 by Evgeny Khramtsov +%%% +%%% +%%% ejabberd, Copyright (C) 2002-2016 ProcessOne +%%% +%%% This program is free software; you can redistribute it and/or +%%% modify it under the terms of the GNU General Public License as +%%% published by the Free Software Foundation; either version 2 of the +%%% License, or (at your option) any later version. +%%% +%%% This program is distributed in the hope that it will be useful, +%%% but WITHOUT ANY WARRANTY; without even the implied warranty of +%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +%%% General Public License for more details. +%%% +%%% You should have received a copy of the GNU General Public License along +%%% with this program; if not, write to the Free Software Foundation, Inc., +%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +%%% +%%%------------------------------------------------------------------- +-module(xmpp_stream_in). +-define(GEN_SERVER, gen_server). +-behaviour(?GEN_SERVER). + +-protocol({rfc, 6120}). +-protocol({xep, 114, '1.6'}). + +%% API +-export([start/3, start_link/3, call/3, cast/2, reply/2, stop/1, + send/2, close/1, close/2, send_error/3, establish/1, + get_transport/1, change_shaper/2, set_timeout/2, format_error/1]). + +%% gen_server callbacks +-export([init/1, handle_cast/2, handle_call/3, handle_info/2, + terminate/2, code_change/3]). + +%%-define(DBGFSM, true). +-ifdef(DBGFSM). +-define(FSMOPTS, [{debug, [trace]}]). +-else. +-define(FSMOPTS, []). +-endif. + +-include("xmpp.hrl"). +-type state() :: map(). +-type stop_reason() :: {stream, reset | {in | out, stream_error()}} | + {tls, inet:posix() | atom() | binary()} | + {socket, inet:posix() | closed | timeout} | + internal_failure. + +-callback init(list()) -> {ok, state()} | {error, term()} | ignore. +-callback handle_cast(term(), state()) -> state(). +-callback handle_call(term(), term(), state()) -> state(). +-callback handle_info(term(), state()) -> state(). +-callback terminate(term(), state()) -> any(). +-callback code_change(term(), state(), term()) -> {ok, state()} | {error, term()}. +-callback handle_stream_start(stream_start(), state()) -> state(). +-callback handle_stream_established(state()) -> state(). +-callback handle_stream_end(stop_reason(), state()) -> state(). +-callback handle_cdata(binary(), state()) -> state(). +-callback handle_unauthenticated_packet(xmpp_element(), state()) -> state(). +-callback handle_authenticated_packet(xmpp_element(), state()) -> state(). +-callback handle_unbinded_packet(xmpp_element(), state()) -> state(). +-callback handle_auth_success(binary(), binary(), module(), state()) -> state(). +-callback handle_auth_failure(binary(), binary(), binary(), state()) -> state(). +-callback handle_send(xmpp_element(), ok | {error, inet:posix()}, state()) -> state(). +-callback handle_recv(fxml:xmlel(), xmpp_element() | {error, term()}, state()) -> state(). +-callback handle_timeout(state()) -> state(). +-callback get_password_fun(state()) -> fun(). +-callback check_password_fun(state()) -> fun(). +-callback check_password_digest_fun(state()) -> fun(). +-callback bind(binary(), state()) -> {ok, state()} | {error, stanza_error(), state()}. +-callback compress_methods(state()) -> [binary()]. +-callback tls_options(state()) -> [proplists:property()]. +-callback tls_required(state()) -> boolean(). +-callback tls_verify(state()) -> boolean(). +-callback tls_enabled(state()) -> boolean(). +-callback sasl_mechanisms([cyrsasl:mechanism()], state()) -> [cyrsasl:mechanism()]. +-callback unauthenticated_stream_features(state()) -> [xmpp_element()]. +-callback authenticated_stream_features(state()) -> [xmpp_element()]. + +%% All callbacks are optional +-optional_callbacks([init/1, + handle_cast/2, + handle_call/3, + handle_info/2, + terminate/2, + code_change/3, + handle_stream_start/2, + handle_stream_established/1, + handle_stream_end/2, + handle_cdata/2, + handle_authenticated_packet/2, + handle_unauthenticated_packet/2, + handle_unbinded_packet/2, + handle_auth_success/4, + handle_auth_failure/4, + handle_send/3, + handle_recv/3, + handle_timeout/1, + get_password_fun/1, + check_password_fun/1, + check_password_digest_fun/1, + bind/2, + compress_methods/1, + tls_options/1, + tls_required/1, + tls_verify/1, + tls_enabled/1, + sasl_mechanisms/2, + unauthenticated_stream_features/1, + authenticated_stream_features/1]). + +%%%=================================================================== +%%% API +%%%=================================================================== +start(Mod, Args, Opts) -> + ?GEN_SERVER:start(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS). + +start_link(Mod, Args, Opts) -> + ?GEN_SERVER:start_link(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS). + +call(Ref, Msg, Timeout) -> + ?GEN_SERVER:call(Ref, Msg, Timeout). + +cast(Ref, Msg) -> + ?GEN_SERVER:cast(Ref, Msg). + +reply(Ref, Reply) -> + ?GEN_SERVER:reply(Ref, Reply). + +-spec stop(pid()) -> ok; + (state()) -> no_return(). +stop(Pid) when is_pid(Pid) -> + cast(Pid, stop); +stop(#{owner := Owner} = State) when Owner == self() -> + terminate(normal, State), + exit(normal); +stop(_) -> + erlang:error(badarg). + +-spec send(pid(), xmpp_element()) -> ok; + (state(), xmpp_element()) -> state(). +send(Pid, Pkt) when is_pid(Pid) -> + cast(Pid, {send, Pkt}); +send(#{owner := Owner} = State, Pkt) when Owner == self() -> + send_pkt(State, Pkt); +send(_, _) -> + erlang:error(badarg). + +-spec close(pid()) -> ok; + (state()) -> state(). +close(Ref) -> + close(Ref, true). + +-spec close(pid(), boolean()) -> ok; + (state(), boolean()) -> state(). +close(Pid, SendTrailer) when is_pid(Pid) -> + cast(Pid, {close, SendTrailer}); +close(#{owner := Owner} = State, SendTrailer) when Owner == self() -> + if SendTrailer -> send_trailer(State); + true -> close_socket(State) + end; +close(_, _) -> + erlang:error(badarg). + +-spec establish(state()) -> state(). +establish(State) -> + process_stream_established(State). + +-spec set_timeout(state(), non_neg_integer() | infinity) -> state(). +set_timeout(#{owner := Owner} = State, Timeout) when Owner == self() -> + case Timeout of + infinity -> State#{stream_timeout => infinity}; + _ -> + Time = p1_time_compat:monotonic_time(milli_seconds), + State#{stream_timeout => {Timeout, Time}} + end; +set_timeout(_, _) -> + erlang:error(badarg). + +get_transport(#{sockmod := SockMod, socket := Socket, owner := Owner}) + when Owner == self() -> + SockMod:get_transport(Socket); +get_transport(_) -> + erlang:error(badarg). + +-spec change_shaper(state(), shaper:shaper()) -> ok. +change_shaper(#{sockmod := SockMod, socket := Socket, owner := Owner}, Shaper) + when Owner == self() -> + SockMod:change_shaper(Socket, Shaper); +change_shaper(_, _) -> + erlang:error(badarg). + +-spec format_error(stop_reason()) -> binary(). +format_error({socket, Reason}) -> + format("Connection failed: ~s", [format_inet_error(Reason)]); +format_error({stream, reset}) -> + <<"Stream reset by peer">>; +format_error({stream, {in, #stream_error{reason = Reason, text = Txt}}}) -> + format("Stream closed by peer: ~s", [format_stream_error(Reason, Txt)]); +format_error({stream, {out, #stream_error{reason = Reason, text = Txt}}}) -> + format("Stream closed by us: ~s", [format_stream_error(Reason, Txt)]); +format_error({tls, Reason}) -> + format("TLS failed: ~s", [format_tls_error(Reason)]); +format_error(internal_failure) -> + <<"Internal server error">>; +format_error(Err) -> + format("Unrecognized error: ~w", [Err]). + +%%%=================================================================== +%%% gen_server callbacks +%%%=================================================================== +init([Module, {SockMod, Socket}, Opts]) -> + Encrypted = proplists:get_bool(tls, Opts), + SocketMonitor = SockMod:monitor(Socket), + case SockMod:peername(Socket) of + {ok, IP} -> + Time = p1_time_compat:monotonic_time(milli_seconds), + State = #{owner => self(), + mod => Module, + socket => Socket, + sockmod => SockMod, + socket_monitor => SocketMonitor, + stream_timeout => {timer:seconds(30), Time}, + stream_direction => in, + stream_id => new_id(), + stream_state => wait_for_stream, + stream_header_sent => false, + stream_restarted => false, + stream_compressed => false, + stream_encrypted => Encrypted, + stream_version => {1,0}, + stream_authenticated => false, + xmlns => ?NS_CLIENT, + lang => <<"">>, + user => <<"">>, + server => <<"">>, + resource => <<"">>, + lserver => <<"">>, + ip => IP}, + case try Module:init([State, Opts]) + catch _:undef -> {ok, State} + end of + {ok, State1} when not Encrypted -> + {_, State2, Timeout} = noreply(State1), + {ok, State2, Timeout}; + {ok, State1} when Encrypted -> + TLSOpts = try Module:tls_options(State1) + catch _:undef -> [] + end, + case SockMod:starttls(Socket, TLSOpts) of + {ok, TLSSocket} -> + State2 = State1#{socket => TLSSocket}, + {_, State3, Timeout} = noreply(State2), + {ok, State3, Timeout}; + {error, Reason} -> + {stop, Reason} + end; + {error, Reason} -> + {stop, Reason}; + ignore -> + ignore + end; + {error, _Reason} -> + ignore + end. + +handle_cast({send, Pkt}, State) -> + noreply(send_pkt(State, Pkt)); +handle_cast(stop, State) -> + {stop, normal, State}; +handle_cast(Cast, #{mod := Mod} = State) -> + noreply(try Mod:handle_cast(Cast, State) + catch _:undef -> State + end). + +handle_call(Call, From, #{mod := Mod} = State) -> + noreply(try Mod:handle_call(Call, From, State) + catch _:undef -> State + end). + +handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}}, + #{stream_state := wait_for_stream, + xmlns := XMLNS, lang := MyLang} = State) -> + El = #xmlel{name = Name, attrs = Attrs}, + noreply( + try xmpp:decode(El, XMLNS, []) of + #stream_start{} = Pkt -> + State1 = send_header(State, Pkt), + case is_disconnected(State1) of + true -> State1; + false -> process_stream(Pkt, State1) + end; + _ -> + State1 = send_header(State), + case is_disconnected(State1) of + true -> State1; + false -> send_pkt(State1, xmpp:serr_invalid_xml()) + end + catch _:{xmpp_codec, Why} -> + State1 = send_header(State), + case is_disconnected(State1) of + true -> State1; + false -> + Txt = xmpp:io_format_error(Why), + Lang = select_lang(MyLang, xmpp:get_lang(El)), + Err = xmpp:serr_invalid_xml(Txt, Lang), + send_pkt(State1, Err) + end + end); +handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) -> + State1 = send_header(State), + noreply( + case is_disconnected(State1) of + true -> State1; + false -> + Err = case Reason of + <<"XML stanza is too big">> -> + xmpp:serr_policy_violation(Reason, Lang); + {_, Txt} -> + xmpp:serr_not_well_formed(Txt, Lang) + end, + send_pkt(State1, Err) + end); +handle_info({'$gen_event', {xmlstreamelement, El}}, + #{xmlns := NS, mod := Mod} = State) -> + noreply( + try xmpp:decode(El, NS, [ignore_els]) of + Pkt -> + State1 = try Mod:handle_recv(El, Pkt, State) + catch _:undef -> State + end, + case is_disconnected(State1) of + true -> State1; + false -> process_element(Pkt, State1) + end + catch _:{xmpp_codec, Why} -> + State1 = try Mod:handle_recv(El, {error, Why}, State) + catch _:undef -> State + end, + case is_disconnected(State1) of + true -> State1; + false -> process_invalid_xml(State1, El, Why) + end + end); +handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}}, + #{mod := Mod} = State) -> + noreply(try Mod:handle_cdata(Data, State) + catch _:undef -> State + end); +handle_info({'$gen_event', {xmlstreamend, _}}, State) -> + noreply(process_stream_end({stream, reset}, State)); +handle_info({'$gen_event', closed}, State) -> + noreply(process_stream_end({socket, closed}, State)); +handle_info(timeout, #{mod := Mod} = State) -> + Disconnected = is_disconnected(State), + noreply(try Mod:handle_timeout(State) + catch _:undef when not Disconnected -> + send_pkt(State, xmpp:serr_connection_timeout()); + _:undef -> + stop(State) + end); +handle_info({'DOWN', MRef, _Type, _Object, _Info}, + #{socket_monitor := MRef} = State) -> + noreply(process_stream_end({socket, closed}, State)); +handle_info(Info, #{mod := Mod} = State) -> + noreply(try Mod:handle_info(Info, State) + catch _:undef -> State + end). + +terminate(Reason, #{mod := Mod} = State) -> + case get(already_terminated) of + true -> + State; + _ -> + put(already_terminated, true), + try Mod:terminate(Reason, State) + catch _:undef -> ok + end, + send_trailer(State) + end. + +code_change(OldVsn, #{mod := Mod} = State, Extra) -> + Mod:code_change(OldVsn, State, Extra). + +%%%=================================================================== +%%% Internal functions +%%%=================================================================== +-spec noreply(state()) -> {noreply, state(), non_neg_integer() | infinity}. +noreply(#{stream_timeout := infinity} = State) -> + {noreply, State, infinity}; +noreply(#{stream_timeout := {MSecs, StartTime}} = State) -> + CurrentTime = p1_time_compat:monotonic_time(milli_seconds), + Timeout = max(0, MSecs - CurrentTime + StartTime), + {noreply, State, Timeout}. + +-spec new_id() -> binary(). +new_id() -> + randoms:get_string(). + +-spec is_disconnected(state()) -> boolean(). +is_disconnected(#{stream_state := StreamState}) -> + StreamState == disconnected. + +-spec process_invalid_xml(state(), fxml:xmlel(), term()) -> state(). +process_invalid_xml(#{lang := MyLang} = State, El, Reason) -> + case xmpp:is_stanza(El) of + true -> + Txt = xmpp:io_format_error(Reason), + Lang = select_lang(MyLang, xmpp:get_lang(El)), + send_error(State, El, xmpp:err_bad_request(Txt, Lang)); + false -> + case {xmpp:get_name(El), xmpp:get_ns(El)} of + {Tag, ?NS_SASL} when Tag == <<"auth">>; + Tag == <<"response">>; + Tag == <<"abort">> -> + Txt = xmpp:io_format_error(Reason), + Err = #sasl_failure{reason = 'malformed-request', + text = xmpp:mk_text(Txt, MyLang)}, + send_pkt(State, Err); + {<<"starttls">>, ?NS_TLS} -> + send_pkt(State, #starttls_failure{}); + {<<"compress">>, ?NS_COMPRESS} -> + Err = #compress_failure{reason = 'setup-failed'}, + send_pkt(State, Err); + _ -> + %% Maybe add something more? + State + end + end. + +-spec process_stream_end(stop_reason(), state()) -> state(). +process_stream_end(_, #{stream_state := disconnected} = State) -> + State; +process_stream_end(Reason, #{mod := Mod} = State) -> + State1 = send_trailer(State), + try Mod:handle_stream_end(Reason, State1) + catch _:undef -> stop(State1) + end. + +-spec process_stream(stream_start(), state()) -> state(). +process_stream(#stream_start{xmlns = XML_NS, + stream_xmlns = STREAM_NS}, + #{xmlns := NS} = State) + when XML_NS /= NS; STREAM_NS /= ?NS_STREAM -> + send_pkt(State, xmpp:serr_invalid_namespace()); +process_stream(#stream_start{version = {N, _}}, State) when N > 1 -> + send_pkt(State, xmpp:serr_unsupported_version()); +process_stream(#stream_start{lang = Lang}, + #{xmlns := ?NS_CLIENT, lang := DefaultLang} = State) + when size(Lang) > 35 -> + %% As stated in BCP47, 4.4.1: + %% Protocols or specifications that specify limited buffer sizes for + %% language tags MUST allow for language tags of at least 35 characters. + %% Do not store long language tag to avoid possible DoS/flood attacks + Txt = <<"Too long value of 'xml:lang' attribute">>, + send_pkt(State, xmpp:serr_policy_violation(Txt, DefaultLang)); +process_stream(#stream_start{to = undefined}, #{lang := Lang} = State) -> + Txt = <<"Missing 'to' attribute">>, + send_pkt(State, xmpp:serr_improper_addressing(Txt, Lang)); +process_stream(#stream_start{to = #jid{luser = U, lresource = R}}, + #{lang := Lang} = State) when U /= <<"">>; R /= <<"">> -> + Txt = <<"Improper 'to' attribute">>, + send_pkt(State, xmpp:serr_improper_addressing(Txt, Lang)); +process_stream(#stream_start{to = #jid{lserver = RemoteServer}} = StreamStart, + #{xmlns := ?NS_COMPONENT, mod := Mod} = State) -> + State1 = State#{remote_server => RemoteServer, + stream_state => wait_for_handshake}, + try Mod:handle_stream_start(StreamStart, State1) + catch _:undef -> State1 + end; +process_stream(#stream_start{to = #jid{server = Server, lserver = LServer}, + from = From} = StreamStart, + #{stream_authenticated := Authenticated, + stream_restarted := StreamWasRestarted, + mod := Mod, xmlns := NS, resource := Resource, + stream_encrypted := Encrypted} = State) -> + State1 = if not StreamWasRestarted -> + State#{server => Server, lserver => LServer}; + true -> + State + end, + State2 = case From of + #jid{lserver = RemoteServer} when NS == ?NS_SERVER -> + State1#{remote_server => RemoteServer}; + _ -> + State1 + end, + State3 = try Mod:handle_stream_start(StreamStart, State2) + catch _:undef -> State2 + end, + case is_disconnected(State3) of + true -> State3; + false -> + State4 = send_features(State3), + case is_disconnected(State4) of + true -> State4; + false -> + TLSRequired = is_starttls_required(State4), + if not Authenticated and (TLSRequired and not Encrypted) -> + State4#{stream_state => wait_for_starttls}; + not Authenticated -> + State4#{stream_state => wait_for_sasl_request}; + (NS == ?NS_CLIENT) and (Resource == <<"">>) -> + State4#{stream_state => wait_for_bind}; + true -> + process_stream_established(State4) + end + end + end. + +-spec process_element(xmpp_element(), state()) -> state(). +process_element(Pkt, #{stream_state := StateName, lang := Lang} = State) -> + case Pkt of + #starttls{} when StateName == wait_for_starttls; + StateName == wait_for_sasl_request -> + process_starttls(State); + #starttls{} -> + process_starttls_failure(unexpected_starttls_request, State); + #sasl_auth{} when StateName == wait_for_starttls -> + send_pkt(State, #sasl_failure{reason = 'encryption-required'}); + #sasl_auth{} when StateName == wait_for_sasl_request -> + process_sasl_request(Pkt, State); + #sasl_auth{} when StateName == wait_for_sasl_response -> + process_sasl_request(Pkt, maps:remove(sasl_state, State)); + #sasl_auth{} -> + Txt = <<"SASL negotiation is not allowed in this state">>, + send_pkt(State, #sasl_failure{reason = 'not-authorized', + text = xmpp:mk_text(Txt, Lang)}); + #sasl_response{} when StateName == wait_for_starttls -> + send_pkt(State, #sasl_failure{reason = 'encryption-required'}); + #sasl_response{} when StateName == wait_for_sasl_response -> + process_sasl_response(Pkt, State); + #sasl_response{} -> + Txt = <<"SASL negotiation is not allowed in this state">>, + send_pkt(State, #sasl_failure{reason = 'not-authorized', + text = xmpp:mk_text(Txt, Lang)}); + #sasl_abort{} when StateName == wait_for_sasl_response -> + process_sasl_abort(State); + #sasl_abort{} -> + send_pkt(State, #sasl_failure{reason = 'aborted'}); + #sasl_success{} -> + State; + #compress{} when StateName == wait_for_sasl_response -> + send_pkt(State, #compress_failure{reason = 'setup-failed'}); + #compress{} -> + process_compress(Pkt, State); + #handshake{} when StateName == wait_for_handshake -> + process_handshake(Pkt, State); + #handshake{} -> + State; + #stream_error{} -> + process_stream_end({stream, {in, Pkt}}, State); + _ when StateName == wait_for_sasl_request; + StateName == wait_for_handshake; + StateName == wait_for_sasl_response -> + process_unauthenticated_packet(Pkt, State); + _ when StateName == wait_for_starttls -> + Txt = <<"Use of STARTTLS required">>, + Err = xmpp:err_policy_violation(Txt, Lang), + send_error(State, Pkt, Err); + _ when StateName == wait_for_bind -> + process_bind(Pkt, State); + _ when StateName == established -> + process_authenticated_packet(Pkt, State) + end. + +-spec process_unauthenticated_packet(xmpp_element(), state()) -> state(). +process_unauthenticated_packet(Pkt, #{mod := Mod} = State) -> + NewPkt = set_lang(Pkt, State), + try Mod:handle_unauthenticated_packet(NewPkt, State) + catch _:undef -> + Err = xmpp:serr_not_authorized(), + send(State, Err) + end. + +-spec process_authenticated_packet(xmpp_element(), state()) -> state(). +process_authenticated_packet(Pkt, #{xmlns := NS, mod := Mod} = State) -> + Pkt1 = set_lang(Pkt, State), + case set_from_to(Pkt1, State) of + {ok, #iq{type = set, sub_els = [_]} = Pkt2} when NS == ?NS_CLIENT -> + case xmpp:get_subtag(Pkt2, #xmpp_session{}) of + #xmpp_session{} -> + send_pkt(State, xmpp:make_iq_result(Pkt2)); + _ -> + try Mod:handle_authenticated_packet(Pkt2, State) + catch _:undef -> + Err = xmpp:err_service_unavailable(), + send_error(State, Pkt, Err) + end + end; + {ok, Pkt2} -> + try Mod:handle_authenticated_packet(Pkt2, State) + catch _:undef -> + Err = xmpp:err_service_unavailable(), + send_error(State, Pkt, Err) + end; + {error, Err} -> + send_pkt(State, Err) + end. + +-spec process_bind(xmpp_element(), state()) -> state(). +process_bind(#iq{type = set, sub_els = [_]} = Pkt, + #{xmlns := ?NS_CLIENT, mod := Mod, lang := Lang} = State) -> + case xmpp:get_subtag(Pkt, #bind{}) of + #bind{resource = R} -> + case jid:resourceprep(R) of + error -> + Txt = <<"Malformed resource">>, + Err = xmpp:err_bad_request(Txt, Lang), + send_error(State, Pkt, Err); + _ -> + case Mod:bind(R, State) of + {ok, #{user := U, + server := S, + resource := NewR} = State1} when NewR /= <<"">> -> + Reply = #bind{jid = jid:make(U, S, NewR)}, + State2 = send_pkt(State1, xmpp:make_iq_result(Pkt, Reply)), + process_stream_established(State2); + {error, #stanza_error{}, State1} = Err -> + send_error(State1, Pkt, Err) + end + end; + _ -> + try Mod:handle_unbinded_packet(Pkt, State) + catch _:undef -> + Err = xmpp:err_not_authorized(), + send_error(State, Pkt, Err) + end + end; +process_bind(Pkt, #{mod := Mod} = State) -> + try Mod:handle_unbinded_packet(Pkt, State) + catch _:undef -> + Err = xmpp:err_not_authorized(), + send_error(State, Pkt, Err) + end. + +-spec process_handshake(handshake(), state()) -> state(). +process_handshake(#handshake{data = Digest}, + #{mod := Mod, stream_id := StreamID, + remote_server := RemoteServer} = State) -> + GetPW = try Mod:get_password_fun(State) + catch _:undef -> fun(_) -> {false, undefined} end + end, + AuthRes = case GetPW(<<"">>) of + {false, _} -> + false; + {Password, _} -> + p1_sha:sha(<>) == Digest + end, + case AuthRes of + true -> + State1 = try Mod:handle_auth_success( + RemoteServer, <<"handshake">>, undefined, State) + catch _:undef -> State + end, + case is_disconnected(State1) of + true -> State1; + false -> + State2 = send_pkt(State1, #handshake{}), + process_stream_established(State2) + end; + false -> + State1 = try Mod:handle_auth_failure( + RemoteServer, <<"handshake">>, <<"not authorized">>, State) + catch _:undef -> State + end, + case is_disconnected(State1) of + true -> State1; + false -> send_pkt(State1, xmpp:serr_not_authorized()) + end + end. + +-spec process_stream_established(state()) -> state(). +process_stream_established(#{stream_state := StateName} = State) + when StateName == disconnected; StateName == established -> + State; +process_stream_established(#{mod := Mod} = State) -> + State1 = State#{stream_authenticated := true, + stream_state => established, + stream_timeout => infinity}, + try Mod:handle_stream_established(State1) + catch _:undef -> State1 + end. + +-spec process_compress(compress(), state()) -> state(). +process_compress(#compress{}, #{stream_compressed := true} = State) -> + send_pkt(State, #compress_failure{reason = 'setup-failed'}); +process_compress(#compress{methods = HisMethods}, + #{socket := Socket, sockmod := SockMod, mod := Mod} = State) -> + MyMethods = try Mod:compress_methods(State) + catch _:undef -> [] + end, + CommonMethods = lists_intersection(MyMethods, HisMethods), + case lists:member(<<"zlib">>, CommonMethods) of + true -> + State1 = send_pkt(State, #compressed{}), + case is_disconnected(State1) of + true -> State1; + false -> + case SockMod:compress(Socket) of + {ok, ZlibSocket} -> + State1#{socket => ZlibSocket, + stream_id => new_id(), + stream_header_sent => false, + stream_restarted => true, + stream_state => wait_for_stream, + stream_compressed => true}; + {error, _} -> + Err = #compress_failure{reason = 'setup-failed'}, + send_pkt(State1, Err) + end + end; + false -> + send_pkt(State, #compress_failure{reason = 'unsupported-method'}) + end. + +-spec process_starttls(state()) -> state(). +process_starttls(#{stream_encrypted := true} = State) -> + process_starttls_failure(already_encrypted, State); +process_starttls(#{socket := Socket, + sockmod := SockMod, mod := Mod} = State) -> + case is_starttls_available(State) of + true -> + TLSOpts = try Mod:tls_options(State) + catch _:undef -> [] + end, + case SockMod:starttls(Socket, TLSOpts) of + {ok, TLSSocket} -> + State1 = send_pkt(State, #starttls_proceed{}), + case is_disconnected(State1) of + true -> State1; + false -> + State1#{socket => TLSSocket, + stream_id => new_id(), + stream_header_sent => false, + stream_restarted => true, + stream_state => wait_for_stream, + stream_encrypted => true} + end; + {error, Reason} -> + process_starttls_failure(Reason, State) + end; + false -> + process_starttls_failure(starttls_unsupported, State) + end. + +-spec process_starttls_failure(term(), state()) -> state(). +process_starttls_failure(Why, State) -> + State1 = send_pkt(State, #starttls_failure{}), + case is_disconnected(State1) of + true -> State1; + false -> process_stream_end({tls, Why}, State1) + end. + +-spec process_sasl_request(sasl_auth(), state()) -> state(). +process_sasl_request(#sasl_auth{mechanism = Mech, text = ClientIn}, + #{mod := Mod, lserver := LServer} = State) -> + State1 = State#{sasl_mech => Mech}, + Mechs = get_sasl_mechanisms(State1), + case lists:member(Mech, Mechs) of + true when Mech == <<"EXTERNAL">> -> + Res = case xmpp_stream_pkix:authenticate(State1, ClientIn) of + {ok, Peer} -> + {ok, [{auth_module, pkix}, {username, Peer}]}; + {error, Reason, Peer} -> + {error, Reason, Peer} + end, + process_sasl_result(Res, State1); + true -> + GetPW = try Mod:get_password_fun(State1) + catch _:undef -> fun(_) -> false end + end, + CheckPW = try Mod:check_password_fun(State1) + catch _:undef -> fun(_, _, _) -> false end + end, + CheckPWDigest = try Mod:check_password_digest_fun(State1) + catch _:undef -> fun(_, _, _, _, _) -> false end + end, + SASLState = cyrsasl:server_new(<<"jabber">>, LServer, <<"">>, [], + GetPW, CheckPW, CheckPWDigest), + Res = cyrsasl:server_start(SASLState, Mech, ClientIn), + process_sasl_result(Res, State1#{sasl_state => SASLState}); + false -> + process_sasl_result({error, unsupported_mechanism, <<"">>}, State1) + end. + +-spec process_sasl_response(sasl_response(), state()) -> state(). +process_sasl_response(#sasl_response{text = ClientIn}, + #{sasl_state := SASLState} = State) -> + SASLResult = cyrsasl:server_step(SASLState, ClientIn), + process_sasl_result(SASLResult, State). + +-spec process_sasl_result(cyrsasl:sasl_return(), state()) -> state(). +process_sasl_result({ok, Props}, State) -> + process_sasl_success(Props, <<"">>, State); +process_sasl_result({ok, Props, ServerOut}, State) -> + process_sasl_success(Props, ServerOut, State); +process_sasl_result({continue, ServerOut, NewSASLState}, State) -> + process_sasl_continue(ServerOut, NewSASLState, State); +process_sasl_result({error, Reason, User}, State) -> + process_sasl_failure(Reason, User, State). + +-spec process_sasl_success([cyrsasl:sasl_property()], binary(), state()) -> state(). +process_sasl_success(Props, ServerOut, + #{socket := Socket, sockmod := SockMod, + mod := Mod, sasl_mech := Mech} = State) -> + User = identity(Props), + AuthModule = proplists:get_value(auth_module, Props), + SockMod:reset_stream(Socket), + State1 = send_pkt(State, #sasl_success{text = ServerOut}), + case is_disconnected(State1) of + true -> State1; + false -> + State2 = try Mod:handle_auth_success(User, Mech, AuthModule, State1) + catch _:undef -> State1 + end, + case is_disconnected(State2) of + true -> State2; + false -> + State3 = maps:remove(sasl_state, + maps:remove(sasl_mech, State2)), + State3#{stream_id => new_id(), + stream_authenticated => true, + stream_header_sent => false, + stream_restarted => true, + stream_state => wait_for_stream, + user => User} + end + end. + +-spec process_sasl_continue(binary(), cyrsasl:sasl_state(), state()) -> state(). +process_sasl_continue(ServerOut, NewSASLState, State) -> + State1 = State#{sasl_state => NewSASLState, + stream_state => wait_for_sasl_response}, + send_pkt(State1, #sasl_challenge{text = ServerOut}). + +-spec process_sasl_failure(atom(), binary(), state()) -> state(). +process_sasl_failure(Err, User, + #{mod := Mod, sasl_mech := Mech, lang := Lang} = State) -> + {Reason, Text} = format_sasl_error(Mech, Err), + State1 = send_pkt(State, #sasl_failure{reason = Reason, + text = xmpp:mk_text(Text, Lang)}), + case is_disconnected(State1) of + true -> State1; + false -> + State2 = try Mod:handle_auth_failure(User, Mech, Text, State1) + catch _:undef -> State1 + end, + State3 = maps:remove(sasl_state, maps:remove(sasl_mech, State2)), + State3#{stream_state => wait_for_sasl_request} + end. + +-spec process_sasl_abort(state()) -> state(). +process_sasl_abort(State) -> + process_sasl_failure(aborted, <<"">>, State). + +-spec send_features(state()) -> state(). +send_features(#{stream_version := {1,0}, + stream_encrypted := Encrypted} = State) -> + TLSRequired = is_starttls_required(State), + Features = if TLSRequired and not Encrypted -> + get_tls_feature(State); + true -> + get_sasl_feature(State) ++ get_compress_feature(State) + ++ get_tls_feature(State) ++ get_bind_feature(State) + ++ get_session_feature(State) ++ get_other_features(State) + end, + send_pkt(State, #stream_features{sub_els = Features}); +send_features(State) -> + %% clients and servers from stone age + State. + +-spec get_sasl_mechanisms(state()) -> [cyrsasl:mechanism()]. +get_sasl_mechanisms(#{stream_encrypted := Encrypted, mod := Mod, + xmlns := NS, lserver := LServer} = State) -> + Mechs = if NS == ?NS_CLIENT -> cyrsasl:listmech(LServer); + true -> [] + end, + TLSVerify = try Mod:tls_verify(State) + catch _:undef -> false + end, + Mechs1 = if Encrypted andalso (TLSVerify orelse NS == ?NS_SERVER) -> + [<<"EXTERNAL">>|Mechs]; + true -> + Mechs + end, + try Mod:sasl_mechanisms(Mechs1, State) + catch _:undef -> Mechs1 + end. + +-spec get_sasl_feature(state()) -> [sasl_mechanisms()]. +get_sasl_feature(#{stream_authenticated := false, + stream_encrypted := Encrypted} = State) -> + TLSRequired = is_starttls_required(State), + if Encrypted or not TLSRequired -> + Mechs = get_sasl_mechanisms(State), + [#sasl_mechanisms{list = Mechs}]; + true -> + [] + end; +get_sasl_feature(_) -> + []. + +-spec get_compress_feature(state()) -> [compression()]. +get_compress_feature(#{stream_compressed := false, mod := Mod} = State) -> + try Mod:compress_methods(State) of + [] -> []; + Ms -> [#compression{methods = Ms}] + catch _:undef -> + [] + end; +get_compress_feature(_) -> + []. + +-spec get_tls_feature(state()) -> [starttls()]. +get_tls_feature(#{stream_authenticated := false, + stream_encrypted := false} = State) -> + case is_starttls_available(State) of + true -> + TLSRequired = is_starttls_required(State), + [#starttls{required = TLSRequired}]; + false -> + [] + end; +get_tls_feature(_) -> + []. + +-spec get_bind_feature(state()) -> [bind()]. +get_bind_feature(#{xmlns := ?NS_CLIENT, + stream_authenticated := true, + resource := <<"">>}) -> + [#bind{}]; +get_bind_feature(_) -> + []. + +-spec get_session_feature(state()) -> [xmpp_session()]. +get_session_feature(#{xmlns := ?NS_CLIENT, + stream_authenticated := true, + resource := <<"">>}) -> + [#xmpp_session{optional = true}]; +get_session_feature(_) -> + []. + +-spec get_other_features(state()) -> [xmpp_element()]. +get_other_features(#{stream_authenticated := Auth, mod := Mod} = State) -> + try + if Auth -> Mod:authenticated_stream_features(State); + true -> Mod:unauthenticated_stream_features(State) + end + catch _:undef -> + [] + end. + +-spec is_starttls_available(state()) -> boolean(). +is_starttls_available(#{mod := Mod} = State) -> + try Mod:tls_enabled(State) + catch _:undef -> true + end. + +-spec is_starttls_required(state()) -> boolean(). +is_starttls_required(#{mod := Mod} = State) -> + try Mod:tls_required(State) + catch _:undef -> false + end. + +-spec set_from_to(xmpp_element(), state()) -> {ok, xmpp_element()} | + {error, stream_error()}. +set_from_to(Pkt, _State) when not ?is_stanza(Pkt) -> + {ok, Pkt}; +set_from_to(Pkt, #{user := U, server := S, resource := R, + lang := Lang, xmlns := ?NS_CLIENT}) -> + JID = jid:make(U, S, R), + From = case xmpp:get_from(Pkt) of + undefined -> JID; + F -> F + end, + if JID#jid.luser == From#jid.luser andalso + JID#jid.lserver == From#jid.lserver andalso + (JID#jid.lresource == From#jid.lresource + orelse From#jid.lresource == <<"">>) -> + To = case xmpp:get_to(Pkt) of + undefined -> jid:make(U, S); + T -> T + end, + {ok, xmpp:set_from_to(Pkt, JID, To)}; + true -> + Txt = <<"Improper 'from' attribute">>, + {error, xmpp:serr_invalid_from(Txt, Lang)} + end; +set_from_to(Pkt, #{lang := Lang}) -> + From = xmpp:get_from(Pkt), + To = xmpp:get_to(Pkt), + if From == undefined -> + Txt = <<"Missing 'from' attribute">>, + {error, xmpp:serr_improper_addressing(Txt, Lang)}; + To == undefined -> + Txt = <<"Missing 'to' attribute">>, + {error, xmpp:serr_improper_addressing(Txt, Lang)}; + true -> + {ok, Pkt} + end. + +-spec send_header(state()) -> state(). +send_header(#{stream_version := Version} = State) -> + send_header(State, #stream_start{version = Version}). + +-spec send_header(state(), stream_start()) -> state(). +send_header(#{stream_id := StreamID, + stream_version := MyVersion, + stream_header_sent := false, + lang := MyLang, + xmlns := NS} = State, + #stream_start{to = HisTo, from = HisFrom, + lang = HisLang, version = HisVersion}) -> + Lang = select_lang(MyLang, HisLang), + NS_DB = if NS == ?NS_SERVER -> ?NS_SERVER_DIALBACK; + true -> <<"">> + end, + Version = case HisVersion of + undefined -> undefined; + {0,_} -> HisVersion; + _ -> MyVersion + end, + StreamStart = #stream_start{version = Version, + lang = Lang, + xmlns = NS, + stream_xmlns = ?NS_STREAM, + db_xmlns = NS_DB, + id = StreamID, + to = HisFrom, + from = HisTo}, + State1 = State#{lang => Lang, + stream_version => Version, + stream_header_sent => true}, + case socket_send(State1, StreamStart) of + ok -> State1; + {error, Why} -> process_stream_end({socket, Why}, State1) + end; +send_header(State, _) -> + State. + +-spec send_pkt(state(), xmpp_element() | xmlel()) -> state(). +send_pkt(#{mod := Mod} = State, Pkt) -> + Result = socket_send(State, Pkt), + State1 = try Mod:handle_send(Pkt, Result, State) + catch _:undef -> State + end, + case Result of + _ when is_record(Pkt, stream_error) -> + process_stream_end({stream, {out, Pkt}}, State1); + ok -> + State1; + {error, Why} -> + process_stream_end({socket, Why}, State1) + end. + +-spec send_error(state(), xmpp_element() | xmlel(), stanza_error()) -> state(). +send_error(State, Pkt, Err) -> + case xmpp:is_stanza(Pkt) of + true -> + case xmpp:get_type(Pkt) of + result -> State; + error -> State; + <<"result">> -> State; + <<"error">> -> State; + _ -> + ErrPkt = xmpp:make_error(Pkt, Err), + send_pkt(State, ErrPkt) + end; + false -> + State + end. + +-spec send_trailer(state()) -> state(). +send_trailer(State) -> + socket_send(State, trailer), + close_socket(State). + +-spec socket_send(state(), xmpp_element() | xmlel() | trailer) -> ok | {error, inet:posix()}. +socket_send(#{socket := Sock, sockmod := SockMod, + stream_state := StateName, + xmlns := NS, + stream_header_sent := true}, Pkt) when StateName /= disconnected -> + case Pkt of + trailer -> + SockMod:send_trailer(Sock); + #stream_start{} -> + SockMod:send_header(Sock, xmpp:encode(Pkt)); + _ -> + SockMod:send_element(Sock, xmpp:encode(Pkt, NS)) + end; +socket_send(_, _) -> + {error, closed}. + +-spec close_socket(state()) -> state(). +close_socket(#{stream_state := disconnected} = State) -> + State; +close_socket(#{sockmod := SockMod, socket := Socket} = State) -> + SockMod:close(Socket), + State#{stream_timeout => infinity, + stream_state => disconnected}. + +-spec select_lang(binary(), binary()) -> binary(). +select_lang(Lang, <<"">>) -> Lang; +select_lang(_, Lang) -> Lang. + +-spec set_lang(xmpp_element(), state()) -> xmpp_element(). +set_lang(Pkt, #{lang := MyLang, xmlns := ?NS_CLIENT}) when ?is_stanza(Pkt) -> + HisLang = xmpp:get_lang(Pkt), + Lang = select_lang(MyLang, HisLang), + xmpp:set_lang(Pkt, Lang); +set_lang(Pkt, _) -> + Pkt. + +-spec format_inet_error(atom()) -> string(). +format_inet_error(Reason) -> + case inet:format_error(Reason) of + "unknown POSIX error" -> atom_to_list(Reason); + Txt -> Txt + end. + +-spec format_stream_error(atom() | 'see-other-host'(), undefined | text()) -> string(). +format_stream_error(Reason, Txt) -> + Slogan = case Reason of + undefined -> "no reason"; + #'see-other-host'{} -> "see-other-host"; + _ -> atom_to_list(Reason) + end, + case Txt of + undefined -> Slogan; + #text{data = <<"">>} -> Slogan; + #text{data = Data} -> + binary_to_list(Data) ++ " (" ++ Slogan ++ ")" + end. + +-spec format_sasl_error(cyrsasl:mechanism(), atom()) -> {atom(), binary()}. +format_sasl_error(<<"EXTERNAL">>, Err) -> + xmpp_stream_pkix:format_error(Err); +format_sasl_error(Mech, Err) -> + cyrsasl:format_error(Mech, Err). + +-spec format_tls_error(atom() | binary()) -> list(). +format_tls_error(Reason) when is_atom(Reason) -> + format_inet_error(Reason); +format_tls_error(Reason) -> + Reason. + +-spec format(io:format(), list()) -> binary(). +format(Fmt, Args) -> + iolist_to_binary(io_lib:format(Fmt, Args)). + +-spec lists_intersection(list(), list()) -> list(). +lists_intersection(L1, L2) -> + lists:filter( + fun(E) -> + lists:member(E, L2) + end, L1). + +-spec identity([cyrsasl:sasl_property()]) -> binary(). +identity(Props) -> + case proplists:get_value(authzid, Props, <<>>) of + <<>> -> proplists:get_value(username, Props, <<>>); + AuthzId -> AuthzId + end. diff --git a/src/xmpp_stream_out.erl b/src/xmpp_stream_out.erl new file mode 100644 index 000000000..cd1524c12 --- /dev/null +++ b/src/xmpp_stream_out.erl @@ -0,0 +1,989 @@ +%%%------------------------------------------------------------------- +%%% Created : 14 Dec 2016 by Evgeny Khramtsov +%%% +%%% +%%% ejabberd, Copyright (C) 2002-2016 ProcessOne +%%% +%%% This program is free software; you can redistribute it and/or +%%% modify it under the terms of the GNU General Public License as +%%% published by the Free Software Foundation; either version 2 of the +%%% License, or (at your option) any later version. +%%% +%%% This program is distributed in the hope that it will be useful, +%%% but WITHOUT ANY WARRANTY; without even the implied warranty of +%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +%%% General Public License for more details. +%%% +%%% You should have received a copy of the GNU General Public License along +%%% with this program; if not, write to the Free Software Foundation, Inc., +%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +%%% +%%%------------------------------------------------------------------- +-module(xmpp_stream_out). +-define(GEN_SERVER, gen_server). +-behaviour(?GEN_SERVER). + +-protocol({rfc, 6120}). +-protocol({xep, 114, '1.6'}). + +%% API +-export([start/3, start_link/3, call/3, cast/2, reply/2, connect/1, + stop/1, send/2, close/1, close/2, establish/1, format_error/1, + set_timeout/2, get_transport/1, change_shaper/2]). +%% gen_server callbacks +-export([init/1, handle_call/3, handle_cast/2, handle_info/2, + terminate/2, code_change/3]). + +%%-define(DBGFSM, true). +-ifdef(DBGFSM). +-define(FSMOPTS, [{debug, [trace]}]). +-else. +-define(FSMOPTS, []). +-endif. + +-define(TCP_SEND_TIMEOUT, 15000). + +-include("xmpp.hrl"). +-include_lib("kernel/include/inet.hrl"). + +-type state() :: map(). +-type noreply() :: {noreply, state(), timeout()}. +-type host_port() :: {inet:hostname(), inet:port_number()}. +-type ip_port() :: {inet:ip_address(), inet:port_number()}. +-type network_error() :: {error, inet:posix() | inet_res:res_error()}. +-type stop_reason() :: {idna, bad_string} | + {dns, inet:posix() | inet_res:res_error()} | + {stream, reset | {in | out, stream_error()}} | + {tls, inet:posix() | atom() | binary()} | + {pkix, binary()} | + {auth, atom() | binary() | string()} | + {socket, inet:posix() | closed | timeout} | + internal_failure. + +-callback init(list()) -> {ok, state()} | {error, term()} | ignore. +-callback handle_cast(term(), state()) -> state(). +-callback handle_call(term(), term(), state()) -> state(). +-callback handle_info(term(), state()) -> state(). +-callback terminate(term(), state()) -> any(). +-callback code_change(term(), state(), term()) -> {ok, state()} | {error, term()}. +-callback handle_stream_start(stream_start(), state()) -> state(). +-callback handle_stream_established(state()) -> state(). +-callback handle_stream_downgraded(stream_start(), state()) -> state(). +-callback handle_stream_end(stop_reason(), state()) -> state(). +-callback handle_cdata(binary(), state()) -> state(). +-callback handle_send(xmpp_element(), ok | {error, inet:posix()}, state()) -> state(). +-callback handle_recv(fxml:xmlel(), xmpp_element() | {error, term()}, state()) -> state(). +-callback handle_timeout(state()) -> state(). +-callback handle_authenticated_features(stream_features(), state()) -> state(). +-callback handle_unauthenticated_features(stream_features(), state()) -> state(). +-callback handle_auth_success(cyrsasl:mechanism(), state()) -> state(). +-callback handle_auth_failure(cyrsasl:mechanism(), binary(), state()) -> state(). +-callback handle_packet(xmpp_element(), state()) -> state(). +-callback tls_options(state()) -> [proplists:property()]. +-callback tls_required(state()) -> boolean(). +-callback tls_verify(state()) -> boolean(). +-callback tls_enabled(state()) -> boolean(). +-callback dns_timeout(state()) -> timeout(). +-callback dns_retries(state()) -> non_neg_integer(). +-callback default_port(state()) -> inet:port_number(). +-callback address_families(state()) -> [inet:address_family()]. +-callback connect_timeout(state()) -> timeout(). + +-optional_callbacks([init/1, + handle_cast/2, + handle_call/3, + handle_info/2, + terminate/2, + code_change/3, + handle_stream_start/2, + handle_stream_established/1, + handle_stream_downgraded/2, + handle_stream_end/2, + handle_cdata/2, + handle_send/3, + handle_recv/3, + handle_timeout/1, + handle_authenticated_features/2, + handle_unauthenticated_features/2, + handle_auth_success/2, + handle_auth_failure/3, + handle_packet/2, + tls_options/1, + tls_required/1, + tls_verify/1, + tls_enabled/1, + dns_timeout/1, + dns_retries/1, + default_port/1, + address_families/1, + connect_timeout/1]). + +%%%=================================================================== +%%% API +%%%=================================================================== +start(Mod, Args, Opts) -> + ?GEN_SERVER:start(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS). + +start_link(Mod, Args, Opts) -> + ?GEN_SERVER:start_link(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS). + +call(Ref, Msg, Timeout) -> + ?GEN_SERVER:call(Ref, Msg, Timeout). + +cast(Ref, Msg) -> + ?GEN_SERVER:cast(Ref, Msg). + +reply(Ref, Reply) -> + ?GEN_SERVER:reply(Ref, Reply). + +-spec connect(pid()) -> ok. +connect(Ref) -> + cast(Ref, connect). + +-spec stop(pid()) -> ok; + (state()) -> no_return(). +stop(Pid) when is_pid(Pid) -> + cast(Pid, stop); +stop(#{owner := Owner} = State) when Owner == self() -> + terminate(normal, State), + exit(normal); +stop(_) -> + erlang:error(badarg). + +-spec send(pid(), xmpp_element()) -> ok; + (state(), xmpp_element()) -> state(). +send(Pid, Pkt) when is_pid(Pid) -> + cast(Pid, {send, Pkt}); +send(#{owner := Owner} = State, Pkt) when Owner == self() -> + send_pkt(State, Pkt); +send(_, _) -> + erlang:error(badarg). + +-spec close(pid()) -> ok; + (state()) -> state(). +close(Ref) -> + close(Ref, true). + +-spec close(pid(), boolean()) -> ok; + (state(), boolean()) -> state(). +close(Pid, SendTrailer) when is_pid(Pid) -> + cast(Pid, {close, SendTrailer}); +close(#{owner := Owner} = State, SendTrailer) when Owner == self() -> + if SendTrailer -> send_trailer(State); + true -> close_socket(State) + end; +close(_, _) -> + erlang:error(badarg). + +-spec establish(state()) -> state(). +establish(State) -> + process_stream_established(State). + +-spec set_timeout(state(), timeout()) -> state(). +set_timeout(#{owner := Owner} = State, Timeout) when Owner == self() -> + case Timeout of + infinity -> State#{stream_timeout => infinity}; + _ -> + Time = p1_time_compat:monotonic_time(milli_seconds), + State#{stream_timeout => {Timeout, Time}} + end; +set_timeout(_, _) -> + erlang:error(badarg). + +get_transport(#{sockmod := SockMod, socket := Socket, owner := Owner}) + when Owner == self() -> + SockMod:get_transport(Socket); +get_transport(_) -> + erlang:error(badarg). + +-spec change_shaper(state(), shaper:shaper()) -> ok. +change_shaper(#{sockmod := SockMod, socket := Socket, owner := Owner}, Shaper) + when Owner == self() -> + SockMod:change_shaper(Socket, Shaper); +change_shaper(_, _) -> + erlang:error(badarg). + +-spec format_error(stop_reason()) -> binary(). +format_error({idna, _}) -> + <<"Remote domain is not an IDN hostname">>; +format_error({dns, Reason}) -> + format("DNS lookup failed: ~s", [format_inet_error(Reason)]); +format_error({socket, Reason}) -> + format("Connection failed: ~s", [format_inet_error(Reason)]); +format_error({pkix, Reason}) -> + {_, ErrTxt} = xmpp_stream_pkix:format_error(Reason), + format("Peer certificate rejected: ~s", [ErrTxt]); +format_error({stream, reset}) -> + <<"Stream reset by peer">>; +format_error({stream, {in, #stream_error{reason = Reason, text = Txt}}}) -> + format("Stream closed by peer: ~s", [format_stream_error(Reason, Txt)]); +format_error({stream, {out, #stream_error{reason = Reason, text = Txt}}}) -> + format("Stream closed by us: ~s", [format_stream_error(Reason, Txt)]); +format_error({tls, Reason}) -> + format("TLS failed: ~s", [format_tls_error(Reason)]); +format_error({auth, Reason}) -> + format("Authentication failed: ~s", [Reason]); +format_error(internal_failure) -> + <<"Internal server error">>; +format_error(Err) -> + format("Unrecognized error: ~w", [Err]). + +%%%=================================================================== +%%% gen_server callbacks +%%%=================================================================== +-spec init(list()) -> {ok, state(), timeout()} | {stop, term()} | ignore. +init([Mod, SockMod, From, To, Opts]) -> + Time = p1_time_compat:monotonic_time(milli_seconds), + State = #{owner => self(), + mod => Mod, + sockmod => SockMod, + server => From, + user => <<"">>, + resource => <<"">>, + lang => <<"">>, + remote_server => To, + xmlns => ?NS_SERVER, + stream_direction => out, + stream_timeout => {timer:seconds(30), Time}, + stream_id => new_id(), + stream_encrypted => false, + stream_verified => false, + stream_authenticated => false, + stream_restarted => false, + stream_state => connecting}, + case try Mod:init([State, Opts]) + catch _:undef -> {ok, State} + end of + {ok, State1} -> + {_, State2, Timeout} = noreply(State1), + {ok, State2, Timeout}; + {error, Reason} -> + {stop, Reason}; + ignore -> + ignore + end. + +-spec handle_call(term(), term(), state()) -> noreply(). +handle_call(Call, From, #{mod := Mod} = State) -> + noreply(try Mod:handle_call(Call, From, State) + catch _:undef -> State + end). + +-spec handle_cast(term(), state()) -> noreply(). +handle_cast(connect, #{remote_server := RemoteServer, + sockmod := SockMod, + stream_state := connecting} = State) -> + noreply( + case idna_to_ascii(RemoteServer) of + false -> + process_stream_end({idna, bad_string}, State); + ASCIIName -> + case resolve(binary_to_list(ASCIIName), State) of + {ok, AddrPorts} -> + case connect(AddrPorts, State) of + {ok, Socket, AddrPort} -> + SocketMonitor = SockMod:monitor(Socket), + State1 = State#{ip => AddrPort, + socket => Socket, + socket_monitor => SocketMonitor}, + State2 = State1#{stream_state => wait_for_stream}, + send_header(State2); + {error, Why} -> + process_stream_end({socket, Why}, State) + end; + {error, Why} -> + process_stream_end({dns, Why}, State) + end + end); +handle_cast(connect, State) -> + %% Ignoring connection attempts in other states + noreply(State); +handle_cast({send, Pkt}, State) -> + noreply(send_pkt(State, Pkt)); +handle_cast(stop, State) -> + {stop, normal, State}; +handle_cast(Cast, #{mod := Mod} = State) -> + noreply(try Mod:handle_cast(Cast, State) + catch _:undef -> State + end). + +-spec handle_info(term(), state()) -> noreply(). +handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}}, + #{stream_state := wait_for_stream, + xmlns := XMLNS, lang := MyLang} = State) -> + El = #xmlel{name = Name, attrs = Attrs}, + noreply( + try xmpp:decode(El, XMLNS, []) of + #stream_start{} = Pkt -> + process_stream(Pkt, State); + _ -> + send_pkt(State, xmpp:serr_invalid_xml()) + catch _:{xmpp_codec, Why} -> + Txt = xmpp:io_format_error(Why), + Lang = select_lang(MyLang, xmpp:get_lang(El)), + Err = xmpp:serr_invalid_xml(Txt, Lang), + send_pkt(State, Err) + end); +handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) -> + State1 = send_header(State), + noreply( + case is_disconnected(State1) of + true -> State1; + false -> + Err = case Reason of + <<"XML stanza is too big">> -> + xmpp:serr_policy_violation(Reason, Lang); + {_, Txt} -> + xmpp:serr_not_well_formed(Txt, Lang) + end, + send_pkt(State1, Err) + end); +handle_info({'$gen_event', {xmlstreamelement, El}}, + #{xmlns := NS, mod := Mod} = State) -> + noreply( + try xmpp:decode(El, NS, [ignore_els]) of + Pkt -> + State1 = try Mod:handle_recv(El, Pkt, State) + catch _:undef -> State + end, + case is_disconnected(State1) of + true -> State1; + false -> process_element(Pkt, State1) + end + catch _:{xmpp_codec, Why} -> + State1 = try Mod:handle_recv(El, {error, Why}, State) + catch _:undef -> State + end, + case is_disconnected(State1) of + true -> State1; + false -> process_invalid_xml(State1, El, Why) + end + end); +handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}}, + #{mod := Mod} = State) -> + noreply(try Mod:handle_cdata(Data, State) + catch _:undef -> State + end); +handle_info({'$gen_event', {xmlstreamend, _}}, State) -> + noreply(process_stream_end({stream, reset}, State)); +handle_info({'$gen_event', closed}, State) -> + noreply(process_stream_end({socket, closed}, State)); +handle_info(timeout, #{mod := Mod} = State) -> + Disconnected = is_disconnected(State), + noreply(try Mod:handle_timeout(State) + catch _:undef when not Disconnected -> + send_pkt(State, xmpp:serr_connection_timeout()); + _:undef -> + stop(State) + end); +handle_info({'DOWN', MRef, _Type, _Object, _Info}, + #{socket_monitor := MRef} = State) -> + noreply(process_stream_end({socket, closed}, State)); +handle_info(Info, #{mod := Mod} = State) -> + noreply(try Mod:handle_info(Info, State) + catch _:undef -> State + end). + +-spec terminate(term(), state()) -> any(). +terminate(Reason, #{mod := Mod} = State) -> + case get(already_terminated) of + true -> + State; + _ -> + put(already_terminated, true), + try Mod:terminate(Reason, State) + catch _:undef -> ok + end, + send_trailer(State) + end. + +code_change(OldVsn, #{mod := Mod} = State, Extra) -> + Mod:code_change(OldVsn, State, Extra). + +%%%=================================================================== +%%% Internal functions +%%%=================================================================== +-spec noreply(state()) -> noreply(). +noreply(#{stream_timeout := infinity} = State) -> + {noreply, State, infinity}; +noreply(#{stream_timeout := {MSecs, OldTime}} = State) -> + NewTime = p1_time_compat:monotonic_time(milli_seconds), + Timeout = max(0, MSecs - NewTime + OldTime), + {noreply, State, Timeout}. + +-spec new_id() -> binary(). +new_id() -> + randoms:get_string(). + +-spec is_disconnected(state()) -> boolean(). +is_disconnected(#{stream_state := StreamState}) -> + StreamState == disconnected. + +-spec process_invalid_xml(state(), fxml:xmlel(), term()) -> state(). +process_invalid_xml(#{lang := MyLang} = State, El, Reason) -> + case xmpp:is_stanza(El) of + true -> + Txt = xmpp:io_format_error(Reason), + Lang = select_lang(MyLang, xmpp:get_lang(El)), + send_error(State, El, xmpp:err_bad_request(Txt, Lang)); + false -> + State + end. + +-spec process_stream_end(stop_reason(), state()) -> state(). +process_stream_end(_, #{stream_state := disconnected} = State) -> + State; +process_stream_end(Reason, #{mod := Mod} = State) -> + State1 = send_trailer(State), + try Mod:handle_stream_end(Reason, State1) + catch _:undef -> stop(State1) + end. + +-spec process_stream(stream_start(), state()) -> state(). +process_stream(#stream_start{xmlns = XML_NS, + stream_xmlns = STREAM_NS}, + #{xmlns := NS} = State) + when XML_NS /= NS; STREAM_NS /= ?NS_STREAM -> + send_pkt(State, xmpp:serr_invalid_namespace()); +process_stream(#stream_start{version = {N, _}}, State) when N > 1 -> + send_pkt(State, xmpp:serr_unsupported_version()); +process_stream(#stream_start{lang = Lang, id = ID, + version = Version} = StreamStart, + #{mod := Mod} = State) -> + State1 = State#{stream_remote_id => ID, lang => Lang}, + State2 = try Mod:handle_stream_start(StreamStart, State1) + catch _:undef -> State1 + end, + case is_disconnected(State2) of + true -> State2; + false -> + case Version of + {1, _} -> + State2#{stream_state => wait_for_features}; + _ -> + process_stream_downgrade(StreamStart, State2) + end + end. + +-spec process_element(xmpp_element(), state()) -> state(). +process_element(Pkt, #{stream_state := StateName} = State) -> + case Pkt of + #stream_features{} when StateName == wait_for_features -> + process_features(Pkt, State); + #starttls_proceed{} when StateName == wait_for_starttls_response -> + process_starttls(State); + #sasl_success{} when StateName == wait_for_sasl_response -> + process_sasl_success(State); + #sasl_failure{} when StateName == wait_for_sasl_response -> + process_sasl_failure(Pkt, State); + #stream_error{} -> + process_stream_end({stream, {in, Pkt}}, State); + _ when is_record(Pkt, stream_features); + is_record(Pkt, starttls_proceed); + is_record(Pkt, starttls); + is_record(Pkt, sasl_auth); + is_record(Pkt, sasl_success); + is_record(Pkt, sasl_failure); + is_record(Pkt, sasl_response); + is_record(Pkt, sasl_abort); + is_record(Pkt, compress); + is_record(Pkt, handshake) -> + %% Do not pass this crap upstream + State; + _ -> + process_packet(Pkt, State) + end. + +-spec process_features(stream_features(), state()) -> state(). +process_features(StreamFeatures, + #{stream_authenticated := true, mod := Mod} = State) -> + State1 = try Mod:handle_authenticated_features(StreamFeatures, State) + catch _:undef -> State + end, + process_stream_established(State1); +process_features(#stream_features{sub_els = Els} = StreamFeatures, + #{stream_encrypted := Encrypted, + mod := Mod, lang := Lang} = State) -> + State1 = try Mod:handle_unauthenticated_features(StreamFeatures, State) + catch _:undef -> State + end, + case is_disconnected(State1) of + true -> State1; + false -> + TLSRequired = is_starttls_required(State1), + TLSAvailable = is_starttls_available(State1), + %% TODO: improve xmpp.erl + Msg = #message{sub_els = Els}, + case xmpp:get_subtag(Msg, #starttls{}) of + false when TLSRequired and not Encrypted -> + Txt = <<"Use of STARTTLS required">>, + send_pkt(State1, xmpp:serr_policy_violation(Txt, Lang)); + #starttls{required = true} when not TLSAvailable and not Encrypted -> + Txt = <<"Use of STARTTLS forbidden">>, + send_pkt(State1, xmpp:serr_unsupported_feature(Txt, Lang)); + #starttls{} when TLSAvailable and not Encrypted -> + State2 = State1#{stream_state => wait_for_starttls_response}, + send_pkt(State2, #starttls{}); + _ -> + State2 = process_cert_verification(State1), + case is_disconnected(State2) of + true -> State2; + false -> + case xmpp:get_subtag(Msg, #sasl_mechanisms{}) of + #sasl_mechanisms{list = Mechs} -> + process_sasl_mechanisms(Mechs, State2); + false -> + process_sasl_failure( + #sasl_failure{reason = 'invalid-mechanism'}, + State2) + end + end + end + end. + +-spec process_stream_established(state()) -> state(). +process_stream_established(#{stream_state := StateName} = State) + when StateName == disconnected; StateName == established -> + State; +process_stream_established(#{mod := Mod} = State) -> + State1 = State#{stream_authenticated := true, + stream_state => established, + stream_timeout => infinity}, + try Mod:handle_stream_established(State1) + catch _:undef -> State1 + end. + +-spec process_sasl_mechanisms([binary()], state()) -> state(). +process_sasl_mechanisms(Mechs, #{user := User, server := Server} = State) -> + %% TODO: support other mechanisms + Mech = <<"EXTERNAL">>, + case lists:member(<<"EXTERNAL">>, Mechs) of + true -> + State1 = State#{stream_state => wait_for_sasl_response}, + Authzid = jid:to_string(jid:make(User, Server)), + send_pkt(State1, #sasl_auth{mechanism = Mech, text = Authzid}); + false -> + process_sasl_failure( + #sasl_failure{reason = 'invalid-mechanism'}, State) + end. + +-spec process_starttls(state()) -> state(). +process_starttls(#{sockmod := SockMod, socket := Socket, mod := Mod} = State) -> + TLSOpts = try Mod:tls_options(State) + catch _:undef -> [] + end, + case SockMod:starttls(Socket, [connect|TLSOpts]) of + {ok, TLSSocket} -> + State1 = State#{socket => TLSSocket, + stream_id => new_id(), + stream_restarted => true, + stream_state => wait_for_stream, + stream_encrypted => true}, + send_header(State1); + {error, Why} -> + process_stream_end({tls, Why}, State) + end. + +-spec process_stream_downgrade(stream_start(), state()) -> state(). +process_stream_downgrade(StreamStart, + #{mod := Mod, lang := Lang, + stream_encrypted := Encrypted} = State) -> + TLSRequired = is_starttls_required(State), + if not Encrypted and TLSRequired -> + Txt = <<"Use of STARTTLS required">>, + send_pkt(State, xmpp:serr_policy_violation(Txt, Lang)); + true -> + State1 = State#{stream_state => downgraded}, + try Mod:handle_stream_downgraded(StreamStart, State1) + catch _:undef -> + send_pkt(State1, xmpp:serr_unsupported_version()) + end + end. + +-spec process_cert_verification(state()) -> state(). +process_cert_verification(#{stream_encrypted := true, + stream_verified := false, + mod := Mod} = State) -> + case try Mod:tls_verify(State) + catch _:undef -> true + end of + true -> + case xmpp_stream_pkix:authenticate(State) of + {ok, _} -> + State#{stream_verified => true}; + {error, Why, _Peer} -> + process_stream_end({pkix, Why}, State) + end; + false -> + State#{stream_verified => true} + end; +process_cert_verification(State) -> + State. + +-spec process_sasl_success(state()) -> state(). +process_sasl_success(#{mod := Mod, + sockmod := SockMod, + socket := Socket} = State) -> + State1 = try Mod:handle_auth_success(<<"EXTERNAL">>, State) + catch _:undef -> State + end, + case is_disconnected(State1) of + true -> State1; + false -> + SockMod:reset_stream(Socket), + State2 = State1#{stream_id => new_id(), + stream_restarted => true, + stream_state => wait_for_stream, + stream_authenticated => true}, + send_header(State2) + end. + +-spec process_sasl_failure(sasl_failure(), state()) -> state(). +process_sasl_failure(#sasl_failure{reason = Reason}, #{mod := Mod} = State) -> + try Mod:handle_auth_failure(<<"EXTERNAL">>, {auth, Reason}, State) + catch _:undef -> process_stream_end({auth, Reason}, State) + end. + +-spec process_packet(xmpp_element(), state()) -> state(). +process_packet(Pkt, #{mod := Mod} = State) -> + try Mod:handle_packet(Pkt, State) + catch _:undef -> State + end. + +-spec is_starttls_required(state()) -> boolean(). +is_starttls_required(#{mod := Mod} = State) -> + try Mod:tls_required(State) + catch _:undef -> false + end. + +-spec is_starttls_available(state()) -> boolean(). +is_starttls_available(#{mod := Mod} = State) -> + try Mod:tls_enabled(State) + catch _:undef -> true + end. + +-spec send_header(state()) -> state(). +send_header(#{remote_server := RemoteServer, + stream_encrypted := Encrypted, + lang := Lang, + xmlns := NS, + user := User, + resource := Resource, + server := Server} = State) -> + NS_DB = if NS == ?NS_SERVER -> ?NS_SERVER_DIALBACK; + true -> <<"">> + end, + From = if Encrypted -> + jid:make(User, Server, Resource); + NS == ?NS_SERVER -> + jid:make(Server); + true -> + undefined + end, + StreamStart = #stream_start{xmlns = NS, + lang = Lang, + stream_xmlns = ?NS_STREAM, + db_xmlns = NS_DB, + from = From, + to = jid:make(RemoteServer), + version = {1,0}}, + case socket_send(State, StreamStart) of + ok -> State; + {error, Why} -> process_stream_end({socket, Why}, State) + end. + +-spec send_pkt(state(), xmpp_element() | xmlel()) -> state(). +send_pkt(#{mod := Mod} = State, Pkt) -> + Result = socket_send(State, Pkt), + State1 = try Mod:handle_send(Pkt, Result, State) + catch _:undef -> State + end, + case Result of + _ when is_record(Pkt, stream_error) -> + process_stream_end({stream, {out, Pkt}}, State1); + ok -> + State1; + {error, Why} -> + process_stream_end({socket, Why}, State1) + end. + +-spec send_error(state(), xmpp_element() | xmlel(), stanza_error()) -> state(). +send_error(State, Pkt, Err) -> + case xmpp:is_stanza(Pkt) of + true -> + case xmpp:get_type(Pkt) of + result -> State; + error -> State; + <<"result">> -> State; + <<"error">> -> State; + _ -> + ErrPkt = xmpp:make_error(Pkt, Err), + send_pkt(State, ErrPkt) + end; + false -> + State + end. + +-spec socket_send(state(), xmpp_element() | xmlel() | trailer) -> ok | {error, inet:posix()}. +socket_send(#{sockmod := SockMod, socket := Socket, xmlns := NS, + stream_state := StateName}, Pkt) when StateName /= disconnected -> + case Pkt of + trailer -> + SockMod:send_trailer(Socket); + #stream_start{} -> + SockMod:send_header(Socket, xmpp:encode(Pkt)); + _ -> + SockMod:send_element(Socket, xmpp:encode(Pkt, NS)) + end; +socket_send(_, _) -> + {error, closed}. + +-spec send_trailer(state()) -> state(). +send_trailer(State) -> + socket_send(State, trailer), + close_socket(State). + +-spec close_socket(state()) -> state(). +close_socket(#{stream_state := disconnected} = State) -> + State; +close_socket(State) -> + case State of + #{sockmod := SockMod, socket := Socket} -> + SockMod:close(Socket); + _ -> + ok + end, + State#{stream_timeout => infinity, + stream_state => disconnected}. + +-spec select_lang(binary(), binary()) -> binary(). +select_lang(Lang, <<"">>) -> Lang; +select_lang(_, Lang) -> Lang. + +-spec format_inet_error(atom()) -> string(). +format_inet_error(Reason) -> + case inet:format_error(Reason) of + "unknown POSIX error" -> atom_to_list(Reason); + Txt -> Txt + end. + +-spec format_stream_error(atom() | 'see-other-host'(), undefined | text()) -> string(). +format_stream_error(Reason, Txt) -> + Slogan = case Reason of + undefined -> "no reason"; + #'see-other-host'{} -> "see-other-host"; + _ -> atom_to_list(Reason) + end, + case Txt of + undefined -> Slogan; + #text{data = <<"">>} -> Slogan; + #text{data = Data} -> + binary_to_list(Data) ++ " (" ++ Slogan ++ ")" + end. + +-spec format_tls_error(atom() | binary()) -> list(). +format_tls_error(Reason) when is_atom(Reason) -> + format_inet_error(Reason); +format_tls_error(Reason) -> + binary_to_list(Reason). + +-spec format(io:format(), list()) -> binary(). +format(Fmt, Args) -> + iolist_to_binary(io_lib:format(Fmt, Args)). + +%%%=================================================================== +%%% Connection stuff +%%%=================================================================== +idna_to_ascii(<<$[, _/binary>> = Host) -> + %% This is an IPv6 address in 'IP-literal' format (as per RFC7622) + %% We remove brackets here + case binary:last(Host) of + $] -> + IPv6 = binary:part(Host, {1, size(Host)-2}), + case inet:parse_ipv6strict_address(binary_to_list(IPv6)) of + {ok, _} -> IPv6; + {error, _} -> false + end; + _ -> + false + end; +idna_to_ascii(Host) -> + case inet:parse_address(binary_to_list(Host)) of + {ok, _} -> Host; + {error, _} -> ejabberd_idna:domain_utf8_to_ascii(Host) + end. + +-spec resolve(string(), state()) -> {ok, [host_port()]} | network_error(). +resolve(Host, State) -> + case srv_lookup(Host, State) of + {error, _Reason} -> + DefaultPort = get_default_port(State), + a_lookup([{Host, DefaultPort}], State); + {ok, HostPorts} -> + a_lookup(HostPorts, State) + end. + +-spec srv_lookup(string(), state()) -> {ok, [host_port()]} | network_error(). +srv_lookup(_Host, #{xmlns := ?NS_COMPONENT}) -> + %% Do not attempt to lookup SRV for component connections + {error, nxdomain}; +srv_lookup(Host, State) -> + %% Only perform SRV lookups for FQDN names + case string:chr(Host, $.) of + 0 -> + {error, nxdomain}; + _ -> + case inet:parse_address(Host) of + {ok, _} -> + {error, nxdomain}; + {error, _} -> + Timeout = get_dns_timeout(State), + Retries = get_dns_retries(State), + srv_lookup(Host, Timeout, Retries) + end + end. + +-spec srv_lookup(string(), timeout(), integer()) -> + {ok, [host_port()]} | network_error(). +srv_lookup(_Host, _Timeout, Retries) when Retries < 1 -> + {error, timeout}; +srv_lookup(Host, Timeout, Retries) -> + SRVName = "_xmpp-server._tcp." ++ Host, + case inet_res:getbyname(SRVName, srv, Timeout) of + {ok, HostEntry} -> + host_entry_to_host_ports(HostEntry); + {error, timeout} -> + srv_lookup(Host, Timeout, Retries - 1); + {error, _} = Err -> + Err + end. + +-spec a_lookup([{inet:hostname(), inet:port_number()}], state()) -> + {ok, [ip_port()]} | network_error(). +a_lookup(HostPorts, State) -> + HostPortFamilies = [{Host, Port, Family} + || {Host, Port} <- HostPorts, + Family <- get_address_families(State)], + a_lookup(HostPortFamilies, State, {error, nxdomain}). + +-spec a_lookup([{inet:hostname(), inet:port_number(), inet:address_family()}], + state(), network_error()) -> {ok, [ip_port()]} | network_error(). +a_lookup([{Host, Port, Family}|HostPortFamilies], State, _) -> + Timeout = get_dns_timeout(State), + Retries = get_dns_retries(State), + case a_lookup(Host, Port, Family, Timeout, Retries) of + {error, _} = Err -> + a_lookup(HostPortFamilies, State, Err); + {ok, AddrPorts} -> + {ok, AddrPorts} + end; +a_lookup([], _State, Err) -> + Err. + +-spec a_lookup(inet:hostname(), inet:port_number(), inet:address_family(), + timeout(), integer()) -> {ok, [ip_port()]} | network_error(). +a_lookup(_Host, _Port, _Family, _Timeout, Retries) when Retries < 1 -> + {error, timeout}; +a_lookup(Host, Port, Family, Timeout, Retries) -> + case inet:gethostbyname(Host, Family, Timeout) of + {error, timeout} -> + a_lookup(Host, Port, Family, Timeout, Retries - 1); + {error, _} = Err -> + Err; + {ok, HostEntry} -> + host_entry_to_addr_ports(HostEntry, Port) + end. + +-spec host_entry_to_host_ports(inet:hostent()) -> {ok, [host_port()]} | + {error, nxdomain}. +host_entry_to_host_ports(#hostent{h_addr_list = AddrList}) -> + PrioHostPorts = lists:flatmap( + fun({Priority, Weight, Port, Host}) -> + N = case Weight of + 0 -> 0; + _ -> (Weight + 1) * randoms:uniform() + end, + [{Priority * 65536 - N, Host, Port}]; + (_) -> + [] + end, AddrList), + HostPorts = [{Host, Port} + || {_Priority, Host, Port} <- lists:usort(PrioHostPorts)], + case HostPorts of + [] -> {error, nxdomain}; + _ -> {ok, HostPorts} + end. + +-spec host_entry_to_addr_ports(inet:hostent(), inet:port_number()) -> + {ok, [ip_port()]} | {error, nxdomain}. +host_entry_to_addr_ports(#hostent{h_addr_list = AddrList}, Port) -> + AddrPorts = lists:flatmap( + fun(Addr) -> + try get_addr_type(Addr) of + _ -> [{Addr, Port}] + catch _:_ -> + [] + end + end, AddrList), + case AddrPorts of + [] -> {error, nxdomain}; + _ -> {ok, AddrPorts} + end. + +-spec connect([ip_port()], state()) -> {ok, term(), ip_port()} | network_error(). +connect(AddrPorts, #{sockmod := SockMod} = State) -> + Timeout = get_connect_timeout(State), + connect(AddrPorts, SockMod, Timeout, {error, nxdomain}). + +-spec connect([ip_port()], module(), timeout(), network_error()) -> + {ok, term(), ip_port()} | network_error(). +connect([{Addr, Port}|AddrPorts], SockMod, Timeout, _) -> + Type = get_addr_type(Addr), + case SockMod:connect(Addr, Port, + [binary, {packet, 0}, + {send_timeout, ?TCP_SEND_TIMEOUT}, + {send_timeout_close, true}, + {active, false}, Type], + Timeout) of + {ok, Socket} -> + {ok, Socket, {Addr, Port}}; + Err -> + connect(AddrPorts, SockMod, Timeout, Err) + end; +connect([], _SockMod, _Timeout, Err) -> + Err. + +-spec get_addr_type(inet:ip_address()) -> inet:address_family(). +get_addr_type({_, _, _, _}) -> inet; +get_addr_type({_, _, _, _, _, _, _, _}) -> inet6. + +-spec get_dns_timeout(state()) -> timeout(). +get_dns_timeout(#{mod := Mod} = State) -> + try Mod:dns_timeout(State) + catch _:undef -> timer:seconds(10) + end. + +-spec get_dns_retries(state()) -> non_neg_integer(). +get_dns_retries(#{mod := Mod} = State) -> + try Mod:dns_retries(State) + catch _:undef -> 2 + end. + +-spec get_default_port(state()) -> inet:port_number(). +get_default_port(#{mod := Mod, xmlns := NS} = State) -> + try Mod:default_port(State) + catch _:undef when NS == ?NS_SERVER -> 5269; + _:undef when NS == ?NS_CLIENT -> 5222 + end. + +-spec get_address_families(state()) -> [inet:address_family()]. +get_address_families(#{mod := Mod} = State) -> + try Mod:address_families(State) + catch _:undef -> [inet, inet6] + end. + +-spec get_connect_timeout(state()) -> timeout(). +get_connect_timeout(#{mod := Mod} = State) -> + try Mod:connect_timeout(State) + catch _:undef -> timer:seconds(10) + end. diff --git a/src/xmpp_stream_pkix.erl b/src/xmpp_stream_pkix.erl new file mode 100644 index 000000000..5d64c5eb6 --- /dev/null +++ b/src/xmpp_stream_pkix.erl @@ -0,0 +1,176 @@ +%%%------------------------------------------------------------------- +%%% @author Evgeny Khramtsov +%%% @copyright (C) 2016, Evgeny Khramtsov +%%% @doc +%%% +%%% @end +%%% Created : 13 Dec 2016 by Evgeny Khramtsov +%%%------------------------------------------------------------------- +-module(xmpp_stream_pkix). + +%% API +-export([authenticate/1, authenticate/2, format_error/1]). + +-include("xmpp.hrl"). +-include_lib("public_key/include/public_key.hrl"). +-include("XmppAddr.hrl"). + +%%%=================================================================== +%%% API +%%%=================================================================== +-spec authenticate(xmpp_stream_in:state() | xmpp_stream_out:state()) + -> {ok, binary()} | {error, atom(), binary()}. +authenticate(State) -> + authenticate(State, <<"">>). + +-spec authenticate(xmpp_stream_in:state() | xmpp_stream_out:state(), binary()) + -> {ok, binary()} | {error, atom(), binary()}. +authenticate(#{xmlns := ?NS_SERVER, sockmod := SockMod, + socket := Socket} = State, Authzid) -> + Peer = try maps:get(remote_server, State) + catch _:{badkey, _} -> Authzid + end, + case SockMod:get_peer_certificate(Socket) of + {ok, Cert} -> + case SockMod:get_verify_result(Socket) of + 0 -> + case ejabberd_idna:domain_utf8_to_ascii(Peer) of + false -> + {error, idna_failed, Peer}; + AsciiPeer -> + case lists:any( + fun(D) -> match_domain(AsciiPeer, D) end, + get_cert_domains(Cert)) of + true -> + {ok, Peer}; + false -> + {error, hostname_mismatch, Peer} + end + end; + VerifyRes -> + %% TODO: return atomic errors + %% This should be improved in fast_tls + Reason = fast_tls:get_cert_verify_string(VerifyRes, Cert), + {error, erlang:binary_to_atom(Reason, utf8), Peer} + end; + {error, _Reason} -> + {error, get_cert_failed, Peer}; + error -> + {error, get_cert_failed, Peer} + end; +authenticate(_State, _Authzid) -> + %% TODO: client PKIX authentication + {error, client_not_supported, <<"">>}. + +format_error(idna_failed) -> + {'bad-protocol', <<"Remote domain is not an IDN hostname">>}; +format_error(hostname_mismatch) -> + {'not-authorized', <<"Certificate host name mismatch">>}; +format_error(get_cert_failed) -> + {'bad-protocol', <<"Failed to get peer certificate">>}; +format_error(client_not_supported) -> + {'invalid-mechanism', <<"Client certificate verification is not supported">>}; +format_error(Other) -> + {'not-authorized', erlang:atom_to_binary(Other, utf8)}. + +%%%=================================================================== +%%% Internal functions +%%%=================================================================== +get_cert_domains(Cert) -> + TBSCert = Cert#'Certificate'.tbsCertificate, + Subject = case TBSCert#'TBSCertificate'.subject of + {rdnSequence, Subj} -> lists:flatten(Subj); + _ -> [] + end, + Extensions = case TBSCert#'TBSCertificate'.extensions of + Exts when is_list(Exts) -> Exts; + _ -> [] + end, + lists:flatmap( + fun(#'AttributeTypeAndValue'{type = ?'id-at-commonName',value = Val}) -> + case 'OTP-PUB-KEY':decode('X520CommonName', Val) of + {ok, {_, D1}} -> + D = if is_binary(D1) -> D1; + is_list(D1) -> list_to_binary(D1); + true -> error + end, + if D /= error -> + case jid:from_string(D) of + #jid{luser = <<"">>, lserver = LD, + lresource = <<"">>} -> + [LD]; + _ -> [] + end; + true -> [] + end; + _ -> [] + end; + (_) -> [] + end, Subject) ++ + lists:flatmap( + fun(#'Extension'{extnID = ?'id-ce-subjectAltName', + extnValue = Val}) -> + BVal = if is_list(Val) -> list_to_binary(Val); + true -> Val + end, + case 'OTP-PUB-KEY':decode('SubjectAltName', BVal) of + {ok, SANs} -> + lists:flatmap( + fun({otherName, #'AnotherName'{'type-id' = ?'id-on-xmppAddr', + value = XmppAddr}}) -> + case 'XmppAddr':decode('XmppAddr', XmppAddr) of + {ok, D} when is_binary(D) -> + case jid:from_string(D) of + #jid{luser = <<"">>, + lserver = LD, + lresource = <<"">>} -> + case ejabberd_idna:domain_utf8_to_ascii(LD) of + false -> + []; + PCLD -> + [PCLD] + end; + _ -> [] + end; + _ -> [] + end; + ({dNSName, D}) when is_list(D) -> + case jid:from_string(list_to_binary(D)) of + #jid{luser = <<"">>, + lserver = LD, + lresource = <<"">>} -> + [LD]; + _ -> [] + end; + (_) -> [] + end, SANs); + _ -> [] + end; + (_) -> [] + end, Extensions). + +match_domain(Domain, Domain) -> true; +match_domain(Domain, Pattern) -> + DLabels = str:tokens(Domain, <<".">>), + PLabels = str:tokens(Pattern, <<".">>), + match_labels(DLabels, PLabels). + +match_labels([], []) -> true; +match_labels([], [_ | _]) -> false; +match_labels([_ | _], []) -> false; +match_labels([DL | DLabels], [PL | PLabels]) -> + case lists:all(fun (C) -> + $a =< C andalso C =< $z orelse + $0 =< C andalso C =< $9 orelse + C == $- orelse C == $* + end, + binary_to_list(PL)) + of + true -> + Regexp = ejabberd_regexp:sh_to_awk(PL), + case ejabberd_regexp:run(DL, Regexp) of + match -> match_labels(DLabels, PLabels); + nomatch -> false + end; + false -> false + end. diff --git a/tools/hook_deps.sh b/tools/hook_deps.sh new file mode 100755 index 000000000..39c33ca2e --- /dev/null +++ b/tools/hook_deps.sh @@ -0,0 +1,394 @@ +#!/usr/bin/escript +%% -*- erlang -*- +%%! -pa ebin + +-record(state, {run_hooks = dict:new(), + run_fold_hooks = dict:new(), + hooked_funs = dict:new(), + mfas = dict:new(), + specs = dict:new(), + module :: module(), + file :: filename:filename()}). + +main([Dir]) -> + State = + filelib:fold_files( + Dir, ".+\.[eh]rl\$", false, + fun(FileIn, Res) -> + case get_forms(FileIn) of + {ok, Forms} -> + Tree = erl_syntax:form_list(Forms), + Mod = list_to_atom(filename:rootname(filename:basename(FileIn))), + Acc0 = analyze_form(Tree, Res#state{module = Mod, file = FileIn}), + erl_syntax_lib:fold( + fun(Form, Acc) -> + case erl_syntax:type(Form) of + application -> + case erl_syntax_lib:analyze_application(Form) of + {ejabberd_hooks, {run, N}} + when N == 2; N == 3 -> + analyze_run_hook(Form, Acc); + {ejabberd_hooks, {run_fold, N}} + when N == 3; N == 4 -> + analyze_run_fold_hook(Form, Acc); + {ejabberd_hooks, {add, N}} + when N == 4; N == 5 -> + analyze_run_fun(Form, Acc); + {gen_iq_handler, {add_iq_handler, 6}} -> + analyze_iq_handler(Form, Acc); + _ -> + Acc + end; + attribute -> + case catch erl_syntax_lib:analyze_attribute(Form) of + {spec, _} -> + analyze_type_spec(Form, Acc); + _ -> + Acc + end; + _ -> + Acc + end + end, Acc0, Tree); + _Err -> + Res + end + end, #state{}), + report_orphaned_funs(State), + RunDeps = build_deps(State#state.run_hooks, State#state.hooked_funs), + RunFoldDeps = build_deps(State#state.run_fold_hooks, State#state.hooked_funs), + emit_module(RunDeps, RunFoldDeps, State#state.specs, Dir, hooks_type_test). + +analyze_form(_Form, State) -> + %% case catch erl_syntax_lib:analyze_forms(Form) of + %% Props when is_list(Props) -> + %% M = State#state.module, + %% MFAs = lists:foldl( + %% fun({F, A}, Acc) -> + %% dict:append({M, F}, A, Acc) + %% end, State#state.mfas, + %% proplists:get_value(functions, Props, [])), + %% State#state{mfas = MFAs}; + %% _ -> + %% State + %% end. + State. + +analyze_run_hook(Form, State) -> + [Hook|Tail] = erl_syntax:application_arguments(Form), + case atom_value(Hook, State) of + undefined -> + State; + HookName -> + Args = case Tail of + [_Host, Args0] -> Args0; + [Args0] -> + Args0 + end, + Arity = erl_syntax:list_length(Args), + Hooks = dict:store({HookName, Arity}, + {State#state.file, erl_syntax:get_pos(Hook)}, + State#state.run_hooks), + State#state{run_hooks = Hooks} + end. + +analyze_run_fold_hook(Form, State) -> + [Hook|Tail] = erl_syntax:application_arguments(Form), + case atom_value(Hook, State) of + undefined -> + State; + HookName -> + Args = case Tail of + [_Host, _Val, Args0] -> Args0; + [_Val, Args0] -> Args0 + end, + Arity = erl_syntax:list_length(Args) + 1, + Hooks = dict:store({HookName, Arity}, + {State#state.file, erl_syntax:get_pos(Form)}, + State#state.run_fold_hooks), + State#state{run_fold_hooks = Hooks} + end. + +analyze_run_fun(Form, State) -> + [Hook|Tail] = erl_syntax:application_arguments(Form), + case atom_value(Hook, State) of + undefined -> + State; + HookName -> + {Module, Fun, Seq} = case Tail of + [_Host, M, F, S] -> + {M, F, S}; + [M, F, S] -> + {M, F, S} + end, + ModName = module_name(Module, State), + FunName = atom_value(Fun, State), + if ModName /= undefined, FunName /= undefined -> + Funs = dict:append( + HookName, + {ModName, FunName, integer_value(Seq, State), + {State#state.file, erl_syntax:get_pos(Form)}}, + State#state.hooked_funs), + State#state{hooked_funs = Funs}; + true -> + State + end + end. + +analyze_iq_handler(Form, State) -> + [_Component, _Host, _NS, Module, Function, _IQDisc] = + erl_syntax:application_arguments(Form), + Mod = module_name(Module, State), + Fun = atom_value(Function, State), + if Mod /= undefined, Fun /= undefined -> + code:ensure_loaded(Mod), + case erlang:function_exported(Mod, Fun, 1) of + false -> + log("~s:~p: Error: function ~s:~s/1 is registered " + "as iq handler, but is not exported~n", + [State#state.file, erl_syntax:get_pos(Form), + Mod, Fun]); + true -> + ok + end; + true -> + ok + end, + State. + +analyze_type_spec(Form, State) -> + case catch erl_syntax:revert(Form) of + {attribute, _, spec, {{F, A}, _}} -> + Specs = dict:store({State#state.module, F, A}, + {Form, State#state.file}, + State#state.specs), + State#state{specs = Specs}; + _ -> + State + end. + +build_deps(Hooks, Hooked) -> + dict:fold( + fun({Hook, Arity}, {_File, _LineNo} = Meta, Deps) -> + case dict:find(Hook, Hooked) of + {ok, Funs} -> + ExportedFuns = + lists:flatmap( + fun({M, F, Seq, {FunFile, FunLineNo} = FunMeta}) -> + code:ensure_loaded(M), + case erlang:function_exported(M, F, Arity) of + false -> + log("~s:~p: Error: function ~s:~s/~p " + "is hooked on ~s/~p, but is not " + "exported~n", + [FunFile, FunLineNo, M, F, + Arity, Hook, Arity]), + []; + true -> + [{{M, F, Arity}, Seq, FunMeta}] + end + end, Funs), + dict:append_list({Hook, Arity, Meta}, ExportedFuns, Deps); + error -> + %% log("~s:~p: Warning: hook ~p/~p is unused~n", + %% [_File, _LineNo, Hook, Arity]), + dict:append_list({Hook, Arity, Meta}, [], Deps) + end + end, dict:new(), Hooks). + +report_orphaned_funs(State) -> + dict:map( + fun(Hook, Funs) -> + lists:foreach( + fun({M, F, _, {File, Line}}) -> + case get_fun_arities(M, F, State) of + [] -> + log("~s:~p: Error: function ~s:~s is " + "hooked on hook ~s, but is not exported~n", + [File, Line, M, F, Hook]); + Arities -> + case lists:any( + fun(Arity) -> + dict:is_key({Hook, Arity}, + State#state.run_hooks) orelse + dict:is_key({Hook, Arity}, + State#state.run_fold_hooks); + (_) -> + false + end, Arities) of + false -> + Arity = hd(Arities), + log("~s:~p: Error: function ~s:~s/~p is hooked" + " on non-existent hook ~s/~p~n", + [File, Line, M, F, Arity, Hook, Arity]); + true -> + ok + end + end + end, Funs) + end, State#state.hooked_funs). + +get_fun_arities(Mod, Fun, _State) -> + proplists:get_all_values(Fun, Mod:module_info(exports)). + +module_name(Form, State) -> + try + Name = erl_syntax:macro_name(Form), + 'MODULE' = erl_syntax:variable_name(Name), + State#state.module + catch _:_ -> + atom_value(Form, State) + end. + +atom_value(Form, State) -> + case erl_syntax:type(Form) of + atom -> + erl_syntax:atom_value(Form); + _ -> + log("~s:~p: Warning: not an atom: ~s~n", + [State#state.file, + erl_syntax:get_pos(Form), + erl_prettypr:format(Form)]), + undefined + end. + +integer_value(Form, State) -> + case erl_syntax:type(Form) of + integer -> + erl_syntax:integer_value(Form); + _ -> + log("~s:~p: Warning: not an integer: ~s~n", + [State#state.file, + erl_syntax:get_pos(Form), + erl_prettypr:format(Form)]), + 0 + end. + +emit_module(RunDeps, RunFoldDeps, Specs, Dir, Module) -> + File = filename:join([Dir, Module]) ++ ".erl", + try + {ok, Fd} = file:open(File, [write]), + write(Fd, "-module(~s).~n~n", [Module]), + emit_export(Fd, RunDeps, "run hooks"), + emit_export(Fd, RunFoldDeps, "run_fold hooks"), + emit_run_hooks(Fd, RunDeps, Specs), + emit_run_fold_hooks(Fd, RunFoldDeps, Specs), + write(Fd, "bypass_stop({stop, Acc}) -> Acc;~n" + "bypass_stop(Acc) -> Acc.~n", []), + file:close(Fd), + log("Module written to file ~s~n", [File]) + catch _:{badmatch, {error, Reason}} -> + log("writing to ~s failed: ~s", [File, file:format_error(Reason)]) + end. + +emit_run_hooks(Fd, Deps, Specs) -> + DepsList = lists:sort(dict:to_list(Deps)), + lists:foreach( + fun({{Hook, Arity, {File, LineNo}}, []}) -> + Args = lists:duplicate(Arity, "_"), + write(Fd, "%% called at ~s:~p~n", [File, LineNo]), + write(Fd, "~s(~s) -> ok.~n~n", [Hook, string:join(Args, ", ")]); + ({{Hook, Arity, {File, LineNo}}, Funs}) -> + emit_specs(Fd, Funs, Specs), + write(Fd, "%% called at ~s:~p~n", [File, LineNo]), + Args = string:join( + [[N] || N <- lists:sublist(lists:seq($A, $Z), Arity)], + ", "), + write(Fd, "~s(~s) ->~n ", [Hook, Args]), + Calls = [io_lib:format("~s:~s(~s)", [Mod, Fun, Args]) + || {{Mod, Fun, _}, _Seq, _} <- lists:keysort(2, Funs)], + write(Fd, "~s.~n~n", [string:join(Calls, ",\n ")]) + end, DepsList). + +emit_run_fold_hooks(Fd, Deps, Specs) -> + DepsList = lists:sort(dict:to_list(Deps)), + lists:foreach( + fun({{Hook, Arity, {File, LineNo}}, []}) -> + write(Fd, "%% called at ~s:~p~n", [File, LineNo]), + Args = ["Acc"|lists:duplicate(Arity - 1, "_")], + write(Fd, "~s(~s) -> Acc.~n~n", [Hook, string:join(Args, ", ")]); + ({{Hook, Arity, {File, LineNo}}, Funs}) -> + emit_specs(Fd, Funs, Specs), + write(Fd, "%% called at ~s:~p~n", [File, LineNo]), + Args = [[N] || N <- lists:sublist(lists:seq($A, $Z), Arity - 1)], + write(Fd, "~s(~s) ->", [Hook, string:join(["Acc"|Args], ", ")]), + FunsCascade = make_funs_cascade( + lists:reverse(lists:keysort(2, Funs)), + 1, Args), + write(Fd, "~s.~n~n", [FunsCascade]) + end, DepsList). + +make_funs_cascade([{{Mod, Fun, _}, _Seq, _}|Funs], N, Args) -> + io_lib:format("~n~sbypass_stop(~s:~s(~s))", + [lists:duplicate(N, " "), + Mod, Fun, string:join([make_funs_cascade(Funs, N+1, Args)|Args], ", ")]); +make_funs_cascade([], _N, _Args) -> + "Acc". + +emit_export(Fd, Deps, Comment) -> + DepsList = lists:sort(dict:to_list(Deps)), + Exports = lists:map( + fun({{Hook, Arity, _}, _}) -> + io_lib:format("~s/~p", [Hook, Arity]) + end, DepsList), + write(Fd, "%% ~s~n-export([~s]).~n~n", + [Comment, string:join(Exports, ",\n ")]). + +emit_specs(Fd, Funs, Specs) -> + lists:foreach( + fun({{M, _, _} = MFA, _, _}) -> + case dict:find(MFA, Specs) of + {ok, {Form, _File}} -> + Lines = string:tokens(erl_syntax:get_ann(Form), "\n"), + lists:foreach( + fun("%" ++ _) -> + ok; + ("-spec" ++ Spec) -> + write(Fd, "%% -spec ~p:~s~n", + [M, string:strip(Spec, left)]); + (Line) -> + write(Fd, "%% ~s~n", [Line]) + end, Lines); + error -> + ok + end + end, lists:keysort(2, Funs)). + +get_forms(Path) -> + case file:open(Path, [read]) of + {ok, Fd} -> + parse(Path, Fd, 1, []); + Err -> + Err + end. + +parse(Path, Fd, Line, Acc) -> + {ok, Pos} = file:position(Fd, cur), + case epp_dodger:parse_form(Fd, Line) of + {ok, Form, NewLine} -> + {ok, NewPos} = file:position(Fd, cur), + {ok, RawForm} = file:pread(Fd, Pos, NewPos - Pos), + file:position(Fd, {bof, NewPos}), + AnnForm = erl_syntax:set_ann(Form, RawForm), + parse(Path, Fd, NewLine, [AnnForm|Acc]); + {eof, _} -> + {ok, NewPos} = file:position(Fd, cur), + if NewPos > Pos -> + {ok, RawForm} = file:pread(Fd, Pos, NewPos - Pos), + Form = erl_syntax:text(""), + AnnForm = erl_syntax:set_ann(Form, RawForm), + {ok, lists:reverse([AnnForm|Acc])}; + true -> + {ok, lists:reverse(Acc)} + end; + {error, {_, _, ErrDesc}, LineNo} = Err -> + log("~s:~p: Error: ~s~n", + [Path, LineNo, erl_parse:format_error(ErrDesc)]), + Err + end. + +log(Format, Args) -> + io:format(standard_io, Format, Args). + +write(Fd, Format, Args) -> + file:write(Fd, io_lib:format(Format, Args)).