24
1
mirror of https://github.com/processone/ejabberd.git synced 2024-06-30 23:02:00 +02:00

LDAP StartTLS support

Conflicts:

	src/ejabberd_auth_ldap.erl
	src/eldap/eldap.erl
	src/mod_shared_roster_ldap.erl
	src/mod_vcard_ldap.erl
	src/tls/tls.erl
This commit is contained in:
Evgeniy Khramtsov 2012-07-20 23:38:20 +10:00
parent a70c72e50e
commit aad740f34c
3 changed files with 193 additions and 100 deletions

View File

@ -57,6 +57,7 @@
%%% LDAP Client state machine.
%%% Possible states are:
%%% connecting - actually disconnected, but retrying periodically
%%% wait_starttls_response - connected and send starttls request
%%% wait_bind_response - connected and sent bind request
%%% active - bound to LDAP Server and ready to handle commands
%%% active_bind - sent bind() request and waiting for response
@ -81,6 +82,7 @@
%% gen_fsm callbacks
-export([init/1, connecting/2, connecting/3,
wait_bind_response/3, active/3, active_bind/3,
wait_starttls_response/3,
handle_event/3, handle_sync_event/4, handle_info/3,
terminate/3, code_change/4]).
@ -94,6 +96,8 @@
-define(RETRY_TIMEOUT, 500).
-define(STARTTLS_TIMEOUT, 10000).
-define(BIND_TIMEOUT, 10000).
-define(CMD_TIMEOUT, 100000).
@ -134,6 +138,7 @@
passwd = <<"">> :: binary(),
id = 0 :: non_neg_integer(),
bind_timer = make_ref() :: reference(),
starttls_timer = make_ref() :: reference(),
dict = dict:new() :: dict(),
req_q = queue:new() :: queue()}).
@ -431,6 +436,7 @@ init([Hosts, Port, Rootdn, Passwd, Opts]) ->
(none) -> none
end) of
tls -> tls;
starttls -> starttls;
_ -> none
end,
PortTemp = case Port of
@ -450,9 +456,21 @@ init([Hosts, Port, Rootdn, Passwd, Opts]) ->
end) of
undefined ->
[];
Path ->
[{cacertfile, Path}]
CAPath ->
[{cacertfile, CAPath}]
end,
CertOpts = case gen_mod:get_opt(
tls_certfile, Opts,
fun(S) when is_binary(S) ->
binary_to_list(S);
(undefined) ->
undefined
end) of
undefined ->
[];
CPath ->
[{certfile, CPath}]
end,
DepthOpts = case gen_mod:get_opt(
tls_depth, Opts,
fun(I) when is_integer(I), I>=0 ->
@ -471,17 +489,18 @@ init([Hosts, Port, Rootdn, Passwd, Opts]) ->
(false) -> false
end, false),
TLSOpts = if (Verify == hard orelse Verify == soft)
andalso CacertOpts == [] ->
?WARNING_MSG("TLS verification is enabled but no CA "
"certfiles configured, so verification "
"is disabled.",
[]),
[];
andalso CacertOpts == [] ->
?WARNING_MSG("TLS verification is enabled but no CA "
"certfiles configured, so verification "
"is disabled.",
[]),
[];
Verify == soft ->
[{verify, 1}] ++ CacertOpts ++ DepthOpts;
[{verify, 1}] ++ CacertOpts ++ DepthOpts ++ CertOpts;
Verify == hard ->
[{verify, 2}] ++ CacertOpts ++ DepthOpts;
true -> []
[{verify, 2}] ++ CacertOpts ++ DepthOpts ++ CertOpts;
true ->
CacertOpts ++ DepthOpts ++ CertOpts
end,
{ok, connecting,
#eldap{hosts = Hosts, port = PortTemp, rootdn = Rootdn,
@ -490,13 +509,16 @@ init([Hosts, Port, Rootdn, Passwd, Opts]) ->
0}.
connecting(timeout, S) ->
{ok, NextState, NewS} = connect_bind(S),
{next_state, NextState, NewS}.
connect_bind(S).
connecting(Event, From, S) ->
Q = queue:in({Event, From}, S#eldap.req_q),
{next_state, connecting, S#eldap{req_q = Q}}.
wait_starttls_response(Event, From, S) ->
Q = queue:in({Event, From}, S#eldap.req_q),
{next_state, wait_starttls_response, S#eldap{req_q=Q}}.
wait_bind_response(Event, From, S) ->
Q = queue:in({Event, From}, S#eldap.req_q),
{next_state, wait_bind_response, S#eldap{req_q = Q}}.
@ -525,12 +547,40 @@ handle_sync_event(_Event, _From, StateName, S) ->
%%----------------------------------------------------------------------
handle_info({Tag, _Socket, Data}, connecting, S)
when Tag == tcp; Tag == ssl ->
?DEBUG("tcp packet received when disconnected!~n~p",
[Data]),
when Tag == tcp; Tag == ssl ->
activate_socket(S),
?DEBUG("tcp packet received when disconnected!~n~p", [Data]),
{next_state, connecting, S};
handle_info({Tag, _Socket, Data}, wait_bind_response, S)
when Tag == tcp; Tag == ssl ->
handle_info({Tag, Socket, Data}, StateName, S)
when Tag == tcp; Tag == ssl ->
activate_socket(S),
handle_info({asn1, Socket, Data}, StateName, S);
handle_info({asn1, Socket, Data}, wait_starttls_response, S) ->
cancel_timer(S#eldap.starttls_timer),
case catch recvd_wait_starttls_response(Data, S) of
proceed ->
case ssl:connect(Socket, S#eldap.tls_options,
?STARTTLS_TIMEOUT) of
{error, Reason} ->
report_starttls_failure(S#eldap.host, S#eldap.port, Reason);
{ok, SSLSock} ->
bind_request(S#eldap{fd = SSLSock, sockmod = ssl,
id = bump_id(S)})
end;
{fail_starttls, Reason} ->
report_starttls_failure(S#eldap.host, S#eldap.port, Reason),
{next_state, connecting, close_and_retry(S, ?GRACEFUL_RETRY_TIMEOUT)};
{'EXIT', Reason} ->
report_starttls_failure(S#eldap.host, S#eldap.port, Reason),
{next_state, connecting, close_and_retry(S)};
{error, Reason} ->
report_starttls_failure(S#eldap.host, S#eldap.port, Reason),
{next_state, connecting, close_and_retry(S)}
end;
handle_info({asn1, _Socket, Data}, wait_bind_response, S) ->
cancel_timer(S#eldap.bind_timer),
case catch recvd_wait_bind_response(Data, S) of
bound -> dequeue_commands(S);
@ -545,10 +595,9 @@ handle_info({Tag, _Socket, Data}, wait_bind_response, S)
report_bind_failure(S#eldap.host, S#eldap.port, Reason),
{next_state, connecting, close_and_retry(S)}
end;
handle_info({Tag, _Socket, Data}, StateName, S)
when (StateName == active orelse
StateName == active_bind)
andalso (Tag == tcp orelse Tag == ssl) ->
handle_info({asn1, _Socket, Data}, StateName, S)
when (StateName == active orelse StateName == active_bind) ->
case catch recvd_packet(Data, S) of
{response, Response, RequestType} ->
NewS = case Response of
@ -586,11 +635,14 @@ handle_info({timeout, Timer, {cmd_timeout, Id}},
{error, _Reason} -> {next_state, StateName, S}
end;
handle_info({timeout, retry_connect}, connecting, S) ->
{ok, NextState, NewS} = connect_bind(S),
{next_state, NextState, NewS};
handle_info({timeout, _Timer, bind_timeout},
wait_bind_response, S) ->
connect_bind(S);
handle_info({timeout, _Timer, bind_timeout}, wait_bind_response, S) ->
{next_state, connecting, close_and_retry(S)};
handle_info({timeout, _Timer, starttls_timeout}, wait_starttls_response, S) ->
{next_state, connecting, close_and_retry(S)};
%%
%% Make sure we don't fill the message queue with rubbish
%%
@ -633,17 +685,15 @@ send_command(Command, From, S) ->
{Name, Request} = gen_req(Command),
Message = #'LDAPMessage'{messageID = Id,
protocolOp = {Name, Request}},
?DEBUG("~p~n", [{Name, Request}]),
{ok, Bytes} = asn1rt:encode('ELDAPv3', 'LDAPMessage',
Message),
case (S#eldap.sockmod):send(S#eldap.fd, Bytes) of
ok ->
Timer = erlang:start_timer(?CMD_TIMEOUT, self(),
{cmd_timeout, Id}),
New_dict = dict:store(Id,
[{Timer, Command, From, Name}], S#eldap.dict),
{ok, S#eldap{id = Id, dict = New_dict}};
Error -> Error
?DEBUG("~p~n",[Message]),
{ok, Bytes} = asn1rt:encode('ELDAPv3', 'LDAPMessage', Message),
case (S#eldap.sockmod):send(S#eldap.fd, iolist_to_binary(Bytes)) of
ok ->
Timer = erlang:start_timer(?CMD_TIMEOUT, self(), {cmd_timeout, Id}),
New_dict = dict:store(Id, [{Timer, Command, From, Name}], S#eldap.dict),
{ok, S#eldap{id = Id, dict = New_dict}};
Error ->
Error
end.
gen_req({search, A}) ->
@ -683,7 +733,6 @@ gen_req({bind, RootDN, Passwd}) ->
authentication = {simple, Passwd}}}.
recvd_packet(Pkt, S) ->
check_tag(Pkt),
case asn1rt:decode('ELDAPv3', 'LDAPMessage', Pkt) of
{ok, Msg} ->
Op = Msg#'LDAPMessage'.protocolOp,
@ -791,7 +840,6 @@ get_op_rec(Id, Dict) ->
end.
recvd_wait_bind_response(Pkt, S) ->
check_tag(Pkt),
case asn1rt:decode('ELDAPv3', 'LDAPMessage', Pkt) of
{ok, Msg} ->
?DEBUG("~p", [Msg]),
@ -806,35 +854,43 @@ recvd_wait_bind_response(Pkt, S) ->
Else -> {fail_bind, Else}
end.
recvd_wait_starttls_response(Pkt, S) ->
case asn1rt:decode('ELDAPv3', 'LDAPMessage', Pkt) of
{ok, Msg} ->
?DEBUG("~p", [Msg]),
check_id(S#eldap.id, Msg#'LDAPMessage'.messageID),
case Msg#'LDAPMessage'.protocolOp of
{extendedResp, Result} ->
case Result#'ExtendedResponse'.resultCode of
success -> proceed;
Error -> {fail_starttls, Error}
end
end;
Else ->
{error, Else}
end.
check_id(Id, Id) -> ok;
check_id(_, _) -> throw({error, wrong_bind_id}).
%%-----------------------------------------------------------------------
%% General Helpers
%%-----------------------------------------------------------------------
cancel_timer(Timer) ->
erlang:cancel_timer(Timer),
receive {timeout, Timer, _} -> ok after 0 -> ok end.
check_tag(Data) ->
{_Tag, Data1, _Rb} = asn1rt_ber_bin:decode_tag(Data),
{{_Len, _Data2}, _Rb2} =
asn1rt_ber_bin:decode_length(Data1),
ok.
close_and_retry(S, Timeout) ->
catch (S#eldap.sockmod):close(S#eldap.fd),
Queue = dict:fold(fun (_Id,
[{Timer, Command, From, _Name} | _], Q) ->
cancel_timer(Timer),
queue:in_r({Command, From}, Q);
(_, _, Q) -> Q
end,
S#eldap.req_q, S#eldap.dict),
erlang:send_after(Timeout, self(),
{timeout, retry_connect}),
S#eldap{fd = undefined, req_q = Queue, dict = dict:new()}.
Queue = dict:fold(
fun(_Id, [{Timer, Command, From, _Name}|_], Q) ->
cancel_timer(Timer),
queue:in_r({Command, From}, Q);
(_, _, Q) ->
Q
end, S#eldap.req_q, S#eldap.dict),
erlang:send_after(Timeout, self(), {timeout, retry_connect}),
S#eldap{fd=null, req_q=Queue, dict=dict:new()}.
close_and_retry(S) ->
close_and_retry(S, ?RETRY_TIMEOUT).
@ -843,6 +899,15 @@ report_bind_failure(Host, Port, Reason) ->
?WARNING_MSG("LDAP bind failed on ~s:~p~nReason: ~p",
[Host, Port, Reason]).
report_starttls_failure(Host, Port, Reason) ->
?WARNING_MSG("LDAP StartTLS failed:~n"
"** Server: ~s:~p~n"
"** Reason: ~p",
[Host, Port, Reason]).
%%-----------------------------------------------------------------------
%% Sort out timed out commands
%%-----------------------------------------------------------------------
cmd_timeout(Timer, Id, S) ->
Dict = S#eldap.dict,
case dict:find(Id, Dict) of
@ -891,58 +956,74 @@ connect_bind(S) ->
Host = next_host(S#eldap.host, S#eldap.hosts),
?INFO_MSG("LDAP connection on ~s:~p",
[Host, S#eldap.port]),
Opts = if S#eldap.tls == tls ->
[{packet, asn1}, {active, true}, {keepalive, true},
binary
| S#eldap.tls_options];
true ->
[{packet, asn1}, {active, true}, {keepalive, true},
{send_timeout, ?SEND_TIMEOUT}, binary]
end,
Opts = case S#eldap.tls of
tls ->
[{packet, asn1}, {active, once}, {keepalive, true},
binary | S#eldap.tls_options];
_ ->
[{packet, asn1}, {active, once}, {keepalive, true},
{send_timeout, ?SEND_TIMEOUT}, binary]
end,
HostS = binary_to_list(Host),
SocketData = case S#eldap.tls of
tls ->
SockMod = ssl, ssl:connect(HostS, S#eldap.port, Opts);
%% starttls -> %% TODO: Implement STARTTLS;
_ ->
SockMod = gen_tcp,
gen_tcp:connect(HostS, S#eldap.port, Opts)
end,
case SocketData of
{ok, Socket} ->
case bind_request(Socket, S#eldap{sockmod = SockMod}) of
{ok, NewS} ->
Timer = erlang:start_timer(?BIND_TIMEOUT, self(),
{timeout, bind_timeout}),
{ok, wait_bind_response,
NewS#eldap{fd = Socket, sockmod = SockMod, host = Host,
bind_timer = Timer}};
{error, Reason} ->
report_bind_failure(Host, S#eldap.port, Reason),
NewS = close_and_retry(S),
{ok, connecting, NewS#eldap{host = Host}}
end;
{error, Reason} ->
?ERROR_MSG("LDAP connection failed:~n** Server: "
"~s:~p~n** Reason: ~p~n** Socket options: ~p",
[Host, S#eldap.port, Reason, Opts]),
NewS = close_and_retry(S),
{ok, connecting, NewS#eldap{host = Host}}
SockMod = case S#eldap.tls of
tls -> ssl;
_ -> gen_tcp
end,
case SockMod:connect(HostS, S#eldap.port, Opts) of
{ok, Socket} ->
Id = bump_id(S),
NewS = S#eldap{host = Host, sockmod = SockMod,
id = Id, fd = Socket},
case S#eldap.tls of
starttls ->
starttls_request(NewS);
_ ->
bind_request(NewS)
end;
{error, Reason} ->
?ERROR_MSG("LDAP connection failed:~n"
"** Server: ~s:~p~n"
"** Reason: ~p~n"
"** Socket options: ~p",
[Host, S#eldap.port, Reason, Opts]),
NewS = close_and_retry(S),
{next_state, connecting, NewS#eldap{host = Host}}
end.
bind_request(Socket, S) ->
Id = bump_id(S),
bind_request(#eldap{fd = Socket, id = Id} = S) ->
Req = #'BindRequest'{version = S#eldap.version,
name = S#eldap.rootdn,
authentication = {simple, S#eldap.passwd}},
authentication = {simple, S#eldap.passwd}},
Message = #'LDAPMessage'{messageID = Id,
protocolOp = {bindRequest, Req}},
?DEBUG("Bind Request Message:~p~n", [Message]),
{ok, Bytes} = asn1rt:encode('ELDAPv3', 'LDAPMessage',
Message),
?DEBUG("Bind Request Message:~p~n",[Message]),
{ok, Bytes} = asn1rt:encode('ELDAPv3', 'LDAPMessage', Message),
case (S#eldap.sockmod):send(Socket, Bytes) of
ok -> {ok, S#eldap{id = Id}};
Error -> Error
ok ->
Timer = erlang:start_timer(?BIND_TIMEOUT, self(),
bind_timeout),
{next_state, wait_bind_response, S#eldap{bind_timer = Timer}};
{error, Reason} ->
report_bind_failure(S#eldap.host, S#eldap.port, Reason),
NewS = close_and_retry(S),
{next_state, connecting, NewS}
end.
starttls_request(#eldap{fd = Socket, id = Id} = S) ->
Req = #'ExtendedRequest'{requestName = ?STARTTLS},
Message = #'LDAPMessage'{messageID = Id,
protocolOp = {extendedReq, Req}},
?DEBUG("StartTLS Request Message: ~p~n", [Message]),
{ok, Bytes} = asn1rt:encode('ELDAPv3', 'LDAPMessage', Message),
case (S#eldap.sockmod):send(Socket, Bytes) of
ok ->
Timer = erlang:start_timer(?STARTTLS_TIMEOUT, self(),
starttls_timeout),
{next_state, wait_starttls_response, S#eldap{starttls_timer = Timer}};
{error, Reason} ->
report_starttls_failure(S#eldap.host, S#eldap.port, Reason),
NewS = close_and_retry(S),
{next_state, connecting, NewS}
end.
next_host(undefined, [H | _]) ->
@ -963,4 +1044,12 @@ next_host(Host, [_ | T], Hosts) ->
bump_id(#eldap{id = Id})
when Id > (?MAX_TRANSACTION_ID) ->
?MIN_TRANSACTION_ID;
bump_id(#eldap{id = Id}) -> Id + 1.
activate_socket(#eldap{sockmod = SockMod, fd = Sock}) ->
if SockMod == gen_tcp ->
inet:setopts(Sock, [{active, once}]);
true ->
SockMod:setopts(Sock, [{active, once}])
end.

View File

@ -46,6 +46,7 @@
-type tlsopts() :: [{encrypt, tls | starttls | none} |
{tls_cacertfile, binary() | undefined} |
{tls_certfile, binary() | undefined} |
{tls_depth, non_neg_integer() | undefined} |
{tls_verify, hard | soft | false}].

View File

@ -207,6 +207,8 @@ get_config(Host, Opts) ->
(soft) -> soft;
(false) -> false
end, false),
TLSCFile = get_opt({ldap_tls_certfile, Host}, Opts,
fun iolist_to_binary/1),
TLSCAFile = get_opt({ldap_tls_cacertfile, Host}, Opts,
fun iolist_to_binary/1),
TLSDepth = get_opt({ldap_tls_depth, Host}, Opts,
@ -237,6 +239,7 @@ get_config(Host, Opts) ->
backups = Backups,
tls_options = [{encrypt, Encrypt},
{tls_verify, TLSVerify},
{tls_certfile, TLSCFile},
{tls_cacertfile, TLSCAFile},
{tls_depth, TLSDepth}],
port = Port,