25
1
mirror of https://github.com/processone/ejabberd.git synced 2024-12-24 17:29:28 +01:00

Update mod_roster and ejabberd_auth_odbc SQL queries to the new API

This commit is contained in:
Alexey Shchepin 2016-02-15 21:02:22 +03:00
parent 7f3bffe821
commit 3d8219d8f9
3 changed files with 144 additions and 176 deletions

View File

@ -72,22 +72,18 @@ check_password(User, Server, Password) ->
(LUser == <<>>) or (LServer == <<>>) ->
false;
true ->
Username = ejabberd_odbc:escape(LUser),
case is_scrammed() of
true ->
try odbc_queries:get_password_scram(LServer, Username) of
{selected, [<<"password">>, <<"serverkey">>,
<<"salt">>, <<"iterationcount">>],
[[StoredKey, ServerKey, Salt, IterationCount]]} ->
try odbc_queries:get_password_scram(LServer, LUser) of
{selected,
[{StoredKey, ServerKey, Salt, IterationCount}]} ->
Scram =
#scram{storedkey = StoredKey,
serverkey = ServerKey,
salt = Salt,
iterationcount = binary_to_integer(
IterationCount)},
iterationcount = IterationCount},
is_password_scram_valid(Password, Scram);
{selected, [<<"password">>, <<"serverkey">>,
<<"salt">>, <<"iterationcount">>], []} ->
{selected, []} ->
false; %% Account does not exist
{error, _Error} ->
false %% Typical error is that table doesn't exist
@ -96,12 +92,12 @@ check_password(User, Server, Password) ->
false %% Typical error is database not accessible
end;
false ->
try odbc_queries:get_password(LServer, Username) of
{selected, [<<"password">>], [[Password]]} ->
try odbc_queries:get_password(LServer, LUser) of
{selected, [{Password}]} ->
Password /= <<"">>;
{selected, [<<"password">>], [[_Password2]]} ->
{selected, [{_Password2}]} ->
false; %% Password is not correct
{selected, [<<"password">>], []} ->
{selected, []} ->
false; %% Account does not exist
{error, _Error} ->
false %% Typical error is that table doesn't exist
@ -124,10 +120,9 @@ check_password(User, Server, Password, Digest,
true ->
case is_scrammed() of
false ->
Username = ejabberd_odbc:escape(LUser),
try odbc_queries:get_password(LServer, Username) of
try odbc_queries:get_password(LServer, LUser) of
%% Account exists, check if password is valid
{selected, [<<"password">>], [[Passwd]]} ->
{selected, [{Passwd}]} ->
DigRes = if Digest /= <<"">> ->
Digest == DigestGen(Passwd);
true -> false
@ -135,7 +130,7 @@ check_password(User, Server, Password, Digest,
if DigRes -> true;
true -> (Passwd == Password) and (Password /= <<"">>)
end;
{selected, [<<"password">>], []} ->
{selected, []} ->
false; %% Account does not exist
{error, _Error} ->
false %% Typical error is that table doesn't exist
@ -267,24 +262,22 @@ get_password(User, Server) ->
(LUser == <<>>) or (LServer == <<>>) ->
false;
true ->
Username = ejabberd_odbc:escape(LUser),
case is_scrammed() of
true ->
case catch odbc_queries:get_password_scram(
LServer, Username) of
{selected, [<<"password">>, <<"serverkey">>,
<<"salt">>, <<"iterationcount">>],
[[StoredKey, ServerKey, Salt, IterationCount]]} ->
LServer, LUser) of
{selected,
[{StoredKey, ServerKey, Salt, IterationCount}]} ->
{jlib:decode_base64(StoredKey),
jlib:decode_base64(ServerKey),
jlib:decode_base64(Salt),
binary_to_integer(IterationCount)};
IterationCount};
_ -> false
end;
false ->
case catch odbc_queries:get_password(LServer, Username)
case catch odbc_queries:get_password(LServer, LUser)
of
{selected, [<<"password">>], [[Password]]} -> Password;
{selected, [{Password}]} -> Password;
_ -> false
end
end
@ -300,9 +293,8 @@ get_password_s(User, Server) ->
true ->
case is_scrammed() of
false ->
Username = ejabberd_odbc:escape(LUser),
case catch odbc_queries:get_password(LServer, Username) of
{selected, [<<"password">>], [[Password]]} -> Password;
case catch odbc_queries:get_password(LServer, LUser) of
{selected, [{Password}]} -> Password;
_ -> <<"">>
end;
true -> <<"">>
@ -311,15 +303,17 @@ get_password_s(User, Server) ->
%% @spec (User, Server) -> true | false | {error, Error}
is_user_exists(User, Server) ->
case jid:nodeprep(User) of
error -> false;
LUser ->
Username = ejabberd_odbc:escape(LUser),
LServer = jid:nameprep(Server),
try odbc_queries:get_password(LServer, Username) of
{selected, [<<"password">>], [[_Password]]} ->
LServer = jid:nameprep(Server),
LUser = jid:nodeprep(User),
if (LUser == error) or (LServer == error) ->
false;
(LUser == <<>>) or (LServer == <<>>) ->
false;
true ->
try odbc_queries:get_password(LServer, LUser) of
{selected, [{_Password}]} ->
true; %% Account exists
{selected, [<<"password">>], []} ->
{selected, []} ->
false; %% Account does not exist
{error, Error} -> {error, Error}
catch

