Add SCRAM support to ejabberd_auth_odbc

This commit is contained in:
Alexey Shchepin 2015-02-17 23:26:31 +03:00
parent 0eb6b942ff
commit e575c87ea2
4 changed files with 307 additions and 107 deletions

View File

@ -22,6 +22,10 @@ CREATE TABLE users (
created_at timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP created_at timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP
) ENGINE=InnoDB CHARACTER SET utf8; ) ENGINE=InnoDB CHARACTER SET utf8;
-- To support SCRAM auth:
-- ALTER TABLE users ADD COLUMN serverkey text NOT NULL DEFAULT '';
-- ALTER TABLE users ADD COLUMN salt text NOT NULL DEFAULT '';
-- ALTER TABLE users ADD COLUMN iterationcount integer NOT NULL DEFAULT 0;
CREATE TABLE last ( CREATE TABLE last (
username varchar(250) PRIMARY KEY, username varchar(250) PRIMARY KEY,

View File

@ -22,6 +22,10 @@ CREATE TABLE users (
created_at TIMESTAMP NOT NULL DEFAULT now() created_at TIMESTAMP NOT NULL DEFAULT now()
); );
-- To support SCRAM auth:
-- ALTER TABLE users ADD COLUMN serverkey text NOT NULL DEFAULT '';
-- ALTER TABLE users ADD COLUMN salt text NOT NULL DEFAULT '';
-- ALTER TABLE users ADD COLUMN iterationcount integer NOT NULL DEFAULT 0;
CREATE TABLE last ( CREATE TABLE last (
username text PRIMARY KEY, username text PRIMARY KEY,

View File

@ -43,96 +43,180 @@
-include("ejabberd.hrl"). -include("ejabberd.hrl").
-include("logger.hrl"). -include("logger.hrl").
-define(SALT_LENGTH, 16).
%%%---------------------------------------------------------------------- %%%----------------------------------------------------------------------
%%% API %%% API
%%%---------------------------------------------------------------------- %%%----------------------------------------------------------------------
start(_Host) -> ok. start(_Host) -> ok.
plain_password_required() -> false. plain_password_required() ->
case is_scrammed() of
false -> false;
true -> true
end.
store_type() -> plain. store_type() ->
case is_scrammed() of
false -> plain; %% allows: PLAIN DIGEST-MD5 SCRAM
true -> scram %% allows: PLAIN SCRAM
end.
%% @spec (User, Server, Password) -> true | false | {error, Error} %% @spec (User, Server, Password) -> true | false | {error, Error}
check_password(User, Server, Password) -> check_password(User, Server, Password) ->
case jlib:nodeprep(User) of LServer = jlib:nameprep(Server),
error -> false; LUser = jlib:nodeprep(User),
LUser -> if (LUser == error) or (LServer == error) ->
Username = ejabberd_odbc:escape(LUser), false;
LServer = jlib:nameprep(Server), (LUser == <<>>) or (LServer == <<>>) ->
try odbc_queries:get_password(LServer, Username) of false;
{selected, [<<"password">>], [[Password]]} -> true ->
Password /= <<"">>; Username = ejabberd_odbc:escape(LUser),
{selected, [<<"password">>], [[_Password2]]} -> case is_scrammed() of
false; %% Password is not correct true ->
{selected, [<<"password">>], []} -> try odbc_queries:get_password_scram(LServer, Username) of
false; %% Account does not exist {selected, [<<"password">>, <<"serverkey">>,
{error, _Error} -> <<"salt">>, <<"iterationcount">>],
false %% Typical error is that table doesn't exist [[StoredKey, ServerKey, Salt, IterationCount]]} ->
catch Scram =
_:_ -> #scram{storedkey = StoredKey,
false %% Typical error is database not accessible serverkey = ServerKey,
end salt = Salt,
iterationcount = binary_to_integer(
IterationCount)},
is_password_scram_valid(Password, Scram);
{selected, [<<"password">>, <<"serverkey">>,
<<"salt">>, <<"iterationcount">>], []} ->
false; %% Account does not exist
{error, _Error} ->
false %% Typical error is that table doesn't exist
catch
_:_ ->
false %% Typical error is database not accessible
end;
false ->
try odbc_queries:get_password(LServer, Username) of
{selected, [<<"password">>], [[Password]]} ->
Password /= <<"">>;
{selected, [<<"password">>], [[_Password2]]} ->
false; %% Password is not correct
{selected, [<<"password">>], []} ->
false; %% Account does not exist
{error, _Error} ->
false %% Typical error is that table doesn't exist
catch
_:_ ->
false %% Typical error is database not accessible
end
end
end. end.
%% @spec (User, Server, Password, Digest, DigestGen) -> true | false | {error, Error} %% @spec (User, Server, Password, Digest, DigestGen) -> true | false | {error, Error}
check_password(User, Server, Password, Digest, check_password(User, Server, Password, Digest,
DigestGen) -> DigestGen) ->
case jlib:nodeprep(User) of LServer = jlib:nameprep(Server),
error -> false; LUser = jlib:nodeprep(User),
LUser -> if (LUser == error) or (LServer == error) ->
Username = ejabberd_odbc:escape(LUser), false;
LServer = jlib:nameprep(Server), (LUser == <<>>) or (LServer == <<>>) ->
try odbc_queries:get_password(LServer, Username) of false;
%% Account exists, check if password is valid true ->
{selected, [<<"password">>], [[Passwd]]} -> case is_scrammed() of
DigRes = if Digest /= <<"">> -> false ->
Digest == DigestGen(Passwd); Username = ejabberd_odbc:escape(LUser),
true -> false try odbc_queries:get_password(LServer, Username) of
end, %% Account exists, check if password is valid
if DigRes -> true; {selected, [<<"password">>], [[Passwd]]} ->
true -> (Passwd == Password) and (Password /= <<"">>) DigRes = if Digest /= <<"">> ->
end; Digest == DigestGen(Passwd);
{selected, [<<"password">>], []} -> true -> false
false; %% Account does not exist end,
{error, _Error} -> if DigRes -> true;
false %% Typical error is that table doesn't exist true -> (Passwd == Password) and (Password /= <<"">>)
catch end;
_:_ -> {selected, [<<"password">>], []} ->
false %% Typical error is database not accessible false; %% Account does not exist
end {error, _Error} ->
false %% Typical error is that table doesn't exist
catch
_:_ ->
false %% Typical error is database not accessible
end;
true ->
false
end
end. end.
%% @spec (User::string(), Server::string(), Password::string()) -> %% @spec (User::string(), Server::string(), Password::string()) ->
%% ok | {error, invalid_jid} %% ok | {error, invalid_jid}
set_password(User, Server, Password) -> set_password(User, Server, Password) ->
case jlib:nodeprep(User) of LServer = jlib:nameprep(Server),
error -> {error, invalid_jid}; LUser = jlib:nodeprep(User),
LUser -> if (LUser == error) or (LServer == error) ->
Username = ejabberd_odbc:escape(LUser), {error, invalid_jid};
Pass = ejabberd_odbc:escape(Password), (LUser == <<>>) or (LServer == <<>>) ->
LServer = jlib:nameprep(Server), {error, invalid_jid};
case catch odbc_queries:set_password_t(LServer, true ->
Username, Pass) Username = ejabberd_odbc:escape(LUser),
of case is_scrammed() of
{atomic, ok} -> ok; true ->
Other -> {error, Other} Scram = password_to_scram(Password),
end case catch odbc_queries:set_password_scram_t(
LServer,
Username,
ejabberd_odbc:escape(Scram#scram.storedkey),
ejabberd_odbc:escape(Scram#scram.serverkey),
ejabberd_odbc:escape(Scram#scram.salt),
integer_to_binary(Scram#scram.iterationcount)
)
of
{atomic, ok} -> ok;
Other -> {error, Other}
end;
false ->
Pass = ejabberd_odbc:escape(Password),
case catch odbc_queries:set_password_t(LServer,
Username, Pass)
of
{atomic, ok} -> ok;
Other -> {error, Other}
end
end
end. end.
%% @spec (User, Server, Password) -> {atomic, ok} | {atomic, exists} | {error, invalid_jid} %% @spec (User, Server, Password) -> {atomic, ok} | {atomic, exists} | {error, invalid_jid}
try_register(User, Server, Password) -> try_register(User, Server, Password) ->
case jlib:nodeprep(User) of LServer = jlib:nameprep(Server),
error -> {error, invalid_jid}; LUser = jlib:nodeprep(User),
LUser -> if (LUser == error) or (LServer == error) ->
{error, invalid_jid};
(LUser == <<>>) or (LServer == <<>>) ->
{error, invalid_jid};
true ->
Username = ejabberd_odbc:escape(LUser), Username = ejabberd_odbc:escape(LUser),
Pass = ejabberd_odbc:escape(Password), case is_scrammed() of
LServer = jlib:nameprep(Server), true ->
case catch odbc_queries:add_user(LServer, Username, Scram = password_to_scram(Password),
Pass) case catch odbc_queries:add_user_scram(
of LServer,
{updated, 1} -> {atomic, ok}; Username,
_ -> {atomic, exists} ejabberd_odbc:escape(Scram#scram.storedkey),
end ejabberd_odbc:escape(Scram#scram.serverkey),
ejabberd_odbc:escape(Scram#scram.salt),
integer_to_binary(Scram#scram.iterationcount)
) of
{updated, 1} -> {atomic, ok};
_ -> {atomic, exists}
end;
false ->
Pass = ejabberd_odbc:escape(Password),
case catch odbc_queries:add_user(LServer, Username,
Pass)
of
{updated, 1} -> {atomic, ok};
_ -> {atomic, exists}
end
end
end. end.
dirty_get_registered_users() -> dirty_get_registered_users() ->
@ -175,29 +259,53 @@ get_vh_registered_users_number(Server, Opts) ->
end. end.
get_password(User, Server) -> get_password(User, Server) ->
case jlib:nodeprep(User) of LServer = jlib:nameprep(Server),
error -> false; LUser = jlib:nodeprep(User),
LUser -> if (LUser == error) or (LServer == error) ->
Username = ejabberd_odbc:escape(LUser), false;
LServer = jlib:nameprep(Server), (LUser == <<>>) or (LServer == <<>>) ->
case catch odbc_queries:get_password(LServer, Username) false;
of true ->
{selected, [<<"password">>], [[Password]]} -> Password; Username = ejabberd_odbc:escape(LUser),
_ -> false case is_scrammed() of
end true ->
case catch odbc_queries:get_password_scram(
LServer, Username) of
{selected, [<<"password">>, <<"serverkey">>,
<<"salt">>, <<"iterationcount">>],
[[StoredKey, ServerKey, Salt, IterationCount]]} ->
{jlib:decode_base64(StoredKey),
jlib:decode_base64(ServerKey),
jlib:decode_base64(Salt),
binary_to_integer(IterationCount)};
_ -> false
end;
false ->
case catch odbc_queries:get_password(LServer, Username)
of
{selected, [<<"password">>], [[Password]]} -> Password;
_ -> false
end
end
end. end.
get_password_s(User, Server) -> get_password_s(User, Server) ->
case jlib:nodeprep(User) of LServer = jlib:nameprep(Server),
error -> <<"">>; LUser = jlib:nodeprep(User),
LUser -> if (LUser == error) or (LServer == error) ->
Username = ejabberd_odbc:escape(LUser), <<"">>;
LServer = jlib:nameprep(Server), (LUser == <<>>) or (LServer == <<>>) ->
case catch odbc_queries:get_password(LServer, Username) <<"">>;
of true ->
{selected, [<<"password">>], [[Password]]} -> Password; case is_scrammed() of
_ -> <<"">> false ->
end Username = ejabberd_odbc:escape(LUser),
case catch odbc_queries:get_password(LServer, Username) of
{selected, [<<"password">>], [[Password]]} -> Password;
_ -> <<"">>
end;
true -> <<"">>
end
end. end.
%% @spec (User, Server) -> true | false | {error, Error} %% @spec (User, Server) -> true | false | {error, Error}
@ -234,23 +342,72 @@ remove_user(User, Server) ->
%% @spec (User, Server, Password) -> ok | error | not_exists | not_allowed %% @spec (User, Server, Password) -> ok | error | not_exists | not_allowed
%% @doc Remove user if the provided password is correct. %% @doc Remove user if the provided password is correct.
remove_user(User, Server, Password) -> remove_user(User, Server, Password) ->
case jlib:nodeprep(User) of LServer = jlib:nameprep(Server),
error -> error; LUser = jlib:nodeprep(User),
LUser -> if (LUser == error) or (LServer == error) ->
Username = ejabberd_odbc:escape(LUser), error;
Pass = ejabberd_odbc:escape(Password), (LUser == <<>>) or (LServer == <<>>) ->
LServer = jlib:nameprep(Server), error;
F = fun () -> true ->
Result = odbc_queries:del_user_return_password(LServer, case is_scrammed() of
Username, true ->
Pass), case check_password(User, Server, Password) of
case Result of true ->
{selected, [<<"password">>], [[Password]]} -> ok; remove_user(User, Server),
{selected, [<<"password">>], []} -> not_exists; ok;
_ -> not_allowed false -> not_allowed
end end;
end, false ->
{atomic, Result} = odbc_queries:sql_transaction(LServer, Username = ejabberd_odbc:escape(LUser),
F), Pass = ejabberd_odbc:escape(Password),
Result F = fun () ->
Result = odbc_queries:del_user_return_password(
LServer, Username, Pass),
case Result of
{selected, [<<"password">>],
[[Password]]} -> ok;
{selected, [<<"password">>],
[]} -> not_exists;
_ -> not_allowed
end
end,
{atomic, Result} = odbc_queries:sql_transaction(
LServer, F),
Result
end
end. end.
%%%
%%% SCRAM
%%%
is_scrammed() ->
scram ==
ejabberd_config:get_option({auth_password_format, ?MYNAME},
fun(V) -> V end).
password_to_scram(Password) ->
password_to_scram(Password,
?SCRAM_DEFAULT_ITERATION_COUNT).
password_to_scram(Password, IterationCount) ->
Salt = crypto:rand_bytes(?SALT_LENGTH),
SaltedPassword = scram:salted_password(Password, Salt,
IterationCount),
StoredKey =
scram:stored_key(scram:client_key(SaltedPassword)),
ServerKey = scram:server_key(SaltedPassword),
#scram{storedkey = jlib:encode_base64(StoredKey),
serverkey = jlib:encode_base64(ServerKey),
salt = jlib:encode_base64(Salt),
iterationcount = IterationCount}.
is_password_scram_valid(Password, Scram) ->
IterationCount = Scram#scram.iterationcount,
Salt = jlib:decode_base64(Scram#scram.salt),
SaltedPassword = scram:salted_password(Password, Salt,
IterationCount),
StoredKey =
scram:stored_key(scram:client_key(SaltedPassword)),
jlib:decode_base64(Scram#scram.storedkey) == StoredKey.

View File

@ -28,8 +28,10 @@
-author("mremond@process-one.net"). -author("mremond@process-one.net").
-export([get_db_type/0, update/5, update_t/4, sql_transaction/2, -export([get_db_type/0, update/5, update_t/4, sql_transaction/2,
get_last/2, set_last_t/4, del_last/2, get_password/2, get_last/2, set_last_t/4, del_last/2,
set_password_t/3, add_user/3, del_user/2, get_password/2, get_password_scram/2,
set_password_t/3, set_password_scram_t/6,
add_user/3, add_user_scram/6, del_user/2,
del_user_return_password/3, list_users/1, list_users/2, del_user_return_password/3, list_users/1, list_users/2,
users_number/1, users_number/2, add_spool_sql/2, users_number/1, users_number/2, add_spool_sql/2,
add_spool/2, get_and_del_spool_msg_t/2, del_spool_msg/2, add_spool/2, get_and_del_spool_msg_t/2, del_spool_msg/2,
@ -157,6 +159,12 @@ get_password(LServer, Username) ->
[<<"select password from users where username='">>, [<<"select password from users where username='">>,
Username, <<"';">>]). Username, <<"';">>]).
get_password_scram(LServer, Username) ->
ejabberd_odbc:sql_query(
LServer,
[<<"select password, serverkey, salt, iterationcount from users where "
"username='">>, Username, <<"';">>]).
set_password_t(LServer, Username, Pass) -> set_password_t(LServer, Username, Pass) ->
ejabberd_odbc:sql_transaction(LServer, ejabberd_odbc:sql_transaction(LServer,
fun () -> fun () ->
@ -168,12 +176,39 @@ set_password_t(LServer, Username, Pass) ->
<<"'">>]) <<"'">>])
end). end).
set_password_scram_t(LServer, Username,
StoredKey, ServerKey, Salt, IterationCount) ->
ejabberd_odbc:sql_transaction(LServer,
fun () ->
update_t(<<"users">>,
[<<"username">>,
<<"password">>,
<<"serverkey">>,
<<"salt">>,
<<"iterationcount">>],
[Username, StoredKey,
ServerKey, Salt,
IterationCount],
[<<"username='">>, Username,
<<"'">>])
end).
add_user(LServer, Username, Pass) -> add_user(LServer, Username, Pass) ->
ejabberd_odbc:sql_query(LServer, ejabberd_odbc:sql_query(LServer,
[<<"insert into users(username, password) " [<<"insert into users(username, password) "
"values ('">>, "values ('">>,
Username, <<"', '">>, Pass, <<"');">>]). Username, <<"', '">>, Pass, <<"');">>]).
add_user_scram(LServer, Username,
StoredKey, ServerKey, Salt, IterationCount) ->
ejabberd_odbc:sql_query(LServer,
[<<"insert into users(username, password, serverkey, salt, iterationcount) "
"values ('">>,
Username, <<"', '">>, StoredKey, <<"', '">>,
ServerKey, <<"', '">>,
Salt, <<"', '">>,
IterationCount, <<"');">>]).
del_user(LServer, Username) -> del_user(LServer, Username) ->
ejabberd_odbc:sql_query(LServer, ejabberd_odbc:sql_query(LServer,
[<<"delete from users where username='">>, Username, [<<"delete from users where username='">>, Username,