Add SCRAM support to ejabberd_auth_odbc

This commit is contained in:
Alexey Shchepin 2015-02-17 23:26:31 +03:00
parent 0eb6b942ff
commit e575c87ea2
4 changed files with 307 additions and 107 deletions

View File

@ -22,6 +22,10 @@ CREATE TABLE users (
created_at timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP
) ENGINE=InnoDB CHARACTER SET utf8;
-- To support SCRAM auth:
-- ALTER TABLE users ADD COLUMN serverkey text NOT NULL DEFAULT '';
-- ALTER TABLE users ADD COLUMN salt text NOT NULL DEFAULT '';
-- ALTER TABLE users ADD COLUMN iterationcount integer NOT NULL DEFAULT 0;
CREATE TABLE last (
username varchar(250) PRIMARY KEY,

View File

@ -22,6 +22,10 @@ CREATE TABLE users (
created_at TIMESTAMP NOT NULL DEFAULT now()
);
-- To support SCRAM auth:
-- ALTER TABLE users ADD COLUMN serverkey text NOT NULL DEFAULT '';
-- ALTER TABLE users ADD COLUMN salt text NOT NULL DEFAULT '';
-- ALTER TABLE users ADD COLUMN iterationcount integer NOT NULL DEFAULT 0;
CREATE TABLE last (
username text PRIMARY KEY,

View File

@ -43,96 +43,180 @@
-include("ejabberd.hrl").
-include("logger.hrl").
-define(SALT_LENGTH, 16).
%%%----------------------------------------------------------------------
%%% API
%%%----------------------------------------------------------------------
start(_Host) -> ok.
plain_password_required() -> false.
plain_password_required() ->
case is_scrammed() of
false -> false;
true -> true
end.
store_type() -> plain.
store_type() ->
case is_scrammed() of
false -> plain; %% allows: PLAIN DIGEST-MD5 SCRAM
true -> scram %% allows: PLAIN SCRAM
end.
%% @spec (User, Server, Password) -> true | false | {error, Error}
check_password(User, Server, Password) ->
case jlib:nodeprep(User) of
error -> false;
LUser ->
Username = ejabberd_odbc:escape(LUser),
LServer = jlib:nameprep(Server),
try odbc_queries:get_password(LServer, Username) of
{selected, [<<"password">>], [[Password]]} ->
Password /= <<"">>;
{selected, [<<"password">>], [[_Password2]]} ->
false; %% Password is not correct
{selected, [<<"password">>], []} ->
false; %% Account does not exist
{error, _Error} ->
false %% Typical error is that table doesn't exist
catch
_:_ ->
false %% Typical error is database not accessible
end
LServer = jlib:nameprep(Server),
LUser = jlib:nodeprep(User),
if (LUser == error) or (LServer == error) ->
false;
(LUser == <<>>) or (LServer == <<>>) ->
false;
true ->
Username = ejabberd_odbc:escape(LUser),
case is_scrammed() of
true ->
try odbc_queries:get_password_scram(LServer, Username) of
{selected, [<<"password">>, <<"serverkey">>,
<<"salt">>, <<"iterationcount">>],
[[StoredKey, ServerKey, Salt, IterationCount]]} ->
Scram =
#scram{storedkey = StoredKey,
serverkey = ServerKey,
salt = Salt,
iterationcount = binary_to_integer(
IterationCount)},
is_password_scram_valid(Password, Scram);
{selected, [<<"password">>, <<"serverkey">>,
<<"salt">>, <<"iterationcount">>], []} ->
false; %% Account does not exist
{error, _Error} ->
false %% Typical error is that table doesn't exist
catch
_:_ ->
false %% Typical error is database not accessible
end;
false ->
try odbc_queries:get_password(LServer, Username) of
{selected, [<<"password">>], [[Password]]} ->
Password /= <<"">>;
{selected, [<<"password">>], [[_Password2]]} ->
false; %% Password is not correct
{selected, [<<"password">>], []} ->
false; %% Account does not exist
{error, _Error} ->
false %% Typical error is that table doesn't exist
catch
_:_ ->
false %% Typical error is database not accessible
end
end
end.
%% @spec (User, Server, Password, Digest, DigestGen) -> true | false | {error, Error}
check_password(User, Server, Password, Digest,
DigestGen) ->
case jlib:nodeprep(User) of
error -> false;
LUser ->
Username = ejabberd_odbc:escape(LUser),
LServer = jlib:nameprep(Server),
try odbc_queries:get_password(LServer, Username) of
%% Account exists, check if password is valid
{selected, [<<"password">>], [[Passwd]]} ->
DigRes = if Digest /= <<"">> ->
Digest == DigestGen(Passwd);
true -> false
end,
if DigRes -> true;
true -> (Passwd == Password) and (Password /= <<"">>)
end;
{selected, [<<"password">>], []} ->
false; %% Account does not exist
{error, _Error} ->
false %% Typical error is that table doesn't exist
catch
_:_ ->
false %% Typical error is database not accessible
end
LServer = jlib:nameprep(Server),
LUser = jlib:nodeprep(User),
if (LUser == error) or (LServer == error) ->
false;
(LUser == <<>>) or (LServer == <<>>) ->
false;
true ->
case is_scrammed() of
false ->
Username = ejabberd_odbc:escape(LUser),
try odbc_queries:get_password(LServer, Username) of
%% Account exists, check if password is valid
{selected, [<<"password">>], [[Passwd]]} ->
DigRes = if Digest /= <<"">> ->
Digest == DigestGen(Passwd);
true -> false
end,
if DigRes -> true;
true -> (Passwd == Password) and (Password /= <<"">>)
end;
{selected, [<<"password">>], []} ->
false; %% Account does not exist
{error, _Error} ->
false %% Typical error is that table doesn't exist
catch
_:_ ->
false %% Typical error is database not accessible
end;
true ->
false
end
end.
%% @spec (User::string(), Server::string(), Password::string()) ->
%% ok | {error, invalid_jid}
set_password(User, Server, Password) ->
case jlib:nodeprep(User) of
error -> {error, invalid_jid};
LUser ->
Username = ejabberd_odbc:escape(LUser),
Pass = ejabberd_odbc:escape(Password),
LServer = jlib:nameprep(Server),
case catch odbc_queries:set_password_t(LServer,
Username, Pass)
of
{atomic, ok} -> ok;
Other -> {error, Other}
end
LServer = jlib:nameprep(Server),
LUser = jlib:nodeprep(User),
if (LUser == error) or (LServer == error) ->
{error, invalid_jid};
(LUser == <<>>) or (LServer == <<>>) ->
{error, invalid_jid};
true ->
Username = ejabberd_odbc:escape(LUser),
case is_scrammed() of
true ->
Scram = password_to_scram(Password),
case catch odbc_queries:set_password_scram_t(
LServer,
Username,
ejabberd_odbc:escape(Scram#scram.storedkey),
ejabberd_odbc:escape(Scram#scram.serverkey),
ejabberd_odbc:escape(Scram#scram.salt),
integer_to_binary(Scram#scram.iterationcount)
)
of
{atomic, ok} -> ok;
Other -> {error, Other}
end;
false ->
Pass = ejabberd_odbc:escape(Password),
case catch odbc_queries:set_password_t(LServer,
Username, Pass)
of
{atomic, ok} -> ok;
Other -> {error, Other}
end
end
end.
%% @spec (User, Server, Password) -> {atomic, ok} | {atomic, exists} | {error, invalid_jid}
try_register(User, Server, Password) ->
case jlib:nodeprep(User) of
error -> {error, invalid_jid};
LUser ->
LServer = jlib:nameprep(Server),
LUser = jlib:nodeprep(User),
if (LUser == error) or (LServer == error) ->
{error, invalid_jid};
(LUser == <<>>) or (LServer == <<>>) ->
{error, invalid_jid};
true ->
Username = ejabberd_odbc:escape(LUser),
Pass = ejabberd_odbc:escape(Password),
LServer = jlib:nameprep(Server),
case catch odbc_queries:add_user(LServer, Username,
Pass)
of
{updated, 1} -> {atomic, ok};
_ -> {atomic, exists}
end
case is_scrammed() of
true ->
Scram = password_to_scram(Password),
case catch odbc_queries:add_user_scram(
LServer,
Username,
ejabberd_odbc:escape(Scram#scram.storedkey),
ejabberd_odbc:escape(Scram#scram.serverkey),
ejabberd_odbc:escape(Scram#scram.salt),
integer_to_binary(Scram#scram.iterationcount)
) of
{updated, 1} -> {atomic, ok};
_ -> {atomic, exists}
end;
false ->
Pass = ejabberd_odbc:escape(Password),
case catch odbc_queries:add_user(LServer, Username,
Pass)
of
{updated, 1} -> {atomic, ok};
_ -> {atomic, exists}
end
end
end.
dirty_get_registered_users() ->
@ -175,29 +259,53 @@ get_vh_registered_users_number(Server, Opts) ->
end.
get_password(User, Server) ->
case jlib:nodeprep(User) of
error -> false;
LUser ->
Username = ejabberd_odbc:escape(LUser),
LServer = jlib:nameprep(Server),
case catch odbc_queries:get_password(LServer, Username)
of
{selected, [<<"password">>], [[Password]]} -> Password;
_ -> false
end
LServer = jlib:nameprep(Server),
LUser = jlib:nodeprep(User),
if (LUser == error) or (LServer == error) ->
false;
(LUser == <<>>) or (LServer == <<>>) ->
false;
true ->
Username = ejabberd_odbc:escape(LUser),
case is_scrammed() of
true ->
case catch odbc_queries:get_password_scram(
LServer, Username) of
{selected, [<<"password">>, <<"serverkey">>,
<<"salt">>, <<"iterationcount">>],
[[StoredKey, ServerKey, Salt, IterationCount]]} ->
{jlib:decode_base64(StoredKey),
jlib:decode_base64(ServerKey),
jlib:decode_base64(Salt),
binary_to_integer(IterationCount)};
_ -> false
end;
false ->
case catch odbc_queries:get_password(LServer, Username)
of
{selected, [<<"password">>], [[Password]]} -> Password;
_ -> false
end
end
end.
get_password_s(User, Server) ->
case jlib:nodeprep(User) of
error -> <<"">>;
LUser ->
Username = ejabberd_odbc:escape(LUser),
LServer = jlib:nameprep(Server),
case catch odbc_queries:get_password(LServer, Username)
of
{selected, [<<"password">>], [[Password]]} -> Password;
_ -> <<"">>
end
LServer = jlib:nameprep(Server),
LUser = jlib:nodeprep(User),
if (LUser == error) or (LServer == error) ->
<<"">>;
(LUser == <<>>) or (LServer == <<>>) ->
<<"">>;
true ->
case is_scrammed() of
false ->
Username = ejabberd_odbc:escape(LUser),
case catch odbc_queries:get_password(LServer, Username) of
{selected, [<<"password">>], [[Password]]} -> Password;
_ -> <<"">>
end;
true -> <<"">>
end
end.
%% @spec (User, Server) -> true | false | {error, Error}
@ -234,23 +342,72 @@ remove_user(User, Server) ->
%% @spec (User, Server, Password) -> ok | error | not_exists | not_allowed
%% @doc Remove user if the provided password is correct.
remove_user(User, Server, Password) ->
case jlib:nodeprep(User) of
error -> error;
LUser ->
Username = ejabberd_odbc:escape(LUser),
Pass = ejabberd_odbc:escape(Password),
LServer = jlib:nameprep(Server),
F = fun () ->
Result = odbc_queries:del_user_return_password(LServer,
Username,
Pass),
case Result of
{selected, [<<"password">>], [[Password]]} -> ok;
{selected, [<<"password">>], []} -> not_exists;
_ -> not_allowed
end
end,
{atomic, Result} = odbc_queries:sql_transaction(LServer,
F),
Result
LServer = jlib:nameprep(Server),
LUser = jlib:nodeprep(User),
if (LUser == error) or (LServer == error) ->
error;
(LUser == <<>>) or (LServer == <<>>) ->
error;
true ->
case is_scrammed() of
true ->
case check_password(User, Server, Password) of
true ->
remove_user(User, Server),
ok;
false -> not_allowed
end;
false ->
Username = ejabberd_odbc:escape(LUser),
Pass = ejabberd_odbc:escape(Password),
F = fun () ->
Result = odbc_queries:del_user_return_password(
LServer, Username, Pass),
case Result of
{selected, [<<"password">>],
[[Password]]} -> ok;
{selected, [<<"password">>],
[]} -> not_exists;
_ -> not_allowed
end
end,
{atomic, Result} = odbc_queries:sql_transaction(
LServer, F),
Result
end
end.
%%%
%%% SCRAM
%%%
is_scrammed() ->
scram ==
ejabberd_config:get_option({auth_password_format, ?MYNAME},
fun(V) -> V end).
password_to_scram(Password) ->
password_to_scram(Password,
?SCRAM_DEFAULT_ITERATION_COUNT).
password_to_scram(Password, IterationCount) ->
Salt = crypto:rand_bytes(?SALT_LENGTH),
SaltedPassword = scram:salted_password(Password, Salt,
IterationCount),
StoredKey =
scram:stored_key(scram:client_key(SaltedPassword)),
ServerKey = scram:server_key(SaltedPassword),
#scram{storedkey = jlib:encode_base64(StoredKey),
serverkey = jlib:encode_base64(ServerKey),
salt = jlib:encode_base64(Salt),
iterationcount = IterationCount}.
is_password_scram_valid(Password, Scram) ->
IterationCount = Scram#scram.iterationcount,
Salt = jlib:decode_base64(Scram#scram.salt),
SaltedPassword = scram:salted_password(Password, Salt,
IterationCount),
StoredKey =
scram:stored_key(scram:client_key(SaltedPassword)),
jlib:decode_base64(Scram#scram.storedkey) == StoredKey.

View File

@ -28,8 +28,10 @@
-author("mremond@process-one.net").
-export([get_db_type/0, update/5, update_t/4, sql_transaction/2,
get_last/2, set_last_t/4, del_last/2, get_password/2,
set_password_t/3, add_user/3, del_user/2,
get_last/2, set_last_t/4, del_last/2,
get_password/2, get_password_scram/2,
set_password_t/3, set_password_scram_t/6,
add_user/3, add_user_scram/6, del_user/2,
del_user_return_password/3, list_users/1, list_users/2,
users_number/1, users_number/2, add_spool_sql/2,
add_spool/2, get_and_del_spool_msg_t/2, del_spool_msg/2,
@ -157,6 +159,12 @@ get_password(LServer, Username) ->
[<<"select password from users where username='">>,
Username, <<"';">>]).
get_password_scram(LServer, Username) ->
ejabberd_odbc:sql_query(
LServer,
[<<"select password, serverkey, salt, iterationcount from users where "
"username='">>, Username, <<"';">>]).
set_password_t(LServer, Username, Pass) ->
ejabberd_odbc:sql_transaction(LServer,
fun () ->
@ -168,12 +176,39 @@ set_password_t(LServer, Username, Pass) ->
<<"'">>])
end).
set_password_scram_t(LServer, Username,
StoredKey, ServerKey, Salt, IterationCount) ->
ejabberd_odbc:sql_transaction(LServer,
fun () ->
update_t(<<"users">>,
[<<"username">>,
<<"password">>,
<<"serverkey">>,
<<"salt">>,
<<"iterationcount">>],
[Username, StoredKey,
ServerKey, Salt,
IterationCount],
[<<"username='">>, Username,
<<"'">>])
end).
add_user(LServer, Username, Pass) ->
ejabberd_odbc:sql_query(LServer,
[<<"insert into users(username, password) "
"values ('">>,
Username, <<"', '">>, Pass, <<"');">>]).
add_user_scram(LServer, Username,
StoredKey, ServerKey, Salt, IterationCount) ->
ejabberd_odbc:sql_query(LServer,
[<<"insert into users(username, password, serverkey, salt, iterationcount) "
"values ('">>,
Username, <<"', '">>, StoredKey, <<"', '">>,
ServerKey, <<"', '">>,
Salt, <<"', '">>,
IterationCount, <<"');">>]).
del_user(LServer, Username) ->
ejabberd_odbc:sql_query(LServer,
[<<"delete from users where username='">>, Username,