25
1
mirror of https://github.com/processone/ejabberd.git synced 2024-11-24 16:23:40 +01:00

Update mod_roster and ejabberd_auth_odbc SQL queries to the new API

This commit is contained in:
Alexey Shchepin 2016-02-15 21:02:22 +03:00
parent 7f3bffe821
commit 3d8219d8f9
3 changed files with 144 additions and 176 deletions

View File

@ -72,22 +72,18 @@ check_password(User, Server, Password) ->
(LUser == <<>>) or (LServer == <<>>) -> (LUser == <<>>) or (LServer == <<>>) ->
false; false;
true -> true ->
Username = ejabberd_odbc:escape(LUser),
case is_scrammed() of case is_scrammed() of
true -> true ->
try odbc_queries:get_password_scram(LServer, Username) of try odbc_queries:get_password_scram(LServer, LUser) of
{selected, [<<"password">>, <<"serverkey">>, {selected,
<<"salt">>, <<"iterationcount">>], [{StoredKey, ServerKey, Salt, IterationCount}]} ->
[[StoredKey, ServerKey, Salt, IterationCount]]} ->
Scram = Scram =
#scram{storedkey = StoredKey, #scram{storedkey = StoredKey,
serverkey = ServerKey, serverkey = ServerKey,
salt = Salt, salt = Salt,
iterationcount = binary_to_integer( iterationcount = IterationCount},
IterationCount)},
is_password_scram_valid(Password, Scram); is_password_scram_valid(Password, Scram);
{selected, [<<"password">>, <<"serverkey">>, {selected, []} ->
<<"salt">>, <<"iterationcount">>], []} ->
false; %% Account does not exist false; %% Account does not exist
{error, _Error} -> {error, _Error} ->
false %% Typical error is that table doesn't exist false %% Typical error is that table doesn't exist
@ -96,12 +92,12 @@ check_password(User, Server, Password) ->
false %% Typical error is database not accessible false %% Typical error is database not accessible
end; end;
false -> false ->
try odbc_queries:get_password(LServer, Username) of try odbc_queries:get_password(LServer, LUser) of
{selected, [<<"password">>], [[Password]]} -> {selected, [{Password}]} ->
Password /= <<"">>; Password /= <<"">>;
{selected, [<<"password">>], [[_Password2]]} -> {selected, [{_Password2}]} ->
false; %% Password is not correct false; %% Password is not correct
{selected, [<<"password">>], []} -> {selected, []} ->
false; %% Account does not exist false; %% Account does not exist
{error, _Error} -> {error, _Error} ->
false %% Typical error is that table doesn't exist false %% Typical error is that table doesn't exist
@ -124,10 +120,9 @@ check_password(User, Server, Password, Digest,
true -> true ->
case is_scrammed() of case is_scrammed() of
false -> false ->
Username = ejabberd_odbc:escape(LUser), try odbc_queries:get_password(LServer, LUser) of
try odbc_queries:get_password(LServer, Username) of
%% Account exists, check if password is valid %% Account exists, check if password is valid
{selected, [<<"password">>], [[Passwd]]} -> {selected, [{Passwd}]} ->
DigRes = if Digest /= <<"">> -> DigRes = if Digest /= <<"">> ->
Digest == DigestGen(Passwd); Digest == DigestGen(Passwd);
true -> false true -> false
@ -135,7 +130,7 @@ check_password(User, Server, Password, Digest,
if DigRes -> true; if DigRes -> true;
true -> (Passwd == Password) and (Password /= <<"">>) true -> (Passwd == Password) and (Password /= <<"">>)
end; end;
{selected, [<<"password">>], []} -> {selected, []} ->
false; %% Account does not exist false; %% Account does not exist
{error, _Error} -> {error, _Error} ->
false %% Typical error is that table doesn't exist false %% Typical error is that table doesn't exist
@ -267,24 +262,22 @@ get_password(User, Server) ->
(LUser == <<>>) or (LServer == <<>>) -> (LUser == <<>>) or (LServer == <<>>) ->
false; false;
true -> true ->
Username = ejabberd_odbc:escape(LUser),
case is_scrammed() of case is_scrammed() of
true -> true ->
case catch odbc_queries:get_password_scram( case catch odbc_queries:get_password_scram(
LServer, Username) of LServer, LUser) of
{selected, [<<"password">>, <<"serverkey">>, {selected,
<<"salt">>, <<"iterationcount">>], [{StoredKey, ServerKey, Salt, IterationCount}]} ->
[[StoredKey, ServerKey, Salt, IterationCount]]} ->
{jlib:decode_base64(StoredKey), {jlib:decode_base64(StoredKey),
jlib:decode_base64(ServerKey), jlib:decode_base64(ServerKey),
jlib:decode_base64(Salt), jlib:decode_base64(Salt),
binary_to_integer(IterationCount)}; IterationCount};
_ -> false _ -> false
end; end;
false -> false ->
case catch odbc_queries:get_password(LServer, Username) case catch odbc_queries:get_password(LServer, LUser)
of of
{selected, [<<"password">>], [[Password]]} -> Password; {selected, [{Password}]} -> Password;
_ -> false _ -> false
end end
end end
@ -300,9 +293,8 @@ get_password_s(User, Server) ->
true -> true ->
case is_scrammed() of case is_scrammed() of
false -> false ->
Username = ejabberd_odbc:escape(LUser), case catch odbc_queries:get_password(LServer, LUser) of
case catch odbc_queries:get_password(LServer, Username) of {selected, [{Password}]} -> Password;
{selected, [<<"password">>], [[Password]]} -> Password;
_ -> <<"">> _ -> <<"">>
end; end;
true -> <<"">> true -> <<"">>
@ -311,15 +303,17 @@ get_password_s(User, Server) ->
%% @spec (User, Server) -> true | false | {error, Error} %% @spec (User, Server) -> true | false | {error, Error}
is_user_exists(User, Server) -> is_user_exists(User, Server) ->
case jid:nodeprep(User) of LServer = jid:nameprep(Server),
error -> false; LUser = jid:nodeprep(User),
LUser -> if (LUser == error) or (LServer == error) ->
Username = ejabberd_odbc:escape(LUser), false;
LServer = jid:nameprep(Server), (LUser == <<>>) or (LServer == <<>>) ->
try odbc_queries:get_password(LServer, Username) of false;
{selected, [<<"password">>], [[_Password]]} -> true ->
try odbc_queries:get_password(LServer, LUser) of
{selected, [{_Password}]} ->
true; %% Account exists true; %% Account exists
{selected, [<<"password">>], []} -> {selected, []} ->
false; %% Account does not exist false; %% Account does not exist
{error, Error} -> {error, Error} {error, Error} -> {error, Error}
catch catch

View File

@ -203,11 +203,9 @@ read_roster_version(LUser, LServer, mnesia) ->
[] -> error [] -> error
end; end;
read_roster_version(LUser, LServer, odbc) -> read_roster_version(LUser, LServer, odbc) ->
Username = ejabberd_odbc:escape(LUser), case odbc_queries:get_roster_version(LServer, LUser) of
case odbc_queries:get_roster_version(LServer, Username) {selected, [{Version}]} -> Version;
of {selected, []} -> error
{selected, [<<"version">>], [[Version]]} -> Version;
{selected, [<<"version">>], []} -> error
end; end;
read_roster_version(LServer, LUser, riak) -> read_roster_version(LServer, LUser, riak) ->
case ejabberd_riak:get(roster_version, roster_version_schema(), case ejabberd_riak:get(roster_version, roster_version_schema(),
@ -369,46 +367,37 @@ get_roster(LUser, LServer, riak) ->
_Err -> [] _Err -> []
end; end;
get_roster(LUser, LServer, odbc) -> get_roster(LUser, LServer, odbc) ->
Username = ejabberd_odbc:escape(LUser), case catch odbc_queries:get_roster(LServer, LUser) of
case catch odbc_queries:get_roster(LServer, Username) of {selected, Items} when is_list(Items) ->
{selected, JIDGroups = case catch odbc_queries:get_roster_jid_groups(
[<<"username">>, <<"jid">>, <<"nick">>, LServer, LUser) of
<<"subscription">>, <<"ask">>, <<"askmessage">>, {selected, JGrps}
<<"server">>, <<"subscribe">>, <<"type">>], when is_list(JGrps) ->
Items} JGrps;
when is_list(Items) -> _ -> []
JIDGroups = case catch end,
odbc_queries:get_roster_jid_groups(LServer, GroupsDict = lists:foldl(fun({J, G}, Acc) ->
Username) dict:append(J, G, Acc)
of end,
{selected, [<<"jid">>, <<"grp">>], JGrps} dict:new(), JIDGroups),
when is_list(JGrps) -> RItems =
JGrps; lists:flatmap(
_ -> [] fun(I) ->
end, case raw_to_record(LServer, I) of
GroupsDict = lists:foldl(fun ([J, G], Acc) -> %% Bad JID in database:
dict:append(J, G, Acc) error -> [];
end, R ->
dict:new(), JIDGroups), SJID = jid:to_string(R#roster.jid),
RItems = lists:flatmap(fun (I) -> Groups = case dict:find(SJID, GroupsDict) of
case raw_to_record(LServer, I) of {ok, Gs} -> Gs;
%% Bad JID in database: error -> []
error -> []; end,
R -> [R#roster{groups = Groups}]
SJID = end
jid:to_string(R#roster.jid), end,
Groups = case dict:find(SJID, Items),
GroupsDict) RItems;
of _ -> []
{ok, Gs} -> Gs;
error -> []
end,
[R#roster{groups = Groups}]
end
end,
Items),
RItems;
_ -> []
end. end.
set_roster(#roster{us = {LUser, LServer}, jid = LJID} = Item) -> set_roster(#roster{us = {LUser, LServer}, jid = LJID} = Item) ->
@ -460,14 +449,8 @@ get_roster_by_jid_t(LUser, LServer, LJID, mnesia) ->
xs = []} xs = []}
end; end;
get_roster_by_jid_t(LUser, LServer, LJID, odbc) -> get_roster_by_jid_t(LUser, LServer, LJID, odbc) ->
Username = ejabberd_odbc:escape(LUser), {selected, Res} =
SJID = ejabberd_odbc:escape(jid:to_string(LJID)), odbc_queries:get_roster_by_jid(LServer, LUser, jid:to_string(LJID)),
{selected,
[<<"username">>, <<"jid">>, <<"nick">>,
<<"subscription">>, <<"ask">>, <<"askmessage">>,
<<"server">>, <<"subscribe">>, <<"type">>],
Res} =
odbc_queries:get_roster_by_jid(LServer, Username, SJID),
case Res of case Res of
[] -> [] ->
#roster{usj = {LUser, LServer, LJID}, #roster{usj = {LUser, LServer, LJID},
@ -750,30 +733,18 @@ get_roster_by_jid_with_groups_t(LUser, LServer, LJID,
end; end;
get_roster_by_jid_with_groups_t(LUser, LServer, LJID, get_roster_by_jid_with_groups_t(LUser, LServer, LJID,
odbc) -> odbc) ->
Username = ejabberd_odbc:escape(LUser), SJID = jid:to_string(LJID),
SJID = ejabberd_odbc:escape(jid:to_string(LJID)), case odbc_queries:get_roster_by_jid(LServer, LUser, SJID) of
case odbc_queries:get_roster_by_jid(LServer, Username, {selected, [I]} ->
SJID) R = raw_to_record(LServer, I),
of Groups =
{selected, case odbc_queries:get_roster_groups(LServer, LUser, SJID) of
[<<"username">>, <<"jid">>, <<"nick">>, {selected, JGrps} when is_list(JGrps) ->
<<"subscription">>, <<"ask">>, <<"askmessage">>, [JGrp || {JGrp} <- JGrps];
<<"server">>, <<"subscribe">>, <<"type">>], _ -> []
[I]} -> end,
R = raw_to_record(LServer, I), R#roster{groups = Groups};
Groups = case odbc_queries:get_roster_groups(LServer, {selected, []} ->
Username, SJID)
of
{selected, [<<"grp">>], JGrps} when is_list(JGrps) ->
[JGrp || [JGrp] <- JGrps];
_ -> []
end,
R#roster{groups = Groups};
{selected,
[<<"username">>, <<"jid">>, <<"nick">>,
<<"subscription">>, <<"ask">>, <<"askmessage">>,
<<"server">>, <<"subscribe">>, <<"type">>],
[]} ->
#roster{usj = {LUser, LServer, LJID}, #roster{usj = {LUser, LServer, LJID},
us = {LUser, LServer}, jid = LJID} us = {LUser, LServer}, jid = LJID}
end; end;
@ -995,8 +966,7 @@ remove_user(LUser, LServer, mnesia) ->
end, end,
mnesia:transaction(F); mnesia:transaction(F);
remove_user(LUser, LServer, odbc) -> remove_user(LUser, LServer, odbc) ->
Username = ejabberd_odbc:escape(LUser), odbc_queries:del_user_roster_t(LServer, LUser),
odbc_queries:del_user_roster_t(LServer, Username),
ok; ok;
remove_user(LUser, LServer, riak) -> remove_user(LUser, LServer, riak) ->
{atomic, ejabberd_riak:delete_by_index(roster, <<"us">>, {LUser, LServer})}. {atomic, ejabberd_riak:delete_by_index(roster, <<"us">>, {LUser, LServer})}.
@ -1243,12 +1213,9 @@ read_subscription_and_groups(LUser, LServer, LJID,
end; end;
read_subscription_and_groups(LUser, LServer, LJID, read_subscription_and_groups(LUser, LServer, LJID,
odbc) -> odbc) ->
Username = ejabberd_odbc:escape(LUser), SJID = jid:to_string(LJID),
SJID = ejabberd_odbc:escape(jid:to_string(LJID)), case catch odbc_queries:get_subscription(LServer, LUser, SJID) of
case catch odbc_queries:get_subscription(LServer, {selected, [{SSubscription}]} ->
Username, SJID)
of
{selected, [<<"subscription">>], [[SSubscription]]} ->
Subscription = case SSubscription of Subscription = case SSubscription of
<<"B">> -> both; <<"B">> -> both;
<<"T">> -> to; <<"T">> -> to;
@ -1256,11 +1223,11 @@ read_subscription_and_groups(LUser, LServer, LJID,
_ -> none _ -> none
end, end,
Groups = case catch Groups = case catch
odbc_queries:get_rostergroup_by_jid(LServer, Username, odbc_queries:get_rostergroup_by_jid(LServer, LUser,
SJID) SJID)
of of
{selected, [<<"grp">>], JGrps} when is_list(JGrps) -> {selected, JGrps} when is_list(JGrps) ->
[JGrp || [JGrp] <- JGrps]; [JGrp || {JGrp} <- JGrps];
_ -> [] _ -> []
end, end,
{Subscription, Groups}; {Subscription, Groups};
@ -1297,6 +1264,12 @@ get_jid_info(_, User, Server, JID) ->
raw_to_record(LServer, raw_to_record(LServer,
[User, SJID, Nick, SSubscription, SAsk, SAskMessage, [User, SJID, Nick, SSubscription, SAsk, SAskMessage,
_SServer, _SSubscribe, _SType]) -> _SServer, _SSubscribe, _SType]) ->
raw_to_record(LServer,
{User, SJID, Nick, SSubscription, SAsk, SAskMessage,
_SServer, _SSubscribe, _SType});
raw_to_record(LServer,
{User, SJID, Nick, SSubscription, SAsk, SAskMessage,
_SServer, _SSubscribe, _SType}) ->
case jid:from_string(SJID) of case jid:from_string(SJID) of
error -> error; error -> error;
JID -> JID ->

View File

@ -139,16 +139,17 @@ del_last(LServer, Username) ->
[<<"delete from last where username='">>, Username, [<<"delete from last where username='">>, Username,
<<"'">>]). <<"'">>]).
get_password(LServer, Username) -> get_password(LServer, LUser) ->
ejabberd_odbc:sql_query(LServer,
[<<"select password from users where username='">>,
Username, <<"';">>]).
get_password_scram(LServer, Username) ->
ejabberd_odbc:sql_query( ejabberd_odbc:sql_query(
LServer, LServer,
[<<"select password, serverkey, salt, iterationcount from users where " ?SQL("select @(password)s from users where username=%(LUser)s")).
"username='">>, Username, <<"';">>]).
get_password_scram(LServer, LUser) ->
ejabberd_odbc:sql_query(
LServer,
?SQL("select @(password)s, @(serverkey)s, @(salt)s, @(iterationcount)d"
" from users"
" where username=%(LUser)s")).
set_password_t(LServer, Username, Pass) -> set_password_t(LServer, Username, Pass) ->
ejabberd_odbc:sql_transaction(LServer, ejabberd_odbc:sql_transaction(LServer,
@ -311,46 +312,46 @@ del_spool_msg(LServer, LUser) ->
LServer, LServer,
?SQL("delete from spool where username=%(LUser)s")). ?SQL("delete from spool where username=%(LUser)s")).
get_roster(LServer, Username) -> get_roster(LServer, LUser) ->
ejabberd_odbc:sql_query(LServer, ejabberd_odbc:sql_query(
[<<"select username, jid, nick, subscription, " LServer,
"ask, askmessage, server, subscribe, " ?SQL("select @(username)s, @(jid)s, @(nick)s, @(subscription)s, "
"type from rosterusers where username='">>, "@(ask)s, @(askmessage)s, @(server)s, @(subscribe)s, "
Username, <<"'">>]). "@(type)s from rosterusers where username=%(LUser)s")).
get_roster_jid_groups(LServer, Username) -> get_roster_jid_groups(LServer, LUser) ->
ejabberd_odbc:sql_query(LServer, ejabberd_odbc:sql_query(
[<<"select jid, grp from rostergroups where " LServer,
"username='">>, ?SQL("select @(jid)s, @(grp)s from rostergroups where "
Username, <<"'">>]). "username=%(LUser)s")).
get_roster_groups(_LServer, Username, SJID) -> get_roster_groups(_LServer, LUser, SJID) ->
ejabberd_odbc:sql_query_t([<<"select grp from rostergroups where username='">>, ejabberd_odbc:sql_query_t(
Username, <<"' and jid='">>, SJID, <<"';">>]). ?SQL("select @(grp)s from rostergroups"
" where username=%(LUser)s and jid=%(SJID)s")).
del_user_roster_t(LServer, Username) -> del_user_roster_t(LServer, LUser) ->
ejabberd_odbc:sql_transaction(LServer, ejabberd_odbc:sql_transaction(
fun () -> LServer,
ejabberd_odbc:sql_query_t([<<"delete from rosterusers where " fun () ->
"username='">>, ejabberd_odbc:sql_query_t(
Username, ?SQL("delete from rosterusers where username=%(LUser)s")),
<<"';">>]), ejabberd_odbc:sql_query_t(
ejabberd_odbc:sql_query_t([<<"delete from rostergroups where " ?SQL("delete from rostergroups where username=%(LUser)s"))
"username='">>, end).
Username,
<<"';">>])
end).
get_roster_by_jid(_LServer, Username, SJID) -> get_roster_by_jid(_LServer, LUser, SJID) ->
ejabberd_odbc:sql_query_t([<<"select username, jid, nick, subscription, " ejabberd_odbc:sql_query_t(
"ask, askmessage, server, subscribe, " ?SQL("select @(username)s, @(jid)s, @(nick)s, @(subscription)s,"
"type from rosterusers where username='">>, " @(ask)s, @(askmessage)s, @(server)s, @(subscribe)s,"
Username, <<"' and jid='">>, SJID, <<"';">>]). " @(type)s from rosterusers"
" where username=%(LUser)s and jid=%(SJID)s")).
get_rostergroup_by_jid(LServer, Username, SJID) -> get_rostergroup_by_jid(LServer, LUser, SJID) ->
ejabberd_odbc:sql_query(LServer, ejabberd_odbc:sql_query(
[<<"select grp from rostergroups where username='">>, LServer,
Username, <<"' and jid='">>, SJID, <<"'">>]). ?SQL("select @(grp)s from rostergroups"
" where username=%(LUser)s and jid=%(SJID)s")).
del_roster(_LServer, Username, SJID) -> del_roster(_LServer, Username, SJID) ->
ejabberd_odbc:sql_query_t([<<"delete from rosterusers where " ejabberd_odbc:sql_query_t([<<"delete from rosterusers where "
@ -421,11 +422,11 @@ roster_subscribe(_LServer, Username, SJID, ItemVals) ->
[<<"username='">>, Username, <<"' and jid='">>, SJID, [<<"username='">>, Username, <<"' and jid='">>, SJID,
<<"'">>]). <<"'">>]).
get_subscription(LServer, Username, SJID) -> get_subscription(LServer, LUser, SJID) ->
ejabberd_odbc:sql_query(LServer, ejabberd_odbc:sql_query(
[<<"select subscription from rosterusers " LServer,
"where username='">>, ?SQL("select @(subscription)s from rosterusers "
Username, <<"' and jid='">>, SJID, <<"'">>]). "where username=%(LUser)s and jid=%(SJID)s")).
set_private_data(_LServer, Username, LXMLNS, SData) -> set_private_data(_LServer, Username, LXMLNS, SData) ->
update_t(<<"private_storage">>, update_t(<<"private_storage">>,
@ -639,10 +640,10 @@ count_records_where(LServer, Table, WhereClause) ->
WhereClause, <<";">>]). WhereClause, <<";">>]).
get_roster_version(LServer, LUser) -> get_roster_version(LServer, LUser) ->
ejabberd_odbc:sql_query(LServer, ejabberd_odbc:sql_query(
[<<"select version from roster_version where " LServer,
"username = '">>, ?SQL("select @(version)s from roster_version"
LUser, <<"'">>]). " where username = %(LUser)s")).
set_roster_version(LUser, Version) -> set_roster_version(LUser, Version) ->
update_t(<<"roster_version">>, update_t(<<"roster_version">>,