diff --git a/ChangeLog b/ChangeLog index 4a8c7325f..27c2374a5 100644 --- a/ChangeLog +++ b/ChangeLog @@ -1,3 +1,17 @@ +2004-07-28 Alexey Shchepin + + * src/tls/tls_drv.c: Added freeing of SSL stuff + + * src/xml_stream.erl: Added start/2 function + * src/ejabberd_receiver.erl: Now using xml_stream:start/2 + +2004-07-27 Alexey Shchepin + + * src/ejabberd_c2s.erl: Support for TLS library (not completed) + + * src/tls/tls_drv.c: Updated to return binaries instead of lists + * src/tls/tls.erl: Likewise + 2004-07-26 Alexey Shchepin * src/tls/tls.erl: Updated diff --git a/src/ejabberd_c2s.erl b/src/ejabberd_c2s.erl index 25de839e9..a920597f4 100644 --- a/src/ejabberd_c2s.erl +++ b/src/ejabberd_c2s.erl @@ -45,6 +45,9 @@ sasl_state, access, shaper, + tls = false, + tls_enabled = false, + tls_options = [], authentificated = false, jid, user = "", server = ?MYNAME, resource = "", @@ -107,7 +110,6 @@ get_presence(FsmRef) -> %% {stop, StopReason} %%---------------------------------------------------------------------- init([{SockMod, Socket}, Opts]) -> - ReceiverPid = ejabberd_receiver:start(Socket, SockMod, none), Access = case lists:keysearch(access, 1, Opts) of {value, {_, A}} -> A; _ -> all @@ -116,12 +118,31 @@ init([{SockMod, Socket}, Opts]) -> {value, {_, S}} -> S; _ -> none end, - {ok, wait_for_stream, #state{socket = Socket, - sockmod = SockMod, - receiver = ReceiverPid, - streamid = new_id(), - access = Access, - shaper = Shaper}}. + TLS = lists:member(tls, Opts), + TLSEnabled = lists:member(tls_from_start, Opts), + TLSOpts = lists:filter(fun({certfile, _}) -> true; + (_) -> false + end, Opts), + {SockMod1, Socket1, ReceiverPid} = + if + TLSEnabled -> + {ok, TLSSocket} = tls:tcp_to_tls(Socket, TLSOpts), + RecPid = ejabberd_receiver:start(TLSSocket, tls, none), + {tls, TLSSocket, RecPid}; + true -> + RecPid = ejabberd_receiver:start(Socket, SockMod, none), + {SockMod, Socket, RecPid} + end, + {ok, wait_for_stream, #state{socket = Socket1, + sockmod = SockMod1, + receiver = ReceiverPid, + tls = TLS, + tls_enabled = TLSEnabled, + tls_options = TLSOpts, + streamid = new_id(), + access = Access, + shaper = Shaper}}. + %%---------------------------------------------------------------------- %% Func: StateName/2 diff --git a/src/ejabberd_receiver.erl b/src/ejabberd_receiver.erl index 6565d042f..308d1b94c 100644 --- a/src/ejabberd_receiver.erl +++ b/src/ejabberd_receiver.erl @@ -23,7 +23,7 @@ start(Socket, SockMod, Shaper) -> receiver(Socket, SockMod, Shaper, C2SPid) -> - XMLStreamPid = xml_stream:start(C2SPid), + XMLStreamPid = xml_stream:start(self(), C2SPid), ShaperState = shaper:new(Shaper), Timeout = case SockMod of ssl -> @@ -46,7 +46,7 @@ receiver(Socket, SockMod, ShaperState, C2SPid, XMLStreamPid, Timeout) -> XMLStreamPid1 = receive reset_stream -> exit(XMLStreamPid, closed), - xml_stream:start(C2SPid) + xml_stream:start(self(), C2SPid) after 0 -> XMLStreamPid end, diff --git a/src/tls/tls.erl b/src/tls/tls.erl index 96275d3f0..edf94a169 100644 --- a/src/tls/tls.erl +++ b/src/tls/tls.erl @@ -46,12 +46,12 @@ init([]) -> Port = open_port({spawn, tls_drv}, [binary]), Res = port_control(Port, ?SET_CERTIFICATE_FILE, "./ssl.pem" ++ [0]), case Res of - [0] -> + <<0>> -> %ets:new(iconv_table, [set, public, named_table]), %ets:insert(iconv_table, {port, Port}), {ok, Port}; - [1 | Error] -> - {error, Error} + <<1, Error/binary>> -> + {error, binary_to_list(Error)} end. @@ -84,14 +84,15 @@ terminate(_Reason, Port) -> tcp_to_tls(TCPSocket, Options) -> case lists:keysearch(certfile, 1, Options) of {value, {certfile, CertFile}} -> + ok = erl_ddll:load_driver(ejabberd:get_so_path(), tls_drv), Port = open_port({spawn, tls_drv}, [binary]), io:format("open_port: ~p~n", [Port]), case port_control(Port, ?SET_CERTIFICATE_FILE, CertFile ++ [0]) of - [0] -> + <<0>> -> {ok, #tlssock{tcpsock = TCPSocket, tlsport = Port}}; - [1 | Error] -> - {error, Error} + <<1, Error/binary>> -> + {error, binary_to_list(Error)} end; false -> {error, no_certfile} @@ -107,40 +108,41 @@ recv(#tlssock{tcpsock = TCPSocket, tlsport = Port}, Length, Timeout) -> case gen_tcp:recv(TCPSocket, Length, Timeout) of {ok, Packet} -> case port_control(Port, ?SET_ENCRYPTED_INPUT, Packet) of - [0] -> + <<0>> -> case port_control(Port, ?GET_DECRYPTED_INPUT, []) of - [0 | In] -> + <<0, In/binary>> -> case port_control(Port, ?GET_ENCRYPTED_OUTPUT, []) of - [0 | Out] -> + <<0, Out/binary>> -> case gen_tcp:send(TCPSocket, Out) of ok -> {ok, In}; Error -> Error end; - [1 | Error] -> - {error, Error} + <<1, Error/binary>> -> + {error, binary_to_list(Error)} end; - [1 | Error] -> - {error, Error} + <<1, Error/binary>> -> + {error, binary_to_list(Error)} end; - [1 | Error] -> - {error, Error} + <<1, Error/binary>> -> + {error, binary_to_list(Error)} end; {error, _Reason} = Error -> Error end. + send(#tlssock{tcpsock = TCPSocket, tlsport = Port}, Packet) -> case port_control(Port, ?SET_DECRYPTED_OUTPUT, Packet) of - [0] -> + <<0>> -> case port_control(Port, ?GET_ENCRYPTED_OUTPUT, []) of - [0 | Out] -> + <<0, Out/binary>> -> gen_tcp:send(TCPSocket, Out); - [1 | Error] -> - {error, Error} + <<1, Error/binary>> -> + {error, binary_to_list(Error)} end; - [1 | Error] -> - {error, Error} + <<1, Error/binary>> -> + {error, binary_to_list(Error)} end. @@ -171,24 +173,24 @@ loop(Port, Socket) -> {tcp, Socket, Data} -> %io:format("read: ~p~n", [Data]), Res = port_control(Port, ?SET_ENCRYPTED_INPUT, Data), - %io:format("SET_ENCRYPTED_INPUT: ~p~n", [Res]), + io:format("SET_ENCRYPTED_INPUT: ~p~n", [Res]), DIRes = port_control(Port, ?GET_DECRYPTED_INPUT, Data), - %io:format("GET_DECRYPTED_INPUT: ~p~n", [DIRes]), + io:format("GET_DECRYPTED_INPUT: ~p~n", [DIRes]), case DIRes of - [0 | In] -> - io:format("input: ~s~n", [In]); - [1 | DIError] -> - io:format("GET_DECRYPTED_INPUT error: ~p~n", [DIError]) + <<0, In/binary>> -> + io:format("input: ~s~n", [binary_to_list(In)]); + <<1, DIError/binary>> -> + io:format("GET_DECRYPTED_INPUT error: ~p~n", [binary_to_list(DIError)]) end, EORes = port_control(Port, ?GET_ENCRYPTED_OUTPUT, Data), - %io:format("GET_ENCRYPTED_OUTPUT: ~p~n", [EORes]), + io:format("GET_ENCRYPTED_OUTPUT: ~p~n", [EORes]), case EORes of - [0 | Out] -> + <<0, Out/binary>> -> gen_tcp:send(Socket, Out); - [1 | EOError] -> - io:format("GET_ENCRYPTED_OUTPUT error: ~p~n", [EOError]) + <<1, EOError/binary>> -> + io:format("GET_ENCRYPTED_OUTPUT error: ~p~n", [binary_to_list(EOError)]) end, diff --git a/src/tls/tls_drv.c b/src/tls/tls_drv.c index 23d42b464..f320ee31f 100644 --- a/src/tls/tls_drv.c +++ b/src/tls/tls_drv.c @@ -26,13 +26,21 @@ static ErlDrvData tls_drv_start(ErlDrvPort port, char *buff) d->bio_write = NULL; d->ssl = NULL; + set_port_control_flags(port, PORT_CONTROL_FLAG_BINARY); + return (ErlDrvData)d; } static void tls_drv_stop(ErlDrvData handle) { - // TODO - //XML_ParserFree(((tls_data *)handle)->parser); + tls_data *d = (tls_data *)handle; + + if (d->ssl != NULL) + SSL_free(d->ssl); + + if (d->ctx != NULL) + SSL_CTX_free(d->ctx); + driver_free((char *)handle); } @@ -43,17 +51,16 @@ static void tls_drv_stop(ErlDrvData handle) #define GET_ENCRYPTED_OUTPUT 4 #define GET_DECRYPTED_INPUT 5 -#define DECRYPTED_INPUT 1 -#define ENCRYPTED_OUTPUT 2 -#define die_unless(cond, errstr) \ - if (!(cond)) \ - { \ - rlen = strlen(errstr) + 1; \ - *rbuf = driver_alloc(rlen); \ - *rbuf[0] = 1; \ - strncpy(*rbuf + 1, errstr, rlen - 1); \ - return rlen; \ +#define die_unless(cond, errstr) \ + if (!(cond)) \ + { \ + rlen = strlen(errstr) + 1; \ + b = driver_alloc_binary(rlen); \ + b->orig_bytes[0] = 1; \ + strncpy(b->orig_bytes + 1, errstr, rlen - 1); \ + *rbuf = (char *)b; \ + return rlen; \ } @@ -65,6 +72,7 @@ static int tls_drv_control(ErlDrvData handle, tls_data *d = (tls_data *)handle; int res; int size; + ErlDrvBinary *b; switch (command) { @@ -92,48 +100,54 @@ static int tls_drv_control(ErlDrvData handle, SSL_set_accept_state(d->ssl); break; case SET_ENCRYPTED_INPUT: + die_unless(d->ssl, "SSL not initialized"); BIO_write(d->bio_read, buf, len); break; case SET_DECRYPTED_OUTPUT: + die_unless(d->ssl, "SSL not initialized"); res = SSL_write(d->ssl, buf, len); break; case GET_ENCRYPTED_OUTPUT: + die_unless(d->ssl, "SSL not initialized"); size = BUF_SIZE + 1; rlen = 1; - *rbuf = driver_alloc(size); - *rbuf[0] = 0; - while ((res = BIO_read(d->bio_write, *rbuf + rlen, BUF_SIZE)) > 0) + b = driver_alloc_binary(size); + b->orig_bytes[0] = 0; + while ((res = BIO_read(d->bio_write, + b->orig_bytes + rlen, BUF_SIZE)) > 0) { - printf("%d bytes of encrypted data read from state machine\r\n", res); + //printf("%d bytes of encrypted data read from state machine\r\n", res); rlen += res; size += BUF_SIZE; - *rbuf = driver_realloc(*rbuf, size); + b = driver_realloc_binary(b, size); } + b = driver_realloc_binary(b, rlen); + *rbuf = (char *)b; return rlen; case GET_DECRYPTED_INPUT: if (!SSL_is_init_finished(d->ssl)) { - printf("Doing SSL_accept\r\n"); + //printf("Doing SSL_accept\r\n"); res = SSL_accept(d->ssl); - if (res == 0) - printf("SSL_accept returned zero\r\n"); + //if (res == 0) + // printf("SSL_accept returned zero\r\n"); if (res < 0) die_unless(SSL_get_error(d->ssl, res) == SSL_ERROR_WANT_READ, "SSL_accept failed"); } else { size = BUF_SIZE + 1; rlen = 1; - *rbuf = driver_alloc(size); - *rbuf[0] = 0; + b = driver_alloc_binary(size); + b->orig_bytes[0] = 0; - - while ((res = SSL_read(d->ssl, *rbuf + rlen, BUF_SIZE)) > 0) + while ((res = SSL_read(d->ssl, + b->orig_bytes + rlen, BUF_SIZE)) > 0) { - printf("%d bytes of decrypted data read from state machine\r\n",res); + //printf("%d bytes of decrypted data read from state machine\r\n",res); rlen += res; size += BUF_SIZE; - *rbuf = driver_realloc(*rbuf, size); + b = driver_realloc_binary(b, size); } if (res < 0) @@ -142,23 +156,21 @@ static int tls_drv_control(ErlDrvData handle, if (err == SSL_ERROR_WANT_READ) { - printf("SSL_read wants more data\r\n"); + //printf("SSL_read wants more data\r\n"); //return 0; } // TODO } + b = driver_realloc_binary(b, rlen); + *rbuf = (char *)b; return rlen; } break; } - if (command == SET_ENCRYPTED_INPUT || command == SET_DECRYPTED_OUTPUT) - { - - } - - *rbuf = driver_alloc(1); - *rbuf[0] = 0; + b = driver_alloc_binary(1); + b->orig_bytes[0] = 0; + *rbuf = (char *)b; return 1; } diff --git a/src/xml_stream.erl b/src/xml_stream.erl index 4d7f80d4c..69b6123c0 100644 --- a/src/xml_stream.erl +++ b/src/xml_stream.erl @@ -10,7 +10,7 @@ -author('alexey@sevcom.net'). -vsn('$Revision$ '). --export([start/1, init/1, send_text/2]). +-export([start/1, start/2, init/1, init/2, send_text/2]). -define(XML_START, 0). -define(XML_END, 1). @@ -20,10 +20,18 @@ start(CallbackPid) -> spawn(?MODULE, init, [CallbackPid]). +start(Receiver, CallbackPid) -> + spawn(?MODULE, init, [Receiver, CallbackPid]). + init(CallbackPid) -> Port = open_port({spawn, expat_erl}, [binary]), loop(CallbackPid, Port, []). +init(Receiver, CallbackPid) -> + erlang:monitor(process, Receiver), + Port = open_port({spawn, expat_erl}, [binary]), + loop(CallbackPid, Port, []). + loop(CallbackPid, Port, Stack) -> receive {Port, {data, Bin}} -> @@ -31,7 +39,9 @@ loop(CallbackPid, Port, Stack) -> loop(CallbackPid, Port, process_data(CallbackPid, Stack, Data)); {_From, {send, Str}} -> Port ! {self(), {command, Str}}, - loop(CallbackPid, Port, Stack) + loop(CallbackPid, Port, Stack); + {'DOWN', _Ref, _Type, _Object, _Info} -> + ok end. process_data(CallbackPid, Stack, Data) ->