Support certificate verification for outgoing s2s

Handle "s2s_use_starttls: required_trusted" the same way for outgoing
s2s connections as for incoming connections.  That is, check the remote
server's certificate (including the host name) and abort the connection
if verification fails.
This commit is contained in:
Holger Weiss 2014-04-28 01:42:02 +02:00
parent 3a3f8240c1
commit 49bdbf2895
3 changed files with 218 additions and 205 deletions

View File

@ -37,7 +37,8 @@
incoming_s2s_number/0, outgoing_s2s_number/0,
clean_temporarily_blocked_table/0,
list_temporarily_blocked_hosts/0,
external_host_overloaded/1, is_temporarly_blocked/1]).
external_host_overloaded/1, is_temporarly_blocked/1,
check_peer_certificate/3]).
%% gen_server callbacks
-export([init/1, handle_call/3, handle_cast/2,
@ -53,6 +54,14 @@
-include("ejabberd_commands.hrl").
-include_lib("public_key/include/public_key.hrl").
-define(PKIXEXPLICIT, 'OTP-PUB-KEY').
-define(PKIXIMPLICIT, 'OTP-PUB-KEY').
-include("XmppAddr.hrl").
-define(DEFAULT_MAX_S2S_CONNECTIONS_NUMBER, 1).
-define(DEFAULT_MAX_S2S_CONNECTIONS_NUMBER_PER_NODE, 1).
@ -207,6 +216,31 @@ try_register(FromTo) ->
dirty_get_connections() ->
mnesia:dirty_all_keys(s2s).
check_peer_certificate(SockMod, Sock, Peer) ->
case SockMod:get_peer_certificate(Sock) of
{ok, Cert} ->
case SockMod:get_verify_result(Sock) of
0 ->
case idna:domain_utf8_to_ascii(Peer) of
false ->
{error, <<"Cannot decode remote server name">>};
AsciiPeer ->
case
lists:any(fun(D) -> match_domain(AsciiPeer, D) end,
get_cert_domains(Cert)) of
true ->
{ok, <<"Verification successful">>};
false ->
{error, <<"Certificate host name mismatch">>}
end
end;
VerifyRes ->
{error, p1_tls:get_cert_verify_string(VerifyRes, Cert)}
end;
error ->
{error, <<"Cannot get peer certificate">>}
end.
%%====================================================================
%% gen_server callbacks
%%====================================================================
@ -619,3 +653,121 @@ get_s2s_state(S2sPid) ->
{badrpc, _} -> [{status, error}]
end,
[{s2s_pid, S2sPid} | Infos].
get_cert_domains(Cert) ->
{rdnSequence, Subject} =
(Cert#'Certificate'.tbsCertificate)#'TBSCertificate'.subject,
Extensions =
(Cert#'Certificate'.tbsCertificate)#'TBSCertificate'.extensions,
lists:flatmap(fun (#'AttributeTypeAndValue'{type =
?'id-at-commonName',
value = Val}) ->
case 'OTP-PUB-KEY':decode('X520CommonName', Val) of
{ok, {_, D1}} ->
D = if is_binary(D1) -> D1;
is_list(D1) -> list_to_binary(D1);
true -> error
end,
if D /= error ->
case jlib:string_to_jid(D) of
#jid{luser = <<"">>, lserver = LD,
lresource = <<"">>} ->
[LD];
_ -> []
end;
true -> []
end;
_ -> []
end;
(_) -> []
end,
lists:flatten(Subject))
++
lists:flatmap(fun (#'Extension'{extnID =
?'id-ce-subjectAltName',
extnValue = Val}) ->
BVal = if is_list(Val) -> list_to_binary(Val);
true -> Val
end,
case 'OTP-PUB-KEY':decode('SubjectAltName', BVal)
of
{ok, SANs} ->
lists:flatmap(fun ({otherName,
#'AnotherName'{'type-id' =
?'id-on-xmppAddr',
value =
XmppAddr}}) ->
case
'XmppAddr':decode('XmppAddr',
XmppAddr)
of
{ok, D}
when
is_binary(D) ->
case
jlib:string_to_jid((D))
of
#jid{luser =
<<"">>,
lserver =
LD,
lresource =
<<"">>} ->
case
idna:domain_utf8_to_ascii(LD)
of
false ->
[];
PCLD ->
[PCLD]
end;
_ -> []
end;
_ -> []
end;
({dNSName, D})
when is_list(D) ->
case
jlib:string_to_jid(list_to_binary(D))
of
#jid{luser = <<"">>,
lserver = LD,
lresource =
<<"">>} ->
[LD];
_ -> []
end;
(_) -> []
end,
SANs);
_ -> []
end;
(_) -> []
end,
Extensions).
match_domain(Domain, Domain) -> true;
match_domain(Domain, Pattern) ->
DLabels = str:tokens(Domain, <<".">>),
PLabels = str:tokens(Pattern, <<".">>),
match_labels(DLabels, PLabels).
match_labels([], []) -> true;
match_labels([], [_ | _]) -> false;
match_labels([_ | _], []) -> false;
match_labels([DL | DLabels], [PL | PLabels]) ->
case lists:all(fun (C) ->
$a =< C andalso C =< $z orelse
$0 =< C andalso C =< $9 orelse
C == $- orelse C == $*
end,
binary_to_list(PL))
of
true ->
Regexp = ejabberd_regexp:sh_to_awk(PL),
case ejabberd_regexp:run(DL, Regexp) of
match -> match_labels(DLabels, PLabels);
nomatch -> false
end;
false -> false
end.