View File

@ -203,11 +203,9 @@ read_roster_version(LUser, LServer, mnesia) ->
[] -> error
end;
read_roster_version(LUser, LServer, odbc) ->
Username = ejabberd_odbc:escape(LUser),
case odbc_queries:get_roster_version(LServer, Username)
of
{selected, [<<"version">>], [[Version]]} -> Version;
{selected, [<<"version">>], []} -> error
case odbc_queries:get_roster_version(LServer, LUser) of
{selected, [{Version}]} -> Version;
{selected, []} -> error
end;
read_roster_version(LServer, LUser, riak) ->
case ejabberd_riak:get(roster_version, roster_version_schema(),
@ -369,46 +367,37 @@ get_roster(LUser, LServer, riak) ->
_Err -> []
end;
get_roster(LUser, LServer, odbc) ->
Username = ejabberd_odbc:escape(LUser),
case catch odbc_queries:get_roster(LServer, Username) of
{selected,
[<<"username">>, <<"jid">>, <<"nick">>,
<<"subscription">>, <<"ask">>, <<"askmessage">>,
<<"server">>, <<"subscribe">>, <<"type">>],
Items}
when is_list(Items) ->
JIDGroups = case catch
odbc_queries:get_roster_jid_groups(LServer,
Username)
of
{selected, [<<"jid">>, <<"grp">>], JGrps}
when is_list(JGrps) ->
JGrps;
_ -> []
end,
GroupsDict = lists:foldl(fun ([J, G], Acc) ->
dict:append(J, G, Acc)
end,
dict:new(), JIDGroups),
RItems = lists:flatmap(fun (I) ->
case raw_to_record(LServer, I) of
%% Bad JID in database:
error -> [];
R ->
SJID =
jid:to_string(R#roster.jid),
Groups = case dict:find(SJID,
GroupsDict)
of
{ok, Gs} -> Gs;
error -> []
end,
[R#roster{groups = Groups}]
end
end,
Items),
RItems;
_ -> []
case catch odbc_queries:get_roster(LServer, LUser) of
{selected, Items} when is_list(Items) ->
JIDGroups = case catch odbc_queries:get_roster_jid_groups(
LServer, LUser) of
{selected, JGrps}
when is_list(JGrps) ->
JGrps;
_ -> []
end,
GroupsDict = lists:foldl(fun({J, G}, Acc) ->
dict:append(J, G, Acc)
end,
dict:new(), JIDGroups),
RItems =
lists:flatmap(
fun(I) ->
case raw_to_record(LServer, I) of
%% Bad JID in database:
error -> [];
R ->
SJID = jid:to_string(R#roster.jid),
Groups = case dict:find(SJID, GroupsDict) of
{ok, Gs} -> Gs;
error -> []
end,
[R#roster{groups = Groups}]
end
end,
Items),
RItems;
_ -> []
end.
set_roster(#roster{us = {LUser, LServer}, jid = LJID} = Item) ->
@ -460,14 +449,8 @@ get_roster_by_jid_t(LUser, LServer, LJID, mnesia) ->
xs = []}
end;
get_roster_by_jid_t(LUser, LServer, LJID, odbc) ->
Username = ejabberd_odbc:escape(LUser),
SJID = ejabberd_odbc:escape(jid:to_string(LJID)),
{selected,
[<<"username">>, <<"jid">>, <<"nick">>,
<<"subscription">>, <<"ask">>, <<"askmessage">>,
<<"server">>, <<"subscribe">>, <<"type">>],
Res} =
odbc_queries:get_roster_by_jid(LServer, Username, SJID),
{selected, Res} =
odbc_queries:get_roster_by_jid(LServer, LUser, jid:to_string(LJID)),
case Res of
[] ->
#roster{usj = {LUser, LServer, LJID},
@ -750,30 +733,18 @@ get_roster_by_jid_with_groups_t(LUser, LServer, LJID,
end;
get_roster_by_jid_with_groups_t(LUser, LServer, LJID,
odbc) ->
Username = ejabberd_odbc:escape(LUser),
SJID = ejabberd_odbc:escape(jid:to_string(LJID)),
case odbc_queries:get_roster_by_jid(LServer, Username,
SJID)
of
{selected,
[<<"username">>, <<"jid">>, <<"nick">>,
<<"subscription">>, <<"ask">>, <<"askmessage">>,
<<"server">>, <<"subscribe">>, <<"type">>],
[I]} ->
R = raw_to_record(LServer, I),
Groups = case odbc_queries:get_roster_groups(LServer,
Username, SJID)
of
{selected, [<<"grp">>], JGrps} when is_list(JGrps) ->
[JGrp || [JGrp] <- JGrps];
_ -> []
end,
R#roster{groups = Groups};
{selected,
[<<"username">>, <<"jid">>, <<"nick">>,
<<"subscription">>, <<"ask">>, <<"askmessage">>,
<<"server">>, <<"subscribe">>, <<"type">>],
[]} ->
SJID = jid:to_string(LJID),
case odbc_queries:get_roster_by_jid(LServer, LUser, SJID) of
{selected, [I]} ->
R = raw_to_record(LServer, I),
Groups =
case odbc_queries:get_roster_groups(LServer, LUser, SJID) of
{selected, JGrps} when is_list(JGrps) ->
[JGrp || {JGrp} <- JGrps];
_ -> []
end,
R#roster{groups = Groups};
{selected, []} ->
#roster{usj = {LUser, LServer, LJID},
us = {LUser, LServer}, jid = LJID}
end;
@ -995,8 +966,7 @@ remove_user(LUser, LServer, mnesia) ->
end,
mnesia:transaction(F);
remove_user(LUser, LServer, odbc) ->
Username = ejabberd_odbc:escape(LUser),
odbc_queries:del_user_roster_t(LServer, Username),
odbc_queries:del_user_roster_t(LServer, LUser),
ok;
remove_user(LUser, LServer, riak) ->
{atomic, ejabberd_riak:delete_by_index(roster, <<"us">>, {LUser, LServer})}.
@ -1243,12 +1213,9 @@ read_subscription_and_groups(LUser, LServer, LJID,
end;
read_subscription_and_groups(LUser, LServer, LJID,
odbc) ->
Username = ejabberd_odbc:escape(LUser),
SJID = ejabberd_odbc:escape(jid:to_string(LJID)),
case catch odbc_queries:get_subscription(LServer,
Username, SJID)
of
{selected, [<<"subscription">>], [[SSubscription]]} ->
SJID = jid:to_string(LJID),
case catch odbc_queries:get_subscription(LServer, LUser, SJID) of
{selected, [{SSubscription}]} ->
Subscription = case SSubscription of
<<"B">> -> both;
<<"T">> -> to;
@ -1256,11 +1223,11 @@ read_subscription_and_groups(LUser, LServer, LJID,
_ -> none
end,
Groups = case catch
odbc_queries:get_rostergroup_by_jid(LServer, Username,
odbc_queries:get_rostergroup_by_jid(LServer, LUser,
SJID)
of
{selected, [<<"grp">>], JGrps} when is_list(JGrps) ->
[JGrp || [JGrp] <- JGrps];
{selected, JGrps} when is_list(JGrps) ->
[JGrp || {JGrp} <- JGrps];
_ -> []
end,
{Subscription, Groups};
@ -1297,6 +1264,12 @@ get_jid_info(_, User, Server, JID) ->
raw_to_record(LServer,
[User, SJID, Nick, SSubscription, SAsk, SAskMessage,
_SServer, _SSubscribe, _SType]) ->
raw_to_record(LServer,
{User, SJID, Nick, SSubscription, SAsk, SAskMessage,
_SServer, _SSubscribe, _SType});
raw_to_record(LServer,
{User, SJID, Nick, SSubscription, SAsk, SAskMessage,
_SServer, _SSubscribe, _SType}) ->
case jid:from_string(SJID) of
error -> error;
JID ->

View File

@ -139,16 +139,17 @@ del_last(LServer, Username) ->
[<<"delete from last where username='">>, Username,
<<"'">>]).
get_password(LServer, Username) ->
ejabberd_odbc:sql_query(LServer,
[<<"select password from users where username='">>,
Username, <<"';">>]).
get_password_scram(LServer, Username) ->
get_password(LServer, LUser) ->
ejabberd_odbc:sql_query(
LServer,
[<<"select password, serverkey, salt, iterationcount from users where "
"username='">>, Username, <<"';">>]).
?SQL("select @(password)s from users where username=%(LUser)s")).
get_password_scram(LServer, LUser) ->
ejabberd_odbc:sql_query(
LServer,
?SQL("select @(password)s, @(serverkey)s, @(salt)s, @(iterationcount)d"
" from users"
" where username=%(LUser)s")).
set_password_t(LServer, Username, Pass) ->
ejabberd_odbc:sql_transaction(LServer,
@ -311,46 +312,46 @@ del_spool_msg(LServer, LUser) ->
LServer,
?SQL("delete from spool where username=%(LUser)s")).
get_roster(LServer, Username) ->
ejabberd_odbc:sql_query(LServer,
[<<"select username, jid, nick, subscription, "
"ask, askmessage, server, subscribe, "
"type from rosterusers where username='">>,
Username, <<"'">>]).
get_roster(LServer, LUser) ->
ejabberd_odbc:sql_query(
LServer,
?SQL("select @(username)s, @(jid)s, @(nick)s, @(subscription)s, "
"@(ask)s, @(askmessage)s, @(server)s, @(subscribe)s, "
"@(type)s from rosterusers where username=%(LUser)s")).
get_roster_jid_groups(LServer, Username) ->
ejabberd_odbc:sql_query(LServer,
[<<"select jid, grp from rostergroups where "
"username='">>,
Username, <<"'">>]).
get_roster_jid_groups(LServer, LUser) ->
ejabberd_odbc:sql_query(
LServer,
?SQL("select @(jid)s, @(grp)s from rostergroups where "
"username=%(LUser)s")).
get_roster_groups(_LServer, Username, SJID) ->
ejabberd_odbc:sql_query_t([<<"select grp from rostergroups where username='">>,
Username, <<"' and jid='">>, SJID, <<"';">>]).
get_roster_groups(_LServer, LUser, SJID) ->
ejabberd_odbc:sql_query_t(
?SQL("select @(grp)s from rostergroups"
" where username=%(LUser)s and jid=%(SJID)s")).
del_user_roster_t(LServer, Username) ->
ejabberd_odbc:sql_transaction(LServer,
fun () ->
ejabberd_odbc:sql_query_t([<<"delete from rosterusers where "
"username='">>,
Username,
<<"';">>]),
ejabberd_odbc:sql_query_t([<<"delete from rostergroups where "
"username='">>,
Username,
<<"';">>])
end).
del_user_roster_t(LServer, LUser) ->
ejabberd_odbc:sql_transaction(
LServer,
fun () ->
ejabberd_odbc:sql_query_t(
?SQL("delete from rosterusers where username=%(LUser)s")),
ejabberd_odbc:sql_query_t(
?SQL("delete from rostergroups where username=%(LUser)s"))
end).
get_roster_by_jid(_LServer, Username, SJID) ->
ejabberd_odbc:sql_query_t([<<"select username, jid, nick, subscription, "
"ask, askmessage, server, subscribe, "
"type from rosterusers where username='">>,
Username, <<"' and jid='">>, SJID, <<"';">>]).
get_roster_by_jid(_LServer, LUser, SJID) ->
ejabberd_odbc:sql_query_t(
?SQL("select @(username)s, @(jid)s, @(nick)s, @(subscription)s,"
" @(ask)s, @(askmessage)s, @(server)s, @(subscribe)s,"
" @(type)s from rosterusers"
" where username=%(LUser)s and jid=%(SJID)s")).
get_rostergroup_by_jid(LServer, Username, SJID) ->
ejabberd_odbc:sql_query(LServer,
[<<"select grp from rostergroups where username='">>,
Username, <<"' and jid='">>, SJID, <<"'">>]).
get_rostergroup_by_jid(LServer, LUser, SJID) ->
ejabberd_odbc:sql_query(
LServer,
?SQL("select @(grp)s from rostergroups"
" where username=%(LUser)s and jid=%(SJID)s")).
del_roster(_LServer, Username, SJID) ->
ejabberd_odbc:sql_query_t([<<"delete from rosterusers where "
@ -421,11 +422,11 @@ roster_subscribe(_LServer, Username, SJID, ItemVals) ->
[<<"username='">>, Username, <<"' and jid='">>, SJID,
<<"'">>]).
get_subscription(LServer, Username, SJID) ->
ejabberd_odbc:sql_query(LServer,
[<<"select subscription from rosterusers "
"where username='">>,
Username, <<"' and jid='">>, SJID, <<"'">>]).
get_subscription(LServer, LUser, SJID) ->
ejabberd_odbc:sql_query(
LServer,
?SQL("select @(subscription)s from rosterusers "
"where username=%(LUser)s and jid=%(SJID)s")).
set_private_data(_LServer, Username, LXMLNS, SData) ->
update_t(<<"private_storage">>,
@ -639,10 +640,10 @@ count_records_where(LServer, Table, WhereClause) ->
WhereClause, <<";">>]).
get_roster_version(LServer, LUser) ->
ejabberd_odbc:sql_query(LServer,
[<<"select version from roster_version where "
"username = '">>,
LUser, <<"'">>]).
ejabberd_odbc:sql_query(
LServer,
?SQL("select @(version)s from roster_version"
" where username = %(LUser)s")).
set_roster_version(LUser, Version) ->
update_t(<<"roster_version">>,