From 1433dafe6bda9840cb687c5c3270584fb3ee55d1 Mon Sep 17 00:00:00 2001 From: Alexey Shchepin Date: Tue, 25 Oct 2005 01:08:37 +0000 Subject: [PATCH] * src/tls/tls_drv.c: Support for "connect" method * src/tls/tls.erl: Likewise * src/ejabberd_s2s_in.erl: Support for STARTTLS+Dialback * src/ejabberd_s2s_out.erl: Likewise * src/ejabberd_receiver.erl: Added a few hacks ({active,once} mode should be used instead of recv/3 call to avoid them) * src/ejabberd_config.erl: Added s2s_use_starttls and s2s_certfile options * src/ejabberd.cfg.example: Likewise SVN Revision: 426 --- ChangeLog | 13 ++ src/ejabberd.cfg.example | 5 + src/ejabberd_config.erl | 4 + src/ejabberd_receiver.erl | 37 +++-- src/ejabberd_s2s_in.erl | 116 ++++++++++--- src/ejabberd_s2s_out.erl | 342 +++++++++++++++++++++++++++----------- src/tls/tls.erl | 30 ++-- src/tls/tls_drv.c | 44 +++-- 8 files changed, 427 insertions(+), 164 deletions(-) diff --git a/ChangeLog b/ChangeLog index ec7d586f4..dd32cf530 100644 --- a/ChangeLog +++ b/ChangeLog @@ -1,3 +1,16 @@ +2005-10-25 Alexey Shchepin + + * src/tls/tls_drv.c: Support for "connect" method + * src/tls/tls.erl: Likewise + + * src/ejabberd_s2s_in.erl: Support for STARTTLS+Dialback + * src/ejabberd_s2s_out.erl: Likewise + * src/ejabberd_receiver.erl: Added a few hacks ({active,once} mode + should be used instead of recv/3 call to avoid them) + * src/ejabberd_config.erl: Added s2s_use_starttls and s2s_certfile + options + * src/ejabberd.cfg.example: Likewise + 2005-10-22 Alexey Shchepin * src/ejabberd_app.erl: Try to load tls_drv at startup to avoid diff --git a/src/ejabberd.cfg.example b/src/ejabberd.cfg.example index df3c3461c..1cc468bf5 100644 --- a/src/ejabberd.cfg.example +++ b/src/ejabberd.cfg.example @@ -115,6 +115,11 @@ [{password, "secret"}]}]} ]}. + +% Use STARTTLS+Dialback for S2S connections +{s2s_use_starttls, true}. +{s2s_certfile, "./ssl.pem"}. + % If SRV lookup fails, then port 5269 is used to communicate with remote server {outgoing_s2s_port, 5269}. diff --git a/src/ejabberd_config.erl b/src/ejabberd_config.erl index 8bce8eae6..51c183a21 100644 --- a/src/ejabberd_config.erl +++ b/src/ejabberd_config.erl @@ -108,6 +108,10 @@ process_term(Term, State) -> add_option(listen, Val, State); {outgoing_s2s_port, Port} -> add_option(outgoing_s2s_port, Port, State); + {s2s_use_starttls, Port} -> + add_option(s2s_use_starttls, Port, State); + {s2s_certfile, Port} -> + add_option(s2s_certfile, Port, State); {Opt, Val} -> lists:foldl(fun(Host, S) -> process_host_term(Term, Host, S) end, State, State#state.hosts) diff --git a/src/ejabberd_receiver.erl b/src/ejabberd_receiver.erl index 1f1897fb1..204771c1a 100644 --- a/src/ejabberd_receiver.erl +++ b/src/ejabberd_receiver.erl @@ -36,24 +36,27 @@ receiver(Socket, SockMod, Shaper, C2SPid) -> receiver(Socket, SockMod, ShaperState, C2SPid, XMLStreamState, Timeout) -> Res = (catch SockMod:recv(Socket, 0, Timeout)), - case Res of - {ok, Data} -> - receive - {starttls, TLSSocket} -> - xml_stream:close(XMLStreamState), - XMLStreamState1 = xml_stream:new(C2SPid), - TLSRes = tls:recv_data(TLSSocket, Data), - receiver1(TLSSocket, tls, - ShaperState, C2SPid, XMLStreamState1, Timeout, - TLSRes) - after 0 -> - receiver1(Socket, SockMod, - ShaperState, C2SPid, XMLStreamState, Timeout, - Res) - end; - _ -> + receive + {starttls, TLSSocket} -> + xml_stream:close(XMLStreamState), + XMLStreamState1 = xml_stream:new(C2SPid), + TLSRes = case Res of + {ok, Data} -> + tls:recv_data(TLSSocket, Data); + _ -> + tls:recv_data(TLSSocket, "") + end, + receiver1(TLSSocket, tls, + ShaperState, C2SPid, XMLStreamState1, Timeout, + TLSRes); + {change_timeout, NewTimeout} -> % Dirty hack receiver1(Socket, SockMod, - ShaperState, C2SPid, XMLStreamState, Timeout, Res) + ShaperState, C2SPid, XMLStreamState, NewTimeout, + Res) + after 0 -> + receiver1(Socket, SockMod, + ShaperState, C2SPid, XMLStreamState, Timeout, + Res) end. diff --git a/src/ejabberd_s2s_in.erl b/src/ejabberd_s2s_in.erl index 1c09c0607..d2b616751 100644 --- a/src/ejabberd_s2s_in.erl +++ b/src/ejabberd_s2s_in.erl @@ -14,13 +14,12 @@ %% External exports -export([start/2, - start_link/2, - send_text/2, - send_element/2]). + start_link/2]). %% gen_fsm callbacks -export([init/1, wait_for_stream/2, + wait_for_feature_request/2, stream_established/2, handle_event/3, handle_sync_event/4, @@ -34,9 +33,13 @@ -define(DICT, dict). -record(state, {socket, + sockmod, receiver, streamid, shaper, + tls = false, + tls_enabled = false, + tls_options = [], connections = ?DICT:new(), timer}). @@ -49,13 +52,13 @@ -define(FSMOPTS, []). -endif. --define(STREAM_HEADER, +-define(STREAM_HEADER(Version), ("" "") + "id='" ++ StateData#state.streamid ++ "'" ++ Version ++ ">") ). -define(STREAM_TRAILER, ""). @@ -96,12 +99,28 @@ init([{SockMod, Socket}, Opts]) -> {value, {_, S}} -> S; _ -> none end, + StartTLS = case ejabberd_config:get_local_option(s2s_use_starttls) of + undefined -> + false; + UseStartTLS -> + UseStartTLS + end, + TLSOpts = case ejabberd_config:get_local_option(s2s_certfile) of + undefined -> + []; + CertFile -> + [{certfile, CertFile}] + end, Timer = erlang:start_timer(?S2STIMEOUT, self(), []), {ok, wait_for_stream, #state{socket = Socket, + sockmod = SockMod, receiver = ReceiverPid, streamid = new_id(), shaper = Shaper, + tls = StartTLS, + tls_enabled = false, + tls_options = TLSOpts, timer = Timer}}. %%---------------------------------------------------------------------- @@ -113,18 +132,28 @@ init([{SockMod, Socket}, Opts]) -> wait_for_stream({xmlstreamstart, _Name, Attrs}, StateData) -> % TODO - case {xml:get_attr_s("xmlns", Attrs), xml:get_attr_s("xmlns:db", Attrs)} of - {"jabber:server", "jabber:server:dialback"} -> - send_text(StateData#state.socket, ?STREAM_HEADER), - {next_state, stream_established, StateData#state{}}; + case {xml:get_attr_s("xmlns", Attrs), + xml:get_attr_s("xmlns:db", Attrs), + xml:get_attr_s("version", Attrs) == "1.0"} of + {"jabber:server", "jabber:server:dialback", true} when + StateData#state.tls -> + send_text(StateData, ?STREAM_HEADER(" version='1.0'")), + send_element(StateData, + {xmlelement, "stream:features", [], + [{xmlelement, "starttls", + [{"xmlns", ?NS_TLS}], []}]}), + {next_state, wait_for_feature_request, StateData}; + {"jabber:server", "jabber:server:dialback", _} -> + send_text(StateData, ?STREAM_HEADER("")), + {next_state, stream_established, StateData}; _ -> - send_text(StateData#state.socket, ?INVALID_NAMESPACE_ERR), + send_text(StateData, ?INVALID_NAMESPACE_ERR), {stop, normal, StateData} end; wait_for_stream({xmlstreamerror, _}, StateData) -> - send_text(StateData#state.socket, - ?STREAM_HEADER ++ ?INVALID_XML_ERR ++ ?STREAM_TRAILER), + send_text(StateData, + ?STREAM_HEADER("") ++ ?INVALID_XML_ERR ++ ?STREAM_TRAILER), {stop, normal, StateData}; wait_for_stream(timeout, StateData) -> @@ -133,6 +162,45 @@ wait_for_stream(timeout, StateData) -> wait_for_stream(closed, StateData) -> {stop, normal, StateData}. + +wait_for_feature_request({xmlstreamelement, El}, StateData) -> + {xmlelement, Name, Attrs, Els} = El, + TLS = StateData#state.tls, + TLSEnabled = StateData#state.tls_enabled, + SockMod = StateData#state.sockmod, + case {xml:get_attr_s("xmlns", Attrs), Name} of + {?NS_TLS, "starttls"} when TLS == true, + TLSEnabled == false, + SockMod == gen_tcp -> + ?INFO_MSG("starttls", []), + Socket = StateData#state.socket, + TLSOpts = StateData#state.tls_options, + {ok, TLSSocket} = tls:tcp_to_tls(Socket, TLSOpts), + ejabberd_receiver:starttls(StateData#state.receiver, TLSSocket), + send_element(StateData, + {xmlelement, "proceed", [{"xmlns", ?NS_TLS}], []}), + {next_state, wait_for_stream, + StateData#state{sockmod = tls, + socket = TLSSocket, + streamid = new_id(), + tls_enabled = true + }}; + _ -> + stream_established({xmlstreamelement, El}, StateData) + end; + +wait_for_feature_request({xmlstreamend, _Name}, StateData) -> + send_text(StateData, ?STREAM_TRAILER), + {stop, normal, StateData}; + +wait_for_feature_request({xmlstreamerror, _}, StateData) -> + send_text(StateData, ?INVALID_XML_ERR ++ ?STREAM_TRAILER), + {stop, normal, StateData}; + +wait_for_feature_request(closed, StateData) -> + {stop, normal, StateData}. + + stream_established({xmlstreamelement, El}, StateData) -> cancel_timer(StateData#state.timer), Timer = erlang:start_timer(?S2STIMEOUT, self(), []), @@ -154,7 +222,7 @@ stream_established({xmlstreamelement, El}, StateData) -> StateData#state{connections = Conns, timer = Timer}}; _ -> - send_text(StateData#state.socket, ?HOST_UNKNOWN_ERR), + send_text(StateData, ?HOST_UNKNOWN_ERR), {stop, normal, StateData} end; {verify, To, From, Id, Key} -> @@ -165,7 +233,7 @@ stream_established({xmlstreamelement, El}, StateData) -> Type = if Key == Key1 -> "valid"; true -> "invalid" end, - send_element(StateData#state.socket, + send_element(StateData, {xmlelement, "db:verify", [{"from", To}, @@ -204,7 +272,7 @@ stream_established({xmlstreamelement, El}, StateData) -> end; stream_established({valid, From, To}, StateData) -> - send_element(StateData#state.socket, + send_element(StateData, {xmlelement, "db:result", [{"from", To}, @@ -219,7 +287,7 @@ stream_established({valid, From, To}, StateData) -> {next_state, stream_established, NSD}; stream_established({invalid, From, To}, StateData) -> - send_element(StateData#state.socket, + send_element(StateData, {xmlelement, "db:result", [{"from", To}, @@ -237,8 +305,8 @@ stream_established({xmlstreamend, _Name}, StateData) -> {stop, normal, StateData}; stream_established({xmlstreamerror, _}, StateData) -> - send_text(StateData#state.socket, - ?STREAM_HEADER ++ ?INVALID_XML_ERR ++ ?STREAM_TRAILER), + send_text(StateData, + ?INVALID_XML_ERR ++ ?STREAM_TRAILER), {stop, normal, StateData}; stream_established(timeout, StateData) -> @@ -294,7 +362,7 @@ code_change(_OldVsn, StateName, StateData, _Extra) -> %% {stop, Reason, NewStateData} %%---------------------------------------------------------------------- handle_info({send_text, Text}, StateName, StateData) -> - send_text(StateData#state.socket, Text), + send_text(StateData, Text), {next_state, StateName, StateData}; handle_info({timeout, Timer, _}, StateName, @@ -312,18 +380,18 @@ handle_info(_, StateName, StateData) -> %%---------------------------------------------------------------------- terminate(Reason, _StateName, StateData) -> ?INFO_MSG("terminated: ~p", [Reason]), - gen_tcp:close(StateData#state.socket), + (StateData#state.sockmod):close(StateData#state.socket), ok. %%%---------------------------------------------------------------------- %%% Internal functions %%%---------------------------------------------------------------------- -send_text(Socket, Text) -> - gen_tcp:send(Socket,Text). +send_text(StateData, Text) -> + (StateData#state.sockmod):send(StateData#state.socket, Text). -send_element(Socket, El) -> - send_text(Socket, xml:element_to_string(El)). +send_element(StateData, El) -> + send_text(StateData, xml:element_to_string(El)). change_shaper(StateData, Host, JID) -> diff --git a/src/ejabberd_s2s_out.erl b/src/ejabberd_s2s_out.erl index aec98ba6e..0b3d46f19 100644 --- a/src/ejabberd_s2s_out.erl +++ b/src/ejabberd_s2s_out.erl @@ -13,13 +13,15 @@ -behaviour(gen_fsm). %% External exports --export([start/3, start_link/3, send_text/2, send_element/2]). +-export([start/3, start_link/3]). %% gen_fsm callbacks -export([init/1, open_socket/2, wait_for_stream/2, wait_for_validation/2, + wait_for_features/2, + wait_for_starttls_proceed/2, stream_established/2, handle_event/3, handle_sync_event/4, @@ -30,8 +32,15 @@ -include("ejabberd.hrl"). -include("jlib.hrl"). --record(state, {socket, receiver, streamid, - myname, server, xmlpid, queue, +-record(state, {socket, receiver, + sockmod, + streamid, + use_v10, + tls = false, + tls_required = false, + tls_enabled = false, + tls_options = [], + myname, server, queue, new = false, verify = false, timer}). @@ -49,7 +58,7 @@ "xmlns:stream='http://etherx.jabber.org/streams' " "xmlns='jabber:server' " "xmlns:db='jabber:server:dialback' " - "to='~s'>" + "to='~s'~s>" ). -define(STREAM_TRAILER, ""). @@ -86,6 +95,19 @@ start_link(From, Host, Type) -> init([From, Server, Type]) -> ?INFO_MSG("started: ~p", [{From, Server, Type}]), gen_fsm:send_event(self(), init), + TLS = case ejabberd_config:get_local_option(s2s_use_starttls) of + undefined -> + false; + UseStartTLS -> + UseStartTLS + end, + UseV10 = TLS, + TLSOpts = case ejabberd_config:get_local_option(s2s_certfile) of + undefined -> + []; + CertFile -> + [{certfile, CertFile}, connect] + end, {New, Verify} = case Type of {new, Key} -> {Key, false}; @@ -93,7 +115,10 @@ init([From, Server, Type]) -> {false, {Pid, Key, SID}} end, Timer = erlang:start_timer(?S2STIMEOUT, self(), []), - {ok, open_socket, #state{queue = queue:new(), + {ok, open_socket, #state{use_v10 = UseV10, + tls = TLS, + tls_options = TLSOpts, + queue = queue:new(), myname = From, server = Server, new = New, @@ -113,23 +138,34 @@ open_socket(init, StateData) -> ASCIIAddr -> ?DEBUG("s2s_out: connecting to ~s:~p~n", [ASCIIAddr, Port]), case gen_tcp:connect(ASCIIAddr, Port, - [binary, {packet, 0}]) of + [binary, {packet, 0}, + {active, false}]) of {ok, _Socket} = R -> R; {error, Reason1} -> ?DEBUG("s2s_out: connect return ~p~n", [Reason1]), catch gen_tcp:connect(Addr, Port, - [binary, {packet, 0}, inet6]) + [binary, {packet, 0}, + {active, false}, inet6]) end end, case Res of {ok, Socket} -> - XMLStreamPid = xml_stream:start(self()), - send_text(Socket, io_lib:format(?STREAM_HEADER, - [StateData#state.server])), - {next_state, wait_for_stream, - StateData#state{socket = Socket, - xmlpid = XMLStreamPid, - streamid = new_id()}}; + ReceiverPid = ejabberd_receiver:start(Socket, gen_tcp, none), + Version = if + StateData#state.use_v10 -> + " version='1.0'"; + true -> + "" + end, + NewStateData = StateData#state{socket = Socket, + sockmod = gen_tcp, + tls_enabled = false, + receiver = ReceiverPid, + streamid = new_id()}, + send_text(NewStateData, io_lib:format(?STREAM_HEADER, + [StateData#state.server, + Version])), + {next_state, wait_for_stream, NewStateData}; {error, Reason} -> ?DEBUG("s2s_out: inet6 connect return ~p~n", [Reason]), Error = ?ERR_REMOTE_SERVER_NOT_FOUND, @@ -140,58 +176,36 @@ open_socket(init, StateData) -> Error = ?ERR_REMOTE_SERVER_NOT_FOUND, bounce_messages(Error), {stop, normal, StateData} - end. + end; +open_socket(_, StateData) -> + {next_state, open_socket, StateData}. wait_for_stream({xmlstreamstart, Name, Attrs}, StateData) -> - % TODO - case {xml:get_attr_s("xmlns", Attrs), xml:get_attr_s("xmlns:db", Attrs)} of - {"jabber:server", "jabber:server:dialback"} -> - Server = StateData#state.server, - New = case StateData#state.new of - false -> - case ejabberd_s2s:try_register( - {StateData#state.myname, Server}) of - {key, Key} -> - Key; - false -> - false - end; - Key -> - Key - end, - case New of - false -> - ok; - Key1 -> - send_element(StateData#state.socket, - {xmlelement, - "db:result", - [{"from", StateData#state.myname}, - {"to", Server}], - [{xmlcdata, Key1}]}) - end, - case StateData#state.verify of - false -> - ok; - {Pid, Key2, SID} -> - send_element(StateData#state.socket, - {xmlelement, - "db:verify", - [{"from", StateData#state.myname}, - {"to", StateData#state.server}, - {"id", SID}], - [{xmlcdata, Key2}]}) - end, - {next_state, wait_for_validation, StateData#state{new = New}}; + case {xml:get_attr_s("xmlns", Attrs), + xml:get_attr_s("xmlns:db", Attrs), + xml:get_attr_s("version", Attrs) == "1.0"} of + {"jabber:server", "jabber:server:dialback", false} -> + send_db_request(StateData); + {"jabber:server", "jabber:server:dialback", true} when + StateData#state.use_v10 -> + {next_state, wait_for_features, StateData}; + {"jabber:server", "", true} when StateData#state.use_v10 -> + ?INFO_MSG("restarted: ~p", [{StateData#state.myname, + StateData#state.server}]), + % TODO: clear message queue + (StateData#state.sockmod):close(StateData#state.socket), + gen_fsm:send_event(self(), init), + {next_state, open_socket, StateData#state{socket = undefined, + use_v10 = false}}; _ -> - send_text(StateData#state.socket, ?INVALID_NAMESPACE_ERR), + send_text(StateData, ?INVALID_NAMESPACE_ERR), {stop, normal, StateData} end; wait_for_stream({xmlstreamerror, _}, StateData) -> - send_text(StateData#state.socket, - ?STREAM_HEADER ++ ?INVALID_XML_ERR ++ ?STREAM_TRAILER), + send_text(StateData, + ?INVALID_XML_ERR ++ ?STREAM_TRAILER), {stop, normal, StateData}; wait_for_stream(timeout, StateData) -> @@ -208,7 +222,7 @@ wait_for_validation({xmlstreamelement, El}, StateData) -> ?INFO_MSG("recv result: ~p", [{From, To, Id, Type}]), case Type of "valid" -> - send_queue(StateData#state.socket, StateData#state.queue), + send_queue(StateData, StateData#state.queue), {next_state, stream_established, StateData#state{queue = queue:new()}}; _ -> @@ -248,8 +262,8 @@ wait_for_validation({xmlstreamend, Name}, StateData) -> {stop, normal, StateData}; wait_for_validation({xmlstreamerror, _}, StateData) -> - send_text(StateData#state.socket, - ?STREAM_HEADER ++ ?INVALID_XML_ERR ++ ?STREAM_TRAILER), + send_text(StateData, + ?INVALID_XML_ERR ++ ?STREAM_TRAILER), {stop, normal, StateData}; wait_for_validation(timeout, StateData) -> @@ -259,6 +273,111 @@ wait_for_validation(closed, StateData) -> {stop, normal, StateData}. +wait_for_features({xmlstreamelement, El}, StateData) -> + case El of + {xmlelement, "stream:features", _Attrs, Els} -> + {StartTLS, StartTLSRequired} = + lists:foldl( + fun({xmlelement, "starttls", Attrs1, Els1} = El1, Acc) -> + case xml:get_attr_s("xmlns", Attrs1) of + ?NS_TLS -> + Req = case xml:get_subtag(El1, "required") of + {xmlelement, _, _, _} -> true; + false -> false + end, + {true, Req}; + _ -> + Acc + end; + (_, Acc) -> + Acc + end, {false, false}, Els), + if + StartTLS and StateData#state.tls and + (not StateData#state.tls_enabled) -> + StateData#state.receiver ! {change_timeout, 100}, + send_element(StateData, + {xmlelement, "starttls", + [{"xmlns", ?NS_TLS}], []}), + {next_state, wait_for_starttls_proceed, StateData}; + StartTLSRequired and (not StateData#state.tls) -> + ?INFO_MSG("restarted: ~p", [{StateData#state.myname, + StateData#state.server}]), + (StateData#state.sockmod):close(StateData#state.socket), + gen_fsm:send_event(self(), init), + {next_state, open_socket, + StateData#state{socket = undefined, + use_v10 = false}}; + true -> + send_db_request(StateData) + end; + _ -> + send_text(StateData, + xml:element_to_string(?SERR_BAD_FORMAT) ++ + ?STREAM_TRAILER), + {stop, normal, StateData} + end; + +wait_for_features({xmlstreamend, Name}, StateData) -> + {stop, normal, StateData}; + +wait_for_features({xmlstreamerror, _}, StateData) -> + send_text(StateData, + ?INVALID_XML_ERR ++ ?STREAM_TRAILER), + {stop, normal, StateData}; + +wait_for_features(timeout, StateData) -> + {stop, normal, StateData}; + +wait_for_features(closed, StateData) -> + {stop, normal, StateData}. + + +wait_for_starttls_proceed({xmlstreamelement, El}, StateData) -> + case El of + {xmlelement, "proceed", Attrs, _Els} -> + case xml:get_attr_s("xmlns", Attrs) of + ?NS_TLS -> + ?INFO_MSG("starttls: ~p", [{StateData#state.myname, + StateData#state.server}]), + Socket = StateData#state.socket, + TLSOpts = StateData#state.tls_options, + {ok, TLSSocket} = tls:tcp_to_tls(Socket, TLSOpts), + ejabberd_receiver:starttls( + StateData#state.receiver, TLSSocket), + StateData#state.receiver ! {change_timeout, infinity}, + NewStateData = StateData#state{sockmod = tls, + socket = TLSSocket, + streamid = new_id(), + tls_enabled = true + }, + R = send_text(NewStateData, + io_lib:format(?STREAM_HEADER, + [StateData#state.server, + " version='1.0'"])), + {next_state, wait_for_stream, NewStateData}; + _ -> + {stop, normal, StateData} + end; + _ -> + {stop, normal, StateData} + end; + +wait_for_starttls_proceed({xmlstreamend, Name}, StateData) -> + {stop, normal, StateData}; + +wait_for_starttls_proceed({xmlstreamerror, _}, StateData) -> + send_text(StateData, + ?INVALID_XML_ERR ++ ?STREAM_TRAILER), + {stop, normal, StateData}; + +wait_for_starttls_proceed(timeout, StateData) -> + {stop, normal, StateData}; + +wait_for_starttls_proceed(closed, StateData) -> + {stop, normal, StateData}. + + stream_established({xmlstreamelement, El}, StateData) -> ?INFO_MSG("stream established", []), case is_verify_res(El) of @@ -290,8 +409,8 @@ stream_established({xmlstreamend, Name}, StateData) -> {stop, normal, StateData}; stream_established({xmlstreamerror, _}, StateData) -> - send_text(StateData#state.socket, - ?STREAM_HEADER ++ ?INVALID_XML_ERR ++ ?STREAM_TRAILER), + send_text(StateData, + ?INVALID_XML_ERR ++ ?STREAM_TRAILER), {stop, normal, StateData}; stream_established(timeout, StateData) -> @@ -347,7 +466,7 @@ code_change(OldVsn, StateName, StateData, Extra) -> %% {stop, Reason, NewStateData} %%---------------------------------------------------------------------- handle_info({send_text, Text}, StateName, StateData) -> - send_text(StateData#state.socket, Text), + send_text(StateData, Text), cancel_timer(StateData#state.timer), Timer = erlang:start_timer(?S2STIMEOUT, self(), []), {next_state, StateName, StateData#state{timer = Timer}}; @@ -357,7 +476,7 @@ handle_info({send_element, El}, StateName, StateData) -> Timer = erlang:start_timer(?S2STIMEOUT, self(), []), case StateName of stream_established -> - send_element(StateData#state.socket, El), + send_element(StateData, El), {next_state, StateName, StateData#state{timer = Timer}}; _ -> Q = queue:in(El, StateData#state.queue), @@ -365,17 +484,17 @@ handle_info({send_element, El}, StateName, StateData) -> timer = Timer}} end; -handle_info({tcp, Socket, Data}, StateName, StateData) -> - xml_stream:send_text(StateData#state.xmlpid, Data), - {next_state, StateName, StateData}; - -handle_info({tcp_closed, Socket}, StateName, StateData) -> - gen_fsm:send_event(self(), closed), - {next_state, StateName, StateData}; - -handle_info({tcp_error, Socket, Reason}, StateName, StateData) -> - gen_fsm:send_event(self(), closed), - {next_state, StateName, StateData}; +%handle_info({tcp, Socket, Data}, StateName, StateData) -> +% xml_stream:send_text(StateData#state.xmlpid, Data), +% {next_state, StateName, StateData}; +% +%handle_info({tcp_closed, Socket}, StateName, StateData) -> +% gen_fsm:send_event(self(), closed), +% {next_state, StateName, StateData}; +% +%handle_info({tcp_error, Socket, Reason}, StateName, StateData) -> +% gen_fsm:send_event(self(), closed), +% {next_state, StateName, StateData}; handle_info({timeout, Timer, _}, StateName, #state{timer = Timer} = StateData) -> @@ -404,8 +523,7 @@ terminate(Reason, StateName, StateData) -> undefined -> ok; Socket -> - gen_tcp:close(Socket), - exit(StateData#state.xmlpid, closed) + (StateData#state.sockmod):close(Socket) end, ok. @@ -413,17 +531,17 @@ terminate(Reason, StateName, StateData) -> %%% Internal functions %%%---------------------------------------------------------------------- -send_text(Socket, Text) -> - gen_tcp:send(Socket,Text). +send_text(StateData, Text) -> + (StateData#state.sockmod):send(StateData#state.socket, Text). -send_element(Socket, El) -> - send_text(Socket, xml:element_to_string(El)). +send_element(StateData, El) -> + send_text(StateData, xml:element_to_string(El)). -send_queue(Socket, Q) -> +send_queue(StateData, Q) -> case queue:out(Q) of {{value, El}, Q1} -> - send_element(Socket, El), - send_queue(Socket, Q1); + send_element(StateData, El), + send_queue(StateData, Q1); {empty, Q1} -> ok end. @@ -470,20 +588,46 @@ bounce_messages(Error) -> ok end. -%is_key_packet({xmlelement, Name, Attrs, Els}) when Name == "db:result" -> -% {key, -% xml:get_attr_s("to", Attrs), -% xml:get_attr_s("from", Attrs), -% xml:get_attr_s("id", Attrs), -% xml:get_cdata(Els)}; -%is_key_packet({xmlelement, Name, Attrs, Els}) when Name == "db:verify" -> -% {verify, -% xml:get_attr_s("to", Attrs), -% xml:get_attr_s("from", Attrs), -% xml:get_attr_s("id", Attrs), -% xml:get_cdata(Els)}; -%is_key_packet(_) -> -% false. + +send_db_request(StateData) -> + Server = StateData#state.server, + New = case StateData#state.new of + false -> + case ejabberd_s2s:try_register( + {StateData#state.myname, Server}) of + {key, Key} -> + Key; + false -> + false + end; + Key -> + Key + end, + case New of + false -> + ok; + Key1 -> + send_element(StateData, + {xmlelement, + "db:result", + [{"from", StateData#state.myname}, + {"to", Server}], + [{xmlcdata, Key1}]}) + end, + case StateData#state.verify of + false -> + ok; + {Pid, Key2, SID} -> + send_element(StateData, + {xmlelement, + "db:verify", + [{"from", StateData#state.myname}, + {"to", StateData#state.server}, + {"id", SID}], + [{xmlcdata, Key2}]}) + end, + {next_state, wait_for_validation, StateData#state{new = New}}. + is_verify_res({xmlelement, Name, Attrs, Els}) when Name == "db:result" -> {result, diff --git a/src/tls/tls.erl b/src/tls/tls.erl index 361c92fcf..e1925520f 100644 --- a/src/tls/tls.erl +++ b/src/tls/tls.erl @@ -27,11 +27,12 @@ code_change/3, terminate/2]). --define(SET_CERTIFICATE_FILE, 1). --define(SET_ENCRYPTED_INPUT, 2). --define(SET_DECRYPTED_OUTPUT, 3). --define(GET_ENCRYPTED_OUTPUT, 4). --define(GET_DECRYPTED_INPUT, 5). +-define(SET_CERTIFICATE_FILE_ACCEPT, 1). +-define(SET_CERTIFICATE_FILE_CONNECT, 2). +-define(SET_ENCRYPTED_INPUT, 3). +-define(SET_DECRYPTED_OUTPUT, 4). +-define(GET_ENCRYPTED_OUTPUT, 5). +-define(GET_DECRYPTED_INPUT, 6). -record(tlssock, {tcpsock, tlsport}). @@ -44,7 +45,7 @@ start_link() -> init([]) -> ok = erl_ddll:load_driver(ejabberd:get_so_path(), tls_drv), Port = open_port({spawn, tls_drv}, [binary]), - Res = port_control(Port, ?SET_CERTIFICATE_FILE, "./ssl.pem" ++ [0]), + Res = port_control(Port, ?SET_CERTIFICATE_FILE_ACCEPT, "./ssl.pem" ++ [0]), case Res of <<0>> -> %ets:new(iconv_table, [set, public, named_table]), @@ -86,8 +87,13 @@ tcp_to_tls(TCPSocket, Options) -> {value, {certfile, CertFile}} -> ok = erl_ddll:load_driver(ejabberd:get_so_path(), tls_drv), Port = open_port({spawn, tls_drv}, [binary]), - case port_control(Port, ?SET_CERTIFICATE_FILE, - CertFile ++ [0]) of + Command = case lists:member(connect, Options) of + true -> + ?SET_CERTIFICATE_FILE_CONNECT; + false -> + ?SET_CERTIFICATE_FILE_ACCEPT + end, + case port_control(Port, Command, CertFile ++ [0]) of <<0>> -> {ok, #tlssock{tcpsock = TCPSocket, tlsport = Port}}; <<1, Error/binary>> -> @@ -145,7 +151,10 @@ send(#tlssock{tcpsock = TCPSocket, tlsport = Port}, Packet) -> {error, binary_to_list(Error)} end; <<1, Error/binary>> -> - {error, binary_to_list(Error)} + {error, binary_to_list(Error)}; + <<2>> -> % Dirty hack + receive after 100 -> ok end, + send(#tlssock{tcpsock = TCPSocket, tlsport = Port}, Packet) end. @@ -158,7 +167,8 @@ test() -> ok = erl_ddll:load_driver(ejabberd:get_so_path(), tls_drv), Port = open_port({spawn, tls_drv}, [binary]), io:format("open_port: ~p~n", [Port]), - PCRes = port_control(Port, ?SET_CERTIFICATE_FILE, "./ssl.pem" ++ [0]), + PCRes = port_control(Port, ?SET_CERTIFICATE_FILE_ACCEPT, + "./ssl.pem" ++ [0]), io:format("port_control: ~p~n", [PCRes]), {ok, ListenSocket} = gen_tcp:listen(1234, [binary, {packet, 0}, diff --git a/src/tls/tls_drv.c b/src/tls/tls_drv.c index f320ee31f..608830ffb 100644 --- a/src/tls/tls_drv.c +++ b/src/tls/tls_drv.c @@ -4,6 +4,7 @@ #include #include #include +#include #define BUF_SIZE 1024 @@ -45,11 +46,12 @@ static void tls_drv_stop(ErlDrvData handle) } -#define SET_CERTIFICATE_FILE 1 -#define SET_ENCRYPTED_INPUT 2 -#define SET_DECRYPTED_OUTPUT 3 -#define GET_ENCRYPTED_OUTPUT 4 -#define GET_DECRYPTED_INPUT 5 +#define SET_CERTIFICATE_FILE_ACCEPT 1 +#define SET_CERTIFICATE_FILE_CONNECT 2 +#define SET_ENCRYPTED_INPUT 3 +#define SET_DECRYPTED_OUTPUT 4 +#define GET_ENCRYPTED_OUTPUT 5 +#define GET_DECRYPTED_INPUT 6 #define die_unless(cond, errstr) \ @@ -76,8 +78,9 @@ static int tls_drv_control(ErlDrvData handle, switch (command) { - case SET_CERTIFICATE_FILE: - d->ctx = SSL_CTX_new(SSLv23_server_method()); + case SET_CERTIFICATE_FILE_ACCEPT: + case SET_CERTIFICATE_FILE_CONNECT: + d->ctx = SSL_CTX_new(SSLv23_method()); die_unless(d->ctx, "SSL_CTX_new failed"); res = SSL_CTX_use_certificate_file(d->ctx, buf, SSL_FILETYPE_PEM); @@ -97,7 +100,10 @@ static int tls_drv_control(ErlDrvData handle, SSL_set_bio(d->ssl, d->bio_read, d->bio_write); - SSL_set_accept_state(d->ssl); + if (command == SET_CERTIFICATE_FILE_ACCEPT) + SSL_set_accept_state(d->ssl); + else + SSL_set_connect_state(d->ssl); break; case SET_ENCRYPTED_INPUT: die_unless(d->ssl, "SSL not initialized"); @@ -106,6 +112,19 @@ static int tls_drv_control(ErlDrvData handle, case SET_DECRYPTED_OUTPUT: die_unless(d->ssl, "SSL not initialized"); res = SSL_write(d->ssl, buf, len); + if (res <= 0) + { + res = SSL_get_error(d->ssl, res); + if (res == SSL_ERROR_WANT_READ || res == SSL_ERROR_WANT_WRITE) + { + b = driver_alloc_binary(1); + b->orig_bytes[0] = 2; + *rbuf = (char *)b; + return 1; + } else { + die_unless(0, "SSL_write failed"); + } + } break; case GET_ENCRYPTED_OUTPUT: die_unless(d->ssl, "SSL not initialized"); @@ -128,13 +147,10 @@ static int tls_drv_control(ErlDrvData handle, case GET_DECRYPTED_INPUT: if (!SSL_is_init_finished(d->ssl)) { - //printf("Doing SSL_accept\r\n"); - res = SSL_accept(d->ssl); - //if (res == 0) - // printf("SSL_accept returned zero\r\n"); - if (res < 0) + res = SSL_do_handshake(d->ssl); + if (res <= 0) die_unless(SSL_get_error(d->ssl, res) == SSL_ERROR_WANT_READ, - "SSL_accept failed"); + "SSL_do_handshake failed"); } else { size = BUF_SIZE + 1; rlen = 1;