25
1
mirror of https://github.com/processone/ejabberd.git synced 2024-12-20 17:27:00 +01:00

Make it possible to define 'sm_db_type' per virtual host

This commit is contained in:
Evgeniy Khramtsov 2016-02-19 16:15:11 +03:00
parent eece6e69cb
commit 4b0860e7de
5 changed files with 67 additions and 30 deletions

View File

@ -296,3 +296,17 @@ CREATE TABLE archive_prefs (
never text NOT NULL,
created_at timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE sm (
usec bigint NOT NULL,
pid text NOT NULL,
node text NOT NULL,
username text NOT NULL,
resource text NOT NULL,
priority text NOT NULL,
info text NOT NULL
);
CREATE UNIQUE INDEX i_sm_sid ON sm(usec, pid);
CREATE INDEX i_sm_node ON sm(node);
CREATE INDEX i_sm_username ON sm(username);

View File

@ -50,6 +50,7 @@
dirty_get_my_sessions_list/0,
get_vh_session_list/1,
get_vh_session_number/1,
get_vh_by_backend/1,
register_iq_handler/4,
register_iq_handler/5,
unregister_iq_handler/2,
@ -133,10 +134,10 @@ open_session(SID, User, Server, Resource, Info) ->
-spec close_session(sid(), binary(), binary(), binary()) -> ok.
close_session(SID, User, Server, Resource) ->
Mod = get_sm_backend(),
LUser = jid:nodeprep(User),
LServer = jid:nameprep(Server),
LResource = jid:resourceprep(Resource),
Mod = get_sm_backend(LServer),
Info = case Mod:delete_session(LUser, LServer, LResource, SID) of
{ok, #session{info = I}} -> I;
{error, notfound} -> []
@ -172,14 +173,14 @@ disconnect_removed_user(User, Server) ->
get_user_resources(User, Server) ->
LUser = jid:nodeprep(User),
LServer = jid:nameprep(Server),
Mod = get_sm_backend(),
Mod = get_sm_backend(LServer),
Ss = Mod:get_sessions(LUser, LServer),
[element(3, S#session.usr) || S <- clean_session_list(Ss)].
-spec get_user_present_resources(binary(), binary()) -> [tuple()].
get_user_present_resources(LUser, LServer) ->
Mod = get_sm_backend(),
Mod = get_sm_backend(LServer),
Ss = Mod:get_sessions(LUser, LServer),
[{S#session.priority, element(3, S#session.usr)}
|| S <- clean_session_list(Ss), is_integer(S#session.priority)].
@ -190,7 +191,7 @@ get_user_ip(User, Server, Resource) ->
LUser = jid:nodeprep(User),
LServer = jid:nameprep(Server),
LResource = jid:resourceprep(Resource),
Mod = get_sm_backend(),
Mod = get_sm_backend(LServer),
case Mod:get_sessions(LUser, LServer, LResource) of
[] ->
undefined;
@ -205,7 +206,7 @@ get_user_info(User, Server, Resource) ->
LUser = jid:nodeprep(User),
LServer = jid:nameprep(Server),
LResource = jid:resourceprep(Resource),
Mod = get_sm_backend(),
Mod = get_sm_backend(LServer),
case Mod:get_sessions(LUser, LServer, LResource) of
[] ->
offline;
@ -255,7 +256,7 @@ get_session_pid(User, Server, Resource) ->
LUser = jid:nodeprep(User),
LServer = jid:nameprep(Server),
LResource = jid:resourceprep(Resource),
Mod = get_sm_backend(),
Mod = get_sm_backend(LServer),
case Mod:get_sessions(LUser, LServer, LResource) of
[#session{sid = {_, Pid}}] -> Pid;
_ -> none
@ -264,33 +265,40 @@ get_session_pid(User, Server, Resource) ->
-spec dirty_get_sessions_list() -> [ljid()].
dirty_get_sessions_list() ->
Mod = get_sm_backend(),
[S#session.usr || S <- Mod:get_sessions()].
lists:flatmap(
fun(Mod) ->
[S#session.usr || S <- Mod:get_sessions()]
end, get_sm_backends()).
-spec dirty_get_my_sessions_list() -> [#session{}].
dirty_get_my_sessions_list() ->
Mod = get_sm_backend(),
[S || S <- Mod:get_sessions(), node(element(2, S#session.sid)) == node()].
lists:flatmap(
fun(Mod) ->
[S || S <- Mod:get_sessions(),
node(element(2, S#session.sid)) == node()]
end, get_sm_backends()).
-spec get_vh_session_list(binary()) -> [ljid()].
get_vh_session_list(Server) ->
LServer = jid:nameprep(Server),
Mod = get_sm_backend(),
Mod = get_sm_backend(LServer),
[S#session.usr || S <- Mod:get_sessions(LServer)].
-spec get_all_pids() -> [pid()].
get_all_pids() ->
Mod = get_sm_backend(),
[element(2, S#session.sid) || S <- Mod:get_sessions()].
lists:flatmap(
fun(Mod) ->
[element(2, S#session.sid) || S <- Mod:get_sessions()]
end, get_sm_backends()).
-spec get_vh_session_number(binary()) -> non_neg_integer().
get_vh_session_number(Server) ->
LServer = jid:nameprep(Server),
Mod = get_sm_backend(),
Mod = get_sm_backend(LServer),
length(Mod:get_sessions(LServer)).
register_iq_handler(Host, XMLNS, Module, Fun) ->
@ -312,8 +320,7 @@ unregister_iq_handler(Host, XMLNS) ->
%%====================================================================
init([]) ->
Mod = get_sm_backend(),
Mod:init(),
lists:foreach(fun(Mod) -> Mod:init() end, get_sm_backends()),
ets:new(sm_iqtable, [named_table]),
lists:foreach(
fun(Host) ->
@ -380,7 +387,7 @@ set_session(SID, User, Server, Resource, Priority, Info) ->
LResource = jid:resourceprep(Resource),
US = {LUser, LServer},
USR = {LUser, LServer, LResource},
Mod = get_sm_backend(),
Mod = get_sm_backend(LServer),
Mod:set_session(#session{sid = SID, usr = USR, us = US,
priority = Priority, info = Info}).
@ -397,7 +404,7 @@ do_route(From, To, {broadcast, _} = Packet) ->
get_user_resources(To#jid.user, To#jid.server));
_ ->
{U, S, R} = jid:tolower(To),
Mod = get_sm_backend(),
Mod = get_sm_backend(S),
case Mod:get_sessions(U, S, R) of
[] ->
?DEBUG("packet dropped~n", []);
@ -498,7 +505,7 @@ do_route(From, To, #xmlel{} = Packet) ->
_ -> ok
end;
_ ->
Mod = get_sm_backend(),
Mod = get_sm_backend(LServer),
case Mod:get_sessions(LUser, LServer, LResource) of
[] ->
case Name of
@ -565,7 +572,7 @@ route_message(From, To, Packet, Type) ->
lists:foreach(fun ({P, R}) when P == Priority;
(P >= 0) and (Type == headline) ->
LResource = jid:resourceprep(R),
Mod = get_sm_backend(),
Mod = get_sm_backend(LServer),
case Mod:get_sessions(LUser, LServer,
LResource) of
[] ->
@ -647,11 +654,11 @@ get_resource_sessions(User, Server, Resource) ->
LUser = jid:nodeprep(User),
LServer = jid:nameprep(Server),
LResource = jid:resourceprep(Resource),
Mod = get_sm_backend(),
Mod = get_sm_backend(LServer),
[S#session.sid || S <- Mod:get_sessions(LUser, LServer, LResource)].
check_max_sessions(LUser, LServer) ->
Mod = get_sm_backend(),
Mod = get_sm_backend(LServer),
SIDs = [S#session.sid || S <- Mod:get_sessions(LUser, LServer)],
MaxSessions = get_max_user_sessions(LUser, LServer),
if length(SIDs) =< MaxSessions -> ok;
@ -703,17 +710,17 @@ process_iq(From, To, Packet) ->
-spec force_update_presence({binary(), binary()}) -> any().
force_update_presence({LUser, LServer}) ->
Mod = get_sm_backend(),
Mod = get_sm_backend(LServer),
Ss = Mod:get_sessions(LUser, LServer),
lists:foreach(fun (#session{sid = {_, Pid}}) ->
Pid ! {force_update_presence, LUser, LServer}
end,
Ss).
-spec get_sm_backend() -> module().
-spec get_sm_backend(binary()) -> module().
get_sm_backend() ->
DBType = ejabberd_config:get_option(sm_db_type,
get_sm_backend(Host) ->
DBType = ejabberd_config:get_option({sm_db_type, Host},
fun(mnesia) -> mnesia;
(internal) -> mnesia;
(odbc) -> odbc;
@ -721,6 +728,19 @@ get_sm_backend() ->
end, mnesia),
list_to_atom("ejabberd_sm_" ++ atom_to_list(DBType)).
-spec get_sm_backends() -> [module()].
get_sm_backends() ->
lists:usort([get_sm_backend(Host) || Host <- ?MYHOSTS]).
-spec get_vh_by_backend(module()) -> [binary()].
get_vh_by_backend(Mod) ->
lists:filter(
fun(Host) ->
get_sm_backend(Host) == Mod
end, ?MYHOSTS).
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%% ejabberd commands

View File

@ -43,7 +43,7 @@ init() ->
end;
(_, Err) ->
Err
end, ok, ?MYHOSTS).
end, ok, ejabberd_sm:get_vh_by_backend(?MODULE)).
set_session(#session{sid = {Now, Pid}, usr = {U, LServer, R},
priority = Priority, info = Info}) ->
@ -90,7 +90,7 @@ get_sessions() ->
lists:flatmap(
fun(LServer) ->
get_sessions(LServer)
end, ?MYHOSTS).
end, ejabberd_sm:get_vh_by_backend(?MODULE)).
get_sessions(LServer) ->
case ejabberd_odbc:sql_query(

View File

@ -107,7 +107,7 @@ get_sessions() ->
lists:flatmap(
fun(LServer) ->
get_sessions(LServer)
end, ?MYHOSTS).
end, ejabberd_sm:get_vh_by_backend(?MODULE)).
-spec get_sessions(binary()) -> [#session{}].
get_sessions(LServer) ->
@ -204,7 +204,7 @@ clean_table() ->
?ERROR_MSG("failed to clean redis table for "
"server ~s: ~p", [LServer, Err])
end
end, ?MYHOSTS).
end, ejabberd_sm:get_vh_by_backend(?MODULE)).
opt_type(redis_connect_timeout) ->
fun (I) when is_integer(I), I > 0 -> I end;

View File

@ -8,6 +8,7 @@ host_config:
odbc_password: "@@pgsql_pass@@"
odbc_database: "@@pgsql_db@@"
auth_method: odbc
sm_db_type: odbc
modules:
mod_announce:
db_type: odbc
@ -60,6 +61,7 @@ Welcome to this XMPP server."
"sqlite.localhost":
odbc_type: sqlite
auth_method: odbc
sm_db_type: odbc
modules:
mod_announce:
db_type: odbc
@ -118,6 +120,7 @@ Welcome to this XMPP server."
odbc_password: "@@mysql_pass@@"
odbc_database: "@@mysql_db@@"
auth_method: odbc
sm_db_type: odbc
modules:
mod_announce:
db_type: odbc