diff --git a/src/mod_mqtt.erl b/src/mod_mqtt.erl index a5214a922..24d033892 100644 --- a/src/mod_mqtt.erl +++ b/src/mod_mqtt.erl @@ -37,6 +37,8 @@ -export([open_session/1, close_session/1, lookup_session/1, publish/3, subscribe/4, unsubscribe/2, select_retained/4, check_publish_access/2, check_subscribe_access/2]). +%% ejabberd_hooks +-export([remove_user/2]). -include("logger.hrl"). -include("mqtt.hrl"). @@ -53,6 +55,7 @@ -callback open_session(jid:ljid()) -> ok | {error, db_failure}. -callback close_session(jid:ljid()) -> ok | {error, db_failure}. -callback lookup_session(jid:ljid()) -> {ok, pid()} | {error, notfound | db_failure}. +-callback get_sessions(binary(), binary()) -> [jid:ljid()]. -callback subscribe(jid:ljid(), binary(), sub_opts(), non_neg_integer()) -> ok | {error, db_failure}. -callback unsubscribe(jid:ljid(), binary()) -> ok | {error, notfound | db_failure}. -callback find_subscriber(binary(), binary() | continuation()) -> @@ -71,7 +74,7 @@ -optional_callbacks([use_cache/1, cache_nodes/1]). --record(state, {}). +-record(state, {host :: binary()}). %%%=================================================================== %%% API @@ -163,6 +166,13 @@ select_retained({_, S, _} = USR, TopicFilter, QoS, SubID) -> Limit = mod_mqtt_opt:match_retained_limit(S), select_retained(Mod, USR, TopicFilter, QoS, SubID, Limit). +remove_user(User, Server) -> + LUser = jid:nodeprep(User), + LServer = jid:nameprep(Server), + Mod = gen_mod:ram_db_mod(LServer, ?MODULE), + Sessions = Mod:get_sessions(LUser, LServer), + [close_session(Session) || Session <- Sessions]. + %%%=================================================================== %%% gen_server callbacks %%%=================================================================== @@ -170,11 +180,12 @@ init([Host|_]) -> Opts = gen_mod:get_module_opts(Host, ?MODULE), Mod = gen_mod:db_mod(Opts, ?MODULE), RMod = gen_mod:ram_db_mod(Opts, ?MODULE), + ejabberd_hooks:add(remove_user, Host, ?MODULE, remove_user, 50), try ok = Mod:init(Host, Opts), ok = RMod:init(), ok = init_cache(Mod, Host, Opts), - {ok, #state{}} + {ok, #state{host = Host}} catch _:{badmatch, {error, Why}} -> {stop, Why} end. @@ -191,7 +202,8 @@ handle_info(Info, State) -> ?WARNING_MSG("Unexpected info: ~p", [Info]), {noreply, State}. -terminate(_Reason, _State) -> +terminate(_Reason, #state{host = Host}) -> + ejabberd_hooks:delete(remove_user, Host, ?MODULE, remove_user, 50), ok. code_change(_OldVsn, State, _Extra) -> diff --git a/src/mod_mqtt_mnesia.erl b/src/mod_mqtt_mnesia.erl index a5b66bf5b..92c43d2ee 100644 --- a/src/mod_mqtt_mnesia.erl +++ b/src/mod_mqtt_mnesia.erl @@ -23,7 +23,7 @@ -export([list_topics/1, use_cache/1]). -export([init/0]). -export([subscribe/4, unsubscribe/2, find_subscriber/2]). --export([open_session/1, close_session/1, lookup_session/1]). +-export([open_session/1, close_session/1, lookup_session/1, get_sessions/2]). -include("logger.hrl"). -include("mqtt.hrl"). @@ -46,9 +46,9 @@ pid :: pid(), timestamp :: erlang:timestamp()}). --record(mqtt_session, {usr :: jid:ljid(), - pid :: pid(), - timestamp :: erlang:timestamp()}). +-record(mqtt_session, {usr :: jid:ljid() | {'_', '_', '$1'}, + pid :: pid() | '_', + timestamp :: erlang:timestamp() | '_'}). %%%=================================================================== %%% API @@ -196,6 +196,14 @@ lookup_session(USR) -> {error, notfound} end. +get_sessions(U, S) -> + Resources = mnesia:dirty_select(mqtt_session, + [{#mqtt_session{usr = {U, S, '$1'}, + _ = '_'}, + [], + ['$1']}]), + [{U, S, Resource} || Resource <- Resources]. + subscribe({U, S, R} = USR, TopicFilter, SubOpts, ID) -> T1 = misc:unique_timestamp(), P1 = self(), diff --git a/src/mod_mqtt_sql.erl b/src/mod_mqtt_sql.erl index 3bc4d927f..fefd000cd 100644 --- a/src/mod_mqtt_sql.erl +++ b/src/mod_mqtt_sql.erl @@ -24,7 +24,7 @@ %% Unsupported backend API -export([init/0]). -export([subscribe/4, unsubscribe/2, find_subscriber/2]). --export([open_session/1, close_session/1, lookup_session/1]). +-export([open_session/1, close_session/1, lookup_session/1, get_sessions/2]). -include("logger.hrl"). -include("ejabberd_sql_pt.hrl"). @@ -125,6 +125,9 @@ close_session(_) -> lookup_session(_) -> erlang:nif_error(unsupported_db). +get_sessions(_, _) -> + erlang:nif_error(unsupported_db). + subscribe(_, _, _, _) -> erlang:nif_error(unsupported_db).