View File

@ -30,8 +30,7 @@
-behaviour(p1_fsm).
%% External exports
-export([start/2, start_link/2, match_domain/2,
socket_type/0]).
-export([start/2, start_link/2, socket_type/0]).
%% gen_fsm callbacks
-export([init/1, wait_for_stream/2,
@ -44,14 +43,6 @@
-include("jlib.hrl").
-include_lib("public_key/include/public_key.hrl").
-define(PKIXEXPLICIT, 'OTP-PUB-KEY').
-define(PKIXIMPLICIT, 'OTP-PUB-KEY').
-include("XmppAddr.hrl").
-define(DICT, dict).
-record(state,
@ -227,45 +218,11 @@ wait_for_stream({xmlstreamstart, _Name, Attrs},
Auth = if StateData#state.tls_enabled ->
case jlib:nameprep(xml:get_attr_s(<<"from">>, Attrs)) of
From when From /= <<"">>, From /= error ->
case
(StateData#state.sockmod):get_peer_certificate(StateData#state.socket)
of
{ok, Cert} ->
case
(StateData#state.sockmod):get_verify_result(StateData#state.socket)
of
0 ->
case
idna:domain_utf8_to_ascii(From)
of
false ->
{error, From,
<<"Cannot decode 'from' attribute">>};
PCAuthDomain ->
case
lists:any(fun (D) ->
match_domain(PCAuthDomain,
D)
end,
get_cert_domains(Cert))
of
true ->
{ok, From,
<<"Success">>};
false ->
{error, From,
<<"Certificate host name mismatch">>}
end
end;
CertVerifyRes ->
{error, From,
p1_tls:get_cert_verify_string(CertVerifyRes,
Cert)}
end;
error ->
{error, From,
<<"Cannot get peer certificate">>}
end;
{Result, Message} =
ejabberd_s2s:check_peer_certificate(StateData#state.sockmod,
StateData#state.socket,
From),
{Result, From, Message};
_ ->
{error, <<"(unknown)">>,
<<"Got no valid 'from' attribute">>}
@ -746,124 +703,6 @@ is_key_packet(#xmlel{name = Name, attrs = Attrs,
xml:get_attr_s(<<"id">>, Attrs), xml:get_cdata(Els)};
is_key_packet(_) -> false.
get_cert_domains(Cert) ->
{rdnSequence, Subject} =
(Cert#'Certificate'.tbsCertificate)#'TBSCertificate'.subject,
Extensions =
(Cert#'Certificate'.tbsCertificate)#'TBSCertificate'.extensions,
lists:flatmap(fun (#'AttributeTypeAndValue'{type =
?'id-at-commonName',
value = Val}) ->
case 'OTP-PUB-KEY':decode('X520CommonName', Val) of
{ok, {_, D1}} ->
D = if is_binary(D1) -> D1;
is_list(D1) -> list_to_binary(D1);
true -> error
end,
if D /= error ->
case jlib:string_to_jid(D) of
#jid{luser = <<"">>, lserver = LD,
lresource = <<"">>} ->
[LD];
_ -> []
end;
true -> []
end;
_ -> []
end;
(_) -> []
end,
lists:flatten(Subject))
++
lists:flatmap(fun (#'Extension'{extnID =
?'id-ce-subjectAltName',
extnValue = Val}) ->
BVal = if is_list(Val) -> list_to_binary(Val);
true -> Val
end,
case 'OTP-PUB-KEY':decode('SubjectAltName', BVal)
of
{ok, SANs} ->
lists:flatmap(fun ({otherName,
#'AnotherName'{'type-id' =
?'id-on-xmppAddr',
value =
XmppAddr}}) ->
case
'XmppAddr':decode('XmppAddr',
XmppAddr)
of
{ok, D}
when
is_binary(D) ->
case
jlib:string_to_jid((D))
of
#jid{luser =
<<"">>,
lserver =
LD,
lresource =
<<"">>} ->
case
idna:domain_utf8_to_ascii(LD)
of
false ->
[];
PCLD ->
[PCLD]
end;
_ -> []
end;
_ -> []
end;
({dNSName, D})
when is_list(D) ->
case
jlib:string_to_jid(list_to_binary(D))
of
#jid{luser = <<"">>,
lserver = LD,
lresource =
<<"">>} ->
[LD];
_ -> []
end;
(_) -> []
end,
SANs);
_ -> []
end;
(_) -> []
end,
Extensions).
match_domain(Domain, Domain) -> true;
match_domain(Domain, Pattern) ->
DLabels = str:tokens(Domain, <<".">>),
PLabels = str:tokens(Pattern, <<".">>),
match_labels(DLabels, PLabels).
match_labels([], []) -> true;
match_labels([], [_ | _]) -> false;
match_labels([_ | _], []) -> false;
match_labels([DL | DLabels], [PL | PLabels]) ->
case lists:all(fun (C) ->
$a =< C andalso C =< $z orelse
$0 =< C andalso C =< $9 orelse
C == $- orelse C == $*
end,
binary_to_list(PL))
of
true ->
Regexp = ejabberd_regexp:sh_to_awk(PL),
case ejabberd_regexp:run(DL, Regexp) of
match -> match_labels(DLabels, PLabels);
nomatch -> false
end;
false -> false
end.
fsm_limit_opts(Opts) ->
case lists:keysearch(max_fsm_queue, 1, Opts) of
{value, {_, N}} when is_integer(N) -> [{max_queue, N}];

View File

@ -69,6 +69,7 @@
use_v10 = true :: boolean(),
tls = false :: boolean(),
tls_required = false :: boolean(),
tls_certverify = false :: boolean(),
tls_enabled = false :: boolean(),
tls_options = [connect] :: list(),
authenticated = false :: boolean(),
@ -160,28 +161,27 @@ stop_connection(Pid) -> p1_fsm:send_event(Pid, closed).
init([From, Server, Type]) ->
process_flag(trap_exit, true),
?DEBUG("started: ~p", [{From, Server, Type}]),
{TLS, TLSRequired} = case
ejabberd_config:get_option(
s2s_use_starttls,
fun(true) -> true;
(false) -> false;
(optional) -> optional;
(required) -> required;
(required_trusted) -> required_trusted
end)
of
UseTls
when (UseTls == undefined) or
(UseTls == false) ->
{false, false};
UseTls
when (UseTls == true) or (UseTls == optional) ->
{true, false};
UseTls
when (UseTls == required) or
(UseTls == required_trusted) ->
{true, true}
end,
{TLS, TLSRequired, TLSCertverify} =
case ejabberd_config:get_option(
s2s_use_starttls,
fun(true) -> true;
(false) -> false;
(optional) -> optional;
(required) -> required;
(required_trusted) -> required_trusted
end)
of
UseTls
when (UseTls == undefined) or (UseTls == false) ->
{false, false, false};
UseTls
when (UseTls == true) or (UseTls == optional) ->
{true, false, false};
required ->
{true, true, false};
required_trusted ->
{true, true, true}
end,
UseV10 = TLS,
TLSOpts1 = case
ejabberd_config:get_option(
@ -223,9 +223,9 @@ init([From, Server, Type]) ->
Timer = erlang:start_timer(?S2STIMEOUT, self(), []),
{ok, open_socket,
#state{use_v10 = UseV10, tls = TLS,
tls_required = TLSRequired, tls_options = TLSOpts,
queue = queue:new(), myname = From, server = Server,
new = New, verify = Verify, timer = Timer}}.
tls_required = TLSRequired, tls_certverify = TLSCertverify,
tls_options = TLSOpts, queue = queue:new(), myname = From,
server = Server, new = New, verify = Verify, timer = Timer}}.
%%----------------------------------------------------------------------
%% Func: StateName/2
@ -345,35 +345,57 @@ open_socket2(Type, Addr, Port) ->
wait_for_stream({xmlstreamstart, _Name, Attrs},
StateData) ->
{CertCheckRes, CertCheckMsg, NewStateData} =
if StateData#state.tls_certverify, StateData#state.tls_enabled ->
{Res, Msg} =
ejabberd_s2s:check_peer_certificate(ejabberd_socket,
StateData#state.socket,
StateData#state.server),
?DEBUG("Certificate verification result for ~s: ~s",
[StateData#state.server, Msg]),
{Res, Msg, StateData#state{tls_certverify = false}};
true ->
{no_verify, <<"Not verified">>, StateData}
end,
case {xml:get_attr_s(<<"xmlns">>, Attrs),
xml:get_attr_s(<<"xmlns:db">>, Attrs),
xml:get_attr_s(<<"version">>, Attrs) == <<"1.0">>}
of
_ when CertCheckRes == error ->
send_text(NewStateData,
<<(xml:element_to_binary(?SERRT_POLICY_VIOLATION(<<"en">>,
CertCheckMsg)))/binary,
(?STREAM_TRAILER)/binary>>),
?INFO_MSG("Closing s2s connection: ~s -> ~s (~s)",
[NewStateData#state.myname,
NewStateData#state.server,
CertCheckMsg]),
{stop, normal, NewStateData};
{<<"jabber:server">>, <<"jabber:server:dialback">>,
false} ->
send_db_request(StateData);
send_db_request(NewStateData);
{<<"jabber:server">>, <<"jabber:server:dialback">>,
true}
when StateData#state.use_v10 ->
{next_state, wait_for_features, StateData, ?FSMTIMEOUT};
when NewStateData#state.use_v10 ->
{next_state, wait_for_features, NewStateData, ?FSMTIMEOUT};
%% Clause added to handle Tigase's workaround for an old ejabberd bug:
{<<"jabber:server">>, <<"jabber:server:dialback">>,
true}
when not StateData#state.use_v10 ->
send_db_request(StateData);
when not NewStateData#state.use_v10 ->
send_db_request(NewStateData);
{<<"jabber:server">>, <<"">>, true}
when StateData#state.use_v10 ->
when NewStateData#state.use_v10 ->
{next_state, wait_for_features,
StateData#state{db_enabled = false}, ?FSMTIMEOUT};
NewStateData#state{db_enabled = false}, ?FSMTIMEOUT};
{NSProvided, DB, _} ->
send_text(StateData, ?INVALID_NAMESPACE_ERR),
send_text(NewStateData, ?INVALID_NAMESPACE_ERR),
?INFO_MSG("Closing s2s connection: ~s -> ~s (invalid "
"namespace).~nNamespace provided: ~p~nNamespac"
"e expected: \"jabber:server\"~nxmlns:db "
"provided: ~p~nAll attributes: ~p",
[StateData#state.myname, StateData#state.server,
[NewStateData#state.myname, NewStateData#state.server,
NSProvided, DB, Attrs]),
{stop, normal, StateData}
{stop, normal, NewStateData}
end;
wait_for_stream({xmlstreamerror, _}, StateData) ->
send_text(StateData,
@ -736,8 +758,8 @@ wait_for_starttls_proceed({xmlstreamelement, El},
tls_options = TLSOpts},
send_text(NewStateData,
io_lib:format(?STREAM_HEADER,
[StateData#state.myname,
StateData#state.server,
[NewStateData#state.myname,
NewStateData#state.server,
<<" version='1.0'">>])),
{next_state, wait_for_stream, NewStateData,
?FSMTIMEOUT};