From 0bb14d16c7505493b2e8ae69f1d155f6cf71f5e9 Mon Sep 17 00:00:00 2001 From: Evgeniy Khramtsov Date: Fri, 6 Jul 2018 01:07:36 +0300 Subject: [PATCH] Move XMPP stream and SASL processing to xmpp repo --- asn1/XmppAddr.asn1 | 14 - include/scram.hrl | 28 - rebar.config | 4 +- src/cyrsasl.erl | 229 ------ src/cyrsasl_anonymous.erl | 50 -- src/cyrsasl_digest.erl | 270 ------- src/cyrsasl_oauth.erl | 104 --- src/cyrsasl_plain.erl | 94 --- src/cyrsasl_scram.erl | 249 ------ src/ejabberd_bosh.erl | 4 +- src/ejabberd_c2s.erl | 43 +- src/ejabberd_config.erl | 13 +- src/ejabberd_idna.erl | 224 ------ src/ejabberd_logger.erl | 4 + src/ejabberd_pkix.erl | 2 +- src/ejabberd_s2s_in.erl | 15 +- src/ejabberd_service.erl | 2 +- src/ejabberd_sup.erl | 1 - src/scram.erl | 81 -- src/xmpp_socket.erl | 393 ---------- src/xmpp_stream_in.erl | 1220 ----------------------------- src/xmpp_stream_out.erl | 1321 -------------------------------- src/xmpp_stream_pkix.erl | 271 ------- test/ejabberd_cyrsasl_test.exs | 18 +- test/suite.erl | 4 +- 25 files changed, 69 insertions(+), 4589 deletions(-) delete mode 100644 asn1/XmppAddr.asn1 delete mode 100644 include/scram.hrl delete mode 100644 src/cyrsasl.erl delete mode 100644 src/cyrsasl_anonymous.erl delete mode 100644 src/cyrsasl_digest.erl delete mode 100644 src/cyrsasl_oauth.erl delete mode 100644 src/cyrsasl_plain.erl delete mode 100644 src/cyrsasl_scram.erl delete mode 100644 src/ejabberd_idna.erl delete mode 100644 src/scram.erl delete mode 100644 src/xmpp_socket.erl delete mode 100644 src/xmpp_stream_in.erl delete mode 100644 src/xmpp_stream_out.erl delete mode 100644 src/xmpp_stream_pkix.erl diff --git a/asn1/XmppAddr.asn1 b/asn1/XmppAddr.asn1 deleted file mode 100644 index 14f350d3d..000000000 --- a/asn1/XmppAddr.asn1 +++ /dev/null @@ -1,14 +0,0 @@ -XmppAddr { iso(1) identified-organization(3) - dod(6) internet(1) security(5) mechanisms(5) pkix(7) - id-on(8) id-on-xmppAddr(5) } - -DEFINITIONS EXPLICIT TAGS ::= -BEGIN - -id-on-xmppAddr OBJECT IDENTIFIER ::= { iso(1) identified-organization(3) - dod(6) internet(1) security(5) mechanisms(5) pkix(7) - id-on(8) 5 } - -XmppAddr ::= UTF8String - -END diff --git a/include/scram.hrl b/include/scram.hrl deleted file mode 100644 index 156a6401a..000000000 --- a/include/scram.hrl +++ /dev/null @@ -1,28 +0,0 @@ -%%%---------------------------------------------------------------------- -%%% -%%% ejabberd, Copyright (C) 2002-2018 ProcessOne -%%% -%%% This program is free software; you can redistribute it and/or -%%% modify it under the terms of the GNU General Public License as -%%% published by the Free Software Foundation; either version 2 of the -%%% License, or (at your option) any later version. -%%% -%%% This program is distributed in the hope that it will be useful, -%%% but WITHOUT ANY WARRANTY; without even the implied warranty of -%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -%%% General Public License for more details. -%%% -%%% You should have received a copy of the GNU General Public License along -%%% with this program; if not, write to the Free Software Foundation, Inc., -%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. -%%% -%%%---------------------------------------------------------------------- - --record(scram, {storedkey = <<"">> :: binary(), - serverkey = <<"">> :: binary(), - salt = <<"">> :: binary(), - iterationcount = 0 :: integer()}). - --type scram() :: #scram{}. - --define(SCRAM_DEFAULT_ITERATION_COUNT, 4096). diff --git a/rebar.config b/rebar.config index d5e19b29b..35dfb3553 100644 --- a/rebar.config +++ b/rebar.config @@ -25,7 +25,7 @@ {fast_tls, ".*", {git, "https://github.com/processone/fast_tls", {tag, "1.0.23"}}}, {stringprep, ".*", {git, "https://github.com/processone/stringprep", {tag, "1.0.12"}}}, {fast_xml, ".*", {git, "https://github.com/processone/fast_xml", {tag, "1.1.32"}}}, - {xmpp, ".*", {git, "https://github.com/processone/xmpp", "0e2ef5d"}}, + {xmpp, ".*", {git, "https://github.com/processone/xmpp", "2a5193c"}}, {fast_yaml, ".*", {git, "https://github.com/processone/fast_yaml", {tag, "1.0.15"}}}, {jiffy, ".*", {git, "https://github.com/davisp/jiffy", {tag, "0.14.8"}}}, {p1_oauth2, ".*", {git, "https://github.com/processone/p1_oauth2", {tag, "0.6.3"}}}, @@ -100,7 +100,7 @@ {if_have_fun, {public_key, short_name_hash, 1}, {d, 'SHORT_NAME_HASH'}}, {if_var_true, new_sql_schema, {d, 'NEW_SQL_SCHEMA'}}, {if_var_true, hipe, native}, - {src_dirs, [asn1, src, + {src_dirs, [src, {if_var_true, tools, tools}, {if_var_true, elixir, include}]}]}. diff --git a/src/cyrsasl.erl b/src/cyrsasl.erl deleted file mode 100644 index 223d5fe68..000000000 --- a/src/cyrsasl.erl +++ /dev/null @@ -1,229 +0,0 @@ -%%%---------------------------------------------------------------------- -%%% File : cyrsasl.erl -%%% Author : Alexey Shchepin -%%% Purpose : Cyrus SASL-like library -%%% Created : 8 Mar 2003 by Alexey Shchepin -%%% -%%% -%%% ejabberd, Copyright (C) 2002-2018 ProcessOne -%%% -%%% This program is free software; you can redistribute it and/or -%%% modify it under the terms of the GNU General Public License as -%%% published by the Free Software Foundation; either version 2 of the -%%% License, or (at your option) any later version. -%%% -%%% This program is distributed in the hope that it will be useful, -%%% but WITHOUT ANY WARRANTY; without even the implied warranty of -%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -%%% General Public License for more details. -%%% -%%% You should have received a copy of the GNU General Public License along -%%% with this program; if not, write to the Free Software Foundation, Inc., -%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. -%%% -%%%---------------------------------------------------------------------- - --module(cyrsasl). - --author('alexey@process-one.net'). --behaviour(gen_server). - --export([start_link/0, register_mechanism/3, listmech/1, - server_new/7, server_start/3, server_step/2, - get_mech/1, format_error/2]). -%% gen_server callbacks --export([init/1, handle_call/3, handle_cast/2, handle_info/2, - terminate/2, code_change/3]). - --include("logger.hrl"). - --record(state, {}). - --record(sasl_mechanism, - {mechanism = <<"">> :: mechanism() | '$1', - module :: atom(), - password_type = plain :: password_type() | '$2'}). - --type(mechanism() :: binary()). --type(mechanisms() :: [mechanism(),...]). --type(password_type() :: plain | digest | scram). --type sasl_property() :: {username, binary()} | - {authzid, binary()} | - {mechanism, binary()} | - {auth_module, atom()}. --type sasl_return() :: {ok, [sasl_property()]} | - {ok, [sasl_property()], binary()} | - {continue, binary(), sasl_state()} | - {error, atom(), binary()}. - --type(sasl_mechanism() :: #sasl_mechanism{}). --type error_reason() :: cyrsasl_digest:error_reason() | - cyrsasl_oauth:error_reason() | - cyrsasl_plain:error_reason() | - cyrsasl_scram:error_reason() | - unsupported_mechanism | nodeprep_failed | - empty_username | aborted. --record(sasl_state, -{ - service, - myname, - realm, - get_password, - check_password, - check_password_digest, - mech_name = <<"">>, - mech_mod, - mech_state -}). --type sasl_state() :: #sasl_state{}. --export_type([mechanism/0, mechanisms/0, sasl_mechanism/0, error_reason/0, - sasl_state/0, sasl_return/0, sasl_property/0]). - --callback start(list()) -> any(). --callback stop() -> any(). --callback mech_new(binary(), fun(), fun(), fun()) -> any(). --callback mech_step(any(), binary()) -> sasl_return(). - -start_link() -> - gen_server:start_link({local, ?MODULE}, ?MODULE, [], []). - -init([]) -> - ets:new(sasl_mechanism, - [named_table, public, - {keypos, #sasl_mechanism.mechanism}]), - cyrsasl_plain:start([]), - cyrsasl_digest:start([]), - cyrsasl_scram:start([]), - cyrsasl_anonymous:start([]), - cyrsasl_oauth:start([]), - {ok, #state{}}. - -handle_call(_Request, _From, State) -> - Reply = ok, - {reply, Reply, State}. - -handle_cast(_Msg, State) -> - {noreply, State}. - -handle_info(_Info, State) -> - {noreply, State}. - -terminate(_Reason, _State) -> - cyrsasl_plain:stop(), - cyrsasl_digest:stop(), - cyrsasl_scram:stop(), - cyrsasl_anonymous:stop(), - cyrsasl_oauth:stop(). - -code_change(_OldVsn, State, _Extra) -> - {ok, State}. - --spec format_error(mechanism() | sasl_state(), error_reason()) -> {atom(), binary()}. -format_error(_, unsupported_mechanism) -> - {'invalid-mechanism', <<"Unsupported mechanism">>}; -format_error(_, nodeprep_failed) -> - {'bad-protocol', <<"Nodeprep failed">>}; -format_error(_, empty_username) -> - {'bad-protocol', <<"Empty username">>}; -format_error(_, aborted) -> - {'aborted', <<"Aborted">>}; -format_error(#sasl_state{mech_mod = Mod}, Reason) -> - Mod:format_error(Reason); -format_error(Mech, Reason) -> - case ets:lookup(sasl_mechanism, Mech) of - [#sasl_mechanism{module = Mod}] -> - Mod:format_error(Reason); - [] -> - {'invalid-mechanism', <<"Unsupported mechanism">>} - end. - --spec register_mechanism(Mechanim :: mechanism(), Module :: module(), - PasswordType :: password_type()) -> any(). - -register_mechanism(Mechanism, Module, PasswordType) -> - ets:insert(sasl_mechanism, - #sasl_mechanism{mechanism = Mechanism, module = Module, - password_type = PasswordType}). - -check_credentials(_State, Props) -> - User = proplists:get_value(authzid, Props, <<>>), - case jid:nodeprep(User) of - error -> {error, nodeprep_failed}; - <<"">> -> {error, empty_username}; - _LUser -> ok - end. - --spec listmech(Host ::binary()) -> Mechanisms::mechanisms(). - -listmech(Host) -> - ets:select(sasl_mechanism, - [{#sasl_mechanism{mechanism = '$1', - password_type = '$2', _ = '_'}, - case catch ejabberd_auth:store_type(Host) of - external -> [{'==', '$2', plain}]; - scram -> [{'/=', '$2', digest}]; - {'EXIT', {undef, [{Module, store_type, []} | _]}} -> - ?WARNING_MSG("~p doesn't implement the function store_type/0", - [Module]), - []; - _Else -> [] - end, - ['$1']}]). - --spec server_new(binary(), binary(), binary(), term(), - fun(), fun(), fun()) -> sasl_state(). -server_new(Service, ServerFQDN, UserRealm, _SecFlags, - GetPassword, CheckPassword, CheckPasswordDigest) -> - #sasl_state{service = Service, myname = ServerFQDN, - realm = UserRealm, get_password = GetPassword, - check_password = CheckPassword, - check_password_digest = CheckPasswordDigest}. - --spec server_start(sasl_state(), mechanism(), binary()) -> sasl_return(). -server_start(State, Mech, ClientIn) -> - case lists:member(Mech, - listmech(State#sasl_state.myname)) - of - true -> - case ets:lookup(sasl_mechanism, Mech) of - [#sasl_mechanism{module = Module}] -> - {ok, MechState} = - Module:mech_new(State#sasl_state.myname, - State#sasl_state.get_password, - State#sasl_state.check_password, - State#sasl_state.check_password_digest), - server_step(State#sasl_state{mech_mod = Module, - mech_name = Mech, - mech_state = MechState}, - ClientIn); - _ -> {error, unsupported_mechanism, <<"">>} - end; - false -> {error, unsupported_mechanism, <<"">>} - end. - --spec server_step(sasl_state(), binary()) -> sasl_return(). -server_step(State, ClientIn) -> - Module = State#sasl_state.mech_mod, - MechState = State#sasl_state.mech_state, - case Module:mech_step(MechState, ClientIn) of - {ok, Props} -> - case check_credentials(State, Props) of - ok -> {ok, Props}; - {error, Error} -> {error, Error, <<"">>} - end; - {ok, Props, ServerOut} -> - case check_credentials(State, Props) of - ok -> {ok, Props, ServerOut}; - {error, Error} -> {error, Error, <<"">>} - end; - {continue, ServerOut, NewMechState} -> - {continue, ServerOut, State#sasl_state{mech_state = NewMechState}}; - {error, Error, Username} -> - {error, Error, Username}; - {error, Error} -> - {error, Error, <<"">>} - end. - --spec get_mech(sasl_state()) -> binary(). -get_mech(#sasl_state{mech_name = Mech}) -> - Mech. diff --git a/src/cyrsasl_anonymous.erl b/src/cyrsasl_anonymous.erl deleted file mode 100644 index 557f22c29..000000000 --- a/src/cyrsasl_anonymous.erl +++ /dev/null @@ -1,50 +0,0 @@ -%%%---------------------------------------------------------------------- -%%% File : cyrsasl_anonymous.erl -%%% Author : Magnus Henoch -%%% Purpose : ANONYMOUS SASL mechanism -%%% See http://www.ietf.org/internet-drafts/draft-ietf-sasl-anon-05.txt -%%% Created : 23 Aug 2005 by Magnus Henoch -%%% -%%% -%%% ejabberd, Copyright (C) 2002-2018 ProcessOne -%%% -%%% This program is free software; you can redistribute it and/or -%%% modify it under the terms of the GNU General Public License as -%%% published by the Free Software Foundation; either version 2 of the -%%% License, or (at your option) any later version. -%%% -%%% This program is distributed in the hope that it will be useful, -%%% but WITHOUT ANY WARRANTY; without even the implied warranty of -%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -%%% General Public License for more details. -%%% -%%% You should have received a copy of the GNU General Public License along -%%% with this program; if not, write to the Free Software Foundation, Inc., -%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. -%%% -%%%---------------------------------------------------------------------- - --module(cyrsasl_anonymous). - --protocol({xep, 175, '1.2'}). - --export([start/1, stop/0, mech_new/4, mech_step/2]). - --behaviour(cyrsasl). - --record(state, {server = <<"">> :: binary()}). - -start(_Opts) -> - cyrsasl:register_mechanism(<<"ANONYMOUS">>, ?MODULE, plain). - -stop() -> ok. - -mech_new(Host, _GetPassword, _CheckPassword, _CheckPasswordDigest) -> - {ok, #state{server = Host}}. - -mech_step(#state{}, _ClientIn) -> - User = iolist_to_binary([p1_rand:get_string(), - integer_to_binary(p1_time_compat:unique_integer([positive]))]), - {ok, [{username, User}, - {authzid, User}, - {auth_module, ejabberd_auth_anonymous}]}. diff --git a/src/cyrsasl_digest.erl b/src/cyrsasl_digest.erl deleted file mode 100644 index 73ec9e1d1..000000000 --- a/src/cyrsasl_digest.erl +++ /dev/null @@ -1,270 +0,0 @@ -%%%---------------------------------------------------------------------- -%%% File : cyrsasl_digest.erl -%%% Author : Alexey Shchepin -%%% Purpose : DIGEST-MD5 SASL mechanism -%%% Created : 11 Mar 2003 by Alexey Shchepin -%%% -%%% -%%% ejabberd, Copyright (C) 2002-2018 ProcessOne -%%% -%%% This program is free software; you can redistribute it and/or -%%% modify it under the terms of the GNU General Public License as -%%% published by the Free Software Foundation; either version 2 of the -%%% License, or (at your option) any later version. -%%% -%%% This program is distributed in the hope that it will be useful, -%%% but WITHOUT ANY WARRANTY; without even the implied warranty of -%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -%%% General Public License for more details. -%%% -%%% You should have received a copy of the GNU General Public License along -%%% with this program; if not, write to the Free Software Foundation, Inc., -%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. -%%% -%%%---------------------------------------------------------------------- - --module(cyrsasl_digest). - --behaviour(ejabberd_config). - --author('alexey@sevcom.net'). - --export([start/1, stop/0, mech_new/4, mech_step/2, - parse/1, format_error/1, opt_type/1]). - --include("logger.hrl"). - --behaviour(cyrsasl). - --type get_password_fun() :: fun((binary()) -> {false, any()} | - {binary(), atom()}). --type check_password_fun() :: fun((binary(), binary(), binary(), binary(), - fun((binary()) -> binary())) -> - {boolean(), any()} | - false). --type error_reason() :: parser_failed | invalid_digest_uri | - not_authorized | unexpected_response. --export_type([error_reason/0]). - --record(state, {step = 1 :: 1 | 3 | 5, - nonce = <<"">> :: binary(), - username = <<"">> :: binary(), - authzid = <<"">> :: binary(), - get_password :: get_password_fun(), - check_password :: check_password_fun(), - auth_module :: atom(), - host = <<"">> :: binary(), - hostfqdn = [] :: [binary()]}). - -start(_Opts) -> - Fqdn = get_local_fqdn(), - ?DEBUG("FQDN used to check DIGEST-MD5 SASL authentication: ~s", - [Fqdn]), - cyrsasl:register_mechanism(<<"DIGEST-MD5">>, ?MODULE, - digest). - -stop() -> ok. - --spec format_error(error_reason()) -> {atom(), binary()}. -format_error(parser_failed) -> - {'bad-protocol', <<"Response decoding failed">>}; -format_error(invalid_digest_uri) -> - {'bad-protocol', <<"Invalid digest URI">>}; -format_error(not_authorized) -> - {'not-authorized', <<"Invalid username or password">>}; -format_error(unexpected_response) -> - {'bad-protocol', <<"Unexpected response">>}. - -mech_new(Host, GetPassword, _CheckPassword, - CheckPasswordDigest) -> - {ok, - #state{step = 1, nonce = p1_rand:get_string(), - host = Host, hostfqdn = get_local_fqdn(), - get_password = GetPassword, - check_password = CheckPasswordDigest}}. - -mech_step(#state{step = 1, nonce = Nonce} = State, _) -> - {continue, - <<"nonce=\"", Nonce/binary, - "\",qop=\"auth\",charset=utf-8,algorithm=md5-sess">>, - State#state{step = 3}}; -mech_step(#state{step = 3, nonce = Nonce} = State, - ClientIn) -> - case parse(ClientIn) of - bad -> {error, parser_failed}; - KeyVals -> - DigestURI = proplists:get_value(<<"digest-uri">>, KeyVals, <<>>), - UserName = proplists:get_value(<<"username">>, KeyVals, <<>>), - case is_digesturi_valid(DigestURI, State#state.host, - State#state.hostfqdn) - of - false -> - ?DEBUG("User login not authorized because digest-uri " - "seems invalid: ~p (checking for Host " - "~p, FQDN ~p)", - [DigestURI, State#state.host, State#state.hostfqdn]), - {error, invalid_digest_uri, UserName}; - true -> - AuthzId = proplists:get_value(<<"authzid">>, KeyVals, <<>>), - case (State#state.get_password)(UserName) of - {false, _} -> {error, not_authorized, UserName}; - {Passwd, AuthModule} -> - case (State#state.check_password)(UserName, UserName, <<"">>, - proplists:get_value(<<"response">>, KeyVals, <<>>), - fun (PW) -> - response(KeyVals, - UserName, - PW, - Nonce, - AuthzId, - <<"AUTHENTICATE">>) - end) - of - {true, _} -> - RspAuth = response(KeyVals, UserName, Passwd, Nonce, - AuthzId, <<"">>), - {continue, <<"rspauth=", RspAuth/binary>>, - State#state{step = 5, auth_module = AuthModule, - username = UserName, - authzid = AuthzId}}; - false -> {error, not_authorized, UserName}; - {false, _} -> {error, not_authorized, UserName} - end - end - end - end; -mech_step(#state{step = 5, auth_module = AuthModule, - username = UserName, authzid = AuthzId}, - <<"">>) -> - {ok, - [{username, UserName}, {authzid, case AuthzId of - <<"">> -> UserName; - _ -> AuthzId - end - }, - {auth_module, AuthModule}]}; -mech_step(A, B) -> - ?DEBUG("SASL DIGEST: A ~p B ~p", [A, B]), - {error, unexpected_response}. - -parse(S) -> parse1(binary_to_list(S), "", []). - -parse1([$= | Cs], S, Ts) -> - parse2(Cs, lists:reverse(S), "", Ts); -parse1([$, | Cs], [], Ts) -> parse1(Cs, [], Ts); -parse1([$\s | Cs], [], Ts) -> parse1(Cs, [], Ts); -parse1([C | Cs], S, Ts) -> parse1(Cs, [C | S], Ts); -parse1([], [], T) -> lists:reverse(T); -parse1([], _S, _T) -> bad. - -parse2([$" | Cs], Key, Val, Ts) -> - parse3(Cs, Key, Val, Ts); -parse2([C | Cs], Key, Val, Ts) -> - parse4(Cs, Key, [C | Val], Ts); -parse2([], _, _, _) -> bad. - -parse3([$" | Cs], Key, Val, Ts) -> - parse4(Cs, Key, Val, Ts); -parse3([$\\, C | Cs], Key, Val, Ts) -> - parse3(Cs, Key, [C | Val], Ts); -parse3([C | Cs], Key, Val, Ts) -> - parse3(Cs, Key, [C | Val], Ts); -parse3([], _, _, _) -> bad. - -parse4([$, | Cs], Key, Val, Ts) -> - parse1(Cs, "", [{list_to_binary(Key), list_to_binary(lists:reverse(Val))} | Ts]); -parse4([$\s | Cs], Key, Val, Ts) -> - parse4(Cs, Key, Val, Ts); -parse4([C | Cs], Key, Val, Ts) -> - parse4(Cs, Key, [C | Val], Ts); -parse4([], Key, Val, Ts) -> -%% @doc Check if the digest-uri is valid. -%% RFC-2831 allows to provide the IP address in Host, -%% however ejabberd doesn't allow that. -%% If the service (for example jabber.example.org) -%% is provided by several hosts (being one of them server3.example.org), -%% then acceptable digest-uris would be: -%% xmpp/server3.example.org/jabber.example.org, xmpp/server3.example.org and -%% xmpp/jabber.example.org -%% The last version is not actually allowed by the RFC, but implemented by popular clients - parse1([], "", [{list_to_binary(Key), list_to_binary(lists:reverse(Val))} | Ts]). - -is_digesturi_valid(DigestURICase, JabberDomain, - JabberFQDN) -> - DigestURI = stringprep:tolower(DigestURICase), - case catch str:tokens(DigestURI, <<"/">>) of - [<<"xmpp">>, Host] -> - IsHostFqdn = is_host_fqdn(Host, JabberFQDN), - (Host == JabberDomain) or IsHostFqdn; - [<<"xmpp">>, Host, ServName] -> - IsHostFqdn = is_host_fqdn(Host, JabberFQDN), - (ServName == JabberDomain) and IsHostFqdn; - _ -> - false - end. - -is_host_fqdn(_Host, []) -> - false; -is_host_fqdn(Host, [Fqdn | _FqdnTail]) when Host == Fqdn -> - true; -is_host_fqdn(Host, [Fqdn | FqdnTail]) when Host /= Fqdn -> - is_host_fqdn(Host, FqdnTail). - -get_local_fqdn() -> - case ejabberd_config:get_option(fqdn) of - undefined -> - {ok, Hostname} = inet:gethostname(), - {ok, {hostent, Fqdn, _, _, _, _}} = inet:gethostbyname(Hostname), - [list_to_binary(Fqdn)]; - Fqdn -> - Fqdn - end. - -hex(S) -> - str:to_hexlist(S). - -proplists_get_bin_value(Key, Pairs, Default) -> - case proplists:get_value(Key, Pairs, Default) of - L when is_list(L) -> - list_to_binary(L); - L2 -> - L2 - end. - -response(KeyVals, User, Passwd, Nonce, AuthzId, - A2Prefix) -> - Realm = proplists_get_bin_value(<<"realm">>, KeyVals, <<>>), - CNonce = proplists_get_bin_value(<<"cnonce">>, KeyVals, <<>>), - DigestURI = proplists_get_bin_value(<<"digest-uri">>, KeyVals, <<>>), - NC = proplists_get_bin_value(<<"nc">>, KeyVals, <<>>), - QOP = proplists_get_bin_value(<<"qop">>, KeyVals, <<>>), - MD5Hash = erlang:md5(<>), - A1 = case AuthzId of - <<"">> -> - <>; - _ -> - <> - end, - A2 = case QOP of - <<"auth">> -> - <>; - _ -> - <> - end, - T = <<(hex((erlang:md5(A1))))/binary, ":", Nonce/binary, - ":", NC/binary, ":", CNonce/binary, ":", QOP/binary, - ":", (hex((erlang:md5(A2))))/binary>>, - hex((erlang:md5(T))). - --spec opt_type(fqdn) -> fun((binary() | [binary()]) -> [binary()]); - (atom()) -> [atom()]. -opt_type(fqdn) -> - fun(FQDN) when is_binary(FQDN) -> - [FQDN]; - (FQDNs) when is_list(FQDNs) -> - [iolist_to_binary(FQDN) || FQDN <- FQDNs] - end; -opt_type(_) -> [fqdn]. diff --git a/src/cyrsasl_oauth.erl b/src/cyrsasl_oauth.erl deleted file mode 100644 index 5520451cb..000000000 --- a/src/cyrsasl_oauth.erl +++ /dev/null @@ -1,104 +0,0 @@ -%%%---------------------------------------------------------------------- -%%% File : cyrsasl_oauth.erl -%%% Author : Alexey Shchepin -%%% Purpose : X-OAUTH2 SASL mechanism -%%% Created : 17 Sep 2015 by Alexey Shchepin -%%% -%%% -%%% ejabberd, Copyright (C) 2002-2018 ProcessOne -%%% -%%% This program is free software; you can redistribute it and/or -%%% modify it under the terms of the GNU General Public License as -%%% published by the Free Software Foundation; either version 2 of the -%%% License, or (at your option) any later version. -%%% -%%% This program is distributed in the hope that it will be useful, -%%% but WITHOUT ANY WARRANTY; without even the implied warranty of -%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -%%% General Public License for more details. -%%% -%%% You should have received a copy of the GNU General Public License along -%%% with this program; if not, write to the Free Software Foundation, Inc., -%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. -%%% -%%%---------------------------------------------------------------------- - --module(cyrsasl_oauth). - --author('alexey@process-one.net'). - --export([start/1, stop/0, mech_new/4, mech_step/2, parse/1, format_error/1]). - --behaviour(cyrsasl). - --record(state, {host}). --type error_reason() :: parser_failed | not_authorized. --export_type([error_reason/0]). - -start(_Opts) -> - cyrsasl:register_mechanism(<<"X-OAUTH2">>, ?MODULE, plain). - -stop() -> ok. - --spec format_error(error_reason()) -> {atom(), binary()}. -format_error(parser_failed) -> - {'bad-protocol', <<"Response decoding failed">>}; -format_error(not_authorized) -> - {'not-authorized', <<"Invalid token">>}. - -mech_new(Host, _GetPassword, _CheckPassword, _CheckPasswordDigest) -> - {ok, #state{host = Host}}. - -mech_step(State, ClientIn) -> - case prepare(ClientIn) of - [AuthzId, User, Token] -> - case ejabberd_oauth:check_token( - User, State#state.host, [<<"sasl_auth">>], Token) of - true -> - {ok, - [{username, User}, {authzid, AuthzId}, - {auth_module, ejabberd_oauth}]}; - _ -> - {error, not_authorized, User} - end; - _ -> {error, parser_failed} - end. - -prepare(ClientIn) -> - case parse(ClientIn) of - [<<"">>, UserMaybeDomain, Token] -> - case parse_domain(UserMaybeDomain) of - %% login@domainpwd - [User, _Domain] -> [User, User, Token]; - %% loginpwd - [User] -> [User, User, Token] - end; - %% login@domainloginpwd - [AuthzId, User, Token] -> - case parse_domain(AuthzId) of - %% login@domainloginpwd - [AuthzUser, _Domain] -> [AuthzUser, User, Token]; - %% loginloginpwd - [AuthzUser] -> [AuthzUser, User, Token] - end; - _ -> error - end. - -parse(S) -> parse1(binary_to_list(S), "", []). - -parse1([0 | Cs], S, T) -> - parse1(Cs, "", [list_to_binary(lists:reverse(S)) | T]); -parse1([C | Cs], S, T) -> parse1(Cs, [C | S], T); -%parse1([], [], T) -> -% lists:reverse(T); -parse1([], S, T) -> - lists:reverse([list_to_binary(lists:reverse(S)) | T]). - -parse_domain(S) -> parse_domain1(binary_to_list(S), "", []). - -parse_domain1([$@ | Cs], S, T) -> - parse_domain1(Cs, "", [list_to_binary(lists:reverse(S)) | T]); -parse_domain1([C | Cs], S, T) -> - parse_domain1(Cs, [C | S], T); -parse_domain1([], S, T) -> - lists:reverse([list_to_binary(lists:reverse(S)) | T]). diff --git a/src/cyrsasl_plain.erl b/src/cyrsasl_plain.erl deleted file mode 100644 index 3bdcf8476..000000000 --- a/src/cyrsasl_plain.erl +++ /dev/null @@ -1,94 +0,0 @@ -%%%---------------------------------------------------------------------- -%%% File : cyrsasl_plain.erl -%%% Author : Alexey Shchepin -%%% Purpose : PLAIN SASL mechanism -%%% Created : 8 Mar 2003 by Alexey Shchepin -%%% -%%% -%%% ejabberd, Copyright (C) 2002-2018 ProcessOne -%%% -%%% This program is free software; you can redistribute it and/or -%%% modify it under the terms of the GNU General Public License as -%%% published by the Free Software Foundation; either version 2 of the -%%% License, or (at your option) any later version. -%%% -%%% This program is distributed in the hope that it will be useful, -%%% but WITHOUT ANY WARRANTY; without even the implied warranty of -%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -%%% General Public License for more details. -%%% -%%% You should have received a copy of the GNU General Public License along -%%% with this program; if not, write to the Free Software Foundation, Inc., -%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. -%%% -%%%---------------------------------------------------------------------- - --module(cyrsasl_plain). - --author('alexey@process-one.net'). - --export([start/1, stop/0, mech_new/4, mech_step/2, parse/1, format_error/1]). - --behaviour(cyrsasl). - --record(state, {check_password}). --type error_reason() :: parser_failed | not_authorized. --export_type([error_reason/0]). - -start(_Opts) -> - cyrsasl:register_mechanism(<<"PLAIN">>, ?MODULE, plain). - -stop() -> ok. - --spec format_error(error_reason()) -> {atom(), binary()}. -format_error(parser_failed) -> - {'bad-protocol', <<"Response decoding failed">>}; -format_error(not_authorized) -> - {'not-authorized', <<"Invalid username or password">>}. - -mech_new(_Host, _GetPassword, CheckPassword, _CheckPasswordDigest) -> - {ok, #state{check_password = CheckPassword}}. - -mech_step(State, ClientIn) -> - case prepare(ClientIn) of - [AuthzId, User, Password] -> - case (State#state.check_password)(User, AuthzId, Password) of - {true, AuthModule} -> - {ok, - [{username, User}, {authzid, AuthzId}, - {auth_module, AuthModule}]}; - _ -> {error, not_authorized, User} - end; - _ -> {error, parser_failed} - end. - -prepare(ClientIn) -> - case parse(ClientIn) of - [<<"">>, UserMaybeDomain, Password] -> - case parse_domain(UserMaybeDomain) of - %% login@domainpwd - [User, _Domain] -> [User, User, Password]; - %% loginpwd - [User] -> [User, User, Password] - end; - [AuthzId, User, Password] -> - case parse_domain(AuthzId) of - %% login@domainloginpwd - [AuthzUser, _Domain] -> [AuthzUser, User, Password]; - %% loginloginpwd - [AuthzUser] -> [AuthzUser, User, Password] - end; - _ -> error - end. - -parse(S) -> - binary:split(S, <<0>>, [global]). - -parse_domain(S) -> parse_domain1(binary_to_list(S), "", []). - -parse_domain1([$@ | Cs], S, T) -> - parse_domain1(Cs, "", [list_to_binary(lists:reverse(S)) | T]); -parse_domain1([C | Cs], S, T) -> - parse_domain1(Cs, [C | S], T); -parse_domain1([], S, T) -> - lists:reverse([list_to_binary(lists:reverse(S)) | T]). diff --git a/src/cyrsasl_scram.erl b/src/cyrsasl_scram.erl deleted file mode 100644 index 069eae117..000000000 --- a/src/cyrsasl_scram.erl +++ /dev/null @@ -1,249 +0,0 @@ -%%%---------------------------------------------------------------------- -%%% File : cyrsasl_scram.erl -%%% Author : Stephen Röttger -%%% Purpose : SASL SCRAM authentication -%%% Created : 7 Aug 2011 by Stephen Röttger -%%% -%%% -%%% ejabberd, Copyright (C) 2002-2018 ProcessOne -%%% -%%% This program is free software; you can redistribute it and/or -%%% modify it under the terms of the GNU General Public License as -%%% published by the Free Software Foundation; either version 2 of the -%%% License, or (at your option) any later version. -%%% -%%% This program is distributed in the hope that it will be useful, -%%% but WITHOUT ANY WARRANTY; without even the implied warranty of -%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -%%% General Public License for more details. -%%% -%%% You should have received a copy of the GNU General Public License along -%%% with this program; if not, write to the Free Software Foundation, Inc., -%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. -%%% -%%%---------------------------------------------------------------------- - --module(cyrsasl_scram). - --author('stephen.roettger@googlemail.com'). - --protocol({rfc, 5802}). - --export([start/1, stop/0, mech_new/4, mech_step/2, format_error/1]). - --include("scram.hrl"). --include("logger.hrl"). - --behaviour(cyrsasl). - --record(state, - {step = 2 :: 2 | 4, - stored_key = <<"">> :: binary(), - server_key = <<"">> :: binary(), - username = <<"">> :: binary(), - auth_module :: module(), - get_password :: fun((binary()) -> - {false | ejabberd_auth:password(), module()}), - auth_message = <<"">> :: binary(), - client_nonce = <<"">> :: binary(), - server_nonce = <<"">> :: binary()}). - --define(SALT_LENGTH, 16). --define(NONCE_LENGTH, 16). - --type error_reason() :: unsupported_extension | bad_username | - not_authorized | saslprep_failed | - parser_failed | bad_attribute | - nonce_mismatch | bad_channel_binding. - --export_type([error_reason/0]). - -start(_Opts) -> - cyrsasl:register_mechanism(<<"SCRAM-SHA-1">>, ?MODULE, - scram). - -stop() -> ok. - --spec format_error(error_reason()) -> {atom(), binary()}. -format_error(unsupported_extension) -> - {'bad-protocol', <<"Unsupported extension">>}; -format_error(bad_username) -> - {'invalid-authzid', <<"Malformed username">>}; -format_error(not_authorized) -> - {'not-authorized', <<"Invalid username or password">>}; -format_error(saslprep_failed) -> - {'not-authorized', <<"SASLprep failed">>}; -format_error(parser_failed) -> - {'bad-protocol', <<"Response decoding failed">>}; -format_error(bad_attribute) -> - {'bad-protocol', <<"Malformed or unexpected attribute">>}; -format_error(nonce_mismatch) -> - {'bad-protocol', <<"Nonce mismatch">>}; -format_error(bad_channel_binding) -> - {'bad-protocol', <<"Invalid channel binding">>}. - -mech_new(_Host, GetPassword, _CheckPassword, - _CheckPasswordDigest) -> - {ok, #state{step = 2, get_password = GetPassword}}. - -mech_step(#state{step = 2} = State, ClientIn) -> - case re:split(ClientIn, <<",">>, [{return, binary}]) of - [_CBind, _AuthorizationIdentity, _UserNameAttribute, _ClientNonceAttribute, ExtensionAttribute | _] - when ExtensionAttribute /= <<"">> -> - {error, unsupported_extension}; - [CBind, _AuthorizationIdentity, UserNameAttribute, ClientNonceAttribute | _] - when (CBind == <<"y">>) or (CBind == <<"n">>) -> - case parse_attribute(UserNameAttribute) of - {error, Reason} -> {error, Reason}; - {_, EscapedUserName} -> - case unescape_username(EscapedUserName) of - error -> {error, bad_username}; - UserName -> - case parse_attribute(ClientNonceAttribute) of - {$r, ClientNonce} -> - {Pass, AuthModule} = (State#state.get_password)(UserName), - LPass = if is_binary(Pass) -> jid:resourceprep(Pass); - true -> Pass - end, - if Pass == false -> - {error, not_authorized, UserName}; - LPass == error -> - {error, saslprep_failed, UserName}; - true -> - {StoredKey, ServerKey, Salt, IterationCount} = - if is_record(Pass, scram) -> - {base64:decode(Pass#scram.storedkey), - base64:decode(Pass#scram.serverkey), - base64:decode(Pass#scram.salt), - Pass#scram.iterationcount}; - true -> - TempSalt = - p1_rand:bytes(?SALT_LENGTH), - SaltedPassword = - scram:salted_password(Pass, - TempSalt, - ?SCRAM_DEFAULT_ITERATION_COUNT), - {scram:stored_key(scram:client_key(SaltedPassword)), - scram:server_key(SaltedPassword), - TempSalt, - ?SCRAM_DEFAULT_ITERATION_COUNT} - end, - ClientFirstMessageBare = - str:substr(ClientIn, - str:str(ClientIn, <<"n=">>)), - ServerNonce = - base64:encode(p1_rand:bytes(?NONCE_LENGTH)), - ServerFirstMessage = - iolist_to_binary( - ["r=", - ClientNonce, - ServerNonce, - ",", "s=", - base64:encode(Salt), - ",", "i=", - integer_to_list(IterationCount)]), - {continue, ServerFirstMessage, - State#state{step = 4, stored_key = StoredKey, - server_key = ServerKey, - auth_module = AuthModule, - auth_message = - <>, - client_nonce = ClientNonce, - server_nonce = ServerNonce, - username = UserName}} - end; - _ -> {error, bad_attribute} - end - end - end; - _Else -> {error, parser_failed} - end; -mech_step(#state{step = 4} = State, ClientIn) -> - case str:tokens(ClientIn, <<",">>) of - [GS2ChannelBindingAttribute, NonceAttribute, - ClientProofAttribute] -> - case parse_attribute(GS2ChannelBindingAttribute) of - {$c, CVal} -> - ChannelBindingSupport = try binary:first(base64:decode(CVal)) - catch _:badarg -> 0 - end, - if (ChannelBindingSupport == $n) - or (ChannelBindingSupport == $y) -> - Nonce = <<(State#state.client_nonce)/binary, - (State#state.server_nonce)/binary>>, - case parse_attribute(NonceAttribute) of - {$r, CompareNonce} when CompareNonce == Nonce -> - case parse_attribute(ClientProofAttribute) of - {$p, ClientProofB64} -> - ClientProof = try base64:decode(ClientProofB64) - catch _:badarg -> <<>> - end, - AuthMessage = iolist_to_binary( - [State#state.auth_message, - ",", - str:substr(ClientIn, 1, - str:str(ClientIn, <<",p=">>) - - 1)]), - ClientSignature = - scram:client_signature(State#state.stored_key, - AuthMessage), - ClientKey = scram:client_key(ClientProof, - ClientSignature), - CompareStoredKey = scram:stored_key(ClientKey), - if CompareStoredKey == State#state.stored_key -> - ServerSignature = - scram:server_signature(State#state.server_key, - AuthMessage), - {ok, [{username, State#state.username}, - {auth_module, State#state.auth_module}, - {authzid, State#state.username}], - <<"v=", - (base64:encode(ServerSignature))/binary>>}; - true -> {error, not_authorized, State#state.username} - end; - _ -> {error, bad_attribute} - end; - {$r, _} -> {error, nonce_mismatch}; - _ -> {error, bad_attribute} - end; - true -> {error, bad_channel_binding} - end; - _ -> {error, bad_attribute} - end; - _ -> {error, parser_failed} - end. - -parse_attribute(<>) when Val /= <<>> -> - case is_alpha(Name) of - true -> {Name, Val}; - false -> {error, bad_attribute} - end; -parse_attribute(_) -> - {error, bad_attribute}. - -unescape_username(<<"">>) -> <<"">>; -unescape_username(EscapedUsername) -> - Pos = str:str(EscapedUsername, <<"=">>), - if Pos == 0 -> EscapedUsername; - true -> - Start = str:substr(EscapedUsername, 1, Pos - 1), - End = str:substr(EscapedUsername, Pos), - EndLen = byte_size(End), - if EndLen < 3 -> error; - true -> - case str:substr(End, 1, 3) of - <<"=2C">> -> - <>; - <<"=3D">> -> - < TLSRequired. -tls_verify(#{tls_verify := TLSVerify}) -> - TLSVerify. - tls_enabled(#{tls_enabled := TLSEnabled, tls_required := TLSRequired, tls_verify := TLSVerify}) -> @@ -358,25 +355,41 @@ unauthenticated_stream_features(#{lserver := LServer}) -> authenticated_stream_features(#{lserver := LServer}) -> ejabberd_hooks:run_fold(c2s_post_auth_features, LServer, [], [LServer]). -sasl_mechanisms(Mechs, #{lserver := LServer}) -> +sasl_mechanisms(Mechs, #{lserver := LServer} = State) -> + Type = ejabberd_auth:store_type(LServer), Mechs1 = ejabberd_config:get_option({disable_sasl_mechanisms, LServer}, []), - Mechs2 = case ejabberd_auth_anonymous:is_sasl_anonymous_enabled(LServer) of - true -> Mechs1; - false -> [<<"ANONYMOUS">>|Mechs1] - end, - Mechs -- Mechs2. + %% I re-created it from cyrsasl ets magic, but I think it's wrong + %% TODO: need to check before 18.09 release + lists:filter( + fun(<<"ANONYMOUS">>) -> + ejabberd_auth_anonymous:is_sasl_anonymous_enabled(LServer); + (<<"DIGEST-MD5">>) -> Type == plain; + (<<"SCRAM-SHA-1">>) -> Type /= external; + (<<"PLAIN">>) -> true; + (<<"X-OAUTH2">>) -> true; + (<<"EXTERNAL">>) -> maps:get(tls_verify, State, false); + (_) -> false + end, Mechs -- Mechs1). -get_password_fun(#{lserver := LServer}) -> +get_password_fun(_Mech, #{lserver := LServer}) -> fun(U) -> ejabberd_auth:get_password_with_authmodule(U, LServer) end. -check_password_fun(#{lserver := LServer}) -> +check_password_fun(<<"X-OAUTH2">>, #{lserver := LServer}) -> + fun(User, _AuthzId, Token) -> + case ejabberd_oauth:check_token( + User, LServer, [<<"sasl_auth">>], Token) of + true -> {true, ejabberd_oauth}; + _ -> {false, ejabberd_oauth} + end + end; +check_password_fun(_Mech, #{lserver := LServer}) -> fun(U, AuthzId, P) -> ejabberd_auth:check_password_with_authmodule(U, AuthzId, LServer, P) end. -check_password_digest_fun(#{lserver := LServer}) -> +check_password_digest_fun(_Mech, #{lserver := LServer}) -> fun(U, AuthzId, P, D, DG) -> ejabberd_auth:check_password_with_authmodule(U, AuthzId, LServer, P, D, DG) end. @@ -920,7 +933,7 @@ change_shaper(#{shaper := ShaperName, ip := IP, lserver := LServer, Shaper = acl:access_matches(ShaperName, #{usr => jid:split(JID), ip => IP}, LServer), - xmpp_stream_in:change_shaper(State, Shaper). + xmpp_stream_in:change_shaper(State, ejabberd_shaper:new(Shaper)). -spec format_reason(state(), term()) -> binary(). format_reason(#{stop_reason := Reason}, _) -> diff --git a/src/ejabberd_config.erl b/src/ejabberd_config.erl index 8e55ae4d0..5dcb24711 100644 --- a/src/ejabberd_config.erl +++ b/src/ejabberd_config.erl @@ -782,8 +782,13 @@ set_opts(State) -> fun(#local_config{key = Key, value = Val}) -> {Key, Val} end, Opts)), + set_fqdn(), set_log_level(). +set_fqdn() -> + FQDNs = get_option(fqdn, []), + xmpp:set_config([{fqdn, FQDNs}]). + set_log_level() -> Level = get_option(loglevel, 4), ejabberd_logger:set(Level). @@ -1452,10 +1457,16 @@ opt_type(node_start) -> fun(I) when is_integer(I), I>=0 -> I end; opt_type(validate_stream) -> fun(B) when is_boolean(B) -> B end; +opt_type(fqdn) -> + fun(Domain) when is_binary(Domain) -> + [Domain]; + (Domains) -> + [iolist_to_binary(Domain) || Domain <- Domains] + end; opt_type(_) -> [hide_sensitive_log_data, hosts, language, max_fsm_queue, default_db, default_ram_db, queue_type, queue_dir, loglevel, - use_cache, cache_size, cache_missed, cache_life_time, + use_cache, cache_size, cache_missed, cache_life_time, fqdn, shared_key, node_start, validate_stream, negotiation_timeout]. -spec may_hide_data(any()) -> any(). diff --git a/src/ejabberd_idna.erl b/src/ejabberd_idna.erl deleted file mode 100644 index ef47c45c1..000000000 --- a/src/ejabberd_idna.erl +++ /dev/null @@ -1,224 +0,0 @@ -%%%---------------------------------------------------------------------- -%%% File : ejabberd_idna.erl -%%% Author : Alexey Shchepin -%%% Purpose : Support for IDNA (RFC3490) -%%% Created : 10 Apr 2004 by Alexey Shchepin -%%% -%%% -%%% ejabberd, Copyright (C) 2002-2018 ProcessOne -%%% -%%% This program is free software; you can redistribute it and/or -%%% modify it under the terms of the GNU General Public License as -%%% published by the Free Software Foundation; either version 2 of the -%%% License, or (at your option) any later version. -%%% -%%% This program is distributed in the hope that it will be useful, -%%% but WITHOUT ANY WARRANTY; without even the implied warranty of -%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -%%% General Public License for more details. -%%% -%%% You should have received a copy of the GNU General Public License along -%%% with this program; if not, write to the Free Software Foundation, Inc., -%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. -%%% -%%%---------------------------------------------------------------------- - --module(ejabberd_idna). - --author('alexey@process-one.net'). - --export([domain_utf8_to_ascii/1, - domain_ucs2_to_ascii/1, - utf8_to_ucs2/1]). - --ifdef(TEST). --include_lib("eunit/include/eunit.hrl"). --endif. - --spec domain_utf8_to_ascii(binary()) -> false | binary(). - -domain_utf8_to_ascii(Domain) -> - domain_ucs2_to_ascii(utf8_to_ucs2(Domain)). - -utf8_to_ucs2(S) -> - utf8_to_ucs2(binary_to_list(S), ""). - -utf8_to_ucs2([], R) -> lists:reverse(R); -utf8_to_ucs2([C | S], R) when C < 128 -> - utf8_to_ucs2(S, [C | R]); -utf8_to_ucs2([C1, C2 | S], R) when C1 < 224 -> - utf8_to_ucs2(S, [C1 band 31 bsl 6 bor C2 band 63 | R]); -utf8_to_ucs2([C1, C2, C3 | S], R) when C1 < 240 -> - utf8_to_ucs2(S, - [C1 band 15 bsl 12 bor (C2 band 63 bsl 6) bor C3 band 63 - | R]). - --spec domain_ucs2_to_ascii(list()) -> false | binary(). - -domain_ucs2_to_ascii(Domain) -> - case catch domain_ucs2_to_ascii1(Domain) of - {'EXIT', _Reason} -> false; - Res -> iolist_to_binary(Res) - end. - -domain_ucs2_to_ascii1(Domain) -> - Parts = string:tokens(Domain, - [46, 12290, 65294, 65377]), - ASCIIParts = lists:map(fun (P) -> to_ascii(P) end, - Parts), - string:strip(lists:flatmap(fun (P) -> [$. | P] end, - ASCIIParts), - left, $.). - -%% Domain names are already nameprep'ed in ejabberd, so we skiping this step -to_ascii(Name) -> - false = lists:any(fun (C) - when (0 =< C) and (C =< 44) or - (46 =< C) and (C =< 47) - or (58 =< C) and (C =< 64) - or (91 =< C) and (C =< 96) - or (123 =< C) and (C =< 127) -> - true; - (_) -> false - end, - Name), - case Name of - [H | _] when H /= $- -> true = lists:last(Name) /= $- - end, - ASCIIName = case lists:any(fun (C) -> C > 127 end, Name) - of - true -> - true = case Name of - "xn--" ++ _ -> false; - _ -> true - end, - "xn--" ++ punycode_encode(Name); - false -> Name - end, - L = length(ASCIIName), - true = (1 =< L) and (L =< 63), - ASCIIName. - -%%% PUNYCODE (RFC3492) - --define(BASE, 36). - --define(TMIN, 1). - --define(TMAX, 26). - --define(SKEW, 38). - --define(DAMP, 700). - --define(INITIAL_BIAS, 72). - --define(INITIAL_N, 128). - -punycode_encode(Input) -> - N = (?INITIAL_N), - Delta = 0, - Bias = (?INITIAL_BIAS), - Basic = lists:filter(fun (C) -> C =< 127 end, Input), - NonBasic = lists:filter(fun (C) -> C > 127 end, Input), - L = length(Input), - B = length(Basic), - SNonBasic = lists:usort(NonBasic), - Output1 = if B > 0 -> Basic ++ "-"; - true -> "" - end, - Output2 = punycode_encode1(Input, SNonBasic, B, B, L, N, - Delta, Bias, ""), - Output1 ++ Output2. - -punycode_encode1(Input, [M | SNonBasic], B, H, L, N, - Delta, Bias, Out) - when H < L -> - Delta1 = Delta + (M - N) * (H + 1), - % let n = m - {NewDelta, NewBias, NewH, NewOut} = lists:foldl(fun (C, - {ADelta, ABias, AH, - AOut}) -> - if C < M -> - {ADelta + 1, - ABias, AH, - AOut}; - C == M -> - NewOut = - punycode_encode_delta(ADelta, - ABias, - AOut), - NewBias = - adapt(ADelta, - H + - 1, - H - == - B), - {0, NewBias, - AH + 1, - NewOut}; - true -> - {ADelta, - ABias, AH, - AOut} - end - end, - {Delta1, Bias, H, Out}, - Input), - punycode_encode1(Input, SNonBasic, B, NewH, L, M + 1, - NewDelta + 1, NewBias, NewOut); -punycode_encode1(_Input, _SNonBasic, _B, _H, _L, _N, - _Delta, _Bias, Out) -> - lists:reverse(Out). - -punycode_encode_delta(Delta, Bias, Out) -> - punycode_encode_delta(Delta, Bias, Out, ?BASE). - -punycode_encode_delta(Delta, Bias, Out, K) -> - T = if K =< Bias -> ?TMIN; - K >= Bias + (?TMAX) -> ?TMAX; - true -> K - Bias - end, - if Delta < T -> [codepoint(Delta) | Out]; - true -> - C = T + (Delta - T) rem ((?BASE) - T), - punycode_encode_delta((Delta - T) div ((?BASE) - T), - Bias, [codepoint(C) | Out], K + (?BASE)) - end. - -adapt(Delta, NumPoints, FirstTime) -> - Delta1 = if FirstTime -> Delta div (?DAMP); - true -> Delta div 2 - end, - Delta2 = Delta1 + Delta1 div NumPoints, - adapt1(Delta2, 0). - -adapt1(Delta, K) -> - if Delta > ((?BASE) - (?TMIN)) * (?TMAX) div 2 -> - adapt1(Delta div ((?BASE) - (?TMIN)), K + (?BASE)); - true -> - K + - ((?BASE) - (?TMIN) + 1) * Delta div (Delta + (?SKEW)) - end. - -codepoint(C) -> - if (0 =< C) and (C =< 25) -> C + 97; - (26 =< C) and (C =< 35) -> C + 22 - end. - -%%%=================================================================== -%%% Unit tests -%%%=================================================================== --ifdef(TEST). - -acsii_test() -> - ?assertEqual(<<"test.org">>, domain_utf8_to_ascii(<<"test.org">>)). - -utf8_test() -> - ?assertEqual( - <<"xn--d1acufc.xn--p1ai">>, - domain_utf8_to_ascii( - <<208,180,208,190,208,188,208,181,208,189,46,209,128,209,132>>)). - --endif. diff --git a/src/ejabberd_logger.erl b/src/ejabberd_logger.erl index fecd9485c..2b3eab0eb 100644 --- a/src/ejabberd_logger.erl +++ b/src/ejabberd_logger.erl @@ -214,6 +214,10 @@ set(LogLevel) when is_integer(LogLevel) -> ok end, gen_event:which_handlers(lager_event)) end, + case LogLevel of + 5 -> xmpp:set_config([{debug, true}]); + _ -> ok + end, {module, lager}; set({_LogLevel, _}) -> error_logger:error_msg("custom loglevels are not supported for 'lager'"), diff --git a/src/ejabberd_pkix.erl b/src/ejabberd_pkix.erl index 0f23e6871..002a98917 100644 --- a/src/ejabberd_pkix.erl +++ b/src/ejabberd_pkix.erl @@ -133,7 +133,7 @@ get_certfile(Domain) -> -spec get_certfile_no_default(binary()) -> {ok, binary()} | error. get_certfile_no_default(Domain) -> - case ejabberd_idna:domain_utf8_to_ascii(Domain) of + case xmpp_idna:domain_utf8_to_ascii(Domain) of false -> error; ASCIIDomain -> diff --git a/src/ejabberd_s2s_in.erl b/src/ejabberd_s2s_in.erl index 7a6bc46e5..20106ab6f 100644 --- a/src/ejabberd_s2s_in.erl +++ b/src/ejabberd_s2s_in.erl @@ -30,8 +30,8 @@ %% xmpp_stream_in callbacks -export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]). --export([tls_options/1, tls_required/1, tls_verify/1, tls_enabled/1, - compress_methods/1, +-export([tls_options/1, tls_required/1, tls_enabled/1, + compress_methods/1, sasl_mechanisms/2, unauthenticated_stream_features/1, authenticated_stream_features/1, handle_stream_start/2, handle_stream_end/2, handle_stream_established/1, handle_auth_success/4, @@ -144,12 +144,15 @@ tls_options(#{tls_options := TLSOpts, server_host := LServer}) -> tls_required(#{server_host := LServer}) -> ejabberd_s2s:tls_required(LServer). -tls_verify(#{server_host := LServer}) -> - ejabberd_s2s:tls_verify(LServer). - tls_enabled(#{server_host := LServer}) -> ejabberd_s2s:tls_enabled(LServer). +sasl_mechanisms(Mechs, #{server_host := LServer}) -> + lists:filter( + fun(<<"EXTERNAL">>) -> ejabberd_s2s:tls_verify(LServer); + (_) -> false + end, Mechs). + compress_methods(#{server_host := LServer}) -> case ejabberd_s2s:zlib_enabled(LServer) of true -> [<<"zlib">>]; @@ -344,7 +347,7 @@ set_idle_timeout(State) -> change_shaper(#{shaper := ShaperName, server_host := ServerHost} = State, RServer) -> Shaper = acl:match_rule(ServerHost, ShaperName, jid:make(RServer)), - xmpp_stream_in:change_shaper(State, Shaper). + xmpp_stream_in:change_shaper(State, ejabberd_shaper:new(Shaper)). -spec listen_opt_type(shaper) -> fun((any()) -> any()); (certfile) -> fun((binary()) -> binary()); diff --git a/src/ejabberd_service.erl b/src/ejabberd_service.erl index 496532a6b..96478be93 100644 --- a/src/ejabberd_service.erl +++ b/src/ejabberd_service.erl @@ -101,7 +101,7 @@ init([State, Opts]) -> end, GlobalRoutes = proplists:get_value(global_routes, Opts, true), Timeout = ejabberd_config:negotiation_timeout(), - State1 = xmpp_stream_in:change_shaper(State, Shaper), + State1 = xmpp_stream_in:change_shaper(State, ejabberd_shaper:new(Shaper)), State2 = xmpp_stream_in:set_timeout(State1, Timeout), State3 = State2#{access => Access, xmlns => ?NS_COMPONENT, diff --git a/src/ejabberd_sup.erl b/src/ejabberd_sup.erl index 73cb5b99f..b509cdfc0 100644 --- a/src/ejabberd_sup.erl +++ b/src/ejabberd_sup.erl @@ -39,7 +39,6 @@ init([]) -> {ok, {{one_for_one, 10, 1}, [worker(ejabberd_hooks), worker(ejabberd_cluster), - worker(cyrsasl), worker(translate), worker(ejabberd_access_permissions), worker(ejabberd_ctl), diff --git a/src/scram.erl b/src/scram.erl deleted file mode 100644 index 48557ab39..000000000 --- a/src/scram.erl +++ /dev/null @@ -1,81 +0,0 @@ -%%%---------------------------------------------------------------------- -%%% File : scram.erl -%%% Author : Stephen Röttger -%%% Purpose : SCRAM (RFC 5802) -%%% Created : 7 Aug 2011 by Stephen Röttger -%%% -%%% -%%% ejabberd, Copyright (C) 2002-2018 ProcessOne -%%% -%%% This program is free software; you can redistribute it and/or -%%% modify it under the terms of the GNU General Public License as -%%% published by the Free Software Foundation; either version 2 of the -%%% License, or (at your option) any later version. -%%% -%%% This program is distributed in the hope that it will be useful, -%%% but WITHOUT ANY WARRANTY; without even the implied warranty of -%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -%%% General Public License for more details. -%%% -%%% You should have received a copy of the GNU General Public License along -%%% with this program; if not, write to the Free Software Foundation, Inc., -%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. -%%% -%%%---------------------------------------------------------------------- - --module(scram). - --author('stephen.roettger@googlemail.com'). - -%% External exports -%% ejabberd doesn't implement SASLPREP, so we use the similar RESOURCEPREP instead --export([salted_password/3, stored_key/1, server_key/1, - server_signature/2, client_signature/2, client_key/1, - client_key/2]). - --spec salted_password(binary(), binary(), non_neg_integer()) -> binary(). - -salted_password(Password, Salt, IterationCount) -> - hi(jid:resourceprep(Password), Salt, IterationCount). - --spec client_key(binary()) -> binary(). - -client_key(SaltedPassword) -> - sha_mac(SaltedPassword, <<"Client Key">>). - --spec stored_key(binary()) -> binary(). - -stored_key(ClientKey) -> crypto:hash(sha, ClientKey). - --spec server_key(binary()) -> binary(). - -server_key(SaltedPassword) -> - sha_mac(SaltedPassword, <<"Server Key">>). - --spec client_signature(binary(), binary()) -> binary(). - -client_signature(StoredKey, AuthMessage) -> - sha_mac(StoredKey, AuthMessage). - --spec client_key(binary(), binary()) -> binary(). - -client_key(ClientProof, ClientSignature) -> - crypto:exor(ClientProof, ClientSignature). - --spec server_signature(binary(), binary()) -> binary(). - -server_signature(ServerKey, AuthMessage) -> - sha_mac(ServerKey, AuthMessage). - -hi(Password, Salt, IterationCount) -> - U1 = sha_mac(Password, <>), - crypto:exor(U1, hi_round(Password, U1, IterationCount - 1)). - -hi_round(Password, UPrev, 1) -> - sha_mac(Password, UPrev); -hi_round(Password, UPrev, IterationCount) -> - U = sha_mac(Password, UPrev), - crypto:exor(U, hi_round(Password, U, IterationCount - 1)). - -sha_mac(Key, Data) -> - crypto:hmac(sha, Key, Data). diff --git a/src/xmpp_socket.erl b/src/xmpp_socket.erl deleted file mode 100644 index 5eedce67e..000000000 --- a/src/xmpp_socket.erl +++ /dev/null @@ -1,393 +0,0 @@ -%%%---------------------------------------------------------------------- -%%% File : xmpp_socket.erl -%%% Author : Alexey Shchepin -%%% Purpose : Socket with zlib and TLS support library -%%% Created : 23 Aug 2006 by Alexey Shchepin -%%% -%%% -%%% ejabberd, Copyright (C) 2002-2018 ProcessOne -%%% -%%% This program is free software; you can redistribute it and/or -%%% modify it under the terms of the GNU General Public License as -%%% published by the Free Software Foundation; either version 2 of the -%%% License, or (at your option) any later version. -%%% -%%% This program is distributed in the hope that it will be useful, -%%% but WITHOUT ANY WARRANTY; without even the implied warranty of -%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -%%% General Public License for more details. -%%% -%%% You should have received a copy of the GNU General Public License along -%%% with this program; if not, write to the Free Software Foundation, Inc., -%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. -%%% -%%%---------------------------------------------------------------------- - --module(xmpp_socket). - --author('alexey@process-one.net'). - -%% API --export([start/4, - connect/3, - connect/4, - connect/5, - starttls/2, - compress/1, - compress/2, - reset_stream/1, - send_element/2, - send_header/2, - send_trailer/1, - send/2, - send_xml/2, - recv/2, - activate/1, - change_shaper/2, - monitor/1, - get_sockmod/1, - get_transport/1, - get_peer_certificate/2, - get_verify_result/1, - close/1, - pp/1, - sockname/1, peername/1]). - --include("xmpp.hrl"). --include("logger.hrl"). - --type sockmod() :: ejabberd_bosh | - ejabberd_http_ws | - gen_tcp | fast_tls | ezlib. --type receiver() :: atom(). --type socket() :: pid() | inet:socket() | - fast_tls:tls_socket() | - ezlib:zlib_socket() | - ejabberd_bosh:bosh_socket() | - ejabberd_http_ws:ws_socket(). - --record(socket_state, {sockmod = gen_tcp :: sockmod(), - socket :: socket(), - max_stanza_size = infinity :: timeout(), - xml_stream :: undefined | fxml_stream:xml_stream_state(), - shaper = none :: none | ejabberd_shaper:shaper(), - receiver :: receiver()}). - --type socket_state() :: #socket_state{}. - --export_type([socket/0, socket_state/0, sockmod/0]). - --callback start({module(), socket_state()}, - [proplists:property()]) -> {ok, pid()} | {error, term()} | ignore. --callback start_link({module(), socket_state()}, - [proplists:property()]) -> {ok, pid()} | {error, term()} | ignore. --callback socket_type() -> xml_stream | independent | raw. - --define(is_http_socket(S), - (S#socket_state.sockmod == ejabberd_bosh orelse - S#socket_state.sockmod == ejabberd_http_ws)). - -%%==================================================================== -%% API -%%==================================================================== --spec start(atom(), sockmod(), socket(), [proplists:property()]) - -> {ok, pid() | independent} | {error, inet:posix() | any()} | ignore. -start(Module, SockMod, Socket, Opts) -> - try - case Module:socket_type() of - independent -> - {ok, independent}; - xml_stream -> - MaxStanzaSize = proplists:get_value(max_stanza_size, Opts, infinity), - Receiver = proplists:get_value(receiver, Opts), - SocketData = #socket_state{sockmod = SockMod, - socket = Socket, - receiver = Receiver, - max_stanza_size = MaxStanzaSize}, - {ok, Pid} = Module:start({?MODULE, SocketData}, Opts), - Receiver1 = if is_pid(Receiver) -> Receiver; - true -> Pid - end, - ok = controlling_process(SocketData, Receiver1), - ok = become_controller(SocketData, Pid), - {ok, Receiver1}; - raw -> - {ok, Pid} = Module:start({SockMod, Socket}, Opts), - ok = SockMod:controlling_process(Socket, Pid), - {ok, Pid} - end - catch - _:{badmatch, {error, _} = Err} -> - SockMod:close(Socket), - Err; - _:{badmatch, ignore} -> - SockMod:close(Socket), - ignore - end. - -connect(Addr, Port, Opts) -> - connect(Addr, Port, Opts, infinity, self()). - -connect(Addr, Port, Opts, Timeout) -> - connect(Addr, Port, Opts, Timeout, self()). - -connect(Addr, Port, Opts, Timeout, Owner) -> - case gen_tcp:connect(Addr, Port, Opts, Timeout) of - {ok, Socket} -> - SocketData = #socket_state{sockmod = gen_tcp, socket = Socket}, - case controlling_process(SocketData, Owner) of - ok -> - activate_after(Socket, Owner, 0), - {ok, SocketData}; - {error, _Reason} = Error -> - gen_tcp:close(Socket), - Error - end; - {error, _Reason} = Error -> - Error - end. - -starttls(#socket_state{socket = Socket, - receiver = undefined} = SocketData, TLSOpts) -> - case fast_tls:tcp_to_tls(Socket, TLSOpts) of - {ok, TLSSocket} -> - SocketData1 = SocketData#socket_state{socket = TLSSocket, - sockmod = fast_tls}, - SocketData2 = reset_stream(SocketData1), - case fast_tls:recv_data(TLSSocket, <<>>) of - {ok, TLSData} -> - parse(SocketData2, TLSData); - {error, _} = Err -> - Err - end; - {error, _} = Err -> - Err - end. - -compress(SocketData) -> compress(SocketData, undefined). - -compress(#socket_state{receiver = undefined, - sockmod = SockMod, - socket = Socket} = SocketData, Data) -> - {ok, ZlibSocket} = ezlib:enable_zlib(SockMod, Socket), - case Data of - undefined -> ok; - _ -> send(SocketData, Data) - end, - SocketData1 = SocketData#socket_state{socket = ZlibSocket, - sockmod = ezlib}, - SocketData2 = reset_stream(SocketData1), - case ezlib:recv_data(ZlibSocket, <<"">>) of - {ok, ZlibData} -> - parse(SocketData2, ZlibData); - {error, _} = Err -> - Err - end. - -reset_stream(#socket_state{xml_stream = XMLStream, - receiver = Receiver, - sockmod = SockMod, socket = Socket, - max_stanza_size = MaxStanzaSize} = SocketData) -> - XMLStream1 = try fxml_stream:reset(XMLStream) - catch error:_ -> - close_stream(XMLStream), - fxml_stream:new(self(), MaxStanzaSize) - end, - case Receiver of - undefined -> - SocketData#socket_state{xml_stream = XMLStream1}; - _ -> - Socket1 = SockMod:reset_stream(Socket), - SocketData#socket_state{xml_stream = XMLStream1, socket = Socket1} - end. - --spec send_element(socket_state(), fxml:xmlel()) -> ok | {error, inet:posix()}. -send_element(SocketData, El) when ?is_http_socket(SocketData) -> - send_xml(SocketData, {xmlstreamelement, El}); -send_element(SocketData, El) -> - send(SocketData, fxml:element_to_binary(El)). - --spec send_header(socket_state(), fxml:xmlel()) -> ok | {error, inet:posix()}. -send_header(SocketData, El) when ?is_http_socket(SocketData) -> - send_xml(SocketData, {xmlstreamstart, El#xmlel.name, El#xmlel.attrs}); -send_header(SocketData, El) -> - send(SocketData, fxml:element_to_header(El)). - --spec send_trailer(socket_state()) -> ok | {error, inet:posix()}. -send_trailer(SocketData) when ?is_http_socket(SocketData) -> - send_xml(SocketData, {xmlstreamend, <<"stream:stream">>}); -send_trailer(SocketData) -> - send(SocketData, <<"">>). - --spec send(socket_state(), iodata()) -> ok | {error, closed | inet:posix()}. -send(#socket_state{sockmod = SockMod, socket = Socket} = SocketData, Data) -> - ?DEBUG("(~s) Send XML on stream = ~p", [pp(SocketData), Data]), - try SockMod:send(Socket, Data) of - {error, einval} -> {error, closed}; - Result -> Result - catch _:badarg -> - %% Some modules throw badarg exceptions on closed sockets - %% TODO: their code should be improved - {error, closed} - end. - --spec send_xml(socket_state(), - {xmlstreamelement, fxml:xmlel()} | - {xmlstreamstart, binary(), [{binary(), binary()}]} | - {xmlstreamend, binary()} | - {xmlstreamraw, iodata()}) -> term(). -send_xml(SocketData, El) -> - (SocketData#socket_state.sockmod):send_xml(SocketData#socket_state.socket, El). - -recv(#socket_state{xml_stream = undefined} = SocketData, Data) -> - XMLStream = fxml_stream:new(self(), SocketData#socket_state.max_stanza_size), - recv(SocketData#socket_state{xml_stream = XMLStream}, Data); -recv(#socket_state{sockmod = SockMod, socket = Socket} = SocketData, Data) -> - case SockMod of - fast_tls -> - case fast_tls:recv_data(Socket, Data) of - {ok, TLSData} -> - parse(SocketData, TLSData); - {error, _} = Err -> - Err - end; - ezlib -> - case ezlib:recv_data(Socket, Data) of - {ok, ZlibData} -> - parse(SocketData, ZlibData); - {error, _} = Err -> - Err - end; - _ -> - parse(SocketData, Data) - end. - -change_shaper(#socket_state{receiver = undefined} = SocketData, Shaper) -> - ShaperState = ejabberd_shaper:new(Shaper), - SocketData#socket_state{shaper = ShaperState}; -change_shaper(#socket_state{sockmod = SockMod, - socket = Socket} = SocketData, Shaper) -> - SockMod:change_shaper(Socket, Shaper), - SocketData. - -monitor(#socket_state{receiver = undefined}) -> - make_ref(); -monitor(#socket_state{sockmod = SockMod, socket = Socket}) -> - SockMod:monitor(Socket). - -controlling_process(#socket_state{sockmod = SockMod, - socket = Socket}, Pid) -> - SockMod:controlling_process(Socket, Pid). - -become_controller(#socket_state{receiver = Receiver, - sockmod = SockMod, - socket = Socket}, Pid) -> - if is_pid(Receiver) -> - SockMod:become_controller(Receiver, Pid); - true -> - activate_after(Socket, Pid, 0) - end. - -get_sockmod(SocketData) -> - SocketData#socket_state.sockmod. - -get_transport(#socket_state{sockmod = SockMod, - socket = Socket}) -> - case SockMod of - gen_tcp -> tcp; - fast_tls -> tls; - ezlib -> - case ezlib:get_sockmod(Socket) of - gen_tcp -> tcp_zlib; - fast_tls -> tls_zlib - end; - ejabberd_bosh -> http_bind; - ejabberd_http_ws -> websocket - end. - -get_peer_certificate(SocketData, Type) -> - fast_tls:get_peer_certificate(SocketData#socket_state.socket, Type). - -get_verify_result(SocketData) -> - fast_tls:get_verify_result(SocketData#socket_state.socket). - -close(#socket_state{sockmod = SockMod, socket = Socket}) -> - SockMod:close(Socket). - -sockname(#socket_state{sockmod = SockMod, - socket = Socket}) -> - case SockMod of - gen_tcp -> inet:sockname(Socket); - _ -> SockMod:sockname(Socket) - end. - -peername(#socket_state{sockmod = SockMod, - socket = Socket}) -> - case SockMod of - gen_tcp -> inet:peername(Socket); - _ -> SockMod:peername(Socket) - end. - -activate(#socket_state{sockmod = SockMod, socket = Socket}) -> - case SockMod of - gen_tcp -> inet:setopts(Socket, [{active, once}]); - _ -> SockMod:setopts(Socket, [{active, once}]) - end. - -activate_after(Socket, Pid, Pause) -> - if Pause > 0 -> - erlang:send_after(Pause, Pid, {tcp, Socket, <<>>}); - true -> - Pid ! {tcp, Socket, <<>>} - end, - ok. - -pp(#socket_state{receiver = Receiver} = State) -> - Transport = get_transport(State), - Receiver1 = case Receiver of - undefined -> self(); - _ -> Receiver - end, - io_lib:format("~s|~w", [Transport, Receiver1]). - -parse(SocketData, Data) when Data == <<>>; Data == [] -> - case activate(SocketData) of - ok -> - {ok, SocketData}; - {error, _} = Err -> - Err - end; -parse(SocketData, [El | Els]) when is_record(El, xmlel) -> - self() ! {'$gen_event', {xmlstreamelement, El}}, - parse(SocketData, Els); -parse(SocketData, [El | Els]) when - element(1, El) == xmlstreamstart; - element(1, El) == xmlstreamelement; - element(1, El) == xmlstreamend; - element(1, El) == xmlstreamerror -> - self() ! {'$gen_event', El}, - parse(SocketData, Els); -parse(#socket_state{xml_stream = XMLStream, - socket = Socket, - shaper = ShaperState} = SocketData, Data) - when is_binary(Data) -> - ?DEBUG("(~s) Received XML on stream = ~p", [pp(SocketData), Data]), - XMLStream1 = fxml_stream:parse(XMLStream, Data), - {ShaperState1, Pause} = ejabberd_shaper:update(ShaperState, byte_size(Data)), - Ret = if Pause > 0 -> - activate_after(Socket, self(), Pause); - true -> - activate(SocketData) - end, - case Ret of - ok -> - {ok, SocketData#socket_state{xml_stream = XMLStream1, - shaper = ShaperState1}}; - {error, _} = Err -> - Err - end. - -close_stream(undefined) -> - ok; -close_stream(XMLStream) -> - fxml_stream:close(XMLStream). diff --git a/src/xmpp_stream_in.erl b/src/xmpp_stream_in.erl deleted file mode 100644 index 31018d434..000000000 --- a/src/xmpp_stream_in.erl +++ /dev/null @@ -1,1220 +0,0 @@ -%%%------------------------------------------------------------------- -%%% Created : 26 Nov 2016 by Evgeny Khramtsov -%%% -%%% -%%% ejabberd, Copyright (C) 2002-2018 ProcessOne -%%% -%%% This program is free software; you can redistribute it and/or -%%% modify it under the terms of the GNU General Public License as -%%% published by the Free Software Foundation; either version 2 of the -%%% License, or (at your option) any later version. -%%% -%%% This program is distributed in the hope that it will be useful, -%%% but WITHOUT ANY WARRANTY; without even the implied warranty of -%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -%%% General Public License for more details. -%%% -%%% You should have received a copy of the GNU General Public License along -%%% with this program; if not, write to the Free Software Foundation, Inc., -%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. -%%% -%%%------------------------------------------------------------------- --module(xmpp_stream_in). --define(GEN_SERVER, p1_server). --behaviour(?GEN_SERVER). - --protocol({rfc, 6120}). --protocol({xep, 114, '1.6'}). - -%% API --export([start/3, start_link/3, call/3, cast/2, reply/2, stop/1, - send/2, close/1, close/2, send_error/3, establish/1, - get_transport/1, change_shaper/2, set_timeout/2, format_error/1]). - -%% gen_server callbacks --export([init/1, handle_cast/2, handle_call/3, handle_info/2, - terminate/2, code_change/3]). - -%%-define(DBGFSM, true). --ifdef(DBGFSM). --define(FSMOPTS, [{debug, [trace]}]). --else. --define(FSMOPTS, []). --endif. - --include("xmpp.hrl"). --type state() :: map(). --type stop_reason() :: {stream, reset | {in | out, stream_error()}} | - {tls, inet:posix() | atom() | binary()} | - {socket, inet:posix() | atom()} | - internal_failure. --export_type([state/0, stop_reason/0]). --callback init(list()) -> {ok, state()} | {error, term()} | ignore. --callback handle_cast(term(), state()) -> state(). --callback handle_call(term(), term(), state()) -> state(). --callback handle_info(term(), state()) -> state(). --callback terminate(term(), state()) -> any(). --callback code_change(term(), state(), term()) -> {ok, state()} | {error, term()}. --callback handle_stream_start(stream_start(), state()) -> state(). --callback handle_stream_established(state()) -> state(). --callback handle_stream_end(stop_reason(), state()) -> state(). --callback handle_cdata(binary(), state()) -> state(). --callback handle_unauthenticated_packet(xmpp_element(), state()) -> state(). --callback handle_authenticated_packet(xmpp_element(), state()) -> state(). --callback handle_unbinded_packet(xmpp_element(), state()) -> state(). --callback handle_auth_success(binary(), binary(), module(), state()) -> state(). --callback handle_auth_failure(binary(), binary(), binary(), state()) -> state(). --callback handle_send(xmpp_element(), ok | {error, inet:posix()}, state()) -> state(). --callback handle_recv(fxml:xmlel(), xmpp_element() | {error, term()}, state()) -> state(). --callback handle_timeout(state()) -> state(). --callback get_password_fun(state()) -> fun(). --callback check_password_fun(state()) -> fun(). --callback check_password_digest_fun(state()) -> fun(). --callback bind(binary(), state()) -> {ok, state()} | {error, stanza_error(), state()}. --callback compress_methods(state()) -> [binary()]. --callback tls_options(state()) -> [proplists:property()]. --callback tls_required(state()) -> boolean(). --callback tls_verify(state()) -> boolean(). --callback tls_enabled(state()) -> boolean(). --callback sasl_mechanisms([cyrsasl:mechanism()], state()) -> [cyrsasl:mechanism()]. --callback unauthenticated_stream_features(state()) -> [xmpp_element()]. --callback authenticated_stream_features(state()) -> [xmpp_element()]. - -%% All callbacks are optional --optional_callbacks([init/1, - handle_cast/2, - handle_call/3, - handle_info/2, - terminate/2, - code_change/3, - handle_stream_start/2, - handle_stream_established/1, - handle_stream_end/2, - handle_cdata/2, - handle_authenticated_packet/2, - handle_unauthenticated_packet/2, - handle_unbinded_packet/2, - handle_auth_success/4, - handle_auth_failure/4, - handle_send/3, - handle_recv/3, - handle_timeout/1, - get_password_fun/1, - check_password_fun/1, - check_password_digest_fun/1, - bind/2, - compress_methods/1, - tls_options/1, - tls_required/1, - tls_verify/1, - tls_enabled/1, - sasl_mechanisms/2, - unauthenticated_stream_features/1, - authenticated_stream_features/1]). - -%%%=================================================================== -%%% API -%%%=================================================================== -start(Mod, Args, Opts) -> - ?GEN_SERVER:start(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS). - -start_link(Mod, Args, Opts) -> - ?GEN_SERVER:start_link(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS). - -call(Ref, Msg, Timeout) -> - ?GEN_SERVER:call(Ref, Msg, Timeout). - -cast(Ref, Msg) -> - ?GEN_SERVER:cast(Ref, Msg). - -reply(Ref, Reply) -> - ?GEN_SERVER:reply(Ref, Reply). - --spec stop(pid()) -> ok; - (state()) -> no_return(). -stop(Pid) when is_pid(Pid) -> - cast(Pid, stop); -stop(#{owner := Owner} = State) when Owner == self() -> - terminate(normal, State), - exit(normal); -stop(_) -> - erlang:error(badarg). - --spec send(pid(), xmpp_element()) -> ok; - (state(), xmpp_element()) -> state(). -send(Pid, Pkt) when is_pid(Pid) -> - cast(Pid, {send, Pkt}); -send(#{owner := Owner} = State, Pkt) when Owner == self() -> - send_pkt(State, Pkt); -send(_, _) -> - erlang:error(badarg). - --spec close(pid()) -> ok; - (state()) -> state(). -close(Pid) when is_pid(Pid) -> - close(Pid, closed); -close(#{owner := Owner} = State) when Owner == self() -> - close_socket(State); -close(_) -> - erlang:error(badarg). - --spec close(pid(), atom()) -> ok. -close(Pid, Reason) -> - cast(Pid, {close, Reason}). - --spec establish(state()) -> state(). -establish(State) -> - process_stream_established(State). - --spec set_timeout(state(), non_neg_integer() | infinity) -> state(). -set_timeout(#{owner := Owner} = State, Timeout) when Owner == self() -> - case Timeout of - infinity -> State#{stream_timeout => infinity}; - _ -> - Time = p1_time_compat:monotonic_time(milli_seconds), - State#{stream_timeout => {Timeout, Time}} - end; -set_timeout(_, _) -> - erlang:error(badarg). - -get_transport(#{socket := Socket, owner := Owner}) - when Owner == self() -> - xmpp_socket:get_transport(Socket); -get_transport(_) -> - erlang:error(badarg). - --spec change_shaper(state(), ejabberd_shaper:shaper()) -> state(). -change_shaper(#{socket := Socket, owner := Owner} = State, Shaper) - when Owner == self() -> - Socket1 = xmpp_socket:change_shaper(Socket, Shaper), - State#{socket => Socket1}; -change_shaper(_, _) -> - erlang:error(badarg). - --spec format_error(stop_reason()) -> binary(). -format_error({socket, Reason}) -> - format("Connection failed: ~s", [format_inet_error(Reason)]); -format_error({stream, reset}) -> - <<"Stream reset by peer">>; -format_error({stream, {in, #stream_error{} = Err}}) -> - format("Stream closed by peer: ~s", [xmpp:format_stream_error(Err)]); -format_error({stream, {out, #stream_error{} = Err}}) -> - format("Stream closed by us: ~s", [xmpp:format_stream_error(Err)]); -format_error({tls, Reason}) -> - format("TLS failed: ~s", [format_tls_error(Reason)]); -format_error(internal_failure) -> - <<"Internal server error">>; -format_error(Err) -> - format("Unrecognized error: ~w", [Err]). - -%%%=================================================================== -%%% gen_server callbacks -%%%=================================================================== -init([Mod, {_SockMod, Socket}, Opts]) -> - Encrypted = proplists:get_bool(tls, Opts), - SocketMonitor = xmpp_socket:monitor(Socket), - case xmpp_socket:peername(Socket) of - {ok, IP} -> - Time = p1_time_compat:monotonic_time(milli_seconds), - State = #{owner => self(), - mod => Mod, - socket => Socket, - socket_monitor => SocketMonitor, - stream_timeout => {timer:seconds(30), Time}, - stream_direction => in, - stream_id => new_id(), - stream_state => wait_for_stream, - stream_header_sent => false, - stream_restarted => false, - stream_compressed => false, - stream_encrypted => Encrypted, - stream_version => {1,0}, - stream_authenticated => false, - codec_options => [ignore_els], - xmlns => ?NS_CLIENT, - lang => <<"">>, - user => <<"">>, - server => <<"">>, - resource => <<"">>, - lserver => <<"">>, - ip => IP}, - case try Mod:init([State, Opts]) - catch _:undef -> {ok, State} - end of - {ok, State1} when not Encrypted -> - {_, State2, Timeout} = noreply(State1), - {ok, State2, Timeout}; - {ok, State1} when Encrypted -> - TLSOpts = try callback(tls_options, State1) - catch _:{?MODULE, undef} -> [] - end, - case xmpp_socket:starttls(Socket, TLSOpts) of - {ok, TLSSocket} -> - State2 = State1#{socket => TLSSocket}, - {_, State3, Timeout} = noreply(State2), - {ok, State3, Timeout}; - {error, Reason} -> - {stop, Reason} - end; - {error, Reason} -> - {stop, Reason}; - ignore -> - ignore - end; - {error, _Reason} -> - ignore - end. - -handle_cast({send, Pkt}, State) -> - noreply(send_pkt(State, Pkt)); -handle_cast(stop, State) -> - {stop, normal, State}; -handle_cast({close, Reason}, State) -> - State1 = close_socket(State), - noreply( - case is_disconnected(State) of - true -> State1; - false -> process_stream_end({socket, Reason}, State) - end); -handle_cast(Cast, State) -> - noreply(try callback(handle_cast, Cast, State) - catch _:{?MODULE, undef} -> State - end). - -handle_call(Call, From, State) -> - noreply(try callback(handle_call, Call, From, State) - catch _:{?MODULE, undef} -> State - end). - -handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}}, - #{stream_state := wait_for_stream, - xmlns := XMLNS, lang := MyLang} = State) -> - El = #xmlel{name = Name, attrs = Attrs}, - noreply( - try xmpp:decode(El, XMLNS, []) of - #stream_start{} = Pkt -> - State1 = send_header(State, Pkt), - case is_disconnected(State1) of - true -> State1; - false -> process_stream(Pkt, State1) - end; - _ -> - State1 = send_header(State), - case is_disconnected(State1) of - true -> State1; - false -> send_pkt(State1, xmpp:serr_invalid_xml()) - end - catch _:{xmpp_codec, Why} -> - State1 = send_header(State), - case is_disconnected(State1) of - true -> State1; - false -> - Txt = xmpp:io_format_error(Why), - Lang = select_lang(MyLang, xmpp:get_lang(El)), - Err = xmpp:serr_invalid_xml(Txt, Lang), - send_pkt(State1, Err) - end - end); -handle_info({'$gen_event', {xmlstreamend, _}}, State) -> - noreply(process_stream_end({stream, reset}, State)); -handle_info({'$gen_event', closed}, State) -> - noreply(process_stream_end({socket, closed}, State)); -handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) -> - State1 = send_header(State), - noreply( - case is_disconnected(State1) of - true -> State1; - false -> - Err = case Reason of - <<"XML stanza is too big">> -> - xmpp:serr_policy_violation(Reason, Lang); - {_, Txt} -> - xmpp:serr_not_well_formed(Txt, Lang) - end, - send_pkt(State1, Err) - end); -handle_info({'$gen_event', El}, #{stream_state := wait_for_stream} = State) -> - error_logger:warning_msg("unexpected event from XML driver: ~p; " - "xmlstreamstart was expected", [El]), - State1 = send_header(State), - noreply( - case is_disconnected(State1) of - true -> State1; - false -> send_pkt(State1, xmpp:serr_invalid_xml()) - end); -handle_info({'$gen_event', {xmlstreamelement, El}}, - #{xmlns := NS, codec_options := Opts} = State) -> - noreply( - try xmpp:decode(El, NS, Opts) of - Pkt -> - State1 = try callback(handle_recv, El, Pkt, State) - catch _:{?MODULE, undef} -> State - end, - case is_disconnected(State1) of - true -> State1; - false -> process_element(Pkt, State1) - end - catch _:{xmpp_codec, Why} -> - State1 = try callback(handle_recv, El, {error, Why}, State) - catch _:{?MODULE, undef} -> State - end, - case is_disconnected(State1) of - true -> State1; - false -> process_invalid_xml(State1, El, Why) - end - end); -handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}}, - State) -> - noreply(try callback(handle_cdata, Data, State) - catch _:{?MODULE, undef} -> State - end); -handle_info(timeout, #{lang := Lang} = State) -> - Disconnected = is_disconnected(State), - noreply(try callback(handle_timeout, State) - catch _:{?MODULE, undef} when not Disconnected -> - Txt = <<"Idle connection">>, - send_pkt(State, xmpp:serr_connection_timeout(Txt, Lang)); - _:{?MODULE, undef} -> - stop(State) - end); -handle_info({'DOWN', MRef, _Type, _Object, _Info}, - #{socket_monitor := MRef} = State) -> - noreply(process_stream_end({socket, closed}, State)); -handle_info({tcp, _, Data}, #{socket := Socket} = State) -> - noreply( - case xmpp_socket:recv(Socket, Data) of - {ok, NewSocket} -> - State#{socket => NewSocket}; - {error, Reason} when is_atom(Reason) -> - process_stream_end({socket, Reason}, State); - {error, Reason} -> - %% TODO: make fast_tls return atoms - process_stream_end({tls, Reason}, State) - end); -handle_info({tcp_closed, _}, State) -> - handle_info({'$gen_event', closed}, State); -handle_info({tcp_error, _, Reason}, State) -> - noreply(process_stream_end({socket, Reason}, State)); -handle_info(Info, State) -> - noreply(try callback(handle_info, Info, State) - catch _:{?MODULE, undef} -> State - end). - -terminate(Reason, State) -> - case get(already_terminated) of - true -> - State; - _ -> - put(already_terminated, true), - try callback(terminate, Reason, State) - catch _:{?MODULE, undef} -> ok - end, - send_trailer(State) - end. - -code_change(OldVsn, State, Extra) -> - callback(code_change, OldVsn, State, Extra). - -%%%=================================================================== -%%% Internal functions -%%%=================================================================== --spec noreply(state()) -> {noreply, state(), non_neg_integer() | infinity}. -noreply(#{stream_timeout := infinity} = State) -> - {noreply, State, infinity}; -noreply(#{stream_timeout := {MSecs, StartTime}} = State) -> - CurrentTime = p1_time_compat:monotonic_time(milli_seconds), - Timeout = max(0, MSecs - CurrentTime + StartTime), - {noreply, State, Timeout}. - --spec new_id() -> binary(). -new_id() -> - p1_rand:get_string(). - --spec is_disconnected(state()) -> boolean(). -is_disconnected(#{stream_state := StreamState}) -> - StreamState == disconnected. - --spec process_invalid_xml(state(), fxml:xmlel(), term()) -> state(). -process_invalid_xml(#{lang := MyLang} = State, El, Reason) -> - case xmpp:is_stanza(El) of - true -> - Txt = xmpp:io_format_error(Reason), - Lang = select_lang(MyLang, xmpp:get_lang(El)), - send_error(State, El, xmpp:err_bad_request(Txt, Lang)); - false -> - case {xmpp:get_name(El), xmpp:get_ns(El)} of - {Tag, ?NS_SASL} when Tag == <<"auth">>; - Tag == <<"response">>; - Tag == <<"abort">> -> - Txt = xmpp:io_format_error(Reason), - Err = #sasl_failure{reason = 'malformed-request', - text = xmpp:mk_text(Txt, MyLang)}, - send_pkt(State, Err); - {<<"starttls">>, ?NS_TLS} -> - send_pkt(State, #starttls_failure{}); - {<<"compress">>, ?NS_COMPRESS} -> - Err = #compress_failure{reason = 'setup-failed'}, - send_pkt(State, Err); - _ -> - %% Maybe add something more? - State - end - end. - --spec process_stream_end(stop_reason(), state()) -> state(). -process_stream_end(_, #{stream_state := disconnected} = State) -> - State; -process_stream_end(Reason, State) -> - State1 = State#{stream_timeout => infinity, - stream_state => disconnected}, - try callback(handle_stream_end, Reason, State1) - catch _:{?MODULE, undef} -> stop(State1) - end. - --spec process_stream(stream_start(), state()) -> state(). -process_stream(#stream_start{xmlns = XML_NS, - stream_xmlns = STREAM_NS}, - #{xmlns := NS} = State) - when XML_NS /= NS; STREAM_NS /= ?NS_STREAM -> - send_pkt(State, xmpp:serr_invalid_namespace()); -process_stream(#stream_start{version = {N, _}}, State) when N > 1 -> - send_pkt(State, xmpp:serr_unsupported_version()); -process_stream(#stream_start{lang = Lang}, - #{xmlns := ?NS_CLIENT, lang := DefaultLang} = State) - when size(Lang) > 35 -> - %% As stated in BCP47, 4.4.1: - %% Protocols or specifications that specify limited buffer sizes for - %% language tags MUST allow for language tags of at least 35 characters. - %% Do not store long language tag to avoid possible DoS/flood attacks - Txt = <<"Too long value of 'xml:lang' attribute">>, - send_pkt(State, xmpp:serr_policy_violation(Txt, DefaultLang)); -process_stream(#stream_start{to = undefined, version = Version} = StreamStart, - #{lang := Lang, server := Server, xmlns := NS} = State) -> - if Version < {1,0} andalso NS /= ?NS_COMPONENT -> - %% Work-around for gmail servers - To = jid:make(Server), - process_stream(StreamStart#stream_start{to = To}, State); - true -> - Txt = <<"Missing 'to' attribute">>, - send_pkt(State, xmpp:serr_improper_addressing(Txt, Lang)) - end; -process_stream(#stream_start{to = #jid{luser = U, lresource = R}}, - #{lang := Lang} = State) when U /= <<"">>; R /= <<"">> -> - Txt = <<"Improper 'to' attribute">>, - send_pkt(State, xmpp:serr_improper_addressing(Txt, Lang)); -process_stream(#stream_start{to = #jid{lserver = RemoteServer}} = StreamStart, - #{xmlns := ?NS_COMPONENT} = State) -> - State1 = State#{remote_server => RemoteServer, - stream_state => wait_for_handshake}, - try callback(handle_stream_start, StreamStart, State1) - catch _:{?MODULE, undef} -> State1 - end; -process_stream(#stream_start{to = #jid{server = Server, lserver = LServer}, - from = From} = StreamStart, - #{stream_authenticated := Authenticated, - stream_restarted := StreamWasRestarted, - xmlns := NS, resource := Resource, - stream_encrypted := Encrypted} = State) -> - State1 = if not StreamWasRestarted -> - State#{server => Server, lserver => LServer}; - true -> - State - end, - State2 = case From of - #jid{lserver = RemoteServer} when NS == ?NS_SERVER -> - State1#{remote_server => RemoteServer}; - _ -> - State1 - end, - State3 = try callback(handle_stream_start, StreamStart, State2) - catch _:{?MODULE, undef} -> State2 - end, - case is_disconnected(State3) of - true -> State3; - false -> - State4 = send_features(State3), - case is_disconnected(State4) of - true -> State4; - false -> - TLSRequired = is_starttls_required(State4), - if not Authenticated and (TLSRequired and not Encrypted) -> - State4#{stream_state => wait_for_starttls}; - not Authenticated -> - State4#{stream_state => wait_for_sasl_request}; - (NS == ?NS_CLIENT) and (Resource == <<"">>) -> - State4#{stream_state => wait_for_bind}; - true -> - process_stream_established(State4) - end - end - end. - --spec process_element(xmpp_element(), state()) -> state(). -process_element(Pkt, #{stream_state := StateName, lang := Lang} = State) -> - case Pkt of - #starttls{} when StateName == wait_for_starttls; - StateName == wait_for_sasl_request -> - process_starttls(State); - #starttls{} -> - process_starttls_failure(unexpected_starttls_request, State); - #sasl_auth{} when StateName == wait_for_starttls -> - send_pkt(State, #sasl_failure{reason = 'encryption-required'}); - #sasl_auth{} when StateName == wait_for_sasl_request -> - process_sasl_request(Pkt, State); - #sasl_auth{} when StateName == wait_for_sasl_response -> - process_sasl_request(Pkt, maps:remove(sasl_state, State)); - #sasl_auth{} -> - Txt = <<"SASL negotiation is not allowed in this state">>, - send_pkt(State, #sasl_failure{reason = 'not-authorized', - text = xmpp:mk_text(Txt, Lang)}); - #sasl_response{} when StateName == wait_for_starttls -> - send_pkt(State, #sasl_failure{reason = 'encryption-required'}); - #sasl_response{} when StateName == wait_for_sasl_response -> - process_sasl_response(Pkt, State); - #sasl_response{} -> - Txt = <<"SASL negotiation is not allowed in this state">>, - send_pkt(State, #sasl_failure{reason = 'not-authorized', - text = xmpp:mk_text(Txt, Lang)}); - #sasl_abort{} when StateName == wait_for_sasl_response -> - process_sasl_abort(State); - #sasl_abort{} -> - send_pkt(State, #sasl_failure{reason = 'aborted'}); - #sasl_success{} -> - State; - #compress{} -> - process_compress(Pkt, State); - #handshake{} when StateName == wait_for_handshake -> - process_handshake(Pkt, State); - #handshake{} -> - State; - #stream_error{} -> - process_stream_end({stream, {in, Pkt}}, State); - _ when StateName == wait_for_sasl_request; - StateName == wait_for_handshake; - StateName == wait_for_sasl_response -> - process_unauthenticated_packet(Pkt, State); - _ when StateName == wait_for_starttls -> - Txt = <<"Use of STARTTLS required">>, - Err = xmpp:serr_policy_violation(Txt, Lang), - send_pkt(State, Err); - _ when StateName == wait_for_bind -> - process_bind(Pkt, State); - _ when StateName == established -> - process_authenticated_packet(Pkt, State) - end. - --spec process_unauthenticated_packet(xmpp_element(), state()) -> state(). -process_unauthenticated_packet(Pkt, State) -> - NewPkt = set_lang(Pkt, State), - try callback(handle_unauthenticated_packet, NewPkt, State) - catch _:{?MODULE, undef} -> - Err = xmpp:serr_not_authorized(), - send(State, Err) - end. - --spec process_authenticated_packet(xmpp_element(), state()) -> state(). -process_authenticated_packet(Pkt, State) -> - Pkt1 = set_lang(Pkt, State), - case set_from_to(Pkt1, State) of - {ok, Pkt2} -> - try callback(handle_authenticated_packet, Pkt2, State) - catch _:{?MODULE, undef} -> - Err = xmpp:err_service_unavailable(), - send_error(State, Pkt, Err) - end; - {error, Err} -> - send_pkt(State, Err) - end. - --spec process_bind(xmpp_element(), state()) -> state(). -process_bind(#iq{type = set, sub_els = [_]} = Pkt, - #{xmlns := ?NS_CLIENT, lang := MyLang} = State) -> - try xmpp:try_subtag(Pkt, #bind{}) of - #bind{resource = R} -> - case callback(bind, R, State) of - {ok, #{user := U, server := S, resource := NewR} = State1} - when NewR /= <<"">> -> - Reply = #bind{jid = jid:make(U, S, NewR)}, - State2 = send_pkt(State1, xmpp:make_iq_result(Pkt, Reply)), - process_stream_established(State2); - {error, #stanza_error{} = Err, State1} -> - send_error(State1, Pkt, Err) - end; - _ -> - try callback(handle_unbinded_packet, Pkt, State) - catch _:{?MODULE, undef} -> - Err = xmpp:err_not_authorized(), - send_error(State, Pkt, Err) - end - catch _:{xmpp_codec, Why} -> - Txt = xmpp:io_format_error(Why), - Lang = select_lang(MyLang, xmpp:get_lang(Pkt)), - Err = xmpp:err_bad_request(Txt, Lang), - send_error(State, Pkt, Err) - end; -process_bind(Pkt, State) -> - try callback(handle_unbinded_packet, Pkt, State) - catch _:{?MODULE, undef} -> - Err = xmpp:err_not_authorized(), - send_error(State, Pkt, Err) - end. - --spec process_handshake(handshake(), state()) -> state(). -process_handshake(#handshake{data = Digest}, - #{stream_id := StreamID, - remote_server := RemoteServer} = State) -> - GetPW = try callback(get_password_fun, State) - catch _:{?MODULE, undef} -> fun(_) -> {false, undefined} end - end, - AuthRes = case GetPW(<<"">>) of - {false, _} -> - false; - {Password, _} -> - str:sha(<>) == Digest - end, - case AuthRes of - true -> - State1 = try callback(handle_auth_success, - RemoteServer, <<"handshake">>, undefined, State) - catch _:{?MODULE, undef} -> State - end, - case is_disconnected(State1) of - true -> State1; - false -> - State2 = send_pkt(State1, #handshake{}), - process_stream_established(State2) - end; - false -> - State1 = try callback(handle_auth_failure, - RemoteServer, <<"handshake">>, <<"not authorized">>, State) - catch _:{?MODULE, undef} -> State - end, - case is_disconnected(State1) of - true -> State1; - false -> send_pkt(State1, xmpp:serr_not_authorized()) - end - end. - --spec process_stream_established(state()) -> state(). -process_stream_established(#{stream_state := StateName} = State) - when StateName == disconnected; StateName == established -> - State; -process_stream_established(State) -> - State1 = State#{stream_authenticated => true, - stream_state => established, - stream_timeout => infinity}, - try callback(handle_stream_established, State1) - catch _:{?MODULE, undef} -> State1 - end. - --spec process_compress(compress(), state()) -> state(). -process_compress(#compress{}, - #{stream_compressed := Compressed, - stream_authenticated := Authenticated} = State) - when Compressed or not Authenticated -> - send_pkt(State, #compress_failure{reason = 'setup-failed'}); -process_compress(#compress{methods = HisMethods}, - #{socket := Socket} = State) -> - MyMethods = try callback(compress_methods, State) - catch _:{?MODULE, undef} -> [] - end, - CommonMethods = lists_intersection(MyMethods, HisMethods), - case lists:member(<<"zlib">>, CommonMethods) of - true -> - case xmpp_socket:compress(Socket) of - {ok, ZlibSocket} -> - State1 = send_pkt(State, #compressed{}), - case is_disconnected(State1) of - true -> State1; - false -> - State1#{socket => ZlibSocket, - stream_id => new_id(), - stream_header_sent => false, - stream_restarted => true, - stream_state => wait_for_stream, - stream_compressed => true} - end; - {error, _} -> - Err = #compress_failure{reason = 'setup-failed'}, - send_pkt(State, Err) - end; - false -> - send_pkt(State, #compress_failure{reason = 'unsupported-method'}) - end. - --spec process_starttls(state()) -> state(). -process_starttls(#{stream_encrypted := true} = State) -> - process_starttls_failure(already_encrypted, State); -process_starttls(#{socket := Socket} = State) -> - case is_starttls_available(State) of - true -> - TLSOpts = try callback(tls_options, State) - catch _:{?MODULE, undef} -> [] - end, - case xmpp_socket:starttls(Socket, TLSOpts) of - {ok, TLSSocket} -> - State1 = send_pkt(State, #starttls_proceed{}), - case is_disconnected(State1) of - true -> State1; - false -> - State1#{socket => TLSSocket, - stream_id => new_id(), - stream_header_sent => false, - stream_restarted => true, - stream_state => wait_for_stream, - stream_encrypted => true} - end; - {error, Reason} -> - process_starttls_failure(Reason, State) - end; - false -> - process_starttls_failure(starttls_unsupported, State) - end. - --spec process_starttls_failure(term(), state()) -> state(). -process_starttls_failure(Why, State) -> - State1 = send_pkt(State, #starttls_failure{}), - case is_disconnected(State1) of - true -> State1; - false -> process_stream_end({tls, Why}, State1) - end. - --spec process_sasl_request(sasl_auth(), state()) -> state(). -process_sasl_request(#sasl_auth{mechanism = Mech, text = ClientIn}, - #{lserver := LServer} = State) -> - State1 = State#{sasl_mech => Mech}, - Mechs = get_sasl_mechanisms(State1), - case lists:member(Mech, Mechs) of - true when Mech == <<"EXTERNAL">> -> - Res = case xmpp_stream_pkix:authenticate(State1, ClientIn) of - {ok, Peer} -> - {ok, [{auth_module, pkix}, {username, Peer}]}; - {error, Reason, Peer} -> - {error, Reason, Peer} - end, - process_sasl_result(Res, State1); - true -> - GetPW = try callback(get_password_fun, State1) - catch _:{?MODULE, undef} -> fun(_) -> false end - end, - CheckPW = try callback(check_password_fun, State1) - catch _:{?MODULE, undef} -> fun(_, _, _) -> false end - end, - CheckPWDigest = try callback(check_password_digest_fun, State1) - catch _:{?MODULE, undef} -> fun(_, _, _, _, _) -> false end - end, - SASLState = cyrsasl:server_new(<<"jabber">>, LServer, <<"">>, [], - GetPW, CheckPW, CheckPWDigest), - Res = cyrsasl:server_start(SASLState, Mech, ClientIn), - process_sasl_result(Res, State1#{sasl_state => SASLState}); - false -> - process_sasl_result({error, unsupported_mechanism, <<"">>}, State1) - end. - --spec process_sasl_response(sasl_response(), state()) -> state(). -process_sasl_response(#sasl_response{text = ClientIn}, - #{sasl_state := SASLState} = State) -> - SASLResult = cyrsasl:server_step(SASLState, ClientIn), - process_sasl_result(SASLResult, State). - --spec process_sasl_result(cyrsasl:sasl_return(), state()) -> state(). -process_sasl_result({ok, Props}, State) -> - process_sasl_success(Props, <<"">>, State); -process_sasl_result({ok, Props, ServerOut}, State) -> - process_sasl_success(Props, ServerOut, State); -process_sasl_result({continue, ServerOut, NewSASLState}, State) -> - process_sasl_continue(ServerOut, NewSASLState, State); -process_sasl_result({error, Reason, User}, State) -> - process_sasl_failure(Reason, User, State). - --spec process_sasl_success([cyrsasl:sasl_property()], binary(), state()) -> state(). -process_sasl_success(Props, ServerOut, - #{socket := Socket, - sasl_mech := Mech} = State) -> - User = identity(Props), - AuthModule = proplists:get_value(auth_module, Props), - Socket1 = xmpp_socket:reset_stream(Socket), - State0 = State#{socket => Socket1}, - State1 = try callback(handle_auth_success, User, Mech, AuthModule, State0) - catch _:{?MODULE, undef} -> State - end, - case is_disconnected(State1) of - true -> State1; - false -> - State2 = send_pkt(State1, #sasl_success{text = ServerOut}), - case is_disconnected(State2) of - true -> State2; - false -> - State3 = maps:remove(sasl_state, - maps:remove(sasl_mech, State2)), - State3#{stream_id => new_id(), - stream_authenticated => true, - stream_header_sent => false, - stream_restarted => true, - stream_state => wait_for_stream, - user => User} - end - end. - --spec process_sasl_continue(binary(), cyrsasl:sasl_state(), state()) -> state(). -process_sasl_continue(ServerOut, NewSASLState, State) -> - State1 = State#{sasl_state => NewSASLState, - stream_state => wait_for_sasl_response}, - send_pkt(State1, #sasl_challenge{text = ServerOut}). - --spec process_sasl_failure(atom(), binary(), state()) -> state(). -process_sasl_failure(Err, User, - #{sasl_mech := Mech, lang := Lang} = State) -> - {Reason, Text} = format_sasl_error(Mech, Err), - State1 = try callback(handle_auth_failure, User, Mech, Text, State) - catch _:{?MODULE, undef} -> State - end, - case is_disconnected(State1) of - true -> State1; - false -> - State2 = send_pkt(State1, - #sasl_failure{reason = Reason, - text = xmpp:mk_text(Text, Lang)}), - case is_disconnected(State2) of - true -> State2; - false -> - State3 = maps:remove(sasl_state, - maps:remove(sasl_mech, State2)), - State3#{stream_state => wait_for_sasl_request} - end - end. - --spec process_sasl_abort(state()) -> state(). -process_sasl_abort(State) -> - process_sasl_failure(aborted, <<"">>, State). - --spec send_features(state()) -> state(). -send_features(#{stream_version := {1,0}, - stream_encrypted := Encrypted} = State) -> - TLSRequired = is_starttls_required(State), - Features = if TLSRequired and not Encrypted -> - get_tls_feature(State); - true -> - get_sasl_feature(State) ++ get_compress_feature(State) - ++ get_tls_feature(State) ++ get_bind_feature(State) - ++ get_session_feature(State) ++ get_other_features(State) - end, - send_pkt(State, #stream_features{sub_els = Features}); -send_features(State) -> - %% clients and servers from stone age - State. - --spec get_sasl_mechanisms(state()) -> [cyrsasl:mechanism()]. -get_sasl_mechanisms(#{stream_encrypted := Encrypted, - xmlns := NS, lserver := LServer} = State) -> - Mechs = if NS == ?NS_CLIENT -> cyrsasl:listmech(LServer); - true -> [] - end, - TLSVerify = try callback(tls_verify, State) - catch _:{?MODULE, undef} -> false - end, - Mechs1 = if Encrypted andalso (TLSVerify orelse NS == ?NS_SERVER) -> - [<<"EXTERNAL">>|Mechs]; - true -> - Mechs - end, - try callback(sasl_mechanisms, Mechs1, State) - catch _:{?MODULE, undef} -> Mechs1 - end. - --spec get_sasl_feature(state()) -> [sasl_mechanisms()]. -get_sasl_feature(#{stream_authenticated := false, - stream_encrypted := Encrypted} = State) -> - TLSRequired = is_starttls_required(State), - if Encrypted or not TLSRequired -> - Mechs = get_sasl_mechanisms(State), - [#sasl_mechanisms{list = Mechs}]; - true -> - [] - end; -get_sasl_feature(_) -> - []. - --spec get_compress_feature(state()) -> [compression()]. -get_compress_feature(#{stream_compressed := false, - stream_authenticated := true} = State) -> - try callback(compress_methods, State) of - [] -> []; - Ms -> [#compression{methods = Ms}] - catch _:{?MODULE, undef} -> - [] - end; -get_compress_feature(_) -> - []. - --spec get_tls_feature(state()) -> [starttls()]. -get_tls_feature(#{stream_authenticated := false, - stream_encrypted := false} = State) -> - case is_starttls_available(State) of - true -> - TLSRequired = is_starttls_required(State), - [#starttls{required = TLSRequired}]; - false -> - [] - end; -get_tls_feature(_) -> - []. - --spec get_bind_feature(state()) -> [bind()]. -get_bind_feature(#{xmlns := ?NS_CLIENT, - stream_authenticated := true, - resource := <<"">>}) -> - [#bind{}]; -get_bind_feature(_) -> - []. - --spec get_session_feature(state()) -> [xmpp_session()]. -get_session_feature(#{xmlns := ?NS_CLIENT, - stream_authenticated := true, - resource := <<"">>}) -> - [#xmpp_session{optional = true}]; -get_session_feature(_) -> - []. - --spec get_other_features(state()) -> [xmpp_element()]. -get_other_features(#{stream_authenticated := Auth} = State) -> - try - if Auth -> callback(authenticated_stream_features, State); - true -> callback(unauthenticated_stream_features, State) - end - catch _:{?MODULE, undef} -> - [] - end. - --spec is_starttls_available(state()) -> boolean(). -is_starttls_available(State) -> - try callback(tls_enabled, State) - catch _:{?MODULE, undef} -> true - end. - --spec is_starttls_required(state()) -> boolean(). -is_starttls_required(State) -> - try callback(tls_required, State) - catch _:{?MODULE, undef} -> false - end. - --spec set_from_to(xmpp_element(), state()) -> {ok, xmpp_element()} | - {error, stream_error()}. -set_from_to(Pkt, _State) when not ?is_stanza(Pkt) -> - {ok, Pkt}; -set_from_to(Pkt, #{user := U, server := S, resource := R, - lang := Lang, xmlns := ?NS_CLIENT}) -> - JID = jid:make(U, S, R), - From = case xmpp:get_from(Pkt) of - undefined -> JID; - F -> F - end, - if JID#jid.luser == From#jid.luser andalso - JID#jid.lserver == From#jid.lserver andalso - (JID#jid.lresource == From#jid.lresource - orelse From#jid.lresource == <<"">>) -> - To = case xmpp:get_to(Pkt) of - undefined -> jid:make(U, S); - T -> T - end, - {ok, xmpp:set_from_to(Pkt, JID, To)}; - true -> - Txt = <<"Improper 'from' attribute">>, - {error, xmpp:serr_invalid_from(Txt, Lang)} - end; -set_from_to(Pkt, #{lang := Lang}) -> - From = xmpp:get_from(Pkt), - To = xmpp:get_to(Pkt), - if From == undefined -> - Txt = <<"Missing 'from' attribute">>, - {error, xmpp:serr_improper_addressing(Txt, Lang)}; - To == undefined -> - Txt = <<"Missing 'to' attribute">>, - {error, xmpp:serr_improper_addressing(Txt, Lang)}; - true -> - {ok, Pkt} - end. - --spec send_header(state()) -> state(). -send_header(#{stream_version := Version} = State) -> - send_header(State, #stream_start{version = Version}). - --spec send_header(state(), stream_start()) -> state(). -send_header(#{stream_id := StreamID, - stream_version := MyVersion, - stream_header_sent := false, - lang := MyLang, - xmlns := NS} = State, - #stream_start{to = HisTo, from = HisFrom, - lang = HisLang, version = HisVersion}) -> - Lang = select_lang(MyLang, HisLang), - NS_DB = if NS == ?NS_SERVER -> ?NS_SERVER_DIALBACK; - true -> <<"">> - end, - Version = case HisVersion of - undefined -> undefined; - {0,_} -> HisVersion; - _ -> MyVersion - end, - StreamStart = #stream_start{version = Version, - lang = Lang, - xmlns = NS, - stream_xmlns = ?NS_STREAM, - db_xmlns = NS_DB, - id = StreamID, - to = HisFrom, - from = HisTo}, - State1 = State#{lang => Lang, - stream_version => Version, - stream_header_sent => true}, - case socket_send(State1, StreamStart) of - ok -> State1; - {error, Why} -> process_stream_end({socket, Why}, State1) - end; -send_header(State, _) -> - State. - --spec send_pkt(state(), xmpp_element() | xmlel()) -> state(). -send_pkt(State, Pkt) -> - Result = socket_send(State, Pkt), - State1 = try callback(handle_send, Pkt, Result, State) - catch _:{?MODULE, undef} -> State - end, - case Result of - _ when is_record(Pkt, stream_error) -> - process_stream_end({stream, {out, Pkt}}, State1); - ok -> - State1; - {error, Why} -> - process_stream_end({socket, Why}, State1) - end. - --spec send_error(state(), xmpp_element() | xmlel(), stanza_error()) -> state(). -send_error(State, Pkt, Err) -> - case xmpp:is_stanza(Pkt) of - true -> - case xmpp:get_type(Pkt) of - result -> State; - error -> State; - <<"result">> -> State; - <<"error">> -> State; - _ -> - ErrPkt = xmpp:make_error(Pkt, Err), - send_pkt(State, ErrPkt) - end; - false -> - State - end. - --spec send_trailer(state()) -> state(). -send_trailer(State) -> - socket_send(State, trailer), - close_socket(State). - --spec socket_send(state(), xmpp_element() | xmlel() | trailer) -> ok | {error, inet:posix()}. -socket_send(#{socket := Sock, - stream_state := StateName, - xmlns := NS, - stream_header_sent := true}, Pkt) -> - case Pkt of - trailer -> - xmpp_socket:send_trailer(Sock); - #stream_start{} when StateName /= disconnected -> - xmpp_socket:send_header(Sock, xmpp:encode(Pkt)); - _ when StateName /= disconnected -> - xmpp_socket:send_element(Sock, xmpp:encode(Pkt, NS)); - _ -> - {error, closed} - end; -socket_send(_, _) -> - {error, closed}. - --spec close_socket(state()) -> state(). -close_socket(#{socket := Socket} = State) -> - xmpp_socket:close(Socket), - State#{stream_timeout => infinity, - stream_state => disconnected}. - --spec select_lang(binary(), binary()) -> binary(). -select_lang(Lang, <<"">>) -> Lang; -select_lang(_, Lang) -> Lang. - --spec set_lang(xmpp_element(), state()) -> xmpp_element(). -set_lang(Pkt, #{lang := MyLang, xmlns := ?NS_CLIENT}) when ?is_stanza(Pkt) -> - HisLang = xmpp:get_lang(Pkt), - Lang = select_lang(MyLang, HisLang), - xmpp:set_lang(Pkt, Lang); -set_lang(Pkt, _) -> - Pkt. - --spec format_inet_error(atom()) -> string(). -format_inet_error(closed) -> - "connection closed"; -format_inet_error(Reason) -> - case inet:format_error(Reason) of - "unknown POSIX error" -> atom_to_list(Reason); - Txt -> Txt - end. - --spec format_sasl_error(cyrsasl:mechanism(), atom()) -> {atom(), binary()}. -format_sasl_error(<<"EXTERNAL">>, Err) -> - xmpp_stream_pkix:format_error(Err); -format_sasl_error(Mech, Err) -> - cyrsasl:format_error(Mech, Err). - --spec format_tls_error(atom() | binary()) -> list(). -format_tls_error(Reason) when is_atom(Reason) -> - format_inet_error(Reason); -format_tls_error(Reason) -> - Reason. - --spec format(io:format(), list()) -> binary(). -format(Fmt, Args) -> - iolist_to_binary(io_lib:format(Fmt, Args)). - --spec lists_intersection(list(), list()) -> list(). -lists_intersection(L1, L2) -> - lists:filter( - fun(E) -> - lists:member(E, L2) - end, L1). - --spec identity([cyrsasl:sasl_property()]) -> binary(). -identity(Props) -> - case proplists:get_value(authzid, Props, <<>>) of - <<>> -> proplists:get_value(username, Props, <<>>); - AuthzId -> AuthzId - end. - -%%%=================================================================== -%%% Callbacks -%%%=================================================================== -callback(F, #{mod := Mod} = State) -> - case erlang:function_exported(Mod, F, 1) of - true -> Mod:F(State); - false -> erlang:error({?MODULE, undef}) - end. - -callback(F, Arg1, #{mod := Mod} = State) -> - case erlang:function_exported(Mod, F, 2) of - true -> Mod:F(Arg1, State); - false -> erlang:error({?MODULE, undef}) - end. - -callback(code_change, OldVsn, #{mod := Mod} = State, Extra) -> - %% code_change/3 callback is a special snowflake - case erlang:function_exported(Mod, code_change, 3) of - true -> Mod:code_change(OldVsn, State, Extra); - false -> {ok, State} - end; -callback(F, Arg1, Arg2, #{mod := Mod} = State) -> - case erlang:function_exported(Mod, F, 3) of - true -> Mod:F(Arg1, Arg2, State); - false -> erlang:error({?MODULE, undef}) - end. - -callback(F, Arg1, Arg2, Arg3, #{mod := Mod} = State) -> - case erlang:function_exported(Mod, F, 4) of - true -> Mod:F(Arg1, Arg2, Arg3, State); - false -> erlang:error({?MODULE, undef}) - end. diff --git a/src/xmpp_stream_out.erl b/src/xmpp_stream_out.erl deleted file mode 100644 index 171eef033..000000000 --- a/src/xmpp_stream_out.erl +++ /dev/null @@ -1,1321 +0,0 @@ -%%%------------------------------------------------------------------- -%%% Created : 14 Dec 2016 by Evgeny Khramtsov -%%% -%%% -%%% ejabberd, Copyright (C) 2002-2018 ProcessOne -%%% -%%% This program is free software; you can redistribute it and/or -%%% modify it under the terms of the GNU General Public License as -%%% published by the Free Software Foundation; either version 2 of the -%%% License, or (at your option) any later version. -%%% -%%% This program is distributed in the hope that it will be useful, -%%% but WITHOUT ANY WARRANTY; without even the implied warranty of -%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -%%% General Public License for more details. -%%% -%%% You should have received a copy of the GNU General Public License along -%%% with this program; if not, write to the Free Software Foundation, Inc., -%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. -%%% -%%%------------------------------------------------------------------- --module(xmpp_stream_out). --define(GEN_SERVER, p1_server). --behaviour(?GEN_SERVER). - --protocol({rfc, 6120}). --protocol({xep, 114, '1.6'}). --protocol({xep, 368, '1.0.0'}). - -%% API --export([start/3, start_link/3, call/3, cast/2, reply/2, connect/1, - stop/1, send/2, close/1, close/2, bind/2, establish/1, format_error/1, - set_timeout/2, get_transport/1, change_shaper/2]). -%% gen_server callbacks --export([init/1, handle_call/3, handle_cast/2, handle_info/2, - terminate/2, code_change/3]). - -%%-define(DBGFSM, true). --ifdef(DBGFSM). --define(FSMOPTS, [{debug, [trace]}]). --else. --define(FSMOPTS, []). --endif. - --define(TCP_SEND_TIMEOUT, 15000). - --include("xmpp.hrl"). --include_lib("kernel/include/inet.hrl"). - --type state() :: map(). --type noreply() :: {noreply, state(), timeout()}. --type host_port() :: {inet:hostname(), inet:port_number(), boolean()} | ip_port(). --type ip_port() :: {inet:ip_address(), inet:port_number(), boolean()}. --type h_addr_list() :: {{integer(), integer(), inet:port_number(), string()}, boolean()}. --type network_error() :: {error, inet:posix() | inet_res:res_error()}. --type tls_error_reason() :: inet:posix() | atom() | binary(). --type socket_error_reason() :: inet:posix() | atom(). --type stop_reason() :: {idna, bad_string} | - {dns, inet:posix() | inet_res:res_error()} | - {stream, reset | {in | out, stream_error()}} | - {tls, tls_error_reason()} | - {pkix, binary()} | - {auth, atom() | binary() | string()} | - {bind, stanza_error()} | - {socket, socket_error_reason()} | - internal_failure. --export_type([state/0, stop_reason/0]). --callback init(list()) -> {ok, state()} | {error, term()} | ignore. --callback handle_cast(term(), state()) -> state(). --callback handle_call(term(), term(), state()) -> state(). --callback handle_info(term(), state()) -> state(). --callback terminate(term(), state()) -> any(). --callback code_change(term(), state(), term()) -> {ok, state()} | {error, term()}. --callback handle_stream_start(stream_start(), state()) -> state(). --callback handle_stream_established(state()) -> state(). --callback handle_stream_downgraded(stream_start(), state()) -> state(). --callback handle_stream_end(stop_reason(), state()) -> state(). --callback handle_cdata(binary(), state()) -> state(). --callback handle_send(xmpp_element(), ok | {error, inet:posix()}, state()) -> state(). --callback handle_recv(fxml:xmlel(), xmpp_element() | {error, term()}, state()) -> state(). --callback handle_timeout(state()) -> state(). --callback handle_authenticated_features(stream_features(), state()) -> state(). --callback handle_unauthenticated_features(stream_features(), state()) -> state(). --callback handle_auth_success(cyrsasl:mechanism(), state()) -> state(). --callback handle_auth_failure(cyrsasl:mechanism(), binary(), state()) -> state(). --callback handle_bind_success(state()) -> state(). --callback handle_bind_failure(stanza_error(), state()) -> state(). --callback handle_packet(xmpp_element(), state()) -> state(). --callback tls_options(state()) -> [proplists:property()]. --callback tls_required(state()) -> boolean(). --callback tls_verify(state()) -> boolean(). --callback tls_enabled(state()) -> boolean(). --callback resolve(string(), state()) -> [host_port()]. --callback sasl_mechanisms(state()) -> [binary()]. --callback dns_timeout(state()) -> timeout(). --callback dns_retries(state()) -> non_neg_integer(). --callback default_port(state()) -> inet:port_number(). --callback connect_options(inet:ip_address(), list(), state()) -> list(). --callback address_families(state()) -> [inet:address_family()]. --callback connect_timeout(state()) -> timeout(). - --optional_callbacks([init/1, - handle_cast/2, - handle_call/3, - handle_info/2, - terminate/2, - code_change/3, - handle_stream_start/2, - handle_stream_established/1, - handle_stream_downgraded/2, - handle_stream_end/2, - handle_cdata/2, - handle_send/3, - handle_recv/3, - handle_timeout/1, - handle_authenticated_features/2, - handle_unauthenticated_features/2, - handle_auth_success/2, - handle_auth_failure/3, - handle_bind_success/1, - handle_bind_failure/2, - handle_packet/2, - tls_options/1, - tls_required/1, - tls_verify/1, - tls_enabled/1, - resolve/2, - sasl_mechanisms/1, - dns_timeout/1, - dns_retries/1, - default_port/1, - connect_options/3, - address_families/1, - connect_timeout/1]). - -%%%=================================================================== -%%% API -%%%=================================================================== -start({local, Mod}, Args, Opts) -> - ?GEN_SERVER:start({local, Mod}, ?MODULE, [Mod|Args], Opts ++ ?FSMOPTS); -start(Mod, Args, Opts) -> - ?GEN_SERVER:start(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS). - -start_link({local, Mod}, Args, Opts) -> - ?GEN_SERVER:start_link({local, Mod}, ?MODULE, [Mod|Args], Opts ++ ?FSMOPTS); -start_link(Mod, Args, Opts) -> - ?GEN_SERVER:start_link(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS). - -call(Ref, Msg, Timeout) -> - ?GEN_SERVER:call(Ref, Msg, Timeout). - -cast(Ref, Msg) -> - ?GEN_SERVER:cast(Ref, Msg). - -reply(Ref, Reply) -> - ?GEN_SERVER:reply(Ref, Reply). - --spec connect(pid()) -> ok. -connect(Ref) -> - cast(Ref, connect). - --spec stop(pid()) -> ok; - (state()) -> no_return(). -stop(Pid) when is_pid(Pid) -> - cast(Pid, stop); -stop(#{owner := Owner} = State) when Owner == self() -> - terminate(normal, State), - exit(normal); -stop(_) -> - erlang:error(badarg). - --spec send(pid(), xmpp_element()) -> ok; - (state(), xmpp_element()) -> state(). -send(Pid, Pkt) when is_pid(Pid) -> - cast(Pid, {send, Pkt}); -send(#{owner := Owner} = State, Pkt) when Owner == self() -> - send_pkt(State, Pkt); -send(_, _) -> - erlang:error(badarg). - --spec close(pid()) -> ok; - (state()) -> state(). -close(Pid) when is_pid(Pid) -> - close(Pid, closed); -close(#{owner := Owner} = State) when Owner == self() -> - close_socket(State); -close(_) -> - erlang:error(badarg). - --spec close(pid(), atom()) -> ok. -close(Pid, Reason) -> - cast(Pid, {close, Reason}). - --spec bind(state(), stream_features()) -> state(). -bind(#{stream_authenticated := true} = State, StreamFeatures) -> - process_bind(StreamFeatures, State). - --spec establish(state()) -> state(). -establish(State) -> - process_stream_established(State). - --spec set_timeout(state(), timeout()) -> state(). -set_timeout(#{owner := Owner} = State, Timeout) when Owner == self() -> - case Timeout of - infinity -> State#{stream_timeout => infinity}; - _ -> - Time = p1_time_compat:monotonic_time(milli_seconds), - State#{stream_timeout => {Timeout, Time}} - end; -set_timeout(_, _) -> - erlang:error(badarg). - -get_transport(#{socket := Socket, owner := Owner}) - when Owner == self() -> - xmpp_socket:get_transport(Socket); -get_transport(_) -> - erlang:error(badarg). - --spec change_shaper(state(), ejabberd_shaper:shaper()) -> state(). -change_shaper(#{socket := Socket, owner := Owner} = State, Shaper) - when Owner == self() -> - Socket1 = xmpp_socket:change_shaper(Socket, Shaper), - State#{socket => Socket1}; -change_shaper(_, _) -> - erlang:error(badarg). - --spec format_error(stop_reason()) -> binary(). -format_error({idna, _}) -> - <<"Remote domain is not an IDN hostname">>; -format_error({dns, Reason}) -> - format("DNS lookup failed: ~s", [format_inet_error(Reason)]); -format_error({socket, Reason}) -> - format("Connection failed: ~s", [format_inet_error(Reason)]); -format_error({pkix, Reason}) -> - {_, ErrTxt} = xmpp_stream_pkix:format_error(Reason), - format("Peer certificate rejected: ~s", [ErrTxt]); -format_error({stream, reset}) -> - <<"Stream reset by peer">>; -format_error({stream, {in, #stream_error{} = Err}}) -> - format("Stream closed by peer: ~s", [xmpp:format_stream_error(Err)]); -format_error({stream, {out, #stream_error{} = Err}}) -> - format("Stream closed by us: ~s", [xmpp:format_stream_error(Err)]); -format_error({bind, #stanza_error{} = Err}) -> - format("Resource binding failure: ~s", [xmpp:format_stanza_error(Err)]); -format_error({tls, Reason}) -> - format("TLS failed: ~s", [format_tls_error(Reason)]); -format_error({auth, Reason}) -> - format("Authentication failed: ~s", [Reason]); -format_error(internal_failure) -> - <<"Internal server error">>; -format_error(Err) -> - format("Unrecognized error: ~w", [Err]). - -%%%=================================================================== -%%% gen_server callbacks -%%%=================================================================== --spec init(list()) -> {ok, state(), timeout()} | {stop, term()} | ignore. -init([Mod, From, To, Opts]) -> - Time = p1_time_compat:monotonic_time(milli_seconds), - State = #{owner => self(), - mod => Mod, - server => From, - user => <<"">>, - resource => <<"">>, - password => <<"">>, - lang => <<"">>, - remote_server => To, - xmlns => ?NS_SERVER, - codec_options => [ignore_els], - stream_direction => out, - stream_timeout => {timer:seconds(30), Time}, - stream_id => new_id(), - stream_encrypted => false, - stream_verified => false, - stream_authenticated => false, - stream_restarted => false, - stream_state => connecting}, - case try Mod:init([State, Opts]) - catch _:undef -> {ok, State} - end of - {ok, State1} -> - {_, State2, Timeout} = noreply(State1), - {ok, State2, Timeout}; - {error, Reason} -> - {stop, Reason}; - ignore -> - ignore - end. - --spec handle_call(term(), term(), state()) -> noreply(). -handle_call(Call, From, State) -> - noreply(try callback(handle_call, Call, From, State) - catch _:{?MODULE, undef} -> State - end). - --spec handle_cast(term(), state()) -> noreply(). -handle_cast(connect, #{remote_server := RemoteServer, - stream_state := connecting} = State) -> - noreply( - case idna_to_ascii(RemoteServer) of - false -> - process_stream_end({idna, bad_string}, State); - ASCIIName -> - case resolve(binary_to_list(ASCIIName), State) of - {ok, AddrPorts} -> - case connect(AddrPorts, State) of - {ok, Socket, {Addr, Port, Encrypted}} -> - SocketMonitor = xmpp_socket:monitor(Socket), - State1 = State#{ip => {Addr, Port}, - socket => Socket, - stream_encrypted => Encrypted, - socket_monitor => SocketMonitor}, - State2 = State1#{stream_state => wait_for_stream}, - send_header(State2); - {error, {Class, Why}} -> - process_stream_end({Class, Why}, State) - end; - {error, Why} -> - process_stream_end({dns, Why}, State) - end - end); -handle_cast(connect, #{stream_state := disconnected} = State) -> - State1 = reset_state(State), - handle_cast(connect, State1); -handle_cast(connect, State) -> - %% Ignoring connection attempts in other states - noreply(State); -handle_cast({send, Pkt}, State) -> - noreply(send_pkt(State, Pkt)); -handle_cast(stop, State) -> - {stop, normal, State}; -handle_cast({close, Reason}, State) -> - State1 = close_socket(State), - noreply( - case is_disconnected(State) of - true -> State1; - false -> process_stream_end({socket, Reason}, State) - end); -handle_cast(Cast, State) -> - noreply(try callback(handle_cast, Cast, State) - catch _:{?MODULE, undef} -> State - end). - --spec handle_info(term(), state()) -> noreply(). -handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}}, - #{stream_state := wait_for_stream, - xmlns := XMLNS, lang := MyLang} = State) -> - El = #xmlel{name = Name, attrs = Attrs}, - noreply( - try xmpp:decode(El, XMLNS, []) of - #stream_start{} = Pkt -> - process_stream(Pkt, State); - _ -> - send_pkt(State, xmpp:serr_invalid_xml()) - catch _:{xmpp_codec, Why} -> - Txt = xmpp:io_format_error(Why), - Lang = select_lang(MyLang, xmpp:get_lang(El)), - Err = xmpp:serr_invalid_xml(Txt, Lang), - send_pkt(State, Err) - end); -handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) -> - State1 = send_header(State), - noreply( - case is_disconnected(State1) of - true -> State1; - false -> - Err = case Reason of - <<"XML stanza is too big">> -> - xmpp:serr_policy_violation(Reason, Lang); - {_, Txt} -> - xmpp:serr_not_well_formed(Txt, Lang) - end, - send_pkt(State1, Err) - end); -handle_info({'$gen_event', {xmlstreamelement, El}}, - #{xmlns := NS, codec_options := Opts} = State) -> - noreply( - try xmpp:decode(El, NS, Opts) of - Pkt -> - State1 = try callback(handle_recv, El, Pkt, State) - catch _:{?MODULE, undef} -> State - end, - case is_disconnected(State1) of - true -> State1; - false -> process_element(Pkt, State1) - end - catch _:{xmpp_codec, Why} -> - State1 = try callback(handle_recv, El, {error, Why}, State) - catch _:{?MODULE, undef} -> State - end, - case is_disconnected(State1) of - true -> State1; - false -> process_invalid_xml(State1, El, Why) - end - end); -handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}}, State) -> - noreply(try callback(handle_cdata, Data, State) - catch _:{?MODULE, undef} -> State - end); -handle_info({'$gen_event', {xmlstreamend, _}}, State) -> - noreply(process_stream_end({stream, reset}, State)); -handle_info({'$gen_event', closed}, State) -> - noreply(process_stream_end({socket, closed}, State)); -handle_info(timeout, #{lang := Lang} = State) -> - Disconnected = is_disconnected(State), - noreply(try callback(handle_timeout, State) - catch _:{?MODULE, undef} when not Disconnected -> - Txt = <<"Idle connection">>, - send_pkt(State, xmpp:serr_connection_timeout(Txt, Lang)); - _:{?MODULE, undef} -> - stop(State) - end); -handle_info({'DOWN', MRef, _Type, _Object, _Info}, - #{socket_monitor := MRef} = State) -> - noreply(process_stream_end({socket, closed}, State)); -handle_info({tcp, _, Data}, #{socket := Socket} = State) -> - noreply( - case xmpp_socket:recv(Socket, Data) of - {ok, NewSocket} -> - State#{socket => NewSocket}; - {error, Reason} when is_atom(Reason) -> - process_stream_end({socket, Reason}, State); - {error, Reason} -> - %% TODO: make fast_tls return atoms - process_stream_end({tls, Reason}, State) - end); -handle_info({tcp_closed, _}, State) -> - handle_info({'$gen_event', closed}, State); -handle_info({tcp_error, _, Reason}, State) -> - noreply(process_stream_end({socket, Reason}, State)); -handle_info({'EXIT', _, Reason}, State) -> - {stop, Reason, State}; -handle_info(Info, State) -> - noreply(try callback(handle_info, Info, State) - catch _:{?MODULE, undef} -> State - end). - --spec terminate(term(), state()) -> any(). -terminate(Reason, State) -> - case get(already_terminated) of - true -> - State; - _ -> - put(already_terminated, true), - try callback(terminate, Reason, State) - catch _:{?MODULE, undef} -> ok - end, - send_trailer(State) - end. - -code_change(OldVsn, State, Extra) -> - callback(code_change, OldVsn, State, Extra). - -%%%=================================================================== -%%% Internal functions -%%%=================================================================== --spec noreply(state()) -> noreply(). -noreply(#{stream_timeout := infinity} = State) -> - {noreply, State, infinity}; -noreply(#{stream_timeout := {MSecs, OldTime}} = State) -> - NewTime = p1_time_compat:monotonic_time(milli_seconds), - Timeout = max(0, MSecs - NewTime + OldTime), - {noreply, State, Timeout}. - --spec new_id() -> binary(). -new_id() -> - p1_rand:get_string(). - --spec is_disconnected(state()) -> boolean(). -is_disconnected(#{stream_state := StreamState}) -> - StreamState == disconnected. - --spec process_invalid_xml(state(), fxml:xmlel(), term()) -> state(). -process_invalid_xml(#{lang := MyLang} = State, El, Reason) -> - case xmpp:is_stanza(El) of - true -> - Txt = xmpp:io_format_error(Reason), - Lang = select_lang(MyLang, xmpp:get_lang(El)), - send_error(State, El, xmpp:err_bad_request(Txt, Lang)); - false -> - State - end. - --spec process_stream_end(stop_reason(), state()) -> state(). -process_stream_end(_, #{stream_state := disconnected} = State) -> - State; -process_stream_end(Reason, State) -> - State1 = send_trailer(State), - try callback(handle_stream_end, Reason, State1) - catch _:{?MODULE, undef} -> stop(State1) - end. - --spec process_stream(stream_start(), state()) -> state(). -process_stream(#stream_start{xmlns = XML_NS, - stream_xmlns = STREAM_NS}, - #{xmlns := NS} = State) - when XML_NS /= NS; STREAM_NS /= ?NS_STREAM -> - send_pkt(State, xmpp:serr_invalid_namespace()); -process_stream(#stream_start{version = {N, _}}, State) when N > 1 -> - send_pkt(State, xmpp:serr_unsupported_version()); -process_stream(#stream_start{lang = Lang, id = ID, - version = Version} = StreamStart, - State) -> - State1 = State#{stream_remote_id => ID, lang => Lang}, - State2 = try callback(handle_stream_start, StreamStart, State1) - catch _:{?MODULE, undef} -> State1 - end, - case is_disconnected(State2) of - true -> State2; - false -> - case Version of - {1, _} -> - State2#{stream_state => wait_for_features}; - _ -> - process_stream_downgrade(StreamStart, State2) - end - end. - --spec process_element(xmpp_element(), state()) -> state(). -process_element(Pkt, #{stream_state := StateName} = State) -> - case Pkt of - #stream_features{} when StateName == wait_for_features -> - process_features(Pkt, State); - #starttls_proceed{} when StateName == wait_for_starttls_response -> - process_starttls(State); - #sasl_success{} when StateName == wait_for_sasl_response -> - process_sasl_success(State); - #sasl_failure{} when StateName == wait_for_sasl_response -> - process_sasl_failure(Pkt, State); - #stream_error{} -> - process_stream_end({stream, {in, Pkt}}, State); - _ when is_record(Pkt, stream_features); - is_record(Pkt, starttls_proceed); - is_record(Pkt, starttls); - is_record(Pkt, sasl_auth); - is_record(Pkt, sasl_success); - is_record(Pkt, sasl_failure); - is_record(Pkt, sasl_response); - is_record(Pkt, sasl_abort); - is_record(Pkt, compress); - is_record(Pkt, handshake) -> - %% Do not pass this crap upstream - State; - _ when StateName == wait_for_bind_response -> - process_bind_response(Pkt, State); - _ -> - process_packet(Pkt, State) - end. - --spec process_features(stream_features(), state()) -> state(). -process_features(StreamFeatures, - #{stream_authenticated := true} = State) -> - try callback(handle_authenticated_features, StreamFeatures, State) - catch _:{?MODULE, undef} -> process_bind(StreamFeatures, State) - end; -process_features(StreamFeatures, - #{stream_encrypted := Encrypted, lang := Lang} = State) -> - State1 = try callback(handle_unauthenticated_features, StreamFeatures, State) - catch _:{?MODULE, undef} -> State - end, - case is_disconnected(State1) of - true -> State1; - false -> - TLSRequired = is_starttls_required(State1), - TLSAvailable = is_starttls_available(State1), - try xmpp:try_subtag(StreamFeatures, #starttls{}) of - false when TLSRequired and not Encrypted -> - Txt = <<"Use of STARTTLS required">>, - send_pkt(State1, xmpp:serr_policy_violation(Txt, Lang)); - #starttls{required = true} when not TLSAvailable and not Encrypted -> - Txt = <<"Use of STARTTLS forbidden">>, - send_pkt(State1, xmpp:serr_unsupported_feature(Txt, Lang)); - #starttls{} when TLSAvailable and not Encrypted -> - State2 = State1#{stream_state => wait_for_starttls_response}, - send_pkt(State2, #starttls{}); - _ -> - State2 = process_cert_verification(State1), - case is_disconnected(State2) of - true -> State2; - false -> process_sasl_mechanisms(StreamFeatures, State2) - end - catch _:{xmpp_codec, Why} -> - Txt = xmpp:io_format_error(Why), - send_pkt(State1, xmpp:serr_invalid_xml(Txt, Lang)) - end - end. - --spec process_stream_established(state()) -> state(). -process_stream_established(#{stream_state := StateName} = State) - when StateName == disconnected; StateName == established -> - State; -process_stream_established(State) -> - State1 = State#{stream_authenticated := true, - stream_state => established, - stream_timeout => infinity}, - try callback(handle_stream_established, State1) - catch _:{?MODULE, undef} -> State1 - end. - --spec process_sasl_mechanisms(stream_features(), state()) -> state(). -process_sasl_mechanisms(StreamFeatures, State) -> - AvailMechs = sasl_mechanisms(State), - State1 = State#{sasl_mechs_available => AvailMechs}, - try xmpp:try_subtag(StreamFeatures, #sasl_mechanisms{}) of - #sasl_mechanisms{list = ProvidedMechs} -> - process_sasl_auth(State1#{sasl_mechs_provided => ProvidedMechs}); - false -> - process_sasl_auth(State1#{sasl_mechs_provided => []}) - catch _:{xmpp_codec, Why} -> - Txt = xmpp:io_format_error(Why), - Lang = maps:get(lang, State), - send_pkt(State, xmpp:serr_invalid_xml(Txt, Lang)) - end. - -process_sasl_auth(#{stream_encrypted := false, xmlns := ?NS_SERVER} = State) -> - State1 = State#{sasl_mechs_available => []}, - Txt = case is_starttls_available(State) of - true -> <<"Peer doesn't support STARTTLS">>; - false -> <<"STARTTLS is disabled in local configuration">> - end, - process_sasl_failure(Txt, State1); -process_sasl_auth(#{sasl_mechs_provided := [], - stream_encrypted := Encrypted} = State) -> - State1 = State#{sasl_mechs_available => []}, - Hint = case Encrypted of - true -> <<"; most likely it doesn't accept our certificate">>; - false -> <<"">> - end, - Txt = <<"Peer provided no SASL mechanisms", Hint/binary>>, - process_sasl_failure(Txt, State1); -process_sasl_auth(#{sasl_mechs_available := []} = State) -> - Err = maps:get(sasl_error, State, - <<"No mutually supported SASL mechanisms found">>), - process_sasl_failure(Err, State); -process_sasl_auth(#{sasl_mechs_available := [Mech|AvailMechs], - sasl_mechs_provided := ProvidedMechs} = State) -> - State1 = State#{sasl_mechs_available => AvailMechs}, - if Mech == <<"EXTERNAL">> orelse Mech == <<"PLAIN">> -> - case lists:member(Mech, ProvidedMechs) of - true -> - Text = make_sasl_authzid(Mech, State1), - State2 = State1#{sasl_mech => Mech, - stream_state => wait_for_sasl_response}, - send(State2, #sasl_auth{mechanism = Mech, text = Text}); - false -> - process_sasl_auth(State1) - end; - true -> - process_sasl_auth(State1) - end. - --spec process_starttls(state()) -> state(). -process_starttls(#{socket := Socket} = State) -> - case starttls(Socket, State) of - {ok, TLSSocket} -> - State1 = State#{socket => TLSSocket, - stream_id => new_id(), - stream_restarted => true, - stream_state => wait_for_stream, - stream_encrypted => true}, - send_header(State1); - {error, Why} -> - process_stream_end({tls, Why}, State) - end. - --spec process_stream_downgrade(stream_start(), state()) -> state(). -process_stream_downgrade(StreamStart, - #{lang := Lang, - stream_encrypted := Encrypted} = State) -> - TLSRequired = is_starttls_required(State), - if not Encrypted and TLSRequired -> - Txt = <<"Use of STARTTLS required">>, - send_pkt(State, xmpp:serr_policy_violation(Txt, Lang)); - true -> - State1 = State#{stream_state => downgraded}, - try callback(handle_stream_downgraded, StreamStart, State1) - catch _:{?MODULE, undef} -> - send_pkt(State1, xmpp:serr_unsupported_version()) - end - end. - --spec process_cert_verification(state()) -> state(). -process_cert_verification(#{stream_encrypted := true, - stream_verified := false} = State) -> - case try callback(tls_verify, State) - catch _:{?MODULE, undef} -> true - end of - true -> - case xmpp_stream_pkix:authenticate(State) of - {ok, _} -> - State#{stream_verified => true}; - {error, Why, _Peer} -> - process_stream_end({pkix, Why}, State) - end; - false -> - State#{stream_verified => true} - end; -process_cert_verification(State) -> - State. - --spec process_sasl_success(state()) -> state(). -process_sasl_success(#{socket := Socket, sasl_mech := Mech} = State) -> - Socket1 = xmpp_socket:reset_stream(Socket), - State1 = State#{socket => Socket1}, - State2 = State1#{stream_id => new_id(), - stream_restarted => true, - stream_state => wait_for_stream, - stream_authenticated => true}, - State3 = reset_sasl_state(State2), - State4 = send_header(State3), - case is_disconnected(State4) of - true -> State4; - false -> - try callback(handle_auth_success, Mech, State4) - catch _:{?MODULE, undef} -> State4 - end - end. - --spec process_sasl_failure(sasl_failure() | binary(), state()) -> state(). -process_sasl_failure(Failure, #{sasl_mechs_available := [_|_]} = State) -> - process_sasl_auth(State#{sasl_failure => Failure}); -process_sasl_failure(#sasl_failure{} = Failure, State) -> - Reason = format("Peer responded with error: ~s", - [xmpp:format_sasl_error(Failure)]), - process_sasl_failure(Reason, State); -process_sasl_failure(Reason, State) -> - Mech = case maps:get(sasl_mech, State, undefined) of - undefined -> - case sasl_mechanisms(State) of - [] -> <<"EXTERNAL">>; - [M|_] -> M - end; - M -> M - end, - State1 = reset_sasl_state(State), - try callback(handle_auth_failure, Mech, {auth, Reason}, State1) - catch _:{?MODULE, undef} -> process_stream_end({auth, Reason}, State1) - end. - --spec process_bind(stream_features(), state()) -> state(). -process_bind(StreamFeatures, #{lang := Lang, xmlns := ?NS_CLIENT, - resource := R, - stream_state := StateName} = State) - when StateName /= established, StateName /= disconnected -> - case xmpp:has_subtag(StreamFeatures, #bind{}) of - true -> - ID = new_id(), - Pkt = #iq{id = ID, type = set, - sub_els = [#bind{resource = R}]}, - State1 = State#{stream_state => wait_for_bind_response, - bind_id => ID}, - send_pkt(State1, Pkt); - false -> - Txt = <<"Missing resource binding feature">>, - send_pkt(State, xmpp:serr_invalid_xml(Txt, Lang)) - end; -process_bind(_, State) -> - process_stream_established(State). - --spec process_bind_response(xmpp_element(), state()) -> state(). -process_bind_response(#iq{type = result, id = ID} = IQ, - #{lang := Lang, bind_id := ID} = State) -> - State1 = reset_bind_state(State), - try xmpp:try_subtag(IQ, #bind{}) of - #bind{jid = #jid{user = U, server = S, resource = R}} -> - State2 = State1#{user => U, server => S, resource => R}, - State3 = try callback(handle_bind_success, State2) - catch _:{?MODULE, undef} -> State2 - end, - process_stream_established(State3); - #bind{} -> - Txt = <<"Missing element in resource binding response">>, - send_pkt(State1, xmpp:serr_invalid_xml(Txt, Lang)); - false -> - Txt = <<"Missing element in resource binding response">>, - send_pkt(State1, xmpp:serr_invalid_xml(Txt, Lang)) - catch _:{xmpp_codec, Why} -> - Txt = xmpp:io_format_error(Why), - send_pkt(State1, xmpp:serr_invalid_xml(Txt, Lang)) - end; -process_bind_response(#iq{type = error, id = ID} = IQ, - #{bind_id := ID} = State) -> - Err = xmpp:get_error(IQ), - State1 = reset_bind_state(State), - try callback(handle_bind_failure, Err, State1) - catch _:{?MODULE, undef} -> process_stream_end({bind, Err}, State1) - end; -process_bind_response(Pkt, State) -> - process_packet(Pkt, State). - --spec process_packet(xmpp_element(), state()) -> state(). -process_packet(Pkt, State) -> - Pkt1 = fix_from(Pkt, State), - try callback(handle_packet, Pkt1, State) - catch _:{?MODULE, undef} -> State - end. - --spec is_starttls_required(state()) -> boolean(). -is_starttls_required(State) -> - try callback(tls_required, State) - catch _:{?MODULE, undef} -> false - end. - --spec is_starttls_available(state()) -> boolean(). -is_starttls_available(State) -> - try callback(tls_enabled, State) - catch _:{?MODULE, undef} -> true - end. - --spec sasl_mechanisms(state()) -> [binary()]. -sasl_mechanisms(#{stream_encrypted := Encrypted} = State) -> - try callback(sasl_mechanisms, State) of - Ms when Encrypted -> Ms; - Ms -> lists:delete(<<"EXTERNAL">>, Ms) - catch _:{?MODULE, undef} -> - if Encrypted -> [<<"EXTERNAL">>]; - true -> [] - end - end. - --spec send_header(state()) -> state(). -send_header(#{remote_server := RemoteServer, - stream_encrypted := Encrypted, - lang := Lang, - xmlns := NS, - user := User, - resource := Resource, - server := Server} = State) -> - NS_DB = if NS == ?NS_SERVER -> ?NS_SERVER_DIALBACK; - true -> <<"">> - end, - From = if Encrypted -> - jid:make(User, Server, Resource); - NS == ?NS_SERVER -> - jid:make(Server); - true -> - undefined - end, - StreamStart = #stream_start{xmlns = NS, - lang = Lang, - stream_xmlns = ?NS_STREAM, - db_xmlns = NS_DB, - from = From, - to = jid:make(RemoteServer), - version = {1,0}}, - case socket_send(State, StreamStart) of - ok -> State; - {error, Why} -> process_stream_end({socket, Why}, State) - end. - --spec send_pkt(state(), xmpp_element() | xmlel()) -> state(). -send_pkt(State, Pkt) -> - Result = socket_send(State, Pkt), - State1 = try callback(handle_send, Pkt, Result, State) - catch _:{?MODULE, undef} -> State - end, - case Result of - _ when is_record(Pkt, stream_error) -> - process_stream_end({stream, {out, Pkt}}, State1); - ok -> - State1; - {error, Why} -> - process_stream_end({socket, Why}, State1) - end. - --spec send_error(state(), xmpp_element() | xmlel(), stanza_error()) -> state(). -send_error(State, Pkt, Err) -> - case xmpp:is_stanza(Pkt) of - true -> - case xmpp:get_type(Pkt) of - result -> State; - error -> State; - <<"result">> -> State; - <<"error">> -> State; - _ -> - ErrPkt = xmpp:make_error(Pkt, Err), - send_pkt(State, ErrPkt) - end; - false -> - State - end. - --spec socket_send(state(), xmpp_element() | xmlel() | trailer) -> ok | {error, inet:posix()}. -socket_send(#{socket := Socket, xmlns := NS, - stream_state := StateName}, Pkt) -> - case Pkt of - trailer -> - xmpp_socket:send_trailer(Socket); - #stream_start{} when StateName /= disconnected -> - xmpp_socket:send_header(Socket, xmpp:encode(Pkt)); - _ when StateName /= disconnected -> - xmpp_socket:send_element(Socket, xmpp:encode(Pkt, NS)); - _ -> - {error, closed} - end; -socket_send(_, _) -> - {error, closed}. - --spec send_trailer(state()) -> state(). -send_trailer(State) -> - socket_send(State, trailer), - close_socket(State). - --spec close_socket(state()) -> state(). -close_socket(State) -> - case State of - #{socket := Socket} -> - xmpp_socket:close(Socket); - _ -> - ok - end, - State#{stream_timeout => infinity, - stream_state => disconnected}. - --spec starttls(term(), state()) -> {ok, term()} | {error, tls_error_reason()}. -starttls(Socket, #{xmlns := NS, - remote_server := RemoteServer} = State) -> - TLSOpts = try callback(tls_options, State) - catch _:{?MODULE, undef} -> [] - end, - SNI = idna_to_ascii(RemoteServer), - ALPN = case NS of - ?NS_SERVER -> <<"xmpp-server">>; - ?NS_CLIENT -> <<"xmpp-client">> - end, - xmpp_socket:starttls(Socket, [connect, {sni, SNI}, {alpn, [ALPN]}|TLSOpts]). - --spec select_lang(binary(), binary()) -> binary(). -select_lang(Lang, <<"">>) -> Lang; -select_lang(_, Lang) -> Lang. - --spec format_inet_error(atom()) -> string(). -format_inet_error(closed) -> - "connection closed"; -format_inet_error(Reason) -> - case inet:format_error(Reason) of - "unknown POSIX error" -> atom_to_list(Reason); - Txt -> Txt - end. - --spec format_tls_error(atom() | binary()) -> list(). -format_tls_error(Reason) when is_atom(Reason) -> - format_inet_error(Reason); -format_tls_error(Reason) -> - binary_to_list(Reason). - --spec format(io:format(), list()) -> binary(). -format(Fmt, Args) -> - iolist_to_binary(io_lib:format(Fmt, Args)). - --spec make_sasl_authzid(binary(), state()) -> binary(). -make_sasl_authzid(Mech, #{user := User, server := Server, - password := Password}) -> - case Mech of - <<"EXTERNAL">> -> - jid:encode(jid:make(User, Server)); - <<"PLAIN">> -> - JID = jid:encode(jid:make(User, Server)), - <> - end. --spec fix_from(xmpp_element(), state()) -> xmpp_element(). -fix_from(Pkt, #{xmlns := ?NS_CLIENT} = State) -> - case xmpp:is_stanza(Pkt) of - true -> - case xmpp:get_from(Pkt) of - undefined -> - #{user := U, server := S, resource := R} = State, - From = jid:make(U, S, R), - xmpp:set_from(Pkt, From); - _ -> - Pkt - end; - false -> - Pkt - end; -fix_from(Pkt, _State) -> - Pkt. - -%%%=================================================================== -%%% State resets -%%%=================================================================== --spec reset_sasl_state(state()) -> state(). -reset_sasl_state(State) -> - State1 = maps:remove(sasl_mech, State), - State2 = maps:remove(sasl_failure, State1), - State3 = maps:remove(sasl_mechs_provided, State2), - maps:remove(sasl_mechs_available, State3). - --spec reset_connection_state(state()) -> state(). -reset_connection_state(State) -> - State1 = maps:remove(ip, State), - State2 = maps:remove(socket, State1), - maps:remove(socket_monitor, State2). - --spec reset_stream_state(state()) -> state(). -reset_stream_state(State) -> - State1 = State#{stream_id => new_id(), - stream_encrypted => false, - stream_verified => false, - stream_authenticated => false, - stream_restarted => false, - stream_state => connecting}, - maps:remove(stream_remote_id, State1). - --spec reset_bind_state(state()) -> state(). -reset_bind_state(State) -> - maps:remove(bind_id, State). - --spec reset_state(state()) -> state(). -reset_state(State) -> - State1 = reset_bind_state(State), - State2 = reset_sasl_state(State1), - State3 = reset_connection_state(State2), - reset_stream_state(State3). - -%%%=================================================================== -%%% Connection stuff -%%%=================================================================== --spec idna_to_ascii(binary()) -> binary() | false. -idna_to_ascii(<<$[, _/binary>> = Host) -> - %% This is an IPv6 address in 'IP-literal' format (as per RFC7622) - %% We remove brackets here - case binary:last(Host) of - $] -> - IPv6 = binary:part(Host, {1, size(Host)-2}), - case inet:parse_ipv6strict_address(binary_to_list(IPv6)) of - {ok, _} -> IPv6; - {error, _} -> false - end; - _ -> - false - end; -idna_to_ascii(Host) -> - case inet:parse_address(binary_to_list(Host)) of - {ok, _} -> Host; - {error, _} -> ejabberd_idna:domain_utf8_to_ascii(Host) - end. - --spec resolve(string(), state()) -> {ok, [ip_port()]} | network_error(). -resolve(Host, State) -> - try callback(resolve, Host, State) of - [] -> - do_resolve(Host, State); - HostPorts -> - a_lookup(HostPorts, State) - catch _:{?MODULE, undef} -> - do_resolve(Host, State) - end. - --spec do_resolve(string(), state()) -> {ok, [ip_port()]} | network_error(). -do_resolve(Host, State) -> - case srv_lookup(Host, State) of - {error, _Reason} -> - DefaultPort = get_default_port(State), - a_lookup([{Host, DefaultPort, false}], State); - {ok, HostPorts} -> - a_lookup(HostPorts, State) - end. - --spec srv_lookup(string(), state()) -> {ok, [host_port()]} | network_error(). -srv_lookup(_Host, #{xmlns := ?NS_COMPONENT}) -> - %% Do not attempt to lookup SRV for component connections - {error, nxdomain}; -srv_lookup(Host, State) -> - %% Only perform SRV lookups for FQDN names - case string:chr(Host, $.) of - 0 -> - {error, nxdomain}; - _ -> - case inet:parse_address(Host) of - {ok, _} -> - {error, nxdomain}; - {error, _} -> - Timeout = get_dns_timeout(State), - Retries = get_dns_retries(State), - case srv_lookup(Host, State, Timeout, Retries) of - {ok, AddrList} -> - h_addr_list_to_host_ports(AddrList); - {error, _} = Err -> - Err - end - end - end. - -srv_lookup(Host, #{xmlns := NS} = State, Timeout, Retries) -> - SRVType = case NS of - ?NS_SERVER -> "-server._tcp."; - ?NS_CLIENT -> "-client._tcp." - end, - TLSAddrs = case is_starttls_available(State) of - true -> - case srv_lookup("_xmpps" ++ SRVType ++ Host, - Timeout, Retries) of - {ok, HostEnt} -> - [{A, true} || A <- HostEnt#hostent.h_addr_list]; - {error, _} -> - [] - end; - false -> - [] - end, - case srv_lookup("_xmpp" ++ SRVType ++ Host, Timeout, Retries) of - {ok, HostEntry} -> - Addrs = [{A, false} || A <- HostEntry#hostent.h_addr_list], - {ok, TLSAddrs ++ Addrs}; - {error, _} when TLSAddrs /= [] -> - {ok, TLSAddrs}; - {error, _} = Err -> - Err - end. - --spec srv_lookup(string(), timeout(), integer()) -> - {ok, inet:hostent()} | network_error(). -srv_lookup(_SRVName, _Timeout, Retries) when Retries < 1 -> - {error, timeout}; -srv_lookup(SRVName, Timeout, Retries) -> - case inet_res:getbyname(SRVName, srv, Timeout) of - {ok, HostEntry} -> - {ok, HostEntry}; - {error, timeout} -> - srv_lookup(SRVName, Timeout, Retries - 1); - {error, _} = Err -> - Err - end. - --spec a_lookup([host_port()], state()) -> - {ok, [ip_port()]} | network_error(). -a_lookup(HostPorts, State) -> - HostPortFamilies = [{Host, Port, TLS, Family} - || {Host, Port, TLS} <- HostPorts, - Family <- get_address_families(State)], - a_lookup(HostPortFamilies, State, [], {error, nxdomain}). - --spec a_lookup([{inet:hostname() | inet:ip_address(), inet:port_number(), - boolean(), inet:address_family()}], - state(), [ip_port()], network_error()) -> {ok, [ip_port()]} | network_error(). -a_lookup([{Addr, Port, TLS, Family}|HostPortFamilies], State, Acc, Err) - when is_tuple(Addr) -> - Acc1 = if tuple_size(Addr) == 4 andalso Family == inet -> - [{Addr, Port, TLS}|Acc]; - tuple_size(Addr) == 8 andalso Family == inet6 -> - [{Addr, Port, TLS}|Acc]; - true -> - Acc - end, - a_lookup(HostPortFamilies, State, Acc1, Err); -a_lookup([{Host, Port, TLS, Family}|HostPortFamilies], State, Acc, Err) -> - Timeout = get_dns_timeout(State), - Retries = get_dns_retries(State), - case a_lookup(Host, Port, TLS, Family, Timeout, Retries) of - {error, Reason} -> - a_lookup(HostPortFamilies, State, Acc, {error, Reason}); - {ok, AddrPorts} -> - a_lookup(HostPortFamilies, State, Acc ++ AddrPorts, Err) - end; -a_lookup([], _State, [], Err) -> - Err; -a_lookup([], _State, Acc, _) -> - {ok, Acc}. - --spec a_lookup(inet:hostname(), inet:port_number(), boolean(), inet:address_family(), - timeout(), integer()) -> {ok, [ip_port()]} | network_error(). -a_lookup(_Host, _Port, _TLS, _Family, _Timeout, Retries) when Retries < 1 -> - {error, timeout}; -a_lookup(Host, Port, TLS, Family, Timeout, Retries) -> - Start = p1_time_compat:monotonic_time(milli_seconds), - case inet:gethostbyname(Host, Family, Timeout) of - {error, nxdomain} = Err -> - %% inet:gethostbyname/3 doesn't return {error, timeout}, - %% so we should check if 'nxdomain' is in fact a result - %% of a timeout. - %% We also cannot use inet_res:gethostbyname/3 because - %% it ignores DNS configuration settings (/etc/hosts, etc) - End = p1_time_compat:monotonic_time(milli_seconds), - if (End - Start) >= Timeout -> - a_lookup(Host, Port, TLS, Family, Timeout, Retries - 1); - true -> - Err - end; - {error, _} = Err -> - Err; - {ok, HostEntry} -> - host_entry_to_addr_ports(HostEntry, Port, TLS) - end. - --spec h_addr_list_to_host_ports(h_addr_list()) -> {ok, [host_port()]} | - {error, nxdomain}. -h_addr_list_to_host_ports(AddrList) -> - PrioHostPorts = lists:flatmap( - fun({{Priority, Weight, Port, Host}, TLS}) -> - N = case Weight of - 0 -> 0; - _ -> (Weight + 1) * p1_rand:uniform() - end, - [{Priority * 65536 - N, Host, Port, TLS}]; - (_) -> - [] - end, AddrList), - HostPorts = [{Host, Port, TLS} - || {_Priority, Host, Port, TLS} <- lists:usort(PrioHostPorts)], - case HostPorts of - [] -> {error, nxdomain}; - _ -> {ok, HostPorts} - end. - --spec host_entry_to_addr_ports(inet:hostent(), inet:port_number(), boolean()) -> - {ok, [ip_port()]} | {error, nxdomain}. -host_entry_to_addr_ports(#hostent{h_addr_list = AddrList}, Port, TLS) -> - AddrPorts = lists:flatmap( - fun(Addr) -> - try get_addr_type(Addr) of - _ -> [{Addr, Port, TLS}] - catch _:_ -> - [] - end - end, AddrList), - case AddrPorts of - [] -> {error, nxdomain}; - _ -> {ok, AddrPorts} - end. - --spec connect([ip_port()], state()) -> {ok, term(), ip_port()} | - {error, {socket, socket_error_reason()}} | - {error, {tls, tls_error_reason()}}. -connect(AddrPorts, State) -> - Timeout = get_connect_timeout(State), - case connect(AddrPorts, Timeout, State, {error, nxdomain}) of - {ok, Socket, {Addr, Port, TLS = true}} -> - case starttls(Socket, State) of - {ok, TLSSocket} -> {ok, TLSSocket, {Addr, Port, TLS}}; - {error, Why} -> {error, {tls, Why}} - end; - {ok, Socket, {Addr, Port, TLS = false}} -> - {ok, Socket, {Addr, Port, TLS}}; - {error, Why} -> - {error, {socket, Why}} - end. - --spec connect([ip_port()], timeout(), state(), network_error()) -> - {ok, term(), ip_port()} | network_error(). -connect([{Addr, Port, TLS}|AddrPorts], Timeout, State, _) -> - Type = get_addr_type(Addr), - Opts = [binary, {packet, 0}, - {send_timeout, ?TCP_SEND_TIMEOUT}, - {send_timeout_close, true}, - {active, false}, Type], - Opts1 = try callback(connect_options, Addr, Opts, State) - catch _:{?MODULE, undef} -> Opts - end, - try xmpp_socket:connect(Addr, Port, Opts1, Timeout) of - {ok, Socket} -> - {ok, Socket, {Addr, Port, TLS}}; - Err -> - connect(AddrPorts, Timeout, State, Err) - catch _:badarg -> - connect(AddrPorts, Timeout, State, {error, einval}) - end; -connect([], _Timeout, _State, Err) -> - Err. - --spec get_addr_type(inet:ip_address()) -> inet:address_family(). -get_addr_type({_, _, _, _}) -> inet; -get_addr_type({_, _, _, _, _, _, _, _}) -> inet6. - --spec get_dns_timeout(state()) -> timeout(). -get_dns_timeout(State) -> - try callback(dns_timeout, State) - catch _:{?MODULE, undef} -> timer:seconds(10) - end. - --spec get_dns_retries(state()) -> non_neg_integer(). -get_dns_retries(State) -> - try callback(dns_retries, State) - catch _:{?MODULE, undef} -> 2 - end. - --spec get_default_port(state()) -> inet:port_number(). -get_default_port(#{xmlns := NS} = State) -> - try callback(default_port, State) - catch _:{?MODULE, undef} when NS == ?NS_SERVER -> 5269; - _:{?MODULE, undef} when NS == ?NS_CLIENT -> 5222 - end. - --spec get_address_families(state()) -> [inet:address_family()]. -get_address_families(State) -> - try callback(address_families, State) - catch _:{?MODULE, undef} -> [inet, inet6] - end. - --spec get_connect_timeout(state()) -> timeout(). -get_connect_timeout(State) -> - try callback(connect_timeout, State) - catch _:{?MODULE, undef} -> timer:seconds(10) - end. - -%%%=================================================================== -%%% Callbacks -%%%=================================================================== -callback(F, #{mod := Mod} = State) -> - case erlang:function_exported(Mod, F, 1) of - true -> Mod:F(State); - false -> erlang:error({?MODULE, undef}) - end. - -callback(F, Arg1, #{mod := Mod} = State) -> - case erlang:function_exported(Mod, F, 2) of - true -> Mod:F(Arg1, State); - false -> erlang:error({?MODULE, undef}) - end. - -callback(code_change, OldVsn, #{mod := Mod} = State, Extra) -> - %% code_change/3 callback is a special snowflake - case erlang:function_exported(Mod, code_change, 3) of - true -> Mod:code_change(OldVsn, State, Extra); - false -> {ok, State} - end; -callback(F, Arg1, Arg2, #{mod := Mod} = State) -> - case erlang:function_exported(Mod, F, 3) of - true -> Mod:F(Arg1, Arg2, State); - false -> erlang:error({?MODULE, undef}) - end. diff --git a/src/xmpp_stream_pkix.erl b/src/xmpp_stream_pkix.erl deleted file mode 100644 index 4077e7849..000000000 --- a/src/xmpp_stream_pkix.erl +++ /dev/null @@ -1,271 +0,0 @@ -%%%------------------------------------------------------------------- -%%% Created : 13 Dec 2016 by Evgeny Khramtsov -%%% -%%% -%%% ejabberd, Copyright (C) 2002-2018 ProcessOne -%%% -%%% This program is free software; you can redistribute it and/or -%%% modify it under the terms of the GNU General Public License as -%%% published by the Free Software Foundation; either version 2 of the -%%% License, or (at your option) any later version. -%%% -%%% This program is distributed in the hope that it will be useful, -%%% but WITHOUT ANY WARRANTY; without even the implied warranty of -%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -%%% General Public License for more details. -%%% -%%% You should have received a copy of the GNU General Public License along -%%% with this program; if not, write to the Free Software Foundation, Inc., -%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. -%%% -%%%------------------------------------------------------------------- --module(xmpp_stream_pkix). - -%% API --export([authenticate/1, authenticate/2, get_cert_domains/1, format_error/1]). - --include("xmpp.hrl"). --include_lib("public_key/include/public_key.hrl"). --include("XmppAddr.hrl"). - --type cert() :: #'OTPCertificate'{}. - -%%%=================================================================== -%%% API -%%%=================================================================== --spec authenticate(xmpp_stream_in:state() | xmpp_stream_out:state()) - -> {ok, binary()} | {error, atom(), binary()}. -authenticate(State) -> - authenticate(State, <<"">>). - --spec authenticate(xmpp_stream_in:state() | xmpp_stream_out:state(), binary()) - -> {ok, binary()} | {error, atom(), binary()}. -authenticate(#{xmlns := ?NS_SERVER, - socket := Socket} = State, Authzid) -> - Peer = maps:get(remote_server, State, Authzid), - case verify_cert(Socket) of - {ok, Cert} -> - case ejabberd_idna:domain_utf8_to_ascii(Peer) of - false -> - {error, idna_failed, Peer}; - AsciiPeer -> - case lists:any( - fun(D) -> match_domain(AsciiPeer, D) end, - get_cert_domains(Cert)) of - true -> - {ok, Peer}; - false -> - {error, hostname_mismatch, Peer} - end - end; - {error, Reason} -> - {error, Reason, Peer} - end; -authenticate(#{xmlns := ?NS_CLIENT, - socket := Socket, lserver := LServer}, Authzid) -> - JID = try jid:decode(Authzid) - catch _:{bad_jid, <<>>} -> jid:make(LServer); - _:{bad_jid, _} -> {error, invalid_authzid, Authzid} - end, - case JID of - #jid{user = User} -> - case verify_cert(Socket) of - {ok, Cert} -> - JIDs = get_xmpp_addrs(Cert), - get_username(JID, JIDs, LServer); - {error, Reason} -> - {error, Reason, User} - end; - Err -> - Err - end. - -format_error(idna_failed) -> - {'bad-protocol', <<"Remote domain is not an IDN hostname">>}; -format_error(hostname_mismatch) -> - {'not-authorized', <<"Certificate host name mismatch">>}; -format_error(jid_mismatch) -> - {'not-authorized', <<"Certificate JID mismatch">>}; -format_error(get_cert_failed) -> - {'bad-protocol', <<"Failed to get peer certificate">>}; -format_error(invalid_authzid) -> - {'invalid-authzid', <<"Malformed JID">>}; -format_error(Other) -> - {'not-authorized', erlang:atom_to_binary(Other, utf8)}. - --spec get_cert_domains(cert()) -> [binary()]. -get_cert_domains(Cert) -> - TBSCert = Cert#'OTPCertificate'.tbsCertificate, - {rdnSequence, Subject} = TBSCert#'OTPTBSCertificate'.subject, - Extensions = TBSCert#'OTPTBSCertificate'.extensions, - get_domain_from_subject(lists:flatten(Subject)) ++ - get_domains_from_san(Extensions). - -%%%=================================================================== -%%% Internal functions -%%%=================================================================== --spec verify_cert(xmpp_socket:socket()) -> {ok, cert()} | {error, atom()}. -verify_cert(Socket) -> - case xmpp_socket:get_peer_certificate(Socket, otp) of - {ok, Cert} -> - case xmpp_socket:get_verify_result(Socket) of - 0 -> - {ok, Cert}; - VerifyRes -> - %% TODO: return atomic errors - %% This should be improved in fast_tls - Reason = fast_tls:get_cert_verify_string(VerifyRes, Cert), - {error, erlang:binary_to_atom(Reason, utf8)} - end; - {error, _Reason} -> - {error, get_cert_failed}; - error -> - {error, get_cert_failed} - end. - --spec get_domain_from_subject([#'AttributeTypeAndValue'{}]) -> [binary()]. -get_domain_from_subject(AttrVals) -> - case lists:keyfind(?'id-at-commonName', - #'AttributeTypeAndValue'.type, - AttrVals) of - #'AttributeTypeAndValue'{value = {_, S}} -> - try jid:decode(iolist_to_binary(S)) of - #jid{luser = <<"">>, lresource = <<"">>, lserver = Domain} -> - [Domain]; - _ -> - [] - catch _:{bad_jid, _} -> - [] - end; - _ -> - [] - end. - --spec get_domains_from_san([#'Extension'{}] | asn1_NOVALUE) -> [binary()]. -get_domains_from_san(Extensions) when is_list(Extensions) -> - case lists:keyfind(?'id-ce-subjectAltName', - #'Extension'.extnID, - Extensions) of - #'Extension'{extnValue = Vals} -> - lists:flatmap( - fun({dNSName, S}) -> - [iolist_to_binary(S)]; - ({otherName, AnotherName}) -> - case decode_xmpp_addr(AnotherName) of - {ok, #jid{luser = <<"">>, - lresource = <<"">>, - lserver = Domain}} -> - case ejabberd_idna:domain_utf8_to_ascii(Domain) of - false -> - []; - ASCIIDomain -> - [ASCIIDomain] - end; - _ -> - [] - end; - (_) -> - [] - end, Vals); - _ -> - [] - end; -get_domains_from_san(_) -> - []. - --spec decode_xmpp_addr(#'AnotherName'{}) -> {ok, jid()} | error. -decode_xmpp_addr(#'AnotherName'{'type-id' = ?'id-on-xmppAddr', - value = XmppAddr}) -> - try 'XmppAddr':decode('XmppAddr', XmppAddr) of - {ok, JIDStr} -> - try {ok, jid:decode(iolist_to_binary(JIDStr))} - catch _:{bad_jid, _} -> error - end; - _ -> - error - catch _:_ -> - error - end; -decode_xmpp_addr(_) -> - error. - --spec get_xmpp_addrs(cert()) -> [jid()]. -get_xmpp_addrs(Cert) -> - TBSCert = Cert#'OTPCertificate'.tbsCertificate, - case TBSCert#'OTPTBSCertificate'.extensions of - Extensions when is_list(Extensions) -> - case lists:keyfind(?'id-ce-subjectAltName', - #'Extension'.extnID, - Extensions) of - #'Extension'{extnValue = Vals} -> - lists:flatmap( - fun({otherName, AnotherName}) -> - case decode_xmpp_addr(AnotherName) of - {ok, JID} -> [JID]; - _ -> [] - end; - (_) -> - [] - end, Vals); - _ -> - [] - end; - _ -> - [] - end. - -match_domain(Domain, Domain) -> true; -match_domain(Domain, Pattern) -> - DLabels = str:tokens(Domain, <<".">>), - PLabels = str:tokens(Pattern, <<".">>), - match_labels(DLabels, PLabels). - -match_labels([], []) -> true; -match_labels([], [_ | _]) -> false; -match_labels([_ | _], []) -> false; -match_labels([DL | DLabels], [PL | PLabels]) -> - case lists:all(fun (C) -> - $a =< C andalso C =< $z orelse - $0 =< C andalso C =< $9 orelse - C == $- orelse C == $* - end, - binary_to_list(PL)) - of - true -> - Regexp = ejabberd_regexp:sh_to_awk(PL), - case ejabberd_regexp:run(DL, Regexp) of - match -> match_labels(DLabels, PLabels); - nomatch -> false - end; - false -> false - end. - --spec get_username(jid(), [jid()], binary()) -> - {ok, binary()} | {error, jid_mismatch, binary()}. -get_username(#jid{user = User, lserver = LS}, _, LServer) when LS /= LServer -> - %% The user provided JID from different domain - {error, jid_mismatch, User}; -get_username(#jid{user = <<>>}, [#jid{user = U, lserver = LS}], LServer) - when U /= <<>> andalso LS == LServer -> - %% The user didn't provide JID or username, and there is only - %% one 'non-global' JID matching current domain - {ok, U}; -get_username(#jid{user = User, luser = LUser}, JIDs, LServer) when User /= <<>> -> - %% The user provided username - lists:foldl( - fun(_, {ok, _} = OK) -> - OK; - (#jid{user = <<>>, lserver = LS}, _) when LS == LServer -> - %% Found "global" JID in the certficate - %% (i.e. in the form of 'domain.com') - %% within current domain, so we force matching - {ok, User}; - (#jid{luser = LU, lserver = LS}, _) when LU == LUser, LS == LServer -> - %% Found exact JID matching - {ok, User}; - (_, Err) -> - Err - end, {error, jid_mismatch, User}, JIDs); -get_username(#jid{user = User}, _, _) -> - %% Nothing from above is true - {error, jid_mismatch, User}. diff --git a/test/ejabberd_cyrsasl_test.exs b/test/ejabberd_cyrsasl_test.exs index bdef92cd4..e73c12a14 100644 --- a/test/ejabberd_cyrsasl_test.exs +++ b/test/ejabberd_cyrsasl_test.exs @@ -32,32 +32,32 @@ defmodule EjabberdCyrsaslTest do start_module(:jid) :ejabberd_hooks.start_link :ok = :ejabberd_config.start(["domain1"], []) - {:ok, _} = :cyrsasl.start_link - cyrstate = :cyrsasl.server_new("domain1", "domain1", "domain1", :ok, &get_password/1, + {:ok, _} = :xmpp_sasl.start_link + cyrstate = :xmpp_sasl.server_new("domain1", "domain1", "domain1", :ok, &get_password/1, &check_password/3, &check_password_digest/5) setup_anonymous_mocks() {:ok, cyrstate: cyrstate} end test "Plain text (correct user and pass)", context do - step1 = :cyrsasl.server_start(context[:cyrstate], "PLAIN", <<0,"user1",0,"pass">>) + step1 = :xmpp_sasl.server_start(context[:cyrstate], "PLAIN", <<0,"user1",0,"pass">>) assert {:ok, _} = step1 {:ok, kv} = step1 assert kv[:authzid] == "user1", "got correct user" end test "Plain text (correct user wrong pass)", context do - step1 = :cyrsasl.server_start(context[:cyrstate], "PLAIN", <<0,"user1",0,"badpass">>) + step1 = :xmpp_sasl.server_start(context[:cyrstate], "PLAIN", <<0,"user1",0,"badpass">>) assert step1 == {:error, :not_authorized, "user1"} end test "Plain text (wrong user wrong pass)", context do - step1 = :cyrsasl.server_start(context[:cyrstate], "PLAIN", <<0,"nouser1",0,"badpass">>) + step1 = :xmpp_sasl.server_start(context[:cyrstate], "PLAIN", <<0,"nouser1",0,"badpass">>) assert step1 == {:error, :not_authorized, "nouser1"} end test "Anonymous", context do - step1 = :cyrsasl.server_start(context[:cyrstate], "ANONYMOUS", "domain1") + step1 = :xmpp_sasl.server_start(context[:cyrstate], "ANONYMOUS", "domain1") assert {:ok, _} = step1 end @@ -78,7 +78,7 @@ defmodule EjabberdCyrsaslTest do end defp process_digest_md5(cyrstate, user, domain, pass) do - assert {:continue, init_str, state1} = :cyrsasl.server_start(cyrstate, "DIGEST-MD5", "") + assert {:continue, init_str, state1} = :xmpp_sasl.server_start(cyrstate, "DIGEST-MD5", "") assert [_, nonce] = Regex.run(~r/nonce="(.*?)"/, init_str) digest_uri = "xmpp/#{domain}" cnonce = "abcd" @@ -87,8 +87,8 @@ defmodule EjabberdCyrsaslTest do response = "username=\"#{user}\",realm=\"#{domain}\",nonce=\"#{nonce}\",cnonce=\"#{cnonce}\"," <> "nc=\"#{nc}\",qop=auth,digest-uri=\"#{digest_uri}\",response=\"#{response_hash}\"," <> "charset=utf-8,algorithm=md5-sess" - case :cyrsasl.server_step(state1, response) do - {:continue, _calc_str, state2} -> :cyrsasl.server_step(state2, "") + case :xmpp_sasl.server_step(state1, response) do + {:continue, _calc_str, state2} -> :xmpp_sasl.server_step(state2, "") other -> other end end diff --git a/test/suite.erl b/test/suite.erl index 38d198a9a..efe587253 100644 --- a/test/suite.erl +++ b/test/suite.erl @@ -599,7 +599,7 @@ sasl_new(<<"ANONYMOUS">>, _) -> sasl_new(<<"DIGEST-MD5">>, {User, Server, Password}) -> {<<"">>, fun (ServerIn) -> - case cyrsasl_digest:parse(ServerIn) of + case xmpp_sasl_digest:parse(ServerIn) of bad -> {error, <<"Invalid SASL challenge">>}; KeyVals -> Nonce = fxml:get_attr_s(<<"nonce">>, KeyVals), @@ -625,7 +625,7 @@ sasl_new(<<"DIGEST-MD5">>, {User, Server, Password}) -> MyResponse/binary, "\"">>, {Resp, fun (ServerIn2) -> - case cyrsasl_digest:parse(ServerIn2) of + case xmpp_sasl_digest:parse(ServerIn2) of bad -> {error, <<"Invalid SASL challenge">>}; _KeyVals2 -> {<<"">>,