From aad740f34cd99d036b7695bb9ea22ebab9110c3c Mon Sep 17 00:00:00 2001 From: Evgeniy Khramtsov Date: Fri, 20 Jul 2012 23:38:20 +1000 Subject: [PATCH] LDAP StartTLS support Conflicts: src/ejabberd_auth_ldap.erl src/eldap/eldap.erl src/mod_shared_roster_ldap.erl src/mod_vcard_ldap.erl src/tls/tls.erl --- src/eldap/eldap.erl | 289 +++++++++++++++++++++++++------------- src/eldap/eldap.hrl | 1 + src/eldap/eldap_utils.erl | 3 + 3 files changed, 193 insertions(+), 100 deletions(-) diff --git a/src/eldap/eldap.erl b/src/eldap/eldap.erl index eab8e721c..4d35c5c95 100644 --- a/src/eldap/eldap.erl +++ b/src/eldap/eldap.erl @@ -57,6 +57,7 @@ %%% LDAP Client state machine. %%% Possible states are: %%% connecting - actually disconnected, but retrying periodically +%%% wait_starttls_response - connected and send starttls request %%% wait_bind_response - connected and sent bind request %%% active - bound to LDAP Server and ready to handle commands %%% active_bind - sent bind() request and waiting for response @@ -81,6 +82,7 @@ %% gen_fsm callbacks -export([init/1, connecting/2, connecting/3, wait_bind_response/3, active/3, active_bind/3, + wait_starttls_response/3, handle_event/3, handle_sync_event/4, handle_info/3, terminate/3, code_change/4]). @@ -94,6 +96,8 @@ -define(RETRY_TIMEOUT, 500). +-define(STARTTLS_TIMEOUT, 10000). + -define(BIND_TIMEOUT, 10000). -define(CMD_TIMEOUT, 100000). @@ -134,6 +138,7 @@ passwd = <<"">> :: binary(), id = 0 :: non_neg_integer(), bind_timer = make_ref() :: reference(), + starttls_timer = make_ref() :: reference(), dict = dict:new() :: dict(), req_q = queue:new() :: queue()}). @@ -431,6 +436,7 @@ init([Hosts, Port, Rootdn, Passwd, Opts]) -> (none) -> none end) of tls -> tls; + starttls -> starttls; _ -> none end, PortTemp = case Port of @@ -450,9 +456,21 @@ init([Hosts, Port, Rootdn, Passwd, Opts]) -> end) of undefined -> []; - Path -> - [{cacertfile, Path}] + CAPath -> + [{cacertfile, CAPath}] end, + CertOpts = case gen_mod:get_opt( + tls_certfile, Opts, + fun(S) when is_binary(S) -> + binary_to_list(S); + (undefined) -> + undefined + end) of + undefined -> + []; + CPath -> + [{certfile, CPath}] + end, DepthOpts = case gen_mod:get_opt( tls_depth, Opts, fun(I) when is_integer(I), I>=0 -> @@ -471,17 +489,18 @@ init([Hosts, Port, Rootdn, Passwd, Opts]) -> (false) -> false end, false), TLSOpts = if (Verify == hard orelse Verify == soft) - andalso CacertOpts == [] -> - ?WARNING_MSG("TLS verification is enabled but no CA " - "certfiles configured, so verification " - "is disabled.", - []), - []; + andalso CacertOpts == [] -> + ?WARNING_MSG("TLS verification is enabled but no CA " + "certfiles configured, so verification " + "is disabled.", + []), + []; Verify == soft -> - [{verify, 1}] ++ CacertOpts ++ DepthOpts; + [{verify, 1}] ++ CacertOpts ++ DepthOpts ++ CertOpts; Verify == hard -> - [{verify, 2}] ++ CacertOpts ++ DepthOpts; - true -> [] + [{verify, 2}] ++ CacertOpts ++ DepthOpts ++ CertOpts; + true -> + CacertOpts ++ DepthOpts ++ CertOpts end, {ok, connecting, #eldap{hosts = Hosts, port = PortTemp, rootdn = Rootdn, @@ -490,13 +509,16 @@ init([Hosts, Port, Rootdn, Passwd, Opts]) -> 0}. connecting(timeout, S) -> - {ok, NextState, NewS} = connect_bind(S), - {next_state, NextState, NewS}. + connect_bind(S). connecting(Event, From, S) -> Q = queue:in({Event, From}, S#eldap.req_q), {next_state, connecting, S#eldap{req_q = Q}}. +wait_starttls_response(Event, From, S) -> + Q = queue:in({Event, From}, S#eldap.req_q), + {next_state, wait_starttls_response, S#eldap{req_q=Q}}. + wait_bind_response(Event, From, S) -> Q = queue:in({Event, From}, S#eldap.req_q), {next_state, wait_bind_response, S#eldap{req_q = Q}}. @@ -525,12 +547,40 @@ handle_sync_event(_Event, _From, StateName, S) -> %%---------------------------------------------------------------------- handle_info({Tag, _Socket, Data}, connecting, S) - when Tag == tcp; Tag == ssl -> - ?DEBUG("tcp packet received when disconnected!~n~p", - [Data]), + when Tag == tcp; Tag == ssl -> + activate_socket(S), + ?DEBUG("tcp packet received when disconnected!~n~p", [Data]), {next_state, connecting, S}; -handle_info({Tag, _Socket, Data}, wait_bind_response, S) - when Tag == tcp; Tag == ssl -> + +handle_info({Tag, Socket, Data}, StateName, S) + when Tag == tcp; Tag == ssl -> + activate_socket(S), + handle_info({asn1, Socket, Data}, StateName, S); + +handle_info({asn1, Socket, Data}, wait_starttls_response, S) -> + cancel_timer(S#eldap.starttls_timer), + case catch recvd_wait_starttls_response(Data, S) of + proceed -> + case ssl:connect(Socket, S#eldap.tls_options, + ?STARTTLS_TIMEOUT) of + {error, Reason} -> + report_starttls_failure(S#eldap.host, S#eldap.port, Reason); + {ok, SSLSock} -> + bind_request(S#eldap{fd = SSLSock, sockmod = ssl, + id = bump_id(S)}) + end; + {fail_starttls, Reason} -> + report_starttls_failure(S#eldap.host, S#eldap.port, Reason), + {next_state, connecting, close_and_retry(S, ?GRACEFUL_RETRY_TIMEOUT)}; + {'EXIT', Reason} -> + report_starttls_failure(S#eldap.host, S#eldap.port, Reason), + {next_state, connecting, close_and_retry(S)}; + {error, Reason} -> + report_starttls_failure(S#eldap.host, S#eldap.port, Reason), + {next_state, connecting, close_and_retry(S)} + end; + +handle_info({asn1, _Socket, Data}, wait_bind_response, S) -> cancel_timer(S#eldap.bind_timer), case catch recvd_wait_bind_response(Data, S) of bound -> dequeue_commands(S); @@ -545,10 +595,9 @@ handle_info({Tag, _Socket, Data}, wait_bind_response, S) report_bind_failure(S#eldap.host, S#eldap.port, Reason), {next_state, connecting, close_and_retry(S)} end; -handle_info({Tag, _Socket, Data}, StateName, S) - when (StateName == active orelse - StateName == active_bind) - andalso (Tag == tcp orelse Tag == ssl) -> + +handle_info({asn1, _Socket, Data}, StateName, S) + when (StateName == active orelse StateName == active_bind) -> case catch recvd_packet(Data, S) of {response, Response, RequestType} -> NewS = case Response of @@ -586,11 +635,14 @@ handle_info({timeout, Timer, {cmd_timeout, Id}}, {error, _Reason} -> {next_state, StateName, S} end; handle_info({timeout, retry_connect}, connecting, S) -> - {ok, NextState, NewS} = connect_bind(S), - {next_state, NextState, NewS}; -handle_info({timeout, _Timer, bind_timeout}, - wait_bind_response, S) -> + connect_bind(S); + +handle_info({timeout, _Timer, bind_timeout}, wait_bind_response, S) -> {next_state, connecting, close_and_retry(S)}; + +handle_info({timeout, _Timer, starttls_timeout}, wait_starttls_response, S) -> + {next_state, connecting, close_and_retry(S)}; + %% %% Make sure we don't fill the message queue with rubbish %% @@ -633,17 +685,15 @@ send_command(Command, From, S) -> {Name, Request} = gen_req(Command), Message = #'LDAPMessage'{messageID = Id, protocolOp = {Name, Request}}, - ?DEBUG("~p~n", [{Name, Request}]), - {ok, Bytes} = asn1rt:encode('ELDAPv3', 'LDAPMessage', - Message), - case (S#eldap.sockmod):send(S#eldap.fd, Bytes) of - ok -> - Timer = erlang:start_timer(?CMD_TIMEOUT, self(), - {cmd_timeout, Id}), - New_dict = dict:store(Id, - [{Timer, Command, From, Name}], S#eldap.dict), - {ok, S#eldap{id = Id, dict = New_dict}}; - Error -> Error + ?DEBUG("~p~n",[Message]), + {ok, Bytes} = asn1rt:encode('ELDAPv3', 'LDAPMessage', Message), + case (S#eldap.sockmod):send(S#eldap.fd, iolist_to_binary(Bytes)) of + ok -> + Timer = erlang:start_timer(?CMD_TIMEOUT, self(), {cmd_timeout, Id}), + New_dict = dict:store(Id, [{Timer, Command, From, Name}], S#eldap.dict), + {ok, S#eldap{id = Id, dict = New_dict}}; + Error -> + Error end. gen_req({search, A}) -> @@ -683,7 +733,6 @@ gen_req({bind, RootDN, Passwd}) -> authentication = {simple, Passwd}}}. recvd_packet(Pkt, S) -> - check_tag(Pkt), case asn1rt:decode('ELDAPv3', 'LDAPMessage', Pkt) of {ok, Msg} -> Op = Msg#'LDAPMessage'.protocolOp, @@ -791,7 +840,6 @@ get_op_rec(Id, Dict) -> end. recvd_wait_bind_response(Pkt, S) -> - check_tag(Pkt), case asn1rt:decode('ELDAPv3', 'LDAPMessage', Pkt) of {ok, Msg} -> ?DEBUG("~p", [Msg]), @@ -806,35 +854,43 @@ recvd_wait_bind_response(Pkt, S) -> Else -> {fail_bind, Else} end. +recvd_wait_starttls_response(Pkt, S) -> + case asn1rt:decode('ELDAPv3', 'LDAPMessage', Pkt) of + {ok, Msg} -> + ?DEBUG("~p", [Msg]), + check_id(S#eldap.id, Msg#'LDAPMessage'.messageID), + case Msg#'LDAPMessage'.protocolOp of + {extendedResp, Result} -> + case Result#'ExtendedResponse'.resultCode of + success -> proceed; + Error -> {fail_starttls, Error} + end + end; + Else -> + {error, Else} + end. + check_id(Id, Id) -> ok; check_id(_, _) -> throw({error, wrong_bind_id}). %%----------------------------------------------------------------------- %% General Helpers %%----------------------------------------------------------------------- - cancel_timer(Timer) -> erlang:cancel_timer(Timer), receive {timeout, Timer, _} -> ok after 0 -> ok end. -check_tag(Data) -> - {_Tag, Data1, _Rb} = asn1rt_ber_bin:decode_tag(Data), - {{_Len, _Data2}, _Rb2} = - asn1rt_ber_bin:decode_length(Data1), - ok. - close_and_retry(S, Timeout) -> catch (S#eldap.sockmod):close(S#eldap.fd), - Queue = dict:fold(fun (_Id, - [{Timer, Command, From, _Name} | _], Q) -> - cancel_timer(Timer), - queue:in_r({Command, From}, Q); - (_, _, Q) -> Q - end, - S#eldap.req_q, S#eldap.dict), - erlang:send_after(Timeout, self(), - {timeout, retry_connect}), - S#eldap{fd = undefined, req_q = Queue, dict = dict:new()}. + Queue = dict:fold( + fun(_Id, [{Timer, Command, From, _Name}|_], Q) -> + cancel_timer(Timer), + queue:in_r({Command, From}, Q); + (_, _, Q) -> + Q + end, S#eldap.req_q, S#eldap.dict), + erlang:send_after(Timeout, self(), {timeout, retry_connect}), + S#eldap{fd=null, req_q=Queue, dict=dict:new()}. close_and_retry(S) -> close_and_retry(S, ?RETRY_TIMEOUT). @@ -843,6 +899,15 @@ report_bind_failure(Host, Port, Reason) -> ?WARNING_MSG("LDAP bind failed on ~s:~p~nReason: ~p", [Host, Port, Reason]). +report_starttls_failure(Host, Port, Reason) -> + ?WARNING_MSG("LDAP StartTLS failed:~n" + "** Server: ~s:~p~n" + "** Reason: ~p", + [Host, Port, Reason]). + +%%----------------------------------------------------------------------- +%% Sort out timed out commands +%%----------------------------------------------------------------------- cmd_timeout(Timer, Id, S) -> Dict = S#eldap.dict, case dict:find(Id, Dict) of @@ -891,58 +956,74 @@ connect_bind(S) -> Host = next_host(S#eldap.host, S#eldap.hosts), ?INFO_MSG("LDAP connection on ~s:~p", [Host, S#eldap.port]), - Opts = if S#eldap.tls == tls -> - [{packet, asn1}, {active, true}, {keepalive, true}, - binary - | S#eldap.tls_options]; - true -> - [{packet, asn1}, {active, true}, {keepalive, true}, - {send_timeout, ?SEND_TIMEOUT}, binary] - end, + Opts = case S#eldap.tls of + tls -> + [{packet, asn1}, {active, once}, {keepalive, true}, + binary | S#eldap.tls_options]; + _ -> + [{packet, asn1}, {active, once}, {keepalive, true}, + {send_timeout, ?SEND_TIMEOUT}, binary] + end, HostS = binary_to_list(Host), - SocketData = case S#eldap.tls of - tls -> - SockMod = ssl, ssl:connect(HostS, S#eldap.port, Opts); - %% starttls -> %% TODO: Implement STARTTLS; - _ -> - SockMod = gen_tcp, - gen_tcp:connect(HostS, S#eldap.port, Opts) - end, - case SocketData of - {ok, Socket} -> - case bind_request(Socket, S#eldap{sockmod = SockMod}) of - {ok, NewS} -> - Timer = erlang:start_timer(?BIND_TIMEOUT, self(), - {timeout, bind_timeout}), - {ok, wait_bind_response, - NewS#eldap{fd = Socket, sockmod = SockMod, host = Host, - bind_timer = Timer}}; - {error, Reason} -> - report_bind_failure(Host, S#eldap.port, Reason), - NewS = close_and_retry(S), - {ok, connecting, NewS#eldap{host = Host}} - end; - {error, Reason} -> - ?ERROR_MSG("LDAP connection failed:~n** Server: " - "~s:~p~n** Reason: ~p~n** Socket options: ~p", - [Host, S#eldap.port, Reason, Opts]), - NewS = close_and_retry(S), - {ok, connecting, NewS#eldap{host = Host}} + SockMod = case S#eldap.tls of + tls -> ssl; + _ -> gen_tcp + end, + case SockMod:connect(HostS, S#eldap.port, Opts) of + {ok, Socket} -> + Id = bump_id(S), + NewS = S#eldap{host = Host, sockmod = SockMod, + id = Id, fd = Socket}, + case S#eldap.tls of + starttls -> + starttls_request(NewS); + _ -> + bind_request(NewS) + end; + {error, Reason} -> + ?ERROR_MSG("LDAP connection failed:~n" + "** Server: ~s:~p~n" + "** Reason: ~p~n" + "** Socket options: ~p", + [Host, S#eldap.port, Reason, Opts]), + NewS = close_and_retry(S), + {next_state, connecting, NewS#eldap{host = Host}} end. -bind_request(Socket, S) -> - Id = bump_id(S), +bind_request(#eldap{fd = Socket, id = Id} = S) -> Req = #'BindRequest'{version = S#eldap.version, name = S#eldap.rootdn, - authentication = {simple, S#eldap.passwd}}, + authentication = {simple, S#eldap.passwd}}, Message = #'LDAPMessage'{messageID = Id, protocolOp = {bindRequest, Req}}, - ?DEBUG("Bind Request Message:~p~n", [Message]), - {ok, Bytes} = asn1rt:encode('ELDAPv3', 'LDAPMessage', - Message), + ?DEBUG("Bind Request Message:~p~n",[Message]), + {ok, Bytes} = asn1rt:encode('ELDAPv3', 'LDAPMessage', Message), case (S#eldap.sockmod):send(Socket, Bytes) of - ok -> {ok, S#eldap{id = Id}}; - Error -> Error + ok -> + Timer = erlang:start_timer(?BIND_TIMEOUT, self(), + bind_timeout), + {next_state, wait_bind_response, S#eldap{bind_timer = Timer}}; + {error, Reason} -> + report_bind_failure(S#eldap.host, S#eldap.port, Reason), + NewS = close_and_retry(S), + {next_state, connecting, NewS} + end. + +starttls_request(#eldap{fd = Socket, id = Id} = S) -> + Req = #'ExtendedRequest'{requestName = ?STARTTLS}, + Message = #'LDAPMessage'{messageID = Id, + protocolOp = {extendedReq, Req}}, + ?DEBUG("StartTLS Request Message: ~p~n", [Message]), + {ok, Bytes} = asn1rt:encode('ELDAPv3', 'LDAPMessage', Message), + case (S#eldap.sockmod):send(Socket, Bytes) of + ok -> + Timer = erlang:start_timer(?STARTTLS_TIMEOUT, self(), + starttls_timeout), + {next_state, wait_starttls_response, S#eldap{starttls_timer = Timer}}; + {error, Reason} -> + report_starttls_failure(S#eldap.host, S#eldap.port, Reason), + NewS = close_and_retry(S), + {next_state, connecting, NewS} end. next_host(undefined, [H | _]) -> @@ -963,4 +1044,12 @@ next_host(Host, [_ | T], Hosts) -> bump_id(#eldap{id = Id}) when Id > (?MAX_TRANSACTION_ID) -> ?MIN_TRANSACTION_ID; + bump_id(#eldap{id = Id}) -> Id + 1. + +activate_socket(#eldap{sockmod = SockMod, fd = Sock}) -> + if SockMod == gen_tcp -> + inet:setopts(Sock, [{active, once}]); + true -> + SockMod:setopts(Sock, [{active, once}]) + end. diff --git a/src/eldap/eldap.hrl b/src/eldap/eldap.hrl index d99332ab4..c88da00de 100644 --- a/src/eldap/eldap.hrl +++ b/src/eldap/eldap.hrl @@ -46,6 +46,7 @@ -type tlsopts() :: [{encrypt, tls | starttls | none} | {tls_cacertfile, binary() | undefined} | + {tls_certfile, binary() | undefined} | {tls_depth, non_neg_integer() | undefined} | {tls_verify, hard | soft | false}]. diff --git a/src/eldap/eldap_utils.erl b/src/eldap/eldap_utils.erl index cc72f950d..c6040e60c 100644 --- a/src/eldap/eldap_utils.erl +++ b/src/eldap/eldap_utils.erl @@ -207,6 +207,8 @@ get_config(Host, Opts) -> (soft) -> soft; (false) -> false end, false), + TLSCFile = get_opt({ldap_tls_certfile, Host}, Opts, + fun iolist_to_binary/1), TLSCAFile = get_opt({ldap_tls_cacertfile, Host}, Opts, fun iolist_to_binary/1), TLSDepth = get_opt({ldap_tls_depth, Host}, Opts, @@ -237,6 +239,7 @@ get_config(Host, Opts) -> backups = Backups, tls_options = [{encrypt, Encrypt}, {tls_verify, TLSVerify}, + {tls_certfile, TLSCFile}, {tls_cacertfile, TLSCAFile}, {tls_depth, TLSDepth}], port = Port,