diff --git a/src/cyrsasl.erl b/src/cyrsasl.erl index 4b0f5a26b..1edf44678 100644 --- a/src/cyrsasl.erl +++ b/src/cyrsasl.erl @@ -31,18 +31,11 @@ -export([start/0, register_mechanism/3, listmech/1, server_new/7, server_start/3, server_step/2, - opt_type/1]). + get_mech/1, format_error/2, opt_type/1]). -include("ejabberd.hrl"). -include("logger.hrl"). -%% --export_type([ - mechanism/0, - mechanisms/0, - sasl_mechanism/0 -]). - -record(sasl_mechanism, {mechanism = <<"">> :: mechanism() | '$1', module :: atom(), @@ -51,12 +44,22 @@ -type(mechanism() :: binary()). -type(mechanisms() :: [mechanism(),...]). -type(password_type() :: plain | digest | scram). --type(props() :: [{username, binary()} | - {authzid, binary()} | - {auth_module, atom()}]). +-type sasl_property() :: {username, binary()} | + {authzid, binary()} | + {mechanism, binary()} | + {auth_module, atom()}. +-type sasl_return() :: {ok, [sasl_property()]} | + {ok, [sasl_property()], binary()} | + {continue, binary(), sasl_state()} | + {error, atom(), binary()}. -type(sasl_mechanism() :: #sasl_mechanism{}). - +-type error_reason() :: cyrsasl_digest:error_reason() | + cyrsasl_oauth:error_reason() | + cyrsasl_plain:error_reason() | + cyrsasl_scram:error_reason() | + unsupported_mechanism | nodeprep_failed | + empty_username | aborted. -record(sasl_state, { service, @@ -65,16 +68,16 @@ get_password, check_password, check_password_digest, + mech_name = <<"">>, mech_mod, mech_state }). +-type sasl_state() :: #sasl_state{}. +-export_type([mechanism/0, mechanisms/0, sasl_mechanism/0, error_reason/0, + sasl_state/0, sasl_return/0, sasl_property/0]). -callback mech_new(binary(), fun(), fun(), fun()) -> any(). --callback mech_step(any(), binary()) -> {ok, props()} | - {ok, props(), binary()} | - {continue, binary(), any()} | - {error, atom()} | - {error, atom(), binary()}. +-callback mech_step(any(), binary()) -> sasl_return(). start() -> ets:new(sasl_mechanism, @@ -87,7 +90,25 @@ start() -> cyrsasl_oauth:start([]), ok. -%% +-spec format_error(mechanism() | sasl_state(), error_reason()) -> {atom(), binary()}. +format_error(_, unsupported_mechanism) -> + {'invalid-mechanism', <<"Unsupported mechanism">>}; +format_error(_, nodeprep_failed) -> + {'bad-protocol', <<"Nodeprep failed">>}; +format_error(_, empty_username) -> + {'bad-protocol', <<"Empty username">>}; +format_error(_, aborted) -> + {'aborted', <<"Aborted">>}; +format_error(#sasl_state{mech_mod = Mod}, Reason) -> + Mod:format_error(Reason); +format_error(Mech, Reason) -> + case ets:lookup(sasl_mechanism, Mech) of + [#sasl_mechanism{module = Mod}] -> + Mod:format_error(Reason); + [] -> + {'invalid-mechanism', <<"Unsupported mechanism">>} + end. + -spec register_mechanism(Mechanim :: mechanism(), Module :: module(), PasswordType :: password_type()) -> any(). @@ -105,8 +126,8 @@ register_mechanism(Mechanism, Module, PasswordType) -> check_credentials(_State, Props) -> User = proplists:get_value(authzid, Props, <<>>), case jid:nodeprep(User) of - error -> {error, 'not-authorized'}; - <<"">> -> {error, 'not-authorized'}; + error -> {error, nodeprep_failed}; + <<"">> -> {error, empty_username}; _LUser -> ok end. @@ -128,6 +149,8 @@ listmech(Host) -> ['$1']}]), filter_anonymous(Host, Mechs). +-spec server_new(binary(), binary(), binary(), term(), + fun(), fun(), fun()) -> sasl_state(). server_new(Service, ServerFQDN, UserRealm, _SecFlags, GetPassword, CheckPassword, CheckPasswordDigest) -> #sasl_state{service = Service, myname = ServerFQDN, @@ -135,8 +158,7 @@ server_new(Service, ServerFQDN, UserRealm, _SecFlags, check_password = CheckPassword, check_password_digest = CheckPasswordDigest}. -server_start(State, Mech, undefined) -> - server_start(State, Mech, <<"">>); +-spec server_start(sasl_state(), mechanism(), binary()) -> sasl_return(). server_start(State, Mech, ClientIn) -> case lists:member(Mech, listmech(State#sasl_state.myname)) @@ -150,15 +172,15 @@ server_start(State, Mech, ClientIn) -> State#sasl_state.check_password, State#sasl_state.check_password_digest), server_step(State#sasl_state{mech_mod = Module, + mech_name = Mech, mech_state = MechState}, ClientIn); - _ -> {error, 'no-mechanism'} + _ -> {error, unsupported_mechanism, <<"">>} end; - false -> {error, 'no-mechanism'} + false -> {error, unsupported_mechanism, <<"">>} end. -server_step(State, undefined) -> - server_step(State, <<"">>); +-spec server_step(sasl_state(), binary()) -> sasl_return(). server_step(State, ClientIn) -> Module = State#sasl_state.mech_mod, MechState = State#sasl_state.mech_state, @@ -166,21 +188,25 @@ server_step(State, ClientIn) -> {ok, Props} -> case check_credentials(State, Props) of ok -> {ok, Props}; - {error, Error} -> {error, Error} + {error, Error} -> {error, Error, <<"">>} end; {ok, Props, ServerOut} -> case check_credentials(State, Props) of ok -> {ok, Props, ServerOut}; - {error, Error} -> {error, Error} + {error, Error} -> {error, Error, <<"">>} end; {continue, ServerOut, NewMechState} -> {continue, ServerOut, State#sasl_state{mech_state = NewMechState}}; {error, Error, Username} -> {error, Error, Username}; {error, Error} -> - {error, Error} + {error, Error, <<"">>} end. +-spec get_mech(sasl_state()) -> binary(). +get_mech(#sasl_state{mech_name = Mech}) -> + Mech. + %% Remove the anonymous mechanism from the list if not enabled for the given %% host %% diff --git a/src/cyrsasl_anonymous.erl b/src/cyrsasl_anonymous.erl index 15980afc5..cad9cdf93 100644 --- a/src/cyrsasl_anonymous.erl +++ b/src/cyrsasl_anonymous.erl @@ -43,10 +43,9 @@ stop() -> ok. mech_new(Host, _GetPassword, _CheckPassword, _CheckPasswordDigest) -> {ok, #state{server = Host}}. -mech_step(#state{server = Server} = S, ClientIn) -> +mech_step(#state{}, _ClientIn) -> User = iolist_to_binary([randoms:get_string(), integer_to_binary(p1_time_compat:unique_integer([positive]))]), - case ejabberd_auth:is_user_exists(User, Server) of - true -> mech_step(S, ClientIn); - false -> {ok, [{username, User}, {authzid, User}, {auth_module, ejabberd_auth_anonymous}]} - end. + {ok, [{username, User}, + {authzid, User}, + {auth_module, ejabberd_auth_anonymous}]}. diff --git a/src/cyrsasl_digest.erl b/src/cyrsasl_digest.erl index 150aa854c..9b4faca20 100644 --- a/src/cyrsasl_digest.erl +++ b/src/cyrsasl_digest.erl @@ -30,7 +30,7 @@ -author('alexey@sevcom.net'). -export([start/1, stop/0, mech_new/4, mech_step/2, - parse/1, opt_type/1]). + parse/1, format_error/1, opt_type/1]). -include("ejabberd.hrl"). -include("logger.hrl"). @@ -39,11 +39,13 @@ -type get_password_fun() :: fun((binary()) -> {false, any()} | {binary(), atom()}). - -type check_password_fun() :: fun((binary(), binary(), binary(), fun((binary()) -> binary())) -> {boolean(), any()} | false). +-type error_reason() :: parser_failed | invalid_digest_uri | + not_authorized | unexpected_response. +-export_type([error_reason/0]). -record(state, {step = 1 :: 1 | 3 | 5, nonce = <<"">> :: binary(), @@ -64,6 +66,16 @@ start(_Opts) -> stop() -> ok. +-spec format_error(error_reason()) -> {atom(), binary()}. +format_error(parser_failed) -> + {'bad-protocol', <<"Response decoding failed">>}; +format_error(invalid_digest_uri) -> + {'bad-protocol', <<"Invalid digest URI">>}; +format_error(not_authorized) -> + {'not-authorized', <<"Invalid username or password">>}; +format_error(unexpected_response) -> + {'bad-protocol', <<"Unexpected response">>}. + mech_new(Host, GetPassword, _CheckPassword, CheckPasswordDigest) -> {ok, @@ -80,8 +92,8 @@ mech_step(#state{step = 1, nonce = Nonce} = State, _) -> mech_step(#state{step = 3, nonce = Nonce} = State, ClientIn) -> case parse(ClientIn) of - bad -> {error, 'bad-protocol'}; - KeyVals -> + bad -> {error, parser_failed}; + KeyVals -> DigestURI = proplists:get_value(<<"digest-uri">>, KeyVals, <<>>), UserName = proplists:get_value(<<"username">>, KeyVals, <<>>), case is_digesturi_valid(DigestURI, State#state.host, @@ -92,11 +104,11 @@ mech_step(#state{step = 3, nonce = Nonce} = State, "seems invalid: ~p (checking for Host " "~p, FQDN ~p)", [DigestURI, State#state.host, State#state.hostfqdn]), - {error, 'not-authorized', UserName}; + {error, invalid_digest_uri, UserName}; true -> AuthzId = proplists:get_value(<<"authzid">>, KeyVals, <<>>), case (State#state.get_password)(UserName) of - {false, _} -> {error, 'not-authorized', UserName}; + {false, _} -> {error, not_authorized, UserName}; {Passwd, AuthModule} -> case (State#state.check_password)(UserName, UserName, <<"">>, proplists:get_value(<<"response">>, KeyVals, <<>>), @@ -116,8 +128,8 @@ mech_step(#state{step = 3, nonce = Nonce} = State, State#state{step = 5, auth_module = AuthModule, username = UserName, authzid = AuthzId}}; - false -> {error, 'not-authorized', UserName}; - {false, _} -> {error, 'not-authorized', UserName} + false -> {error, not_authorized, UserName}; + {false, _} -> {error, not_authorized, UserName} end end end @@ -134,7 +146,7 @@ mech_step(#state{step = 5, auth_module = AuthModule, {auth_module, AuthModule}]}; mech_step(A, B) -> ?DEBUG("SASL DIGEST: A ~p B ~p", [A, B]), - {error, 'bad-protocol'}. + {error, unexpected_response}. parse(S) -> parse1(binary_to_list(S), "", []). diff --git a/src/cyrsasl_oauth.erl b/src/cyrsasl_oauth.erl index 21dedc6db..be7e9a68d 100644 --- a/src/cyrsasl_oauth.erl +++ b/src/cyrsasl_oauth.erl @@ -27,11 +27,13 @@ -author('alexey@process-one.net'). --export([start/1, stop/0, mech_new/4, mech_step/2, parse/1]). +-export([start/1, stop/0, mech_new/4, mech_step/2, parse/1, format_error/1]). -behaviour(cyrsasl). -record(state, {host}). +-type error_reason() :: parser_failed | not_authorized. +-export_type([error_reason/0]). start(_Opts) -> cyrsasl:register_mechanism(<<"X-OAUTH2">>, ?MODULE, plain), @@ -39,6 +41,12 @@ start(_Opts) -> stop() -> ok. +-spec format_error(error_reason()) -> {atom(), binary()}. +format_error(parser_failed) -> + {'bad-protocol', <<"Response decoding failed">>}; +format_error(not_authorized) -> + {'not-authorized', <<"Invalid token">>}. + mech_new(Host, _GetPassword, _CheckPassword, _CheckPasswordDigest) -> {ok, #state{host = Host}}. @@ -52,9 +60,9 @@ mech_step(State, ClientIn) -> [{username, User}, {authzid, AuthzId}, {auth_module, ejabberd_oauth}]}; _ -> - {error, 'not-authorized', User} + {error, not_authorized, User} end; - _ -> {error, 'bad-protocol'} + _ -> {error, parser_failed} end. prepare(ClientIn) -> diff --git a/src/cyrsasl_plain.erl b/src/cyrsasl_plain.erl index 8e9b32b99..bbac8deff 100644 --- a/src/cyrsasl_plain.erl +++ b/src/cyrsasl_plain.erl @@ -27,11 +27,13 @@ -author('alexey@process-one.net'). --export([start/1, stop/0, mech_new/4, mech_step/2, parse/1]). +-export([start/1, stop/0, mech_new/4, mech_step/2, parse/1, format_error/1]). -behaviour(cyrsasl). -record(state, {check_password}). +-type error_reason() :: parser_failed | not_authorized. +-export_type([error_reason/0]). start(_Opts) -> cyrsasl:register_mechanism(<<"PLAIN">>, ?MODULE, plain), @@ -39,6 +41,12 @@ start(_Opts) -> stop() -> ok. +-spec format_error(error_reason()) -> {atom(), binary()}. +format_error(parser_failed) -> + {'bad-protocol', <<"Response decoding failed">>}; +format_error(not_authorized) -> + {'not-authorized', <<"Invalid username or password">>}. + mech_new(_Host, _GetPassword, CheckPassword, _CheckPasswordDigest) -> {ok, #state{check_password = CheckPassword}}. @@ -50,9 +58,9 @@ mech_step(State, ClientIn) -> {ok, [{username, User}, {authzid, AuthzId}, {auth_module, AuthModule}]}; - _ -> {error, 'not-authorized', User} + _ -> {error, not_authorized, User} end; - _ -> {error, 'bad-protocol'} + _ -> {error, parser_failed} end. prepare(ClientIn) -> diff --git a/src/cyrsasl_scram.erl b/src/cyrsasl_scram.erl index 1e2a5c681..55e06fd25 100644 --- a/src/cyrsasl_scram.erl +++ b/src/cyrsasl_scram.erl @@ -29,7 +29,7 @@ -protocol({rfc, 5802}). --export([start/1, stop/0, mech_new/4, mech_step/2]). +-export([start/1, stop/0, mech_new/4, mech_step/2, format_error/1]). -include("ejabberd.hrl"). -include("logger.hrl"). @@ -41,6 +41,7 @@ stored_key = <<"">> :: binary(), server_key = <<"">> :: binary(), username = <<"">> :: binary(), + auth_module :: module(), get_password :: fun(), check_password :: fun(), auth_message = <<"">> :: binary(), @@ -48,15 +49,39 @@ server_nonce = <<"">> :: binary()}). -define(SALT_LENGTH, 16). - -define(NONCE_LENGTH, 16). +-type error_reason() :: unsupported_extension | bad_username | + not_authorized | saslprep_failed | + parser_failed | bad_attribute | + nonce_mismatch | bad_channel_binding. + +-export_type([error_reason/0]). + start(_Opts) -> cyrsasl:register_mechanism(<<"SCRAM-SHA-1">>, ?MODULE, scram). stop() -> ok. +-spec format_error(error_reason()) -> {atom(), binary()}. +format_error(unsupported_extension) -> + {'bad-protocol', <<"Unsupported extension">>}; +format_error(bad_username) -> + {'invalid-authzid', <<"Malformed username">>}; +format_error(not_authorized) -> + {'not-authorized', <<"Invalid username or password">>}; +format_error(saslprep_failed) -> + {'not-authorized', <<"SASLprep failed">>}; +format_error(parser_failed) -> + {'bad-protocol', <<"Response decoding failed">>}; +format_error(bad_attribute) -> + {'bad-protocol', <<"Malformed or unexpected attribute">>}; +format_error(nonce_mismatch) -> + {'bad-protocol', <<"Nonce mismatch">>}; +format_error(bad_channel_binding) -> + {'bad-protocol', <<"Invalid channel binding">>}. + mech_new(_Host, GetPassword, _CheckPassword, _CheckPasswordDigest) -> {ok, #state{step = 2, get_password = GetPassword}}. @@ -64,22 +89,22 @@ mech_new(_Host, GetPassword, _CheckPassword, mech_step(#state{step = 2} = State, ClientIn) -> case re:split(ClientIn, <<",">>, [{return, binary}]) of [_CBind, _AuthorizationIdentity, _UserNameAttribute, _ClientNonceAttribute, ExtensionAttribute | _] - when ExtensionAttribute /= [] -> - {error, 'protocol-error-extension-not-supported'}; + when ExtensionAttribute /= <<"">> -> + {error, unsupported_extension}; [CBind, _AuthorizationIdentity, UserNameAttribute, ClientNonceAttribute | _] when (CBind == <<"y">>) or (CBind == <<"n">>) -> case parse_attribute(UserNameAttribute) of {error, Reason} -> {error, Reason}; {_, EscapedUserName} -> case unescape_username(EscapedUserName) of - error -> {error, 'protocol-error-bad-username'}; + error -> {error, bad_username}; UserName -> case parse_attribute(ClientNonceAttribute) of {$r, ClientNonce} -> - {Ret, _AuthModule} = (State#state.get_password)(UserName), + {Ret, AuthModule} = (State#state.get_password)(UserName), case {Ret, jid:resourceprep(Ret)} of - {false, _} -> {error, 'not-authorized', UserName}; - {_, error} when is_binary(Ret) -> ?WARNING_MSG("invalid plain password", []), {error, 'not-authorized', UserName}; + {false, _} -> {error, not_authorized, UserName}; + {_, error} when is_binary(Ret) -> {error, saslprep_failed, UserName}; {Ret, _} -> {StoredKey, ServerKey, Salt, IterationCount} = if is_tuple(Ret) -> Ret; @@ -112,6 +137,7 @@ mech_step(#state{step = 2} = State, ClientIn) -> {continue, ServerFirstMessage, State#state{step = 4, stored_key = StoredKey, server_key = ServerKey, + auth_module = AuthModule, auth_message = <>, @@ -119,11 +145,11 @@ mech_step(#state{step = 2} = State, ClientIn) -> server_nonce = ServerNonce, username = UserName}} end; - _Else -> {error, 'not-supported'} + _ -> {error, bad_attribute} end end end; - _Else -> {error, 'bad-protocol'} + _Else -> {error, parser_failed} end; mech_step(#state{step = 4} = State, ClientIn) -> case str:tokens(ClientIn, <<",">>) of @@ -158,39 +184,31 @@ mech_step(#state{step = 4} = State, ClientIn) -> scram:server_signature(State#state.server_key, AuthMessage), {ok, [{username, State#state.username}, + {auth_module, State#state.auth_module}, {authzid, State#state.username}], <<"v=", (jlib:encode_base64(ServerSignature))/binary>>}; - true -> {error, 'bad-auth', State#state.username} + true -> {error, not_authorized, State#state.username} end; - _Else -> {error, 'bad-protocol'} + _ -> {error, bad_attribute} end; - {$r, _} -> {error, 'bad-nonce'}; - _Else -> {error, 'bad-protocol'} + {$r, _} -> {error, nonce_mismatch}; + _ -> {error, bad_attribute} end; - true -> {error, 'bad-channel-binding'} + true -> {error, bad_channel_binding} end; - _Else -> {error, 'bad-protocol'} + _ -> {error, bad_attribute} end; - _Else -> {error, 'bad-protocol'} + _ -> {error, parser_failed} end. -parse_attribute(Attribute) -> - AttributeLen = byte_size(Attribute), - if AttributeLen >= 3 -> - AttributeS = binary_to_list(Attribute), - SecondChar = lists:nth(2, AttributeS), - case is_alpha(lists:nth(1, AttributeS)) of - true -> - if SecondChar == $= -> - String = str:substr(Attribute, 3), - {lists:nth(1, AttributeS), String}; - true -> {error, 'bad-format-second-char-not-equal-sign'} - end; - _Else -> {error, 'bad-format-first-char-not-a-letter'} - end; - true -> {error, 'bad-format-attribute-too-short'} - end. +parse_attribute(<>) when Val /= <<>> -> + case is_alpha(Name) of + true -> {Name, Val}; + false -> {error, bad_attribute} + end; +parse_attribute(_) -> + {error, bad_attribute}. unescape_username(<<"">>) -> <<"">>; unescape_username(EscapedUsername) ->