* src/ejabberd_s2s_out.erl: Support for STARTTLS+SASL EXTERNAL

* src/ejabberd_s2s_in.erl: Likewise
* src/tls/tls.erl: Likewise
* src/tls/tls_drv.c: Likewise
* src/tls/XmppAddr.asn1: Likewise
* src/tls/Makefile.in: Likewise

SVN Revision: 430
This commit is contained in:
Alexey Shchepin 2005-11-03 05:04:54 +00:00
parent 2efda30fdc
commit f6343f01f7
7 changed files with 393 additions and 31 deletions

View File

@ -1,3 +1,13 @@
2005-11-03 Alexey Shchepin <alexey@sevcom.net>
* src/ejabberd_s2s_out.erl: Support for STARTTLS+SASL EXTERNAL
(not well-tested yet)
* src/ejabberd_s2s_in.erl: Likewise
* src/tls/tls.erl: Likewise
* src/tls/tls_drv.c: Likewise
* src/tls/XmppAddr.asn1: Likewise
* src/tls/Makefile.in: Likewise
2005-10-30 Alexey Shchepin <alexey@sevcom.net>
* src/mod_disco.erl: Minor fix

View File

@ -29,6 +29,9 @@
-include("ejabberd.hrl").
-include("jlib.hrl").
%-include_lib("ssl/pkix/SSL-PKIX.hrl").
-include_lib("ssl/pkix/PKIX1Explicit88.hrl").
-include("tls/XmppAddr.hrl").
-define(DICT, dict).
@ -40,6 +43,8 @@
tls = false,
tls_enabled = false,
tls_options = [],
authenticated = false,
auth_domain,
connections = ?DICT:new(),
timer}).
@ -131,18 +136,50 @@ init([{SockMod, Socket}, Opts]) ->
%%----------------------------------------------------------------------
wait_for_stream({xmlstreamstart, _Name, Attrs}, StateData) ->
% TODO
case {xml:get_attr_s("xmlns", Attrs),
xml:get_attr_s("xmlns:db", Attrs),
xml:get_attr_s("version", Attrs) == "1.0"} of
{"jabber:server", "jabber:server:dialback", true} when
StateData#state.tls ->
StateData#state.tls and (not StateData#state.authenticated) ->
send_text(StateData, ?STREAM_HEADER(" version='1.0'")),
SASL =
if
StateData#state.tls_enabled ->
case tls:get_peer_certificate(StateData#state.socket) of
{ok, _Cert} ->
case tls:get_verify_result(
StateData#state.socket) of
0 ->
[{xmlelement, "mechanisms",
[{"xmlns", ?NS_SASL}],
[{xmlelement, "mechanism", [],
[{xmlcdata, "EXTERNAL"}]}]}];
_ ->
[]
end;
error ->
[]
end;
true ->
[]
end,
StartTLS = if
StateData#state.tls_enabled ->
[];
true ->
[{xmlelement, "starttls",
[{"xmlns", ?NS_TLS}], []}]
end,
send_element(StateData,
{xmlelement, "stream:features", [],
[{xmlelement, "starttls",
[{"xmlns", ?NS_TLS}], []}]}),
SASL ++ StartTLS}),
{next_state, wait_for_feature_request, StateData};
{"jabber:server", _, true} when
StateData#state.authenticated ->
send_text(StateData, ?STREAM_HEADER(" version='1.0'")),
send_element(StateData,
{xmlelement, "stream:features", [], []}),
{next_state, stream_established, StateData};
{"jabber:server", "jabber:server:dialback", _} ->
send_text(StateData, ?STREAM_HEADER("")),
{next_state, stream_established, StateData};
@ -185,6 +222,60 @@ wait_for_feature_request({xmlstreamelement, El}, StateData) ->
streamid = new_id(),
tls_enabled = true
}};
{?NS_SASL, "auth"} when TLSEnabled ->
Mech = xml:get_attr_s("mechanism", Attrs),
case Mech of
"EXTERNAL" ->
Auth = jlib:decode_base64(xml:get_cdata(Els)),
AuthDomain = jlib:nameprep(Auth),
AuthRes =
case tls:get_peer_certificate(StateData#state.socket) of
{ok, Cert} ->
case tls:get_verify_result(
StateData#state.socket) of
0 ->
case AuthDomain of
error ->
false;
_ ->
lists:member(
AuthDomain,
get_cert_domains(Cert))
end;
_ ->
false
end;
error ->
false
end,
if
AuthRes ->
ejabberd_receiver:reset_stream(
StateData#state.receiver),
send_element(StateData,
{xmlelement, "success",
[{"xmlns", ?NS_SASL}], []}),
?INFO_MSG("(~w) Accepted s2s authentication for ~s",
[StateData#state.socket, AuthDomain]),
{next_state, wait_for_stream,
StateData#state{streamid = new_id(),
authenticated = true,
auth_domain = AuthDomain
}};
true ->
send_element(StateData,
{xmlelement, "failure",
[{"xmlns", ?NS_SASL}], []}),
send_text(StateData, ?STREAM_TRAILER),
{stop, normal, StateData}
end;
_ ->
send_element(StateData,
{xmlelement, "failure",
[{"xmlns", ?NS_SASL}],
[{xmlelement, "invalid-mechanism", [], []}]}),
{stop, normal, StateData}
end;
_ ->
stream_established({xmlstreamelement, El}, StateData)
end;
@ -252,18 +343,38 @@ stream_established({xmlstreamelement, El}, StateData) ->
(To /= error) and (From /= error) ->
LFrom = From#jid.lserver,
LTo = To#jid.lserver,
case ?DICT:find({LFrom, LTo},
StateData#state.connections) of
{ok, established} ->
if ((Name == "iq") or
(Name == "message") or
(Name == "presence")) ->
ejabberd_router:route(From, To, El);
true ->
if
StateData#state.authenticated ->
case (LFrom == StateData#state.auth_domain)
andalso
lists:member(
LTo,
ejabberd_router:dirty_get_all_domains()) of
true ->
if ((Name == "iq") or
(Name == "message") or
(Name == "presence")) ->
ejabberd_router:route(From, To, El);
true ->
error
end;
false ->
error
end;
_ ->
error
true ->
case ?DICT:find({LFrom, LTo},
StateData#state.connections) of
{ok, established} ->
if ((Name == "iq") or
(Name == "message") or
(Name == "presence")) ->
ejabberd_router:route(From, To, El);
true ->
error
end;
_ ->
error
end
end;
true ->
error
@ -365,7 +476,7 @@ handle_info({send_text, Text}, StateName, StateData) ->
send_text(StateData, Text),
{next_state, StateName, StateData};
handle_info({timeout, Timer, _}, StateName,
handle_info({timeout, Timer, _}, _StateName,
#state{timer = Timer} = StateData) ->
{stop, normal, StateData};
@ -428,4 +539,51 @@ is_key_packet(_) ->
false.
get_cert_domains(Cert) ->
{rdnSequence, Subject} =
(Cert#'Certificate'.tbsCertificate)#'TBSCertificate'.subject,
lists:flatmap(
fun(#'AttributeTypeAndValue'{type = ?'id-at-commonName',
value = Val}) ->
case 'PKIX1Explicit88':decode(
'X520CommonName', Val) of
{ok, {_, D1}} ->
D = if
is_list(D1) -> D1;
is_binary(D1) -> binary_to_list(D1);
true -> error
end,
if
D /= error ->
case jlib:nameprep(D) of
error ->
[];
LD ->
[LD]
end;
true ->
[]
end;
_ ->
[]
end;
(#'AttributeTypeAndValue'{type = ?'id-on-xmppAddr',
value = Val}) ->
case 'XmppAddr':decode(
'XmppAddr', Val) of
{ok, D} when is_binary(D) ->
case jlib:nameprep(binary_to_list(D)) of
error ->
[];
LD ->
[LD]
end;
_ ->
[]
end;
(_) ->
[]
end, lists:flatten(Subject)).

View File

@ -21,6 +21,7 @@
wait_for_stream/2,
wait_for_validation/2,
wait_for_features/2,
wait_for_auth_result/2,
wait_for_starttls_proceed/2,
stream_established/2,
handle_event/3,
@ -40,6 +41,8 @@
tls_required = false,
tls_enabled = false,
tls_options = [],
authenticated = false,
try_auth = true,
myname, server, queue,
new = false, verify = false,
timer}).
@ -276,23 +279,57 @@ wait_for_validation(closed, StateData) ->
wait_for_features({xmlstreamelement, El}, StateData) ->
case El of
{xmlelement, "stream:features", _Attrs, Els} ->
{StartTLS, StartTLSRequired} =
{SASLEXT, StartTLS, StartTLSRequired} =
lists:foldl(
fun({xmlelement, "starttls", Attrs1, Els1} = El1, Acc) ->
fun({xmlelement, "mechanisms", Attrs1, Els1} = El1,
{SEXT, STLS, STLSReq} = Acc) ->
case xml:get_attr_s("xmlns", Attrs1) of
?NS_SASL ->
NewSEXT =
lists:any(
fun({xmlelement, "mechanism", _, Els2}) ->
case xml:get_cdata(Els2) of
"EXTERNAL" -> true;
_ -> false
end;
(_) -> false
end, Els1),
{NewSEXT, STLS, STLSReq};
_ ->
Acc
end;
({xmlelement, "starttls", Attrs1, Els1} = El1,
{SEXT, STLS, STLSReq} = Acc) ->
case xml:get_attr_s("xmlns", Attrs1) of
?NS_TLS ->
Req = case xml:get_subtag(El1, "required") of
{xmlelement, _, _, _} -> true;
false -> false
end,
{true, Req};
{SEXT, true, Req};
_ ->
Acc
end;
(_, Acc) ->
Acc
end, {false, false}, Els),
end, {false, false, false}, Els),
if
(not SASLEXT) and (not StartTLS) and
StateData#state.authenticated ->
send_queue(StateData, StateData#state.queue),
{next_state, stream_established,
StateData#state{queue = queue:new()}};
SASLEXT and StateData#state.try_auth and
(StateData#state.new /= false) ->
send_element(StateData,
{xmlelement, "auth",
[{"xmlns", ?NS_SASL},
{"mechanism", "EXTERNAL"}],
[{xmlcdata,
jlib:encode_base64(
StateData#state.myname)}]}),
{next_state, wait_for_auth_result,
StateData#state{try_auth = false}};
StartTLS and StateData#state.tls and
(not StateData#state.tls_enabled) ->
StateData#state.receiver ! {change_timeout, 100},
@ -333,6 +370,66 @@ wait_for_features(closed, StateData) ->
{stop, normal, StateData}.
wait_for_auth_result({xmlstreamelement, El}, StateData) ->
case El of
{xmlelement, "success", Attrs, _Els} ->
case xml:get_attr_s("xmlns", Attrs) of
?NS_SASL ->
?INFO_MSG("auth: ~p", [{StateData#state.myname,
StateData#state.server}]),
ejabberd_receiver:reset_stream(
StateData#state.receiver),
send_text(StateData,
io_lib:format(?STREAM_HEADER,
[StateData#state.server,
" version='1.0'"])),
{next_state, wait_for_stream,
StateData#state{streamid = new_id(),
authenticated = true
}};
_ ->
send_text(StateData,
xml:element_to_string(?SERR_BAD_FORMAT) ++
?STREAM_TRAILER),
{stop, normal, StateData}
end;
{xmlelement, "failure", Attrs, _Els} ->
case xml:get_attr_s("xmlns", Attrs) of
?NS_SASL ->
?INFO_MSG("restarted: ~p", [{StateData#state.myname,
StateData#state.server}]),
(StateData#state.sockmod):close(StateData#state.socket),
gen_fsm:send_event(self(), init),
{next_state, open_socket,
StateData#state{socket = undefined}};
_ ->
send_text(StateData,
xml:element_to_string(?SERR_BAD_FORMAT) ++
?STREAM_TRAILER),
{stop, normal, StateData}
end;
_ ->
send_text(StateData,
xml:element_to_string(?SERR_BAD_FORMAT) ++
?STREAM_TRAILER),
{stop, normal, StateData}
end;
wait_for_auth_result({xmlstreamend, Name}, StateData) ->
{stop, normal, StateData};
wait_for_auth_result({xmlstreamerror, _}, StateData) ->
send_text(StateData,
?INVALID_XML_ERR ++ ?STREAM_TRAILER),
{stop, normal, StateData};
wait_for_auth_result(timeout, StateData) ->
{stop, normal, StateData};
wait_for_auth_result(closed, StateData) ->
{stop, normal, StateData}.
wait_for_starttls_proceed({xmlstreamelement, El}, StateData) ->
case El of
{xmlelement, "proceed", Attrs, _Els} ->
@ -351,12 +448,15 @@ wait_for_starttls_proceed({xmlstreamelement, El}, StateData) ->
streamid = new_id(),
tls_enabled = true
},
R = send_text(NewStateData,
io_lib:format(?STREAM_HEADER,
[StateData#state.server,
" version='1.0'"])),
send_text(NewStateData,
io_lib:format(?STREAM_HEADER,
[StateData#state.server,
" version='1.0'"])),
{next_state, wait_for_stream, NewStateData};
_ ->
send_text(StateData,
xml:element_to_string(?SERR_BAD_FORMAT) ++
?STREAM_TRAILER),
{stop, normal, StateData}
end;
_ ->

View File

@ -12,14 +12,18 @@ ERLSHLIBS = ../tls_drv.so
OUTDIR = ..
EFLAGS = -I .. -pz ..
ASN_FLAGS = -bber_bin +der +compact_bit_string +optimize +noobj
OBJS = \
$(OUTDIR)/tls.beam
$(OUTDIR)/tls.beam $(OUTDIR)/XmppAddr.beam
all: $(OBJS) $(ERLSHLIBS)
$(OUTDIR)/%.beam: %.erl
@ERLC@ -W $(EFLAGS) -o $(OUTDIR) $<
%.erl: %.asn1
erlc $(ASN_FLAGS) $<
#all: $(ERLSHLIBS)
# erl -s make all report "{outdir, \"..\"}" -noinput -s erlang halt

14
src/tls/XmppAddr.asn1 Normal file
View File

@ -0,0 +1,14 @@
XmppAddr { iso(1) identified-organization(3)
dod(6) internet(1) security(5) mechanisms(5) pkix(7)
id-on(8) id-on-xmppAddr(5) }
DEFINITIONS EXPLICIT TAGS ::=
BEGIN
id-on-xmppAddr OBJECT IDENTIFIER ::= { iso(1) identified-organization(3)
dod(6) internet(1) security(5) mechanisms(5) pkix(7)
id-on(8) 5 }
XmppAddr ::= UTF8String
END

View File

@ -17,6 +17,8 @@
send/2,
recv/2, recv/3, recv_data/2,
close/1,
get_peer_certificate/1,
get_verify_result/1,
test/0]).
%% Internal exports, call-back functions.
@ -33,6 +35,8 @@
-define(SET_DECRYPTED_OUTPUT, 4).
-define(GET_ENCRYPTED_OUTPUT, 5).
-define(GET_DECRYPTED_INPUT, 6).
-define(GET_PEER_CERTIFICATE, 7).
-define(GET_VERIFY_RESULT, 8).
-record(tlssock, {tcpsock, tlsport}).
@ -69,15 +73,16 @@ handle_call(_, _, State) ->
handle_cast(_, State) ->
{noreply, State}.
handle_info({'EXIT', Pid, Reason}, Port) ->
{noreply, Port};
handle_info({'EXIT', Port, Reason}, Port) ->
{stop, {port_died, Reason}, Port};
handle_info({'EXIT', _Pid, _Reason}, Port) ->
{noreply, Port};
handle_info(_, State) ->
{noreply, State}.
code_change(OldVsn, State, Extra) ->
code_change(_OldVsn, State, _Extra) ->
{ok, State}.
terminate(_Reason, Port) ->
@ -115,7 +120,7 @@ tls_to_tcp(#tlssock{tcpsock = TCPSocket, tlsport = Port}) ->
recv(Socket, Length) ->
recv(Socket, Length, infinity).
recv(#tlssock{tcpsock = TCPSocket, tlsport = Port} = TLSSock,
recv(#tlssock{tcpsock = TCPSocket} = TLSSock,
Length, Timeout) ->
case gen_tcp:recv(TCPSocket, Length, Timeout) of
{ok, Packet} ->
@ -133,6 +138,7 @@ recv_data(#tlssock{tcpsock = TCPSocket, tlsport = Port}, Packet) ->
<<0, Out/binary>> ->
case gen_tcp:send(TCPSocket, Out) of
ok ->
%io:format("IN: ~p~n", [{TCPSocket, binary_to_list(In)}]),
{ok, In};
Error ->
Error
@ -150,6 +156,7 @@ recv_data(#tlssock{tcpsock = TCPSocket, tlsport = Port}, Packet) ->
send(#tlssock{tcpsock = TCPSocket, tlsport = Port}, Packet) ->
case port_control(Port, ?SET_DECRYPTED_OUTPUT, Packet) of
<<0>> ->
%io:format("OUT: ~p~n", [{TCPSocket, lists:flatten(Packet)}]),
case port_control(Port, ?GET_ENCRYPTED_OUTPUT, []) of
<<0, Out/binary>> ->
gen_tcp:send(TCPSocket, Out);
@ -159,8 +166,12 @@ send(#tlssock{tcpsock = TCPSocket, tlsport = Port}, Packet) ->
<<1, Error/binary>> ->
{error, binary_to_list(Error)};
<<2>> -> % Dirty hack
receive after 100 -> ok end,
send(#tlssock{tcpsock = TCPSocket, tlsport = Port}, Packet)
receive
{timeout, _Timer, _} ->
{error, timeout}
after 100 ->
send(#tlssock{tcpsock = TCPSocket, tlsport = Port}, Packet)
end
end.
@ -168,6 +179,23 @@ close(#tlssock{tcpsock = TCPSocket, tlsport = Port}) ->
gen_tcp:close(TCPSocket),
port_close(Port).
get_peer_certificate(#tlssock{tlsport = Port}) ->
case port_control(Port, ?GET_PEER_CERTIFICATE, []) of
<<0, BCert/binary>> ->
case catch ssl_pkix:decode_cert(BCert, [pkix]) of
{ok, Cert} ->
{ok, Cert};
_ ->
error
end;
<<1>> ->
error
end.
get_verify_result(#tlssock{tlsport = Port}) ->
<<Res>> = port_control(Port, ?GET_VERIFY_RESULT, []),
Res.
test() ->
case erl_ddll:load_driver(ejabberd:get_so_path(), tls_drv) of

View File

@ -46,12 +46,19 @@ static void tls_drv_stop(ErlDrvData handle)
}
static int verify_callback(int preverify_ok, X509_STORE_CTX *ctx)
{
return 1;
}
#define SET_CERTIFICATE_FILE_ACCEPT 1
#define SET_CERTIFICATE_FILE_CONNECT 2
#define SET_ENCRYPTED_INPUT 3
#define SET_DECRYPTED_OUTPUT 4
#define GET_ENCRYPTED_OUTPUT 5
#define GET_DECRYPTED_INPUT 6
#define GET_PEER_CERTIFICATE 7
#define GET_VERIFY_RESULT 8
#define die_unless(cond, errstr) \
@ -75,6 +82,7 @@ static int tls_drv_control(ErlDrvData handle,
int res;
int size;
ErlDrvBinary *b;
X509 *cert;
switch (command)
{
@ -92,6 +100,15 @@ static int tls_drv_control(ErlDrvData handle,
res = SSL_CTX_check_private_key(d->ctx);
die_unless(res > 0, "SSL_CTX_check_private_key failed");
SSL_CTX_set_default_verify_paths(d->ctx);
if (command == SET_CERTIFICATE_FILE_ACCEPT)
{
SSL_CTX_set_verify(d->ctx,
SSL_VERIFY_PEER|SSL_VERIFY_CLIENT_ONCE,
verify_callback);
}
d->ssl = SSL_new(d->ctx);
die_unless(d->ssl, "SSL_new failed");
@ -182,6 +199,37 @@ static int tls_drv_control(ErlDrvData handle,
return rlen;
}
break;
case GET_PEER_CERTIFICATE:
cert = SSL_get_peer_certificate(d->ssl);
if (cert == NULL)
{
b = driver_alloc_binary(1);
b->orig_bytes[0] = 1;
*rbuf = (char *)b;
return 1;
} else {
unsigned char *tmp_buf;
rlen = i2d_X509(cert, NULL);
if (rlen >= 0)
{
rlen++;
b = driver_alloc_binary(rlen);
b->orig_bytes[0] = 0;
tmp_buf = &b->orig_bytes[1];
i2d_X509(cert, &tmp_buf);
X509_free(cert);
*rbuf = (char *)b;
return rlen;
} else
X509_free(cert);
}
break;
case GET_VERIFY_RESULT:
b = driver_alloc_binary(1);
b->orig_bytes[0] = SSL_get_verify_result(d->ssl);
*rbuf = (char *)b;
return 1;
break;
}
b = driver_alloc_binary(1);