* src/tls/tls_drv.c: Support for "connect" method

* src/tls/tls.erl: Likewise

* src/ejabberd_s2s_in.erl: Support for STARTTLS+Dialback
* src/ejabberd_s2s_out.erl: Likewise
* src/ejabberd_receiver.erl: Added a few hacks ({active,once} mode
should be used instead of recv/3 call to avoid them)
* src/ejabberd_config.erl: Added s2s_use_starttls and s2s_certfile
options
* src/ejabberd.cfg.example: Likewise

SVN Revision: 426
This commit is contained in:
Alexey Shchepin 2005-10-25 01:08:37 +00:00
parent 6309f41b9a
commit 1433dafe6b
8 changed files with 427 additions and 164 deletions

View File

@ -1,3 +1,16 @@
2005-10-25 Alexey Shchepin <alexey@sevcom.net>
* src/tls/tls_drv.c: Support for "connect" method
* src/tls/tls.erl: Likewise
* src/ejabberd_s2s_in.erl: Support for STARTTLS+Dialback
* src/ejabberd_s2s_out.erl: Likewise
* src/ejabberd_receiver.erl: Added a few hacks ({active,once} mode
should be used instead of recv/3 call to avoid them)
* src/ejabberd_config.erl: Added s2s_use_starttls and s2s_certfile
options
* src/ejabberd.cfg.example: Likewise
2005-10-22 Alexey Shchepin <alexey@sevcom.net>
* src/ejabberd_app.erl: Try to load tls_drv at startup to avoid

View File

@ -115,6 +115,11 @@
[{password, "secret"}]}]}
]}.
% Use STARTTLS+Dialback for S2S connections
{s2s_use_starttls, true}.
{s2s_certfile, "./ssl.pem"}.
% If SRV lookup fails, then port 5269 is used to communicate with remote server
{outgoing_s2s_port, 5269}.

View File

@ -108,6 +108,10 @@ process_term(Term, State) ->
add_option(listen, Val, State);
{outgoing_s2s_port, Port} ->
add_option(outgoing_s2s_port, Port, State);
{s2s_use_starttls, Port} ->
add_option(s2s_use_starttls, Port, State);
{s2s_certfile, Port} ->
add_option(s2s_certfile, Port, State);
{Opt, Val} ->
lists:foldl(fun(Host, S) -> process_host_term(Term, Host, S) end,
State, State#state.hosts)

View File

@ -36,24 +36,27 @@ receiver(Socket, SockMod, Shaper, C2SPid) ->
receiver(Socket, SockMod, ShaperState, C2SPid, XMLStreamState, Timeout) ->
Res = (catch SockMod:recv(Socket, 0, Timeout)),
case Res of
{ok, Data} ->
receive
{starttls, TLSSocket} ->
xml_stream:close(XMLStreamState),
XMLStreamState1 = xml_stream:new(C2SPid),
TLSRes = tls:recv_data(TLSSocket, Data),
receiver1(TLSSocket, tls,
ShaperState, C2SPid, XMLStreamState1, Timeout,
TLSRes)
after 0 ->
receiver1(Socket, SockMod,
ShaperState, C2SPid, XMLStreamState, Timeout,
Res)
end;
_ ->
receive
{starttls, TLSSocket} ->
xml_stream:close(XMLStreamState),
XMLStreamState1 = xml_stream:new(C2SPid),
TLSRes = case Res of
{ok, Data} ->
tls:recv_data(TLSSocket, Data);
_ ->
tls:recv_data(TLSSocket, "")
end,
receiver1(TLSSocket, tls,
ShaperState, C2SPid, XMLStreamState1, Timeout,
TLSRes);
{change_timeout, NewTimeout} -> % Dirty hack
receiver1(Socket, SockMod,
ShaperState, C2SPid, XMLStreamState, Timeout, Res)
ShaperState, C2SPid, XMLStreamState, NewTimeout,
Res)
after 0 ->
receiver1(Socket, SockMod,
ShaperState, C2SPid, XMLStreamState, Timeout,
Res)
end.

View File

