24
1
mirror of https://github.com/processone/ejabberd.git synced 2024-07-02 23:06:21 +02:00

LDAP StartTLS support

Conflicts:

	src/ejabberd_auth_ldap.erl
	src/eldap/eldap.erl
	src/mod_shared_roster_ldap.erl
	src/mod_vcard_ldap.erl
	src/tls/tls.erl
This commit is contained in:
Evgeniy Khramtsov 2012-07-20 23:38:20 +10:00
parent a70c72e50e
commit aad740f34c
3 changed files with 193 additions and 100 deletions

View File

@ -57,6 +57,7 @@
%%% LDAP Client state machine. %%% LDAP Client state machine.
%%% Possible states are: %%% Possible states are:
%%% connecting - actually disconnected, but retrying periodically %%% connecting - actually disconnected, but retrying periodically
%%% wait_starttls_response - connected and send starttls request
%%% wait_bind_response - connected and sent bind request %%% wait_bind_response - connected and sent bind request
%%% active - bound to LDAP Server and ready to handle commands %%% active - bound to LDAP Server and ready to handle commands
%%% active_bind - sent bind() request and waiting for response %%% active_bind - sent bind() request and waiting for response
@ -81,6 +82,7 @@
%% gen_fsm callbacks %% gen_fsm callbacks
-export([init/1, connecting/2, connecting/3, -export([init/1, connecting/2, connecting/3,
wait_bind_response/3, active/3, active_bind/3, wait_bind_response/3, active/3, active_bind/3,
wait_starttls_response/3,
handle_event/3, handle_sync_event/4, handle_info/3, handle_event/3, handle_sync_event/4, handle_info/3,
terminate/3, code_change/4]). terminate/3, code_change/4]).
@ -94,6 +96,8 @@
-define(RETRY_TIMEOUT, 500). -define(RETRY_TIMEOUT, 500).
-define(STARTTLS_TIMEOUT, 10000).
-define(BIND_TIMEOUT, 10000). -define(BIND_TIMEOUT, 10000).
-define(CMD_TIMEOUT, 100000). -define(CMD_TIMEOUT, 100000).
@ -134,6 +138,7 @@
passwd = <<"">> :: binary(), passwd = <<"">> :: binary(),
id = 0 :: non_neg_integer(), id = 0 :: non_neg_integer(),
bind_timer = make_ref() :: reference(), bind_timer = make_ref() :: reference(),
starttls_timer = make_ref() :: reference(),
dict = dict:new() :: dict(), dict = dict:new() :: dict(),
req_q = queue:new() :: queue()}). req_q = queue:new() :: queue()}).
@ -431,6 +436,7 @@ init([Hosts, Port, Rootdn, Passwd, Opts]) ->
(none) -> none (none) -> none
end) of end) of
tls -> tls; tls -> tls;
starttls -> starttls;
_ -> none _ -> none
end, end,
PortTemp = case Port of PortTemp = case Port of
@ -450,9 +456,21 @@ init([Hosts, Port, Rootdn, Passwd, Opts]) ->
end) of end) of
undefined -> undefined ->
[]; [];
Path -> CAPath ->
[{cacertfile, Path}] [{cacertfile, CAPath}]
end, end,
CertOpts = case gen_mod:get_opt(
tls_certfile, Opts,
fun(S) when is_binary(S) ->
binary_to_list(S);
(undefined) ->
undefined
end) of
undefined ->
[];
CPath ->
[{certfile, CPath}]
end,
DepthOpts = case gen_mod:get_opt( DepthOpts = case gen_mod:get_opt(
tls_depth, Opts, tls_depth, Opts,
fun(I) when is_integer(I), I>=0 -> fun(I) when is_integer(I), I>=0 ->
@ -471,17 +489,18 @@ init([Hosts, Port, Rootdn, Passwd, Opts]) ->
(false) -> false (false) -> false
end, false), end, false),
TLSOpts = if (Verify == hard orelse Verify == soft) TLSOpts = if (Verify == hard orelse Verify == soft)
andalso CacertOpts == [] -> andalso CacertOpts == [] ->
?WARNING_MSG("TLS verification is enabled but no CA " ?WARNING_MSG("TLS verification is enabled but no CA "
"certfiles configured, so verification " "certfiles configured, so verification "
"is disabled.", "is disabled.",
[]), []),
[]; [];
Verify == soft -> Verify == soft ->
[{verify, 1}] ++ CacertOpts ++ DepthOpts; [{verify, 1}] ++ CacertOpts ++ DepthOpts ++ CertOpts;
Verify == hard -> Verify == hard ->
[{verify, 2}] ++ CacertOpts ++ DepthOpts; [{verify, 2}] ++ CacertOpts ++ DepthOpts ++ CertOpts;
true -> [] true ->
CacertOpts ++ DepthOpts ++ CertOpts
end, end,
{ok, connecting, {ok, connecting,
#eldap{hosts = Hosts, port = PortTemp, rootdn = Rootdn, #eldap{hosts = Hosts, port = PortTemp, rootdn = Rootdn,
@ -490,13 +509,16 @@ init([Hosts, Port, Rootdn, Passwd, Opts]) ->
0}. 0}.
connecting(timeout, S) -> connecting(timeout, S) ->
{ok, NextState, NewS} = connect_bind(S), connect_bind(S).
{next_state, NextState, NewS}.
connecting(Event, From, S) -> connecting(Event, From, S) ->
Q = queue:in({Event, From}, S#eldap.req_q), Q = queue:in({Event, From}, S#eldap.req_q),
{next_state, connecting, S#eldap{req_q = Q}}. {next_state, connecting, S#eldap{req_q = Q}}.
wait_starttls_response(Event, From, S) ->
Q = queue:in({Event, From}, S#eldap.req_q),
{next_state, wait_starttls_response, S#eldap{req_q=Q}}.
wait_bind_response(Event, From, S) -> wait_bind_response(Event, From, S) ->
Q = queue:in({Event, From}, S#eldap.req_q), Q = queue:in({Event, From}, S#eldap.req_q),
{next_state, wait_bind_response, S#eldap{req_q = Q}}. {next_state, wait_bind_response, S#eldap{req_q = Q}}.
@ -525,12 +547,40 @@ handle_sync_event(_Event, _From, StateName, S) ->
%%---------------------------------------------------------------------- %%----------------------------------------------------------------------
handle_info({Tag, _Socket, Data}, connecting, S) handle_info({Tag, _Socket, Data}, connecting, S)
when Tag == tcp; Tag == ssl -> when Tag == tcp; Tag == ssl ->
?DEBUG("tcp packet received when disconnected!~n~p", activate_socket(S),
[Data]), ?DEBUG("tcp packet received when disconnected!~n~p", [Data]),
{next_state, connecting, S}; {next_state, connecting, S};
handle_info({Tag, _Socket, Data}, wait_bind_response, S)
when Tag == tcp; Tag == ssl -> handle_info({Tag, Socket, Data}, StateName, S)
when Tag == tcp; Tag == ssl ->
activate_socket(S),
handle_info({asn1, Socket, Data}, StateName, S);
handle_info({asn1, Socket, Data}, wait_starttls_response, S) ->
cancel_timer(S#eldap.starttls_timer),
case catch recvd_wait_starttls_response(Data, S) of
proceed ->
case ssl:connect(Socket, S#eldap.tls_options,
?STARTTLS_TIMEOUT) of
{error, Reason} ->
report_starttls_failure(S#eldap.host, S#eldap.port, Reason);
{ok, SSLSock} ->
bind_request(S#eldap{fd = SSLSock, sockmod = ssl,
id = bump_id(S)})
end;
{fail_starttls, Reason} ->
report_starttls_failure(S#eldap.host, S#eldap.port, Reason),
{next_state, connecting, close_and_retry(S, ?GRACEFUL_RETRY_TIMEOUT)};
{'EXIT', Reason} ->
report_starttls_failure(S#eldap.host, S#eldap.port, Reason),
{next_state, connecting, close_and_retry(S)};
{error, Reason} ->
report_starttls_failure(S#eldap.host, S#eldap.port, Reason),
{next_state, connecting, close_and_retry(S)}
end;
handle_info({asn1, _Socket, Data}, wait_bind_response, S) ->
cancel_timer(S#eldap.bind_timer), cancel_timer(S#eldap.bind_timer),
case catch recvd_wait_bind_response(Data, S) of case catch recvd_wait_bind_response(Data, S) of
bound -> dequeue_commands(S); bound -> dequeue_commands(S);
@ -545,10 +595,9 @@ handle_info({Tag, _Socket, Data}, wait_bind_response, S)
report_bind_failure(S#eldap.host, S#eldap.port, Reason), report_bind_failure(S#eldap.host, S#eldap.port, Reason),
{next_state, connecting, close_and_retry(S)} {next_state, connecting, close_and_retry(S)}
end; end;
handle_info({Tag, _Socket, Data}, StateName, S)
when (StateName == active orelse handle_info({asn1, _Socket, Data}, StateName, S)
StateName == active_bind) when (StateName == active orelse StateName == active_bind) ->
andalso (Tag == tcp orelse Tag == ssl) ->
case catch recvd_packet(Data, S) of case catch recvd_packet(Data, S) of
{response, Response, RequestType} -> {response, Response, RequestType} ->
NewS = case Response of NewS = case Response of
@ -586,11 +635,14 @@ handle_info({timeout, Timer, {cmd_timeout, Id}},
{error, _Reason} -> {next_state, StateName, S} {error, _Reason} -> {next_state, StateName, S}
end; end;
handle_info({timeout, retry_connect}, connecting, S) -> handle_info({timeout, retry_connect}, connecting, S) ->
{ok, NextState, NewS} = connect_bind(S), connect_bind(S);
{next_state, NextState, NewS};
handle_info({timeout, _Timer, bind_timeout}, handle_info({timeout, _Timer, bind_timeout}, wait_bind_response, S) ->
wait_bind_response, S) ->
{next_state, connecting, close_and_retry(S)}; {next_state, connecting, close_and_retry(S)};
handle_info({timeout, _Timer, starttls_timeout}, wait_starttls_response, S) ->
{next_state, connecting, close_and_retry(S)};
%% %%
%% Make sure we don't fill the message queue with rubbish %% Make sure we don't fill the message queue with rubbish
%% %%
@ -633,17 +685,15 @@ send_command(Command, From, S) ->
{Name, Request} = gen_req(Command), {Name, Request} = gen_req(Command),
Message = #'LDAPMessage'{messageID = Id, Message = #'LDAPMessage'{messageID = Id,
protocolOp = {Name, Request}}, protocolOp = {Name, Request}},
?DEBUG("~p~n", [{Name, Request}]), ?DEBUG("~p~n",[Message]),
{ok, Bytes} = asn1rt:encode('ELDAPv3', 'LDAPMessage', {ok, Bytes} = asn1rt:encode('ELDAPv3', 'LDAPMessage', Message),
Message), case (S#eldap.sockmod):send(S#eldap.fd, iolist_to_binary(Bytes)) of
case (S#eldap.sockmod):send(S#eldap.fd, Bytes) of ok ->
ok -> Timer = erlang:start_timer(?CMD_TIMEOUT, self(), {cmd_timeout, Id}),
Timer = erlang:start_timer(?CMD_TIMEOUT, self(), New_dict = dict:store(Id, [{Timer, Command, From, Name}], S#eldap.dict),
{cmd_timeout, Id}), {ok, S#eldap{id = Id, dict = New_dict}};
New_dict = dict:store(Id, Error ->
[{Timer, Command, From, Name}], S#eldap.dict), Error
{ok, S#eldap{id = Id, dict = New_dict}};
Error -> Error
end. end.
gen_req({search, A}) -> gen_req({search, A}) ->
@ -683,7 +733,6 @@ gen_req({bind, RootDN, Passwd}) ->
authentication = {simple, Passwd}}}. authentication = {simple, Passwd}}}.
recvd_packet(Pkt, S) -> recvd_packet(Pkt, S) ->
check_tag(Pkt),
case asn1rt:decode('ELDAPv3', 'LDAPMessage', Pkt) of case asn1rt:decode('ELDAPv3', 'LDAPMessage', Pkt) of
{ok, Msg} -> {ok, Msg} ->
Op = Msg#'LDAPMessage'.protocolOp, Op = Msg#'LDAPMessage'.protocolOp,
@ -791,7 +840,6 @@ get_op_rec(Id, Dict) ->
end. end.
recvd_wait_bind_response(Pkt, S) -> recvd_wait_bind_response(Pkt, S) ->
check_tag(Pkt),
case asn1rt:decode('ELDAPv3', 'LDAPMessage', Pkt) of case asn1rt:decode('ELDAPv3', 'LDAPMessage', Pkt) of
{ok, Msg} -> {ok, Msg} ->
?DEBUG("~p", [Msg]), ?DEBUG("~p", [Msg]),
@ -806,35 +854,43 @@ recvd_wait_bind_response(Pkt, S) ->
Else -> {fail_bind, Else} Else -> {fail_bind, Else}
end. end.
recvd_wait_starttls_response(Pkt, S) ->
case asn1rt:decode('ELDAPv3', 'LDAPMessage', Pkt) of
{ok, Msg} ->
?DEBUG("~p", [Msg]),
check_id(S#eldap.id, Msg#'LDAPMessage'.messageID),
case Msg#'LDAPMessage'.protocolOp of
{extendedResp, Result} ->
case Result#'ExtendedResponse'.resultCode of
success -> proceed;
Error -> {fail_starttls, Error}
end
end;
Else ->
{error, Else}
end.
check_id(Id, Id) -> ok; check_id(Id, Id) -> ok;
check_id(_, _) -> throw({error, wrong_bind_id}). check_id(_, _) -> throw({error, wrong_bind_id}).
%%----------------------------------------------------------------------- %%-----------------------------------------------------------------------
%% General Helpers %% General Helpers
%%----------------------------------------------------------------------- %%-----------------------------------------------------------------------
cancel_timer(Timer) -> cancel_timer(Timer) ->
erlang:cancel_timer(Timer), erlang:cancel_timer(Timer),
receive {timeout, Timer, _} -> ok after 0 -> ok end. receive {timeout, Timer, _} -> ok after 0 -> ok end.
check_tag(Data) ->
{_Tag, Data1, _Rb} = asn1rt_ber_bin:decode_tag(Data),
{{_Len, _Data2}, _Rb2} =
asn1rt_ber_bin:decode_length(Data1),
ok.
close_and_retry(S, Timeout) -> close_and_retry(S, Timeout) ->
catch (S#eldap.sockmod):close(S#eldap.fd), catch (S#eldap.sockmod):close(S#eldap.fd),
Queue = dict:fold(fun (_Id, Queue = dict:fold(
[{Timer, Command, From, _Name} | _], Q) -> fun(_Id, [{Timer, Command, From, _Name}|_], Q) ->
cancel_timer(Timer), cancel_timer(Timer),
queue:in_r({Command, From}, Q); queue:in_r({Command, From}, Q);
(_, _, Q) -> Q (_, _, Q) ->
end, Q
S#eldap.req_q, S#eldap.dict), end, S#eldap.req_q, S#eldap.dict),
erlang:send_after(Timeout, self(), erlang:send_after(Timeout, self(), {timeout, retry_connect}),
{timeout, retry_connect}), S#eldap{fd=null, req_q=Queue, dict=dict:new()}.
S#eldap{fd = undefined, req_q = Queue, dict = dict:new()}.
close_and_retry(S) -> close_and_retry(S) ->
close_and_retry(S, ?RETRY_TIMEOUT). close_and_retry(S, ?RETRY_TIMEOUT).
@ -843,6 +899,15 @@ report_bind_failure(Host, Port, Reason) ->
?WARNING_MSG("LDAP bind failed on ~s:~p~nReason: ~p", ?WARNING_MSG("LDAP bind failed on ~s:~p~nReason: ~p",
[Host, Port, Reason]). [Host, Port, Reason]).
report_starttls_failure(Host, Port, Reason) ->
?WARNING_MSG("LDAP StartTLS failed:~n"
"** Server: ~s:~p~n"
"** Reason: ~p",
[Host, Port, Reason]).
%%-----------------------------------------------------------------------
%% Sort out timed out commands
%%-----------------------------------------------------------------------
cmd_timeout(Timer, Id, S) -> cmd_timeout(Timer, Id, S) ->
Dict = S#eldap.dict, Dict = S#eldap.dict,
case dict:find(Id, Dict) of case dict:find(Id, Dict) of
@ -891,58 +956,74 @@ connect_bind(S) ->
Host = next_host(S#eldap.host, S#eldap.hosts), Host = next_host(S#eldap.host, S#eldap.hosts),
?INFO_MSG("LDAP connection on ~s:~p", ?INFO_MSG("LDAP connection on ~s:~p",
[Host, S#eldap.port]), [Host, S#eldap.port]),
Opts = if S#eldap.tls == tls -> Opts = case S#eldap.tls of
[{packet, asn1}, {active, true}, {keepalive, true}, tls ->
binary [{packet, asn1}, {active, once}, {keepalive, true},
| S#eldap.tls_options]; binary | S#eldap.tls_options];
true -> _ ->
[{packet, asn1}, {active, true}, {keepalive, true}, [{packet, asn1}, {active, once}, {keepalive, true},
{send_timeout, ?SEND_TIMEOUT}, binary] {send_timeout, ?SEND_TIMEOUT}, binary]
end, end,
HostS = binary_to_list(Host), HostS = binary_to_list(Host),
SocketData = case S#eldap.tls of SockMod = case S#eldap.tls of
tls -> tls -> ssl;
SockMod = ssl, ssl:connect(HostS, S#eldap.port, Opts); _ -> gen_tcp
%% starttls -> %% TODO: Implement STARTTLS; end,
_ -> case SockMod:connect(HostS, S#eldap.port, Opts) of
SockMod = gen_tcp, {ok, Socket} ->
gen_tcp:connect(HostS, S#eldap.port, Opts) Id = bump_id(S),
end, NewS = S#eldap{host = Host, sockmod = SockMod,
case SocketData of id = Id, fd = Socket},
{ok, Socket} -> case S#eldap.tls of
case bind_request(Socket, S#eldap{sockmod = SockMod}) of starttls ->
{ok, NewS} -> starttls_request(NewS);
Timer = erlang:start_timer(?BIND_TIMEOUT, self(), _ ->
{timeout, bind_timeout}), bind_request(NewS)
{ok, wait_bind_response, end;
NewS#eldap{fd = Socket, sockmod = SockMod, host = Host, {error, Reason} ->
bind_timer = Timer}}; ?ERROR_MSG("LDAP connection failed:~n"
{error, Reason} -> "** Server: ~s:~p~n"
report_bind_failure(Host, S#eldap.port, Reason), "** Reason: ~p~n"
NewS = close_and_retry(S), "** Socket options: ~p",
{ok, connecting, NewS#eldap{host = Host}} [Host, S#eldap.port, Reason, Opts]),
end; NewS = close_and_retry(S),
{error, Reason} -> {next_state, connecting, NewS#eldap{host = Host}}
?ERROR_MSG("LDAP connection failed:~n** Server: "
"~s:~p~n** Reason: ~p~n** Socket options: ~p",
[Host, S#eldap.port, Reason, Opts]),
NewS = close_and_retry(S),
{ok, connecting, NewS#eldap{host = Host}}
end. end.
bind_request(Socket, S) -> bind_request(#eldap{fd = Socket, id = Id} = S) ->
Id = bump_id(S),
Req = #'BindRequest'{version = S#eldap.version, Req = #'BindRequest'{version = S#eldap.version,
name = S#eldap.rootdn, name = S#eldap.rootdn,
authentication = {simple, S#eldap.passwd}}, authentication = {simple, S#eldap.passwd}},
Message = #'LDAPMessage'{messageID = Id, Message = #'LDAPMessage'{messageID = Id,
protocolOp = {bindRequest, Req}}, protocolOp = {bindRequest, Req}},
?DEBUG("Bind Request Message:~p~n", [Message]), ?DEBUG("Bind Request Message:~p~n",[Message]),
{ok, Bytes} = asn1rt:encode('ELDAPv3', 'LDAPMessage', {ok, Bytes} = asn1rt:encode('ELDAPv3', 'LDAPMessage', Message),
Message),
case (S#eldap.sockmod):send(Socket, Bytes) of case (S#eldap.sockmod):send(Socket, Bytes) of
ok -> {ok, S#eldap{id = Id}}; ok ->
Error -> Error Timer = erlang:start_timer(?BIND_TIMEOUT, self(),
bind_timeout),
{next_state, wait_bind_response, S#eldap{bind_timer = Timer}};
{error, Reason} ->
report_bind_failure(S#eldap.host, S#eldap.port, Reason),
NewS = close_and_retry(S),
{next_state, connecting, NewS}
end.
starttls_request(#eldap{fd = Socket, id = Id} = S) ->
Req = #'ExtendedRequest'{requestName = ?STARTTLS},
Message = #'LDAPMessage'{messageID = Id,
protocolOp = {extendedReq, Req}},
?DEBUG("StartTLS Request Message: ~p~n", [Message]),
{ok, Bytes} = asn1rt:encode('ELDAPv3', 'LDAPMessage', Message),
case (S#eldap.sockmod):send(Socket, Bytes) of
ok ->
Timer = erlang:start_timer(?STARTTLS_TIMEOUT, self(),
starttls_timeout),
{next_state, wait_starttls_response, S#eldap{starttls_timer = Timer}};
{error, Reason} ->
report_starttls_failure(S#eldap.host, S#eldap.port, Reason),
NewS = close_and_retry(S),
{next_state, connecting, NewS}
end. end.
next_host(undefined, [H | _]) -> next_host(undefined, [H | _]) ->
@ -963,4 +1044,12 @@ next_host(Host, [_ | T], Hosts) ->
bump_id(#eldap{id = Id}) bump_id(#eldap{id = Id})
when Id > (?MAX_TRANSACTION_ID) -> when Id > (?MAX_TRANSACTION_ID) ->
?MIN_TRANSACTION_ID; ?MIN_TRANSACTION_ID;
bump_id(#eldap{id = Id}) -> Id + 1. bump_id(#eldap{id = Id}) -> Id + 1.
activate_socket(#eldap{sockmod = SockMod, fd = Sock}) ->
if SockMod == gen_tcp ->
inet:setopts(Sock, [{active, once}]);
true ->
SockMod:setopts(Sock, [{active, once}])
end.

View File

@ -46,6 +46,7 @@
-type tlsopts() :: [{encrypt, tls | starttls | none} | -type tlsopts() :: [{encrypt, tls | starttls | none} |
{tls_cacertfile, binary() | undefined} | {tls_cacertfile, binary() | undefined} |
{tls_certfile, binary() | undefined} |
{tls_depth, non_neg_integer() | undefined} | {tls_depth, non_neg_integer() | undefined} |
{tls_verify, hard | soft | false}]. {tls_verify, hard | soft | false}].

View File

@ -207,6 +207,8 @@ get_config(Host, Opts) ->
(soft) -> soft; (soft) -> soft;
(false) -> false (false) -> false
end, false), end, false),
TLSCFile = get_opt({ldap_tls_certfile, Host}, Opts,
fun iolist_to_binary/1),
TLSCAFile = get_opt({ldap_tls_cacertfile, Host}, Opts, TLSCAFile = get_opt({ldap_tls_cacertfile, Host}, Opts,
fun iolist_to_binary/1), fun iolist_to_binary/1),
TLSDepth = get_opt({ldap_tls_depth, Host}, Opts, TLSDepth = get_opt({ldap_tls_depth, Host}, Opts,
@ -237,6 +239,7 @@ get_config(Host, Opts) ->
backups = Backups, backups = Backups,
tls_options = [{encrypt, Encrypt}, tls_options = [{encrypt, Encrypt},
{tls_verify, TLSVerify}, {tls_verify, TLSVerify},
{tls_certfile, TLSCFile},
{tls_cacertfile, TLSCAFile}, {tls_cacertfile, TLSCAFile},
{tls_depth, TLSDepth}], {tls_depth, TLSDepth}],
port = Port, port = Port,