diff --git a/src/ejabberd_auth_odbc.erl b/src/ejabberd_auth_odbc.erl index b8b4594b6..608127817 100644 --- a/src/ejabberd_auth_odbc.erl +++ b/src/ejabberd_auth_odbc.erl @@ -72,22 +72,18 @@ check_password(User, Server, Password) -> (LUser == <<>>) or (LServer == <<>>) -> false; true -> - Username = ejabberd_odbc:escape(LUser), case is_scrammed() of true -> - try odbc_queries:get_password_scram(LServer, Username) of - {selected, [<<"password">>, <<"serverkey">>, - <<"salt">>, <<"iterationcount">>], - [[StoredKey, ServerKey, Salt, IterationCount]]} -> + try odbc_queries:get_password_scram(LServer, LUser) of + {selected, + [{StoredKey, ServerKey, Salt, IterationCount}]} -> Scram = #scram{storedkey = StoredKey, serverkey = ServerKey, salt = Salt, - iterationcount = binary_to_integer( - IterationCount)}, + iterationcount = IterationCount}, is_password_scram_valid(Password, Scram); - {selected, [<<"password">>, <<"serverkey">>, - <<"salt">>, <<"iterationcount">>], []} -> + {selected, []} -> false; %% Account does not exist {error, _Error} -> false %% Typical error is that table doesn't exist @@ -96,12 +92,12 @@ check_password(User, Server, Password) -> false %% Typical error is database not accessible end; false -> - try odbc_queries:get_password(LServer, Username) of - {selected, [<<"password">>], [[Password]]} -> + try odbc_queries:get_password(LServer, LUser) of + {selected, [{Password}]} -> Password /= <<"">>; - {selected, [<<"password">>], [[_Password2]]} -> + {selected, [{_Password2}]} -> false; %% Password is not correct - {selected, [<<"password">>], []} -> + {selected, []} -> false; %% Account does not exist {error, _Error} -> false %% Typical error is that table doesn't exist @@ -124,10 +120,9 @@ check_password(User, Server, Password, Digest, true -> case is_scrammed() of false -> - Username = ejabberd_odbc:escape(LUser), - try odbc_queries:get_password(LServer, Username) of + try odbc_queries:get_password(LServer, LUser) of %% Account exists, check if password is valid - {selected, [<<"password">>], [[Passwd]]} -> + {selected, [{Passwd}]} -> DigRes = if Digest /= <<"">> -> Digest == DigestGen(Passwd); true -> false @@ -135,7 +130,7 @@ check_password(User, Server, Password, Digest, if DigRes -> true; true -> (Passwd == Password) and (Password /= <<"">>) end; - {selected, [<<"password">>], []} -> + {selected, []} -> false; %% Account does not exist {error, _Error} -> false %% Typical error is that table doesn't exist @@ -267,24 +262,22 @@ get_password(User, Server) -> (LUser == <<>>) or (LServer == <<>>) -> false; true -> - Username = ejabberd_odbc:escape(LUser), case is_scrammed() of true -> case catch odbc_queries:get_password_scram( - LServer, Username) of - {selected, [<<"password">>, <<"serverkey">>, - <<"salt">>, <<"iterationcount">>], - [[StoredKey, ServerKey, Salt, IterationCount]]} -> + LServer, LUser) of + {selected, + [{StoredKey, ServerKey, Salt, IterationCount}]} -> {jlib:decode_base64(StoredKey), jlib:decode_base64(ServerKey), jlib:decode_base64(Salt), - binary_to_integer(IterationCount)}; + IterationCount}; _ -> false end; false -> - case catch odbc_queries:get_password(LServer, Username) + case catch odbc_queries:get_password(LServer, LUser) of - {selected, [<<"password">>], [[Password]]} -> Password; + {selected, [{Password}]} -> Password; _ -> false end end @@ -300,9 +293,8 @@ get_password_s(User, Server) -> true -> case is_scrammed() of false -> - Username = ejabberd_odbc:escape(LUser), - case catch odbc_queries:get_password(LServer, Username) of - {selected, [<<"password">>], [[Password]]} -> Password; + case catch odbc_queries:get_password(LServer, LUser) of + {selected, [{Password}]} -> Password; _ -> <<"">> end; true -> <<"">> @@ -311,15 +303,17 @@ get_password_s(User, Server) -> %% @spec (User, Server) -> true | false | {error, Error} is_user_exists(User, Server) -> - case jid:nodeprep(User) of - error -> false; - LUser -> - Username = ejabberd_odbc:escape(LUser), - LServer = jid:nameprep(Server), - try odbc_queries:get_password(LServer, Username) of - {selected, [<<"password">>], [[_Password]]} -> + LServer = jid:nameprep(Server), + LUser = jid:nodeprep(User), + if (LUser == error) or (LServer == error) -> + false; + (LUser == <<>>) or (LServer == <<>>) -> + false; + true -> + try odbc_queries:get_password(LServer, LUser) of + {selected, [{_Password}]} -> true; %% Account exists - {selected, [<<"password">>], []} -> + {selected, []} -> false; %% Account does not exist {error, Error} -> {error, Error} catch diff --git a/src/mod_roster.erl b/src/mod_roster.erl index f68300763..1e5e3b70c 100644 --- a/src/mod_roster.erl +++ b/src/mod_roster.erl @@ -203,11 +203,9 @@ read_roster_version(LUser, LServer, mnesia) -> [] -> error end; read_roster_version(LUser, LServer, odbc) -> - Username = ejabberd_odbc:escape(LUser), - case odbc_queries:get_roster_version(LServer, Username) - of - {selected, [<<"version">>], [[Version]]} -> Version; - {selected, [<<"version">>], []} -> error + case odbc_queries:get_roster_version(LServer, LUser) of + {selected, [{Version}]} -> Version; + {selected, []} -> error end; read_roster_version(LServer, LUser, riak) -> case ejabberd_riak:get(roster_version, roster_version_schema(), @@ -369,46 +367,37 @@ get_roster(LUser, LServer, riak) -> _Err -> [] end; get_roster(LUser, LServer, odbc) -> - Username = ejabberd_odbc:escape(LUser), - case catch odbc_queries:get_roster(LServer, Username) of - {selected, - [<<"username">>, <<"jid">>, <<"nick">>, - <<"subscription">>, <<"ask">>, <<"askmessage">>, - <<"server">>, <<"subscribe">>, <<"type">>], - Items} - when is_list(Items) -> - JIDGroups = case catch - odbc_queries:get_roster_jid_groups(LServer, - Username) - of - {selected, [<<"jid">>, <<"grp">>], JGrps} - when is_list(JGrps) -> - JGrps; - _ -> [] - end, - GroupsDict = lists:foldl(fun ([J, G], Acc) -> - dict:append(J, G, Acc) - end, - dict:new(), JIDGroups), - RItems = lists:flatmap(fun (I) -> - case raw_to_record(LServer, I) of - %% Bad JID in database: - error -> []; - R -> - SJID = - jid:to_string(R#roster.jid), - Groups = case dict:find(SJID, - GroupsDict) - of - {ok, Gs} -> Gs; - error -> [] - end, - [R#roster{groups = Groups}] - end - end, - Items), - RItems; - _ -> [] + case catch odbc_queries:get_roster(LServer, LUser) of + {selected, Items} when is_list(Items) -> + JIDGroups = case catch odbc_queries:get_roster_jid_groups( + LServer, LUser) of + {selected, JGrps} + when is_list(JGrps) -> + JGrps; + _ -> [] + end, + GroupsDict = lists:foldl(fun({J, G}, Acc) -> + dict:append(J, G, Acc) + end, + dict:new(), JIDGroups), + RItems = + lists:flatmap( + fun(I) -> + case raw_to_record(LServer, I) of + %% Bad JID in database: + error -> []; + R -> + SJID = jid:to_string(R#roster.jid), + Groups = case dict:find(SJID, GroupsDict) of + {ok, Gs} -> Gs; + error -> [] + end, + [R#roster{groups = Groups}] + end + end, + Items), + RItems; + _ -> [] end. set_roster(#roster{us = {LUser, LServer}, jid = LJID} = Item) -> @@ -460,14 +449,8 @@ get_roster_by_jid_t(LUser, LServer, LJID, mnesia) -> xs = []} end; get_roster_by_jid_t(LUser, LServer, LJID, odbc) -> - Username = ejabberd_odbc:escape(LUser), - SJID = ejabberd_odbc:escape(jid:to_string(LJID)), - {selected, - [<<"username">>, <<"jid">>, <<"nick">>, - <<"subscription">>, <<"ask">>, <<"askmessage">>, - <<"server">>, <<"subscribe">>, <<"type">>], - Res} = - odbc_queries:get_roster_by_jid(LServer, Username, SJID), + {selected, Res} = + odbc_queries:get_roster_by_jid(LServer, LUser, jid:to_string(LJID)), case Res of [] -> #roster{usj = {LUser, LServer, LJID}, @@ -750,30 +733,18 @@ get_roster_by_jid_with_groups_t(LUser, LServer, LJID, end; get_roster_by_jid_with_groups_t(LUser, LServer, LJID, odbc) -> - Username = ejabberd_odbc:escape(LUser), - SJID = ejabberd_odbc:escape(jid:to_string(LJID)), - case odbc_queries:get_roster_by_jid(LServer, Username, - SJID) - of - {selected, - [<<"username">>, <<"jid">>, <<"nick">>, - <<"subscription">>, <<"ask">>, <<"askmessage">>, - <<"server">>, <<"subscribe">>, <<"type">>], - [I]} -> - R = raw_to_record(LServer, I), - Groups = case odbc_queries:get_roster_groups(LServer, - Username, SJID) - of - {selected, [<<"grp">>], JGrps} when is_list(JGrps) -> - [JGrp || [JGrp] <- JGrps]; - _ -> [] - end, - R#roster{groups = Groups}; - {selected, - [<<"username">>, <<"jid">>, <<"nick">>, - <<"subscription">>, <<"ask">>, <<"askmessage">>, - <<"server">>, <<"subscribe">>, <<"type">>], - []} -> + SJID = jid:to_string(LJID), + case odbc_queries:get_roster_by_jid(LServer, LUser, SJID) of + {selected, [I]} -> + R = raw_to_record(LServer, I), + Groups = + case odbc_queries:get_roster_groups(LServer, LUser, SJID) of + {selected, JGrps} when is_list(JGrps) -> + [JGrp || {JGrp} <- JGrps]; + _ -> [] + end, + R#roster{groups = Groups}; + {selected, []} -> #roster{usj = {LUser, LServer, LJID}, us = {LUser, LServer}, jid = LJID} end; @@ -995,8 +966,7 @@ remove_user(LUser, LServer, mnesia) -> end, mnesia:transaction(F); remove_user(LUser, LServer, odbc) -> - Username = ejabberd_odbc:escape(LUser), - odbc_queries:del_user_roster_t(LServer, Username), + odbc_queries:del_user_roster_t(LServer, LUser), ok; remove_user(LUser, LServer, riak) -> {atomic, ejabberd_riak:delete_by_index(roster, <<"us">>, {LUser, LServer})}. @@ -1243,12 +1213,9 @@ read_subscription_and_groups(LUser, LServer, LJID, end; read_subscription_and_groups(LUser, LServer, LJID, odbc) -> - Username = ejabberd_odbc:escape(LUser), - SJID = ejabberd_odbc:escape(jid:to_string(LJID)), - case catch odbc_queries:get_subscription(LServer, - Username, SJID) - of - {selected, [<<"subscription">>], [[SSubscription]]} -> + SJID = jid:to_string(LJID), + case catch odbc_queries:get_subscription(LServer, LUser, SJID) of + {selected, [{SSubscription}]} -> Subscription = case SSubscription of <<"B">> -> both; <<"T">> -> to; @@ -1256,11 +1223,11 @@ read_subscription_and_groups(LUser, LServer, LJID, _ -> none end, Groups = case catch - odbc_queries:get_rostergroup_by_jid(LServer, Username, + odbc_queries:get_rostergroup_by_jid(LServer, LUser, SJID) of - {selected, [<<"grp">>], JGrps} when is_list(JGrps) -> - [JGrp || [JGrp] <- JGrps]; + {selected, JGrps} when is_list(JGrps) -> + [JGrp || {JGrp} <- JGrps]; _ -> [] end, {Subscription, Groups}; @@ -1297,6 +1264,12 @@ get_jid_info(_, User, Server, JID) -> raw_to_record(LServer, [User, SJID, Nick, SSubscription, SAsk, SAskMessage, _SServer, _SSubscribe, _SType]) -> + raw_to_record(LServer, + {User, SJID, Nick, SSubscription, SAsk, SAskMessage, + _SServer, _SSubscribe, _SType}); +raw_to_record(LServer, + {User, SJID, Nick, SSubscription, SAsk, SAskMessage, + _SServer, _SSubscribe, _SType}) -> case jid:from_string(SJID) of error -> error; JID -> diff --git a/src/odbc_queries.erl b/src/odbc_queries.erl index ee8fa1690..b6c9a36c0 100644 --- a/src/odbc_queries.erl +++ b/src/odbc_queries.erl @@ -139,16 +139,17 @@ del_last(LServer, Username) -> [<<"delete from last where username='">>, Username, <<"'">>]). -get_password(LServer, Username) -> - ejabberd_odbc:sql_query(LServer, - [<<"select password from users where username='">>, - Username, <<"';">>]). - -get_password_scram(LServer, Username) -> +get_password(LServer, LUser) -> ejabberd_odbc:sql_query( LServer, - [<<"select password, serverkey, salt, iterationcount from users where " - "username='">>, Username, <<"';">>]). + ?SQL("select @(password)s from users where username=%(LUser)s")). + +get_password_scram(LServer, LUser) -> + ejabberd_odbc:sql_query( + LServer, + ?SQL("select @(password)s, @(serverkey)s, @(salt)s, @(iterationcount)d" + " from users" + " where username=%(LUser)s")). set_password_t(LServer, Username, Pass) -> ejabberd_odbc:sql_transaction(LServer, @@ -311,46 +312,46 @@ del_spool_msg(LServer, LUser) -> LServer, ?SQL("delete from spool where username=%(LUser)s")). -get_roster(LServer, Username) -> - ejabberd_odbc:sql_query(LServer, - [<<"select username, jid, nick, subscription, " - "ask, askmessage, server, subscribe, " - "type from rosterusers where username='">>, - Username, <<"'">>]). +get_roster(LServer, LUser) -> + ejabberd_odbc:sql_query( + LServer, + ?SQL("select @(username)s, @(jid)s, @(nick)s, @(subscription)s, " + "@(ask)s, @(askmessage)s, @(server)s, @(subscribe)s, " + "@(type)s from rosterusers where username=%(LUser)s")). -get_roster_jid_groups(LServer, Username) -> - ejabberd_odbc:sql_query(LServer, - [<<"select jid, grp from rostergroups where " - "username='">>, - Username, <<"'">>]). +get_roster_jid_groups(LServer, LUser) -> + ejabberd_odbc:sql_query( + LServer, + ?SQL("select @(jid)s, @(grp)s from rostergroups where " + "username=%(LUser)s")). -get_roster_groups(_LServer, Username, SJID) -> - ejabberd_odbc:sql_query_t([<<"select grp from rostergroups where username='">>, - Username, <<"' and jid='">>, SJID, <<"';">>]). +get_roster_groups(_LServer, LUser, SJID) -> + ejabberd_odbc:sql_query_t( + ?SQL("select @(grp)s from rostergroups" + " where username=%(LUser)s and jid=%(SJID)s")). -del_user_roster_t(LServer, Username) -> - ejabberd_odbc:sql_transaction(LServer, - fun () -> - ejabberd_odbc:sql_query_t([<<"delete from rosterusers where " - "username='">>, - Username, - <<"';">>]), - ejabberd_odbc:sql_query_t([<<"delete from rostergroups where " - "username='">>, - Username, - <<"';">>]) - end). +del_user_roster_t(LServer, LUser) -> + ejabberd_odbc:sql_transaction( + LServer, + fun () -> + ejabberd_odbc:sql_query_t( + ?SQL("delete from rosterusers where username=%(LUser)s")), + ejabberd_odbc:sql_query_t( + ?SQL("delete from rostergroups where username=%(LUser)s")) + end). -get_roster_by_jid(_LServer, Username, SJID) -> - ejabberd_odbc:sql_query_t([<<"select username, jid, nick, subscription, " - "ask, askmessage, server, subscribe, " - "type from rosterusers where username='">>, - Username, <<"' and jid='">>, SJID, <<"';">>]). +get_roster_by_jid(_LServer, LUser, SJID) -> + ejabberd_odbc:sql_query_t( + ?SQL("select @(username)s, @(jid)s, @(nick)s, @(subscription)s," + " @(ask)s, @(askmessage)s, @(server)s, @(subscribe)s," + " @(type)s from rosterusers" + " where username=%(LUser)s and jid=%(SJID)s")). -get_rostergroup_by_jid(LServer, Username, SJID) -> - ejabberd_odbc:sql_query(LServer, - [<<"select grp from rostergroups where username='">>, - Username, <<"' and jid='">>, SJID, <<"'">>]). +get_rostergroup_by_jid(LServer, LUser, SJID) -> + ejabberd_odbc:sql_query( + LServer, + ?SQL("select @(grp)s from rostergroups" + " where username=%(LUser)s and jid=%(SJID)s")). del_roster(_LServer, Username, SJID) -> ejabberd_odbc:sql_query_t([<<"delete from rosterusers where " @@ -421,11 +422,11 @@ roster_subscribe(_LServer, Username, SJID, ItemVals) -> [<<"username='">>, Username, <<"' and jid='">>, SJID, <<"'">>]). -get_subscription(LServer, Username, SJID) -> - ejabberd_odbc:sql_query(LServer, - [<<"select subscription from rosterusers " - "where username='">>, - Username, <<"' and jid='">>, SJID, <<"'">>]). +get_subscription(LServer, LUser, SJID) -> + ejabberd_odbc:sql_query( + LServer, + ?SQL("select @(subscription)s from rosterusers " + "where username=%(LUser)s and jid=%(SJID)s")). set_private_data(_LServer, Username, LXMLNS, SData) -> update_t(<<"private_storage">>, @@ -639,10 +640,10 @@ count_records_where(LServer, Table, WhereClause) -> WhereClause, <<";">>]). get_roster_version(LServer, LUser) -> - ejabberd_odbc:sql_query(LServer, - [<<"select version from roster_version where " - "username = '">>, - LUser, <<"'">>]). + ejabberd_odbc:sql_query( + LServer, + ?SQL("select @(version)s from roster_version" + " where username = %(LUser)s")). set_roster_version(LUser, Version) -> update_t(<<"roster_version">>,