@ -14,13 +14,12 @@
%% External exports
-export([start/2,
start_link/2,
send_text/2,
send_element/2]).
start_link/2]).
%% gen_fsm callbacks
-export([init/1,
wait_for_stream/2,
wait_for_feature_request/2,
stream_established/2,
handle_event/3,
handle_sync_event/4,
@ -34,9 +33,13 @@
-define(DICT, dict).
-record(state, {socket,
sockmod,
receiver,
streamid,
shaper,
tls = false,
tls_enabled = false,
tls_options = [],
connections = ?DICT:new(),
timer}).
@ -49,13 +52,13 @@
-define(FSMOPTS, []).
-endif.
-define(STREAM_HEADER,
-define(STREAM_HEADER(Version),
("<?xml version='1.0'?>"
"<stream:stream "
"xmlns:stream='http://etherx.jabber.org/streams' "
"xmlns='jabber:server' "
"xmlns:db='jabber:server:dialback' "
"id='" ++ StateData#state.streamid ++ "'>")
"id='" ++ StateData#state.streamid ++ "'" ++ Version ++ ">")
).
-define(STREAM_TRAILER, "</stream:stream>").
@ -96,12 +99,28 @@ init([{SockMod, Socket}, Opts]) ->
{value, {_, S}} -> S;
_ -> none
end,
StartTLS = case ejabberd_config:get_local_option(s2s_use_starttls) of
undefined ->
false;
UseStartTLS ->
UseStartTLS
end,
TLSOpts = case ejabberd_config:get_local_option(s2s_certfile) of
undefined ->
[];
CertFile ->
[{certfile, CertFile}]
end,
Timer = erlang:start_timer(?S2STIMEOUT, self(), []),
{ok, wait_for_stream,
#state{socket = Socket,
sockmod = SockMod,
receiver = ReceiverPid,
streamid = new_id(),
shaper = Shaper,
tls = StartTLS,
tls_enabled = false,
tls_options = TLSOpts,
timer = Timer}}.
%%----------------------------------------------------------------------
@ -113,18 +132,28 @@ init([{SockMod, Socket}, Opts]) ->
wait_for_stream({xmlstreamstart, _Name, Attrs}, StateData) ->
% TODO
case {xml:get_attr_s("xmlns", Attrs), xml:get_attr_s("xmlns:db", Attrs)} of
{"jabber:server", "jabber:server:dialback"} ->
send_text(StateData#state.socket, ?STREAM_HEADER),
{next_state, stream_established, StateData#state{}};
case {xml:get_attr_s("xmlns", Attrs),
xml:get_attr_s("xmlns:db", Attrs),
xml:get_attr_s("version", Attrs) == "1.0"} of
{"jabber:server", "jabber:server:dialback", true} when
StateData#state.tls ->
send_text(StateData, ?STREAM_HEADER(" version='1.0'")),
send_element(StateData,
{xmlelement, "stream:features", [],
[{xmlelement, "starttls",
[{"xmlns", ?NS_TLS}], []}]}),
{next_state, wait_for_feature_request, StateData};
{"jabber:server", "jabber:server:dialback", _} ->
send_text(StateData, ?STREAM_HEADER("")),
{next_state, stream_established, StateData};
_ ->
send_text(StateData#state.socket, ?INVALID_NAMESPACE_ERR),
send_text(StateData, ?INVALID_NAMESPACE_ERR),
{stop, normal, StateData}
end;
wait_for_stream({xmlstreamerror, _}, StateData) ->
send_text(StateData#state.socket,
?STREAM_HEADER ++ ?INVALID_XML_ERR ++ ?STREAM_TRAILER),
send_text(StateData,
?STREAM_HEADER("") ++ ?INVALID_XML_ERR ++ ?STREAM_TRAILER),
{stop, normal, StateData};
wait_for_stream(timeout, StateData) ->
@ -133,6 +162,45 @@ wait_for_stream(timeout, StateData) ->
wait_for_stream(closed, StateData) ->
{stop, normal, StateData}.
wait_for_feature_request({xmlstreamelement, El}, StateData) ->
{xmlelement, Name, Attrs, Els} = El,
TLS = StateData#state.tls,
TLSEnabled = StateData#state.tls_enabled,
SockMod = StateData#state.sockmod,
case {xml:get_attr_s("xmlns", Attrs), Name} of
{?NS_TLS, "starttls"} when TLS == true,
TLSEnabled == false,
SockMod == gen_tcp ->
?INFO_MSG("starttls", []),
Socket = StateData#state.socket,
TLSOpts = StateData#state.tls_options,
{ok, TLSSocket} = tls:tcp_to_tls(Socket, TLSOpts),
ejabberd_receiver:starttls(StateData#state.receiver, TLSSocket),
send_element(StateData,
{xmlelement, "proceed", [{"xmlns", ?NS_TLS}], []}),
{next_state, wait_for_stream,
StateData#state{sockmod = tls,
socket = TLSSocket,
streamid = new_id(),
tls_enabled = true
}};
_ ->
stream_established({xmlstreamelement, El}, StateData)
end;
wait_for_feature_request({xmlstreamend, _Name}, StateData) ->
send_text(StateData, ?STREAM_TRAILER),
{stop, normal, StateData};
wait_for_feature_request({xmlstreamerror, _}, StateData) ->
send_text(StateData, ?INVALID_XML_ERR ++ ?STREAM_TRAILER),
{stop, normal, StateData};
wait_for_feature_request(closed, StateData) ->
{stop, normal, StateData}.
stream_established({xmlstreamelement, El}, StateData) ->
cancel_timer(StateData#state.timer),
Timer = erlang:start_timer(?S2STIMEOUT, self(), []),
@ -154,7 +222,7 @@ stream_established({xmlstreamelement, El}, StateData) ->
StateData#state{connections = Conns,
timer = Timer}};
_ ->
send_text(StateData#state.socket, ?HOST_UNKNOWN_ERR),
send_text(StateData, ?HOST_UNKNOWN_ERR),
{stop, normal, StateData}
end;
{verify, To, From, Id, Key} ->
@ -165,7 +233,7 @@ stream_established({xmlstreamelement, El}, StateData) ->
Type = if Key == Key1 -> "valid";
true -> "invalid"
end,
send_element(StateData#state.socket,
send_element(StateData,
{xmlelement,
"db:verify",
[{"from", To},
@ -204,7 +272,7 @@ stream_established({xmlstreamelement, El}, StateData) ->
end;
stream_established({valid, From, To}, StateData) ->
send_element(StateData#state.socket,
send_element(StateData,
{xmlelement,
"db:result",
[{"from", To},
@ -219,7 +287,7 @@ stream_established({valid, From, To}, StateData) ->
{next_state, stream_established, NSD};
stream_established({invalid, From, To}, StateData) ->
send_element(StateData#state.socket,
send_element(StateData,
{xmlelement,
"db:result",
[{"from", To},
@ -237,8 +305,8 @@ stream_established({xmlstreamend, _Name}, StateData) ->
{stop, normal, StateData};
stream_established({xmlstreamerror, _}, StateData) ->
send_text(StateData#state.socket,
?STREAM_HEADER ++ ?INVALID_XML_ERR ++ ?STREAM_TRAILER),
send_text(StateData,
?INVALID_XML_ERR ++ ?STREAM_TRAILER),
{stop, normal, StateData};
stream_established(timeout, StateData) ->
@ -294,7 +362,7 @@ code_change(_OldVsn, StateName, StateData, _Extra) ->
%% {stop, Reason, NewStateData}
%%----------------------------------------------------------------------
handle_info({send_text, Text}, StateName, StateData) ->
send_text(StateData#state.socket, Text),
send_text(StateData, Text),
{next_state, StateName, StateData};
handle_info({timeout, Timer, _}, StateName,
@ -312,18 +380,18 @@ handle_info(_, StateName, StateData) ->
%%----------------------------------------------------------------------
terminate(Reason, _StateName, StateData) ->
?INFO_MSG("terminated: ~p", [Reason]),
gen_tcp:close(StateData#state.socket),
(StateData#state.sockmod):close(StateData#state.socket),
ok.
%%%----------------------------------------------------------------------
%%% Internal functions
%%%----------------------------------------------------------------------
send_text(Socket, Text) ->
gen_tcp:send(Socket,Text).
send_text(StateData, Text) ->
(StateData#state.sockmod):send(StateData#state.socket, Text).
send_element(Socket, El) ->
send_text(Socket, xml:element_to_string(El)).
send_element(StateData, El) ->
send_text(StateData, xml:element_to_string(El)).
change_shaper(StateData, Host, JID) ->

View File

@ -13,13 +13,15 @@
-behaviour(gen_fsm).
%% External exports
-export([start/3, start_link/3, send_text/2, send_element/2]).
-export([start/3, start_link/3]).
%% gen_fsm callbacks
-export([init/1,
open_socket/2,
wait_for_stream/2,
wait_for_validation/2,
wait_for_features/2,
wait_for_starttls_proceed/2,
stream_established/2,
handle_event/3,
handle_sync_event/4,
@ -30,8 +32,15 @@
-include("ejabberd.hrl").
-include("jlib.hrl").
-record(state, {socket, receiver, streamid,
myname, server, xmlpid, queue,
-record(state, {socket, receiver,
sockmod,
streamid,
use_v10,
tls = false,
tls_required = false,
tls_enabled = false,
tls_options = [],
myname, server, queue,
new = false, verify = false,
timer}).
@ -49,7 +58,7 @@
"xmlns:stream='http://etherx.jabber.org/streams' "
"xmlns='jabber:server' "
"xmlns:db='jabber:server:dialback' "
"to='~s'>"
"to='~s'~s>"
).
-define(STREAM_TRAILER, "</stream:stream>").
@ -86,6 +95,19 @@ start_link(From, Host, Type) ->
init([From, Server, Type]) ->
?INFO_MSG("started: ~p", [{From, Server, Type}]),
gen_fsm:send_event(self(), init),
TLS = case ejabberd_config:get_local_option(s2s_use_starttls) of
undefined ->
false;
UseStartTLS ->
UseStartTLS
end,
UseV10 = TLS,
TLSOpts = case ejabberd_config:get_local_option(s2s_certfile) of
undefined ->
[];
CertFile ->
[{certfile, CertFile}, connect]
end,
{New, Verify} = case Type of
{new, Key} ->
{Key, false};
@ -93,7 +115,10 @@ init([From, Server, Type]) ->
{false, {Pid, Key, SID}}
end,
Timer = erlang:start_timer(?S2STIMEOUT, self(), []),
{ok, open_socket, #state{queue = queue:new(),
{ok, open_socket, #state{use_v10 = UseV10,
tls = TLS,
tls_options = TLSOpts,
queue = queue:new(),
myname = From,
server = Server,
new = New,
@ -113,23 +138,34 @@ open_socket(init, StateData) ->
ASCIIAddr ->
?DEBUG("s2s_out: connecting to ~s:~p~n", [ASCIIAddr, Port]),
case gen_tcp:connect(ASCIIAddr, Port,
[binary, {packet, 0}]) of
[binary, {packet, 0},
{active, false}]) of
{ok, _Socket} = R -> R;
{error, Reason1} ->
?DEBUG("s2s_out: connect return ~p~n", [Reason1]),
catch gen_tcp:connect(Addr, Port,
[binary, {packet, 0}, inet6])
[binary, {packet, 0},
{active, false}, inet6])
end
end,
case Res of
{ok, Socket} ->
XMLStreamPid = xml_stream:start(self()),
send_text(Socket, io_lib:format(?STREAM_HEADER,
[StateData#state.server])),
{next_state, wait_for_stream,
StateData#state{socket = Socket,
xmlpid = XMLStreamPid,
streamid = new_id()}};
ReceiverPid = ejabberd_receiver:start(Socket, gen_tcp, none),
Version = if
StateData#state.use_v10 ->
" version='1.0'";
true ->
""
end,
NewStateData = StateData#state{socket = Socket,
sockmod = gen_tcp,
tls_enabled = false,
receiver = ReceiverPid,
streamid = new_id()},
send_text(NewStateData, io_lib:format(?STREAM_HEADER,
[StateData#state.server,
Version])),
{next_state, wait_for_stream, NewStateData};
{error, Reason} ->
?DEBUG("s2s_out: inet6 connect return ~p~n", [Reason]),
Error = ?ERR_REMOTE_SERVER_NOT_FOUND,
@ -140,58 +176,36 @@ open_socket(init, StateData) ->
Error = ?ERR_REMOTE_SERVER_NOT_FOUND,
bounce_messages(Error),
{stop, normal, StateData}
end.
end;
open_socket(_, StateData) ->
{next_state, open_socket, StateData}.
wait_for_stream({xmlstreamstart, Name, Attrs}, StateData) ->
% TODO
case {xml:get_attr_s("xmlns", Attrs), xml:get_attr_s("xmlns:db", Attrs)} of
{"jabber:server", "jabber:server:dialback"} ->
Server = StateData#state.server,
New = case StateData#state.new of
false ->
case ejabberd_s2s:try_register(
{StateData#state.myname, Server}) of
{key, Key} ->
Key;
false ->
false
end;
Key ->
Key
end,
case New of
false ->
ok;
Key1 ->
send_element(StateData#state.socket,
{xmlelement,
"db:result",
[{"from", StateData#state.myname},
{"to", Server}],
[{xmlcdata, Key1}]})
end,
case StateData#state.verify of
false ->
ok;
{Pid, Key2, SID} ->
send_element(StateData#state.socket,
{xmlelement,
"db:verify",
[{"from", StateData#state.myname},
{"to", StateData#state.server},
{"id", SID}],
[{xmlcdata, Key2}]})
end,
{next_state, wait_for_validation, StateData#state{new = New}};
case {xml:get_attr_s("xmlns", Attrs),
xml:get_attr_s("xmlns:db", Attrs),
xml:get_attr_s("version", Attrs) == "1.0"} of
{"jabber:server", "jabber:server:dialback", false} ->
send_db_request(StateData);
{"jabber:server", "jabber:server:dialback", true} when
StateData#state.use_v10 ->
{next_state, wait_for_features, StateData};
{"jabber:server", "", true} when StateData#state.use_v10 ->
?INFO_MSG("restarted: ~p", [{StateData#state.myname,
StateData#state.server}]),
% TODO: clear message queue
(StateData#state.sockmod):close(StateData#state.socket),
gen_fsm:send_event(self(), init),
{next_state, open_socket, StateData#state{socket = undefined,
use_v10 = false}};
_ ->
send_text(StateData#state.socket, ?INVALID_NAMESPACE_ERR),
send_text(StateData, ?INVALID_NAMESPACE_ERR),
{stop, normal, StateData}
end;
wait_for_stream({xmlstreamerror, _}, StateData) ->
send_text(StateData#state.socket,
?STREAM_HEADER ++ ?INVALID_XML_ERR ++ ?STREAM_TRAILER),
send_text(StateData,
?INVALID_XML_ERR ++ ?STREAM_TRAILER),
{stop, normal, StateData};
wait_for_stream(timeout, StateData) ->
@ -208,7 +222,7 @@ wait_for_validation({xmlstreamelement, El}, StateData) ->
?INFO_MSG("recv result: ~p", [{From, To, Id, Type}]),
case Type of
"valid" ->
send_queue(StateData#state.socket, StateData#state.queue),
send_queue(StateData, StateData#state.queue),
{next_state, stream_established,
StateData#state{queue = queue:new()}};
_ ->
@ -248,8 +262,8 @@ wait_for_validation({xmlstreamend, Name}, StateData) ->
{stop, normal, StateData};
wait_for_validation({xmlstreamerror, _}, StateData) ->
send_text(StateData#state.socket,
?STREAM_HEADER ++ ?INVALID_XML_ERR ++ ?STREAM_TRAILER),
send_text(StateData,
?INVALID_XML_ERR ++ ?STREAM_TRAILER),
{stop, normal, StateData};
wait_for_validation(timeout, StateData) ->
@ -259,6 +273,111 @@ wait_for_validation(closed, StateData) ->
{stop, normal, StateData}.
wait_for_features({xmlstreamelement, El}, StateData) ->
case El of
{xmlelement, "stream:features", _Attrs, Els} ->
{StartTLS, StartTLSRequired} =
lists:foldl(
fun({xmlelement, "starttls", Attrs1, Els1} = El1, Acc) ->
case xml:get_attr_s("xmlns", Attrs1) of
?NS_TLS ->
Req = case xml:get_subtag(El1, "required") of
{xmlelement, _, _, _} -> true;
false -> false
end,
{true, Req};
_ ->
Acc
end;
(_, Acc) ->
Acc
end, {false, false}, Els),
if
StartTLS and StateData#state.tls and
(not StateData#state.tls_enabled) ->
StateData#state.receiver ! {change_timeout, 100},
send_element(StateData,
{xmlelement, "starttls",
[{"xmlns", ?NS_TLS}], []}),
{next_state, wait_for_starttls_proceed, StateData};
StartTLSRequired and (not StateData#state.tls) ->
?INFO_MSG("restarted: ~p", [{StateData#state.myname,
StateData#state.server}]),
(StateData#state.sockmod):close(StateData#state.socket),
gen_fsm:send_event(self(), init),
{next_state, open_socket,
StateData#state{socket = undefined,
use_v10 = false}};
true ->
send_db_request(StateData)
end;
_ ->
send_text(StateData,
xml:element_to_string(?SERR_BAD_FORMAT) ++
?STREAM_TRAILER),
{stop, normal, StateData}
end;
wait_for_features({xmlstreamend, Name}, StateData) ->
{stop, normal, StateData};
wait_for_features({xmlstreamerror, _}, StateData) ->
send_text(StateData,
?INVALID_XML_ERR ++ ?STREAM_TRAILER),
{stop, normal, StateData};
wait_for_features(timeout, StateData) ->
{stop, normal, StateData};
wait_for_features(closed, StateData) ->
{stop, normal, StateData}.
wait_for_starttls_proceed({xmlstreamelement, El}, StateData) ->
case El of
{xmlelement, "proceed", Attrs, _Els} ->
case xml:get_attr_s("xmlns", Attrs) of
?NS_TLS ->
?INFO_MSG("starttls: ~p", [{StateData#state.myname,
StateData#state.server}]),
Socket = StateData#state.socket,
TLSOpts = StateData#state.tls_options,
{ok, TLSSocket} = tls:tcp_to_tls(Socket, TLSOpts),
ejabberd_receiver:starttls(
StateData#state.receiver, TLSSocket),
StateData#state.receiver ! {change_timeout, infinity},
NewStateData = StateData#state{sockmod = tls,
socket = TLSSocket,
streamid = new_id(),
tls_enabled = true
},
R = send_text(NewStateData,
io_lib:format(?STREAM_HEADER,
[StateData#state.server,
" version='1.0'"])),
{next_state, wait_for_stream, NewStateData};
_ ->
{stop, normal, StateData}
end;
_ ->
{stop, normal, StateData}
end;
wait_for_starttls_proceed({xmlstreamend, Name}, StateData) ->
{stop, normal, StateData};
wait_for_starttls_proceed({xmlstreamerror, _}, StateData) ->
send_text(StateData,
?INVALID_XML_ERR ++ ?STREAM_TRAILER),
{stop, normal, StateData};
wait_for_starttls_proceed(timeout, StateData) ->
{stop, normal, StateData};
wait_for_starttls_proceed(closed, StateData) ->
{stop, normal, StateData}.
stream_established({xmlstreamelement, El}, StateData) ->
?INFO_MSG("stream established", []),
case is_verify_res(El) of
@ -290,8 +409,8 @@ stream_established({xmlstreamend, Name}, StateData) ->
{stop, normal, StateData};
stream_established({xmlstreamerror, _}, StateData) ->
send_text(StateData#state.socket,
?STREAM_HEADER ++ ?INVALID_XML_ERR ++ ?STREAM_TRAILER),
send_text(StateData,
?INVALID_XML_ERR ++ ?STREAM_TRAILER),
{stop, normal, StateData};
stream_established(timeout, StateData) ->
@ -347,7 +466,7 @@ code_change(OldVsn, StateName, StateData, Extra) ->
%% {stop, Reason, NewStateData}
%%----------------------------------------------------------------------
handle_info({send_text, Text}, StateName, StateData) ->
send_text(StateData#state.socket, Text),
send_text(StateData, Text),
cancel_timer(StateData#state.timer),
Timer = erlang:start_timer(?S2STIMEOUT, self(), []),
{next_state, StateName, StateData#state{timer = Timer}};
@ -357,7 +476,7 @@ handle_info({send_element, El}, StateName, StateData) ->
Timer = erlang:start_timer(?S2STIMEOUT, self(), []),
case StateName of
stream_established ->
send_element(StateData#state.socket, El),
send_element(StateData, El),
{next_state, StateName, StateData#state{timer = Timer}};
_ ->
Q = queue:in(El, StateData#state.queue),
@ -365,17 +484,17 @@ handle_info({send_element, El}, StateName, StateData) ->
timer = Timer}}
end;
handle_info({tcp, Socket, Data}, StateName, StateData) ->
xml_stream:send_text(StateData#state.xmlpid, Data),
{next_state, StateName, StateData};
handle_info({tcp_closed, Socket}, StateName, StateData) ->
gen_fsm:send_event(self(), closed),
{next_state, StateName, StateData};
handle_info({tcp_error, Socket, Reason}, StateName, StateData) ->
gen_fsm:send_event(self(), closed),
{next_state, StateName, StateData};
%handle_info({tcp, Socket, Data}, StateName, StateData) ->
% xml_stream:send_text(StateData#state.xmlpid, Data),
% {next_state, StateName, StateData};
%
%handle_info({tcp_closed, Socket}, StateName, StateData) ->
% gen_fsm:send_event(self(), closed),
% {next_state, StateName, StateData};
%
%handle_info({tcp_error, Socket, Reason}, StateName, StateData) ->
% gen_fsm:send_event(self(), closed),
% {next_state, StateName, StateData};
handle_info({timeout, Timer, _}, StateName,
#state{timer = Timer} = StateData) ->
@ -404,8 +523,7 @@ terminate(Reason, StateName, StateData) ->
undefined ->
ok;
Socket ->
gen_tcp:close(Socket),
exit(StateData#state.xmlpid, closed)
(StateData#state.sockmod):close(Socket)
end,
ok.
@ -413,17 +531,17 @@ terminate(Reason, StateName, StateData) ->
%%% Internal functions
%%%----------------------------------------------------------------------
send_text(Socket, Text) ->
gen_tcp:send(Socket,Text).
send_text(StateData, Text) ->
(StateData#state.sockmod):send(StateData#state.socket, Text).
send_element(Socket, El) ->
send_text(Socket, xml:element_to_string(El)).
send_element(StateData, El) ->
send_text(StateData, xml:element_to_string(El)).
send_queue(Socket, Q) ->
send_queue(StateData, Q) ->
case queue:out(Q) of
{{value, El}, Q1} ->
send_element(Socket, El),
send_queue(Socket, Q1);
send_element(StateData, El),
send_queue(StateData, Q1);
{empty, Q1} ->
ok
end.
@ -470,20 +588,46 @@ bounce_messages(Error) ->
ok
end.
%is_key_packet({xmlelement, Name, Attrs, Els}) when Name == "db:result" ->
% {key,
% xml:get_attr_s("to", Attrs),
% xml:get_attr_s("from", Attrs),
% xml:get_attr_s("id", Attrs),
% xml:get_cdata(Els)};
%is_key_packet({xmlelement, Name, Attrs, Els}) when Name == "db:verify" ->
% {verify,
% xml:get_attr_s("to", Attrs),
% xml:get_attr_s("from", Attrs),
% xml:get_attr_s("id", Attrs),
% xml:get_cdata(Els)};
%is_key_packet(_) ->
% false.
send_db_request(StateData) ->
Server = StateData#state.server,
New = case StateData#state.new of
false ->
case ejabberd_s2s:try_register(
{StateData#state.myname, Server}) of
{key, Key} ->
Key;
false ->
false
end;
Key ->
Key
end,
case New of
false ->
ok;
Key1 ->
send_element(StateData,
{xmlelement,
"db:result",
[{"from", StateData#state.myname},
{"to", Server}],
[{xmlcdata, Key1}]})
end,
case StateData#state.verify of
false ->
ok;
{Pid, Key2, SID} ->
send_element(StateData,
{xmlelement,
"db:verify",
[{"from", StateData#state.myname},
{"to", StateData#state.server},
{"id", SID}],
[{xmlcdata, Key2}]})
end,
{next_state, wait_for_validation, StateData#state{new = New}}.
is_verify_res({xmlelement, Name, Attrs, Els}) when Name == "db:result" ->
{result,

View File

@ -27,11 +27,12 @@
code_change/3,
terminate/2]).
-define(SET_CERTIFICATE_FILE, 1).
-define(SET_ENCRYPTED_INPUT, 2).
-define(SET_DECRYPTED_OUTPUT, 3).
-define(GET_ENCRYPTED_OUTPUT, 4).
-define(GET_DECRYPTED_INPUT, 5).
-define(SET_CERTIFICATE_FILE_ACCEPT, 1).
-define(SET_CERTIFICATE_FILE_CONNECT, 2).
-define(SET_ENCRYPTED_INPUT, 3).
-define(SET_DECRYPTED_OUTPUT, 4).
-define(GET_ENCRYPTED_OUTPUT, 5).
-define(GET_DECRYPTED_INPUT, 6).
-record(tlssock, {tcpsock, tlsport}).
@ -44,7 +45,7 @@ start_link() ->
init([]) ->
ok = erl_ddll:load_driver(ejabberd:get_so_path(), tls_drv),
Port = open_port({spawn, tls_drv}, [binary]),
Res = port_control(Port, ?SET_CERTIFICATE_FILE, "./ssl.pem" ++ [0]),
Res = port_control(Port, ?SET_CERTIFICATE_FILE_ACCEPT, "./ssl.pem" ++ [0]),
case Res of
<<0>> ->
%ets:new(iconv_table, [set, public, named_table]),
@ -86,8 +87,13 @@ tcp_to_tls(TCPSocket, Options) ->
{value, {certfile, CertFile}} ->
ok = erl_ddll:load_driver(ejabberd:get_so_path(), tls_drv),
Port = open_port({spawn, tls_drv}, [binary]),
case port_control(Port, ?SET_CERTIFICATE_FILE,
CertFile ++ [0]) of
Command = case lists:member(connect, Options) of
true ->
?SET_CERTIFICATE_FILE_CONNECT;
false ->
?SET_CERTIFICATE_FILE_ACCEPT
end,
case port_control(Port, Command, CertFile ++ [0]) of
<<0>> ->
{ok, #tlssock{tcpsock = TCPSocket, tlsport = Port}};
<<1, Error/binary>> ->
@ -145,7 +151,10 @@ send(#tlssock{tcpsock = TCPSocket, tlsport = Port}, Packet) ->
{error, binary_to_list(Error)}
end;
<<1, Error/binary>> ->
{error, binary_to_list(Error)}
{error, binary_to_list(Error)};
<<2>> -> % Dirty hack
receive after 100 -> ok end,
send(#tlssock{tcpsock = TCPSocket, tlsport = Port}, Packet)
end.
@ -158,7 +167,8 @@ test() ->
ok = erl_ddll:load_driver(ejabberd:get_so_path(), tls_drv),
Port = open_port({spawn, tls_drv}, [binary]),
io:format("open_port: ~p~n", [Port]),
PCRes = port_control(Port, ?SET_CERTIFICATE_FILE, "./ssl.pem" ++ [0]),
PCRes = port_control(Port, ?SET_CERTIFICATE_FILE_ACCEPT,
"./ssl.pem" ++ [0]),
io:format("port_control: ~p~n", [PCRes]),
{ok, ListenSocket} = gen_tcp:listen(1234, [binary,
{packet, 0},

View File

@ -4,6 +4,7 @@
#include <string.h>
#include <erl_driver.h>
#include <openssl/ssl.h>
#include <openssl/err.h>
#define BUF_SIZE 1024
@ -45,11 +46,12 @@ static void tls_drv_stop(ErlDrvData handle)
}
#define SET_CERTIFICATE_FILE 1
#define SET_ENCRYPTED_INPUT 2
#define SET_DECRYPTED_OUTPUT 3
#define GET_ENCRYPTED_OUTPUT 4
#define GET_DECRYPTED_INPUT 5
#define SET_CERTIFICATE_FILE_ACCEPT 1
#define SET_CERTIFICATE_FILE_CONNECT 2
#define SET_ENCRYPTED_INPUT 3
#define SET_DECRYPTED_OUTPUT 4
#define GET_ENCRYPTED_OUTPUT 5
#define GET_DECRYPTED_INPUT 6
#define die_unless(cond, errstr) \
@ -76,8 +78,9 @@ static int tls_drv_control(ErlDrvData handle,
switch (command)
{
case SET_CERTIFICATE_FILE:
d->ctx = SSL_CTX_new(SSLv23_server_method());
case SET_CERTIFICATE_FILE_ACCEPT:
case SET_CERTIFICATE_FILE_CONNECT:
d->ctx = SSL_CTX_new(SSLv23_method());
die_unless(d->ctx, "SSL_CTX_new failed");
res = SSL_CTX_use_certificate_file(d->ctx, buf, SSL_FILETYPE_PEM);
@ -97,7 +100,10 @@ static int tls_drv_control(ErlDrvData handle,
SSL_set_bio(d->ssl, d->bio_read, d->bio_write);
SSL_set_accept_state(d->ssl);
if (command == SET_CERTIFICATE_FILE_ACCEPT)
SSL_set_accept_state(d->ssl);
else
SSL_set_connect_state(d->ssl);
break;
case SET_ENCRYPTED_INPUT:
die_unless(d->ssl, "SSL not initialized");
@ -106,6 +112,19 @@ static int tls_drv_control(ErlDrvData handle,
case SET_DECRYPTED_OUTPUT:
die_unless(d->ssl, "SSL not initialized");
res = SSL_write(d->ssl, buf, len);
if (res <= 0)
{
res = SSL_get_error(d->ssl, res);
if (res == SSL_ERROR_WANT_READ || res == SSL_ERROR_WANT_WRITE)
{
b = driver_alloc_binary(1);
b->orig_bytes[0] = 2;
*rbuf = (char *)b;
return 1;
} else {
die_unless(0, "SSL_write failed");
}
}
break;
case GET_ENCRYPTED_OUTPUT:
die_unless(d->ssl, "SSL not initialized");
@ -128,13 +147,10 @@ static int tls_drv_control(ErlDrvData handle,
case GET_DECRYPTED_INPUT:
if (!SSL_is_init_finished(d->ssl))
{
//printf("Doing SSL_accept\r\n");
res = SSL_accept(d->ssl);
//if (res == 0)
// printf("SSL_accept returned zero\r\n");
if (res < 0)
res = SSL_do_handshake(d->ssl);
if (res <= 0)
die_unless(SSL_get_error(d->ssl, res) == SSL_ERROR_WANT_READ,
"SSL_accept failed");
"SSL_do_handshake failed");
} else {
size = BUF_SIZE + 1;
rlen = 1;