25
1
mirror of https://github.com/processone/ejabberd.git synced 2024-11-20 16:15:59 +01:00

Add SCRAM support to ejabberd_auth_odbc

This commit is contained in:
Alexey Shchepin 2015-02-17 23:26:31 +03:00
parent 0eb6b942ff
commit e575c87ea2
4 changed files with 307 additions and 107 deletions

View File

@ -22,6 +22,10 @@ CREATE TABLE users (
created_at timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP created_at timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP
) ENGINE=InnoDB CHARACTER SET utf8; ) ENGINE=InnoDB CHARACTER SET utf8;
-- To support SCRAM auth:
-- ALTER TABLE users ADD COLUMN serverkey text NOT NULL DEFAULT '';
-- ALTER TABLE users ADD COLUMN salt text NOT NULL DEFAULT '';
-- ALTER TABLE users ADD COLUMN iterationcount integer NOT NULL DEFAULT 0;
CREATE TABLE last ( CREATE TABLE last (
username varchar(250) PRIMARY KEY, username varchar(250) PRIMARY KEY,

View File

@ -22,6 +22,10 @@ CREATE TABLE users (
created_at TIMESTAMP NOT NULL DEFAULT now() created_at TIMESTAMP NOT NULL DEFAULT now()
); );
-- To support SCRAM auth:
-- ALTER TABLE users ADD COLUMN serverkey text NOT NULL DEFAULT '';
-- ALTER TABLE users ADD COLUMN salt text NOT NULL DEFAULT '';
-- ALTER TABLE users ADD COLUMN iterationcount integer NOT NULL DEFAULT 0;
CREATE TABLE last ( CREATE TABLE last (
username text PRIMARY KEY, username text PRIMARY KEY,

View File

@ -43,22 +43,58 @@
-include("ejabberd.hrl"). -include("ejabberd.hrl").
-include("logger.hrl"). -include("logger.hrl").
-define(SALT_LENGTH, 16).
%%%---------------------------------------------------------------------- %%%----------------------------------------------------------------------
%%% API %%% API
%%%---------------------------------------------------------------------- %%%----------------------------------------------------------------------
start(_Host) -> ok. start(_Host) -> ok.
plain_password_required() -> false. plain_password_required() ->
case is_scrammed() of
false -> false;
true -> true
end.
store_type() -> plain. store_type() ->
case is_scrammed() of
false -> plain; %% allows: PLAIN DIGEST-MD5 SCRAM
true -> scram %% allows: PLAIN SCRAM
end.
%% @spec (User, Server, Password) -> true | false | {error, Error} %% @spec (User, Server, Password) -> true | false | {error, Error}
check_password(User, Server, Password) -> check_password(User, Server, Password) ->
case jlib:nodeprep(User) of
error -> false;
LUser ->
Username = ejabberd_odbc:escape(LUser),
LServer = jlib:nameprep(Server), LServer = jlib:nameprep(Server),
LUser = jlib:nodeprep(User),
if (LUser == error) or (LServer == error) ->
false;
(LUser == <<>>) or (LServer == <<>>) ->
false;
true ->
Username = ejabberd_odbc:escape(LUser),
case is_scrammed() of
true ->
try odbc_queries:get_password_scram(LServer, Username) of
{selected, [<<"password">>, <<"serverkey">>,
<<"salt">>, <<"iterationcount">>],
[[StoredKey, ServerKey, Salt, IterationCount]]} ->
Scram =
#scram{storedkey = StoredKey,
serverkey = ServerKey,
salt = Salt,
iterationcount = binary_to_integer(
IterationCount)},
is_password_scram_valid(Password, Scram);
{selected, [<<"password">>, <<"serverkey">>,
<<"salt">>, <<"iterationcount">>], []} ->
false; %% Account does not exist
{error, _Error} ->
false %% Typical error is that table doesn't exist
catch
_:_ ->
false %% Typical error is database not accessible
end;
false ->
try odbc_queries:get_password(LServer, Username) of try odbc_queries:get_password(LServer, Username) of
{selected, [<<"password">>], [[Password]]} -> {selected, [<<"password">>], [[Password]]} ->
Password /= <<"">>; Password /= <<"">>;
@ -72,16 +108,22 @@ check_password(User, Server, Password) ->
_:_ -> _:_ ->
false %% Typical error is database not accessible false %% Typical error is database not accessible
end end
end
end. end.
%% @spec (User, Server, Password, Digest, DigestGen) -> true | false | {error, Error} %% @spec (User, Server, Password, Digest, DigestGen) -> true | false | {error, Error}
check_password(User, Server, Password, Digest, check_password(User, Server, Password, Digest,
DigestGen) -> DigestGen) ->
case jlib:nodeprep(User) of
error -> false;
LUser ->
Username = ejabberd_odbc:escape(LUser),
LServer = jlib:nameprep(Server), LServer = jlib:nameprep(Server),
LUser = jlib:nodeprep(User),
if (LUser == error) or (LServer == error) ->
false;
(LUser == <<>>) or (LServer == <<>>) ->
false;
true ->
case is_scrammed() of
false ->
Username = ejabberd_odbc:escape(LUser),
try odbc_queries:get_password(LServer, Username) of try odbc_queries:get_password(LServer, Username) of
%% Account exists, check if password is valid %% Account exists, check if password is valid
{selected, [<<"password">>], [[Passwd]]} -> {selected, [<<"password">>], [[Passwd]]} ->
@ -99,40 +141,82 @@ check_password(User, Server, Password, Digest,
catch catch
_:_ -> _:_ ->
false %% Typical error is database not accessible false %% Typical error is database not accessible
end;
true ->
false
end end
end. end.
%% @spec (User::string(), Server::string(), Password::string()) -> %% @spec (User::string(), Server::string(), Password::string()) ->
%% ok | {error, invalid_jid} %% ok | {error, invalid_jid}
set_password(User, Server, Password) -> set_password(User, Server, Password) ->
case jlib:nodeprep(User) of
error -> {error, invalid_jid};
LUser ->
Username = ejabberd_odbc:escape(LUser),
Pass = ejabberd_odbc:escape(Password),
LServer = jlib:nameprep(Server), LServer = jlib:nameprep(Server),
LUser = jlib:nodeprep(User),
if (LUser == error) or (LServer == error) ->
{error, invalid_jid};
(LUser == <<>>) or (LServer == <<>>) ->
{error, invalid_jid};
true ->
Username = ejabberd_odbc:escape(LUser),
case is_scrammed() of
true ->
Scram = password_to_scram(Password),
case catch odbc_queries:set_password_scram_t(
LServer,
Username,
ejabberd_odbc:escape(Scram#scram.storedkey),
ejabberd_odbc:escape(Scram#scram.serverkey),
ejabberd_odbc:escape(Scram#scram.salt),
integer_to_binary(Scram#scram.iterationcount)
)
of
{atomic, ok} -> ok;
Other -> {error, Other}
end;
false ->
Pass = ejabberd_odbc:escape(Password),
case catch odbc_queries:set_password_t(LServer, case catch odbc_queries:set_password_t(LServer,
Username, Pass) Username, Pass)
of of
{atomic, ok} -> ok; {atomic, ok} -> ok;
Other -> {error, Other} Other -> {error, Other}
end end
end
end. end.
%% @spec (User, Server, Password) -> {atomic, ok} | {atomic, exists} | {error, invalid_jid} %% @spec (User, Server, Password) -> {atomic, ok} | {atomic, exists} | {error, invalid_jid}
try_register(User, Server, Password) -> try_register(User, Server, Password) ->
case jlib:nodeprep(User) of
error -> {error, invalid_jid};
LUser ->
Username = ejabberd_odbc:escape(LUser),
Pass = ejabberd_odbc:escape(Password),
LServer = jlib:nameprep(Server), LServer = jlib:nameprep(Server),
LUser = jlib:nodeprep(User),
if (LUser == error) or (LServer == error) ->
{error, invalid_jid};
(LUser == <<>>) or (LServer == <<>>) ->
{error, invalid_jid};
true ->
Username = ejabberd_odbc:escape(LUser),
case is_scrammed() of
true ->
Scram = password_to_scram(Password),
case catch odbc_queries:add_user_scram(
LServer,
Username,
ejabberd_odbc:escape(Scram#scram.storedkey),
ejabberd_odbc:escape(Scram#scram.serverkey),
ejabberd_odbc:escape(Scram#scram.salt),
integer_to_binary(Scram#scram.iterationcount)
) of
{updated, 1} -> {atomic, ok};
_ -> {atomic, exists}
end;
false ->
Pass = ejabberd_odbc:escape(Password),
case catch odbc_queries:add_user(LServer, Username, case catch odbc_queries:add_user(LServer, Username,
Pass) Pass)
of of
{updated, 1} -> {atomic, ok}; {updated, 1} -> {atomic, ok};
_ -> {atomic, exists} _ -> {atomic, exists}
end end
end
end. end.
dirty_get_registered_users() -> dirty_get_registered_users() ->
@ -175,28 +259,52 @@ get_vh_registered_users_number(Server, Opts) ->
end. end.
get_password(User, Server) -> get_password(User, Server) ->
case jlib:nodeprep(User) of
error -> false;
LUser ->
Username = ejabberd_odbc:escape(LUser),
LServer = jlib:nameprep(Server), LServer = jlib:nameprep(Server),
LUser = jlib:nodeprep(User),
if (LUser == error) or (LServer == error) ->
false;
(LUser == <<>>) or (LServer == <<>>) ->
false;
true ->
Username = ejabberd_odbc:escape(LUser),
case is_scrammed() of
true ->
case catch odbc_queries:get_password_scram(
LServer, Username) of
{selected, [<<"password">>, <<"serverkey">>,
<<"salt">>, <<"iterationcount">>],
[[StoredKey, ServerKey, Salt, IterationCount]]} ->
{jlib:decode_base64(StoredKey),
jlib:decode_base64(ServerKey),
jlib:decode_base64(Salt),
binary_to_integer(IterationCount)};
_ -> false
end;
false ->
case catch odbc_queries:get_password(LServer, Username) case catch odbc_queries:get_password(LServer, Username)
of of
{selected, [<<"password">>], [[Password]]} -> Password; {selected, [<<"password">>], [[Password]]} -> Password;
_ -> false _ -> false
end end
end
end. end.
get_password_s(User, Server) -> get_password_s(User, Server) ->
case jlib:nodeprep(User) of
error -> <<"">>;
LUser ->
Username = ejabberd_odbc:escape(LUser),
LServer = jlib:nameprep(Server), LServer = jlib:nameprep(Server),
case catch odbc_queries:get_password(LServer, Username) LUser = jlib:nodeprep(User),
of if (LUser == error) or (LServer == error) ->
<<"">>;
(LUser == <<>>) or (LServer == <<>>) ->
<<"">>;
true ->
case is_scrammed() of
false ->
Username = ejabberd_odbc:escape(LUser),
case catch odbc_queries:get_password(LServer, Username) of
{selected, [<<"password">>], [[Password]]} -> Password; {selected, [<<"password">>], [[Password]]} -> Password;
_ -> <<"">> _ -> <<"">>
end;
true -> <<"">>
end end
end. end.
@ -234,23 +342,72 @@ remove_user(User, Server) ->
%% @spec (User, Server, Password) -> ok | error | not_exists | not_allowed %% @spec (User, Server, Password) -> ok | error | not_exists | not_allowed
%% @doc Remove user if the provided password is correct. %% @doc Remove user if the provided password is correct.
remove_user(User, Server, Password) -> remove_user(User, Server, Password) ->
case jlib:nodeprep(User) of LServer = jlib:nameprep(Server),
error -> error; LUser = jlib:nodeprep(User),
LUser -> if (LUser == error) or (LServer == error) ->
error;
(LUser == <<>>) or (LServer == <<>>) ->
error;
true ->
case is_scrammed() of
true ->
case check_password(User, Server, Password) of
true ->
remove_user(User, Server),
ok;
false -> not_allowed
end;
false ->
Username = ejabberd_odbc:escape(LUser), Username = ejabberd_odbc:escape(LUser),
Pass = ejabberd_odbc:escape(Password), Pass = ejabberd_odbc:escape(Password),
LServer = jlib:nameprep(Server),
F = fun () -> F = fun () ->
Result = odbc_queries:del_user_return_password(LServer, Result = odbc_queries:del_user_return_password(
Username, LServer, Username, Pass),
Pass),
case Result of case Result of
{selected, [<<"password">>], [[Password]]} -> ok; {selected, [<<"password">>],
{selected, [<<"password">>], []} -> not_exists; [[Password]]} -> ok;
{selected, [<<"password">>],
[]} -> not_exists;
_ -> not_allowed _ -> not_allowed
end end
end, end,
{atomic, Result} = odbc_queries:sql_transaction(LServer, {atomic, Result} = odbc_queries:sql_transaction(
F), LServer, F),
Result Result
end
end. end.
%%%
%%% SCRAM
%%%
is_scrammed() ->
scram ==
ejabberd_config:get_option({auth_password_format, ?MYNAME},
fun(V) -> V end).
password_to_scram(Password) ->
password_to_scram(Password,
?SCRAM_DEFAULT_ITERATION_COUNT).
password_to_scram(Password, IterationCount) ->
Salt = crypto:rand_bytes(?SALT_LENGTH),
SaltedPassword = scram:salted_password(Password, Salt,
IterationCount),
StoredKey =
scram:stored_key(scram:client_key(SaltedPassword)),
ServerKey = scram:server_key(SaltedPassword),
#scram{storedkey = jlib:encode_base64(StoredKey),
serverkey = jlib:encode_base64(ServerKey),
salt = jlib:encode_base64(Salt),
iterationcount = IterationCount}.
is_password_scram_valid(Password, Scram) ->
IterationCount = Scram#scram.iterationcount,
Salt = jlib:decode_base64(Scram#scram.salt),
SaltedPassword = scram:salted_password(Password, Salt,
IterationCount),
StoredKey =
scram:stored_key(scram:client_key(SaltedPassword)),
jlib:decode_base64(Scram#scram.storedkey) == StoredKey.

View File

@ -28,8 +28,10 @@
-author("mremond@process-one.net"). -author("mremond@process-one.net").
-export([get_db_type/0, update/5, update_t/4, sql_transaction/2, -export([get_db_type/0, update/5, update_t/4, sql_transaction/2,
get_last/2, set_last_t/4, del_last/2, get_password/2, get_last/2, set_last_t/4, del_last/2,
set_password_t/3, add_user/3, del_user/2, get_password/2, get_password_scram/2,
set_password_t/3, set_password_scram_t/6,
add_user/3, add_user_scram/6, del_user/2,
del_user_return_password/3, list_users/1, list_users/2, del_user_return_password/3, list_users/1, list_users/2,
users_number/1, users_number/2, add_spool_sql/2, users_number/1, users_number/2, add_spool_sql/2,
add_spool/2, get_and_del_spool_msg_t/2, del_spool_msg/2, add_spool/2, get_and_del_spool_msg_t/2, del_spool_msg/2,
@ -157,6 +159,12 @@ get_password(LServer, Username) ->
[<<"select password from users where username='">>, [<<"select password from users where username='">>,
Username, <<"';">>]). Username, <<"';">>]).
get_password_scram(LServer, Username) ->
ejabberd_odbc:sql_query(
LServer,
[<<"select password, serverkey, salt, iterationcount from users where "
"username='">>, Username, <<"';">>]).
set_password_t(LServer, Username, Pass) -> set_password_t(LServer, Username, Pass) ->
ejabberd_odbc:sql_transaction(LServer, ejabberd_odbc:sql_transaction(LServer,
fun () -> fun () ->
@ -168,12 +176,39 @@ set_password_t(LServer, Username, Pass) ->
<<"'">>]) <<"'">>])
end). end).
set_password_scram_t(LServer, Username,
StoredKey, ServerKey, Salt, IterationCount) ->
ejabberd_odbc:sql_transaction(LServer,
fun () ->
update_t(<<"users">>,
[<<"username">>,
<<"password">>,
<<"serverkey">>,
<<"salt">>,
<<"iterationcount">>],
[Username, StoredKey,
ServerKey, Salt,
IterationCount],
[<<"username='">>, Username,
<<"'">>])
end).
add_user(LServer, Username, Pass) -> add_user(LServer, Username, Pass) ->
ejabberd_odbc:sql_query(LServer, ejabberd_odbc:sql_query(LServer,
[<<"insert into users(username, password) " [<<"insert into users(username, password) "
"values ('">>, "values ('">>,
Username, <<"', '">>, Pass, <<"');">>]). Username, <<"', '">>, Pass, <<"');">>]).
add_user_scram(LServer, Username,
StoredKey, ServerKey, Salt, IterationCount) ->
ejabberd_odbc:sql_query(LServer,
[<<"insert into users(username, password, serverkey, salt, iterationcount) "
"values ('">>,
Username, <<"', '">>, StoredKey, <<"', '">>,
ServerKey, <<"', '">>,
Salt, <<"', '">>,
IterationCount, <<"');">>]).
del_user(LServer, Username) -> del_user(LServer, Username) ->
ejabberd_odbc:sql_query(LServer, ejabberd_odbc:sql_query(LServer,
[<<"delete from users where username='">>, Username, [<<"delete from users where username='">>, Username,