diff --git a/include/ejabberd_http.hrl b/include/ejabberd_http.hrl index 931706342..50c9d4ad7 100644 --- a/include/ejabberd_http.hrl +++ b/include/ejabberd_http.hrl @@ -34,3 +34,15 @@ opts = [] :: list(), headers = [] :: [{atom() | binary(), binary()}]}). + +-record(ws, + {socket :: inet:socket() | p1_tls:tls_socket(), + sockmod = gen_tcp :: gen_tcp | p1_tls, + ip :: {inet:ip_address(), inet:port_number()}, + host = <<"">> :: binary(), + port = 5280 :: inet:port_number(), + path = [] :: [binary()], + headers = [] :: [{atom() | binary(), binary()}], + local_path = [] :: [binary()], + q = [] :: [{binary() | nokey, binary()}], + buf :: binary()}). diff --git a/src/ejabberd_http.erl b/src/ejabberd_http.erl index 3c91c3c58..ce5c8aa09 100644 --- a/src/ejabberd_http.erl +++ b/src/ejabberd_http.erl @@ -345,48 +345,93 @@ get_transfer_protocol(SockMod, HostPort) -> %% XXX bard: search through request handlers looking for one that %% matches the requested URL path, and pass control to it. If none is %% found, answer with HTTP 404. -process([], _) -> - ejabberd_web:error(not_found); -process(Handlers, Request) -> - %% Only the first element in the path prefix is checked - [{HandlerPathPrefix, HandlerModule} | HandlersLeft] = - Handlers, - case lists:prefix(HandlerPathPrefix, - Request#request.path) - or (HandlerPathPrefix == Request#request.path) - of - true -> - ?DEBUG("~p matches ~p", - [Request#request.path, HandlerPathPrefix]), - LocalPath = lists:nthtail(length(HandlerPathPrefix), - Request#request.path), - ?DEBUG("~p", [Request#request.headers]), - R = HandlerModule:process(LocalPath, Request), - ejabberd_hooks:run(http_request_debug, - [{LocalPath, Request}]), - R; - false -> process(HandlersLeft, Request) + +process([], _, _, _, _) -> ejabberd_web:error(not_found); +process(Handlers, Request, Socket, SockMod, Trail) -> + {HandlerPathPrefix, HandlerModule, HandlerOpts, HandlersLeft} = + case Handlers of + [{Pfx, Mod} | Tail] -> + {Pfx, Mod, [], Tail}; + [{Pfx, Mod, Opts} | Tail] -> + {Pfx, Mod, Opts, Tail} + end, + + case (lists:prefix(HandlerPathPrefix, Request#request.path) or + (HandlerPathPrefix==Request#request.path)) of + true -> + ?DEBUG("~p matches ~p", [Request#request.path, HandlerPathPrefix]), + %% LocalPath is the path "local to the handler", i.e. if + %% the handler was registered to handle "/test/" and the + %% requested path is "/test/foo/bar", the local path is + %% ["foo", "bar"] + LocalPath = lists:nthtail(length(HandlerPathPrefix), Request#request.path), + R = try + HandlerModule:socket_handoff( + LocalPath, Request, Socket, SockMod, Trail, HandlerOpts) + catch error:undef -> + HandlerModule:process(LocalPath, Request) + end, + ejabberd_hooks:run(http_request_debug, [{LocalPath, Request}]), + R; + false -> + process(HandlersLeft, Request, Socket, SockMod, Trail) end. -process_request(#state{request_method = Method, options = Options, - request_path = {abs_path, Path}, request_auth = Auth, - request_lang = Lang, request_handlers = RequestHandlers, - request_host = Host, request_port = Port, - request_tp = TP, request_headers = RequestHeaders, - sockmod = SockMod, - socket = Socket} = State) - when Method=:='GET' orelse Method=:='HEAD' orelse Method=:='DELETE' orelse Method=:='OPTIONS' -> - case (catch url_decode_q_split(Path)) of - {'EXIT', _} -> +extract_path_query(#state{request_method = Method, + request_path = {abs_path, Path}}) + when Method =:= 'GET' orelse + Method =:= 'HEAD' orelse + Method =:= 'DELETE' orelse Method =:= 'OPTIONS' -> + case catch url_decode_q_split(Path) of + {'EXIT', _} -> false; + {NPath, Query} -> + LPath = normalize_path([NPE + || NPE <- str:tokens(path_decode(NPath), <<"/">>)]), + LQuery = case catch parse_urlencoded(Query) of + {'EXIT', _Reason} -> []; + LQ -> LQ + end, + {LPath, LQuery, <<"">>} + end; +extract_path_query(#state{request_method = Method, + request_path = {abs_path, Path}, + request_content_length = Len, + sockmod = _SockMod, + socket = _Socket} = State) + when (Method =:= 'POST' orelse Method =:= 'PUT') andalso + is_integer(Len) -> + Data = recv_data(State, Len), + ?DEBUG("client data: ~p~n", [Data]), + case catch url_decode_q_split(Path) of + {'EXIT', _} -> false; + {NPath, _Query} -> + LPath = normalize_path([NPE + || NPE <- str:tokens(path_decode(NPath), <<"/">>)]), + LQuery = case catch parse_urlencoded(Data) of + {'EXIT', _Reason} -> []; + LQ -> LQ + end, + {LPath, LQuery, Data} + end; +extract_path_query(_State) -> + false. + +process_request(#state{request_method = Method, + request_auth = Auth, + request_lang = Lang, + sockmod = SockMod, + socket = Socket, + options = Options, + request_host = Host, + request_port = Port, + request_tp = TP, + request_headers = RequestHeaders, + request_handlers = RequestHandlers, + trail = Trail} = State) -> + case extract_path_query(State) of + false -> make_bad_request(State); - {NPath, Query} -> - LPath = normalize_path([NPE || NPE <- str:tokens(path_decode(NPath), <<"/">>)]), - LQuery = case (catch parse_urlencoded(Query)) of - {'EXIT', _Reason} -> - []; - LQ -> - LQ - end, + {LPath, LQuery, Data} -> {ok, IPHere} = case SockMod of gen_tcp -> @@ -396,92 +441,36 @@ process_request(#state{request_method = Method, options = Options, end, XFF = proplists:get_value('X-Forwarded-For', RequestHeaders, []), IP = analyze_ip_xff(IPHere, XFF, Host), - Request = #request{method = Method, - path = LPath, + Request = #request{method = Method, + path = LPath, + q = LQuery, + auth = Auth, + data = Data, + lang = Lang, + host = Host, + port = Port, + tp = TP, opts = Options, - q = LQuery, - auth = Auth, - lang = Lang, - host = Host, - port = Port, - tp = TP, - headers = RequestHeaders, - ip = IP}, - %% XXX bard: This previously passed control to - %% ejabberd_web:process_get, now passes it to a local - %% procedure (process) that handles dispatching based on - %% URL path prefix. - case process(RequestHandlers, Request) of - El when element(1, El) == xmlel -> - make_xhtml_output(State, 200, [], El); - {Status, Headers, El} when - element(1, El) == xmlel -> - make_xhtml_output(State, Status, Headers, El); - Output when is_list(Output) or is_binary(Output) -> - make_text_output(State, 200, [], Output); - {Status, Headers, Output} when is_list(Output) or is_binary(Output) -> - make_text_output(State, Status, Headers, Output) + headers = RequestHeaders, + ip = IP}, + case process(RequestHandlers, Request, Socket, SockMod, Trail) of + El when is_record(El, xmlel) -> + make_xhtml_output(State, 200, [], El); + {Status, Headers, El} + when is_record(El, xmlel) -> + make_xhtml_output(State, Status, Headers, El); + Output when is_binary(Output) or is_list(Output) -> + make_text_output(State, 200, [], Output); + {Status, Headers, Output} + when is_binary(Output) or is_list(Output) -> + make_text_output(State, Status, Headers, Output); + {Status, Reason, Headers, Output} + when is_binary(Output) or is_list(Output) -> + make_text_output(State, Status, Reason, Headers, Output); + _ -> + none end - end; -process_request(#state{request_method = Method, options = Options, - request_path = {abs_path, Path}, request_auth = Auth, - request_content_length = Len, request_lang = Lang, - sockmod = SockMod, socket = Socket, request_host = Host, - request_port = Port, request_tp = TP, - request_headers = RequestHeaders, - request_handlers = RequestHandlers} = - State) - when (Method =:= 'POST' orelse Method =:= 'PUT') andalso - is_integer(Len) -> - {ok, IPHere} = case SockMod of - gen_tcp -> inet:peername(Socket); - _ -> SockMod:peername(Socket) - end, - XFF = proplists:get_value('X-Forwarded-For', - RequestHeaders, []), - IP = analyze_ip_xff(IPHere, XFF, Host), - case SockMod of - gen_tcp -> inet:setopts(Socket, [{packet, 0}]); - _ -> ok - end, - Data = recv_data(State, Len), - ?DEBUG("client data: ~p~n", [Data]), - case (catch url_decode_q_split(Path)) of - {'EXIT', _} -> - make_bad_request(State); - {NPath, _Query} -> - LPath = normalize_path([NPE || NPE <- str:tokens(path_decode(NPath), <<"/">>)]), - LQuery = case (catch parse_urlencoded(Data)) of - {'EXIT', _Reason} -> - []; - LQ -> - LQ - end, - Request = #request{method = Method, - path = LPath, - q = LQuery, - opts = Options, - auth = Auth, - data = Data, - lang = Lang, - host = Host, - port = Port, - tp = TP, - headers = RequestHeaders, - ip = IP}, - case process(RequestHandlers, Request) of - El when element(1, El) == xmlel -> - make_xhtml_output(State, 200, [], El); - {Status, Headers, El} when - element(1, El) == xmlel -> - make_xhtml_output(State, Status, Headers, El); - Output when is_list(Output) or is_binary(Output) -> - make_text_output(State, 200, [], Output); - {Status, Headers, Output} when is_list(Output) or is_binary(Output) -> - make_text_output(State, Status, Headers, Output) - end - end; -process_request(State) -> make_bad_request(State). + end. make_bad_request(State) -> %% Support for X-Forwarded-From diff --git a/src/ejabberd_http_ws.erl b/src/ejabberd_http_ws.erl new file mode 100644 index 000000000..47ad03f55 --- /dev/null +++ b/src/ejabberd_http_ws.erl @@ -0,0 +1,340 @@ +%%%---------------------------------------------------------------------- +%%% File : ejabberd_websocket.erl +%%% Author : Eric Cestari +%%% Purpose : XMPP Websocket support +%%% Created : 09-10-2010 by Eric Cestari +%%% +%%% +%%% ejabberd, Copyright (C) 2002-2015 ProcessOne +%%% +%%% This program is free software; you can redistribute it and/or +%%% modify it under the terms of the GNU General Public License as +%%% published by the Free Software Foundation; either version 2 of the +%%% License, or (at your option) any later version. +%%% +%%% This program is distributed in the hope that it will be useful, +%%% but WITHOUT ANY WARRANTY; without even the implied warranty of +%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +%%% General Public License for more details. +%%% +%%% You should have received a copy of the GNU General Public License +%%% along with this program; if not, write to the Free Software +%%% Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +%%% 02111-1307 USA +%%% +%%%---------------------------------------------------------------------- +-module(ejabberd_http_ws). + +-author('ecestari@process-one.net'). + +-behaviour(gen_fsm). + +% External exports +-export([start/1, start_link/1, init/1, handle_event/3, + handle_sync_event/4, code_change/4, handle_info/3, + terminate/3, send_xml/2, setopts/2, sockname/1, peername/1, + controlling_process/2, become_controller/2, close/1, + socket_handoff/6]). + +-include("ejabberd.hrl"). +-include("logger.hrl"). + +-include("jlib.hrl"). + +-include("ejabberd_http.hrl"). + +-define(PING_INTERVAL, 60). +-define(WEBSOCKET_TIMEOUT, 300). + +-record(state, + {socket :: ws_socket(), + ping_interval = ?PING_INTERVAL :: pos_integer(), + ping_timer = make_ref() :: reference(), + pong_expected :: boolean(), + timeout = ?WEBSOCKET_TIMEOUT :: pos_integer(), + timer = make_ref() :: reference(), + input = [] :: list(), + waiting_input = false :: false | pid(), + last_receiver :: pid(), + ws :: {#ws{}, pid()}, + rfc_compilant = undefined :: boolean() | undefined}). + +%-define(DBGFSM, true). + +-ifdef(DBGFSM). + +-define(FSMOPTS, [{debug, [trace]}]). + +-else. + +-define(FSMOPTS, []). + +-endif. + +-type ws_socket() :: {http_ws, pid(), {inet:ip_address(), inet:port_number()}}. +-export_type([ws_socket/0]). + +start(WS) -> + supervisor:start_child(ejabberd_wsloop_sup, [WS]). + +start_link(WS) -> + gen_fsm:start_link(?MODULE, [WS], ?FSMOPTS). + +send_xml({http_ws, FsmRef, _IP}, Packet) -> + gen_fsm:sync_send_all_state_event(FsmRef, + {send_xml, Packet}). + +setopts({http_ws, FsmRef, _IP}, Opts) -> + case lists:member({active, once}, Opts) of + true -> + gen_fsm:send_all_state_event(FsmRef, + {activate, self()}); + _ -> ok + end. + +sockname(_Socket) -> {ok, {{0, 0, 0, 0}, 0}}. + +peername({http_ws, _FsmRef, IP}) -> {ok, IP}. + +controlling_process(_Socket, _Pid) -> ok. + +become_controller(FsmRef, C2SPid) -> + gen_fsm:send_all_state_event(FsmRef, + {become_controller, C2SPid}). + +close({http_ws, FsmRef, _IP}) -> + catch gen_fsm:sync_send_all_state_event(FsmRef, close). + +socket_handoff(LocalPath, Request, Socket, SockMod, Buf, Opts) -> + ejabberd_websocket:socket_handoff(LocalPath, Request, Socket, SockMod, + Buf, Opts, ?MODULE, fun get_human_html_xmlel/0). + +%%% Internal + +init([{#ws{ip = IP}, _} = WS]) -> + Opts = [{xml_socket, true} | ejabberd_c2s_config:get_c2s_limits()], + PingInterval = ejabberd_config:get_option( + {websocket_ping_interval, ?MYNAME}, + fun(I) when is_integer(I), I>=0 -> I end, + ?PING_INTERVAL) * 1000, + WSTimeout = ejabberd_config:get_option( + {websocket_timeout, ?MYNAME}, + fun(I) when is_integer(I), I>0 -> I end, + ?WEBSOCKET_TIMEOUT) * 1000, + Socket = {http_ws, self(), IP}, + ?DEBUG("Client connected through websocket ~p", + [Socket]), + ejabberd_socket:start(ejabberd_c2s, ?MODULE, Socket, + Opts), + Timer = erlang:start_timer(WSTimeout, self(), []), + {ok, loop, + #state{socket = Socket, timeout = WSTimeout, + timer = Timer, ws = WS, + ping_interval = PingInterval}}. + +handle_event({activate, From}, StateName, StateData) -> + case StateData#state.input of + [] -> + {next_state, StateName, + StateData#state{waiting_input = From}}; + Input -> + Receiver = From, + Receiver ! {tcp, StateData#state.socket, Input}, + {next_state, StateName, + StateData#state{input = [], waiting_input = false, + last_receiver = Receiver}} + end. + +handle_sync_event({send_xml, Packet}, _From, StateName, + #state{ws = {_, WsPid}, rfc_compilant = R} = StateData) -> + Packet2 = case {case R of undefined -> true; V -> V end, Packet} of + {true, {xmlstreamstart, _, Attrs}} -> + Attrs2 = [{<<"xmlns">>, <<"urn:ietf:params:xml:ns:xmpp-framing">>} | + lists:keydelete(<<"xmlns">>, 1, lists:keydelete(<<"xmlns:stream">>, 1, Attrs))], + {xmlstreamelement, #xmlel{name = <<"open">>, attrs = Attrs2}}; + {true, {xmlstreamend, _}} -> + {xmlstreamelement, #xmlel{name = <<"close">>, + attrs = [{<<"xmlns">>, <<"urn:ietf:params:xml:ns:xmpp-framing">>}]}}; + {true, {xmlstreamraw, <<"\r\n\r\n">>}} -> % cdata ping + skip; + {true, {xmlstreamelement, #xmlel{name=Name2} = El2}} -> + El3 = case Name2 of + <<"stream:", _/binary>> -> + xml:replace_tag_attr(<<"xmlns:stream">>, ?NS_STREAM, El2); + _ -> + case xml:get_tag_attr_s(<<"xmlns">>, El2) of + <<"">> -> + xml:replace_tag_attr(<<"xmlns">>, <<"jabber:client">>, El2); + _ -> + El2 + end + end, + {xmlstreamelement , El3}; + _ -> + Packet + end, + case Packet2 of + {xmlstreamstart, Name, Attrs3} -> + B = xml:element_to_binary(#xmlel{name = Name, attrs = Attrs3}), + WsPid ! {send, <<(binary:part(B, 0, byte_size(B)-2))/binary, ">">>}; + {xmlstreamend, Name} -> + WsPid ! {send, <<"">>}; + {xmlstreamelement, El} -> + WsPid ! {send, xml:element_to_binary(El)}; + {xmlstreamraw, Bin} -> + WsPid ! {send, Bin}; + {xmlstreamcdata, Bin2} -> + WsPid ! {send, Bin2}; + skip -> + ok + end, + {reply, ok, StateName, StateData}; +handle_sync_event(close, _From, _StateName, StateData) -> + {stop, normal, StateData}. + +handle_info(closed, _StateName, StateData) -> + {stop, normal, StateData}; +handle_info({received, Packet}, StateName, StateDataI) -> + {StateData, Parsed} = parse(StateDataI, Packet), + SD = case StateData#state.waiting_input of + false -> + Input = StateData#state.input ++ Parsed, + StateData#state{input = Input}; + Receiver -> + Receiver ! {tcp, StateData#state.socket, Parsed}, + setup_timers(StateData#state{waiting_input = false, + last_receiver = Receiver}) + end, + {next_state, StateName, SD}; +handle_info(PingPong, StateName, StateData) when PingPong == ping orelse + PingPong == pong -> + StateData2 = setup_timers(StateData), + {next_state, StateName, + StateData2#state{pong_expected = false}}; +handle_info({timeout, Timer, _}, _StateName, + #state{timer = Timer} = StateData) -> + {stop, normal, StateData}; +handle_info({timeout, Timer, _}, StateName, + #state{ping_timer = Timer, ws = {_, WsPid}} = StateData) -> + case StateData#state.pong_expected of + false -> + cancel_timer(StateData#state.ping_timer), + PingTimer = erlang:start_timer(StateData#state.ping_interval, + self(), []), + WsPid ! {ping, <<>>}, + {next_state, StateName, + StateData#state{ping_timer = PingTimer, pong_expected = true}}; + true -> + {stop, normal, StateData} + end; +handle_info(_, StateName, StateData) -> + {next_state, StateName, StateData}. + +code_change(_OldVsn, StateName, StateData, _Extra) -> + {ok, StateName, StateData}. + +terminate(_Reason, _StateName, StateData) -> + case StateData#state.waiting_input of + false -> ok; + Receiver -> + ?DEBUG("C2S Pid : ~p", [Receiver]), + Receiver ! {tcp_closed, StateData#state.socket} + end, + ok. + +setup_timers(StateData) -> + cancel_timer(StateData#state.timer), + Timer = erlang:start_timer(StateData#state.timeout, + self(), []), + cancel_timer(StateData#state.ping_timer), + PingTimer = case {StateData#state.ping_interval, StateData#state.rfc_compilant} of + {0, _} -> StateData#state.ping_timer; + {_, false} -> StateData#state.ping_timer; + {V, _} -> erlang:start_timer(V, self(), []) + end, + StateData#state{timer = Timer, ping_timer = PingTimer, + pong_expected = false}. + +cancel_timer(Timer) -> + erlang:cancel_timer(Timer), + receive {timeout, Timer, _} -> ok after 0 -> ok end. + +get_human_html_xmlel() -> + Heading = <<"ejabberd ", (jlib:atom_to_binary(?MODULE))/binary>>, + #xmlel{name = <<"html">>, + attrs = + [{<<"xmlns">>, <<"http://www.w3.org/1999/xhtml">>}], + children = + [#xmlel{name = <<"head">>, attrs = [], + children = + [#xmlel{name = <<"title">>, attrs = [], + children = [{xmlcdata, Heading}]}]}, + #xmlel{name = <<"body">>, attrs = [], + children = + [#xmlel{name = <<"h1">>, attrs = [], + children = [{xmlcdata, Heading}]}, + #xmlel{name = <<"p">>, attrs = [], + children = + [{xmlcdata, <<"An implementation of ">>}, + #xmlel{name = <<"a">>, + attrs = + [{<<"href">>, + <<"http://tools.ietf.org/html/rfc6455">>}], + children = + [{xmlcdata, + <<"WebSocket protocol">>}]}]}, + #xmlel{name = <<"p">>, attrs = [], + children = + [{xmlcdata, + <<"This web page is only informative. To " + "use WebSocket connection you need a Jabber/XMPP " + "client that supports it.">>}]}]}]}. + + +parse(#state{rfc_compilant = C} = State, Data) -> + case C of + undefined -> + P = xml_stream:new(self()), + P2 = xml_stream:parse(P, Data), + xml_stream:close(P2), + case parsed_items([]) of + error -> + {State#state{rfc_compilant = true}, <<"parse error">>}; + [] -> + {State#state{rfc_compilant = true}, <<"parse error">>}; + [{xmlstreamstart, <<"open">>, _} | _] -> + parse(State#state{rfc_compilant = true}, Data); + _ -> + parse(State#state{rfc_compilant = false}, Data) + end; + true -> + El = xml_stream:parse_element(Data), + case El of + #xmlel{name = <<"open">>, attrs = Attrs} -> + Attrs2 = [{<<"xmlns:stream">>, ?NS_STREAM}, {<<"xmlns">>, <<"jabber:client">>} | + lists:keydelete(<<"xmlns">>, 1, lists:keydelete(<<"xmlns:stream">>, 1, Attrs))], + {State, [{xmlstreamstart, <<"stream:stream">>, Attrs2}]}; + #xmlel{name = <<"close">>} -> + {State, [{xmlstreamend, <<"stream:stream">>}]}; + {error, _} -> + {State, <<"parse error">>}; + _ -> + {State, [El]} + end; + false -> + {State, Data} + end. + +parsed_items(List) -> + receive + {'$gen_event', El} + when element(1, El) == xmlel; + element(1, El) == xmlstreamstart; + element(1, El) == xmlstreamelement; + element(1, El) == xmlstreamend -> + parsed_items([El | List]); + {'$gen_event', {xmlstreamerror, _}} -> + error + after 0 -> + lists:reverse(List) + end. diff --git a/src/ejabberd_websocket.erl b/src/ejabberd_websocket.erl new file mode 100644 index 000000000..8cd1b2289 --- /dev/null +++ b/src/ejabberd_websocket.erl @@ -0,0 +1,403 @@ +%%%---------------------------------------------------------------------- +%%% File : ejabberd_websocket.erl +%%% Author : Eric Cestari +%%% Purpose : XMPP Websocket support +%%% Created : 09-10-2010 by Eric Cestari +%%% +%%% Some code lifted from MISULTIN - WebSocket misultin_websocket.erl - >-|-|-(°> +%%% (http://github.com/ostinelli/misultin/blob/master/src/misultin_websocket.erl) +%%% Copyright (C) 2010, Roberto Ostinelli , Joe Armstrong. +%%% All rights reserved. +%%% +%%% Code portions from Joe Armstrong have been originally taken under MIT license at the address: +%%% +%%% +%%% BSD License +%%% +%%% Redistribution and use in source and binary forms, with or without modification, are permitted provided +%%% that the following conditions are met: +%%% +%%% * Redistributions of source code must retain the above copyright notice, this list of conditions and the +%%% following disclaimer. +%%% * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and +%%% the following disclaimer in the documentation and/or other materials provided with the distribution. +%%% * Neither the name of the authors nor the names of its contributors may be used to endorse or promote +%%% products derived from this software without specific prior written permission. +%%% +%%% THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED +%%% WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +%%% PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +%%% ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +%%% TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) +%%% HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +%%% NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +%%% POSSIBILITY OF SUCH DAMAGE. +%%% ========================================================================================================== +%%% ejabberd, Copyright (C) 2002-2015 ProcessOne +%%%---------------------------------------------------------------------- + +-module(ejabberd_websocket). + +-author('ecestari@process-one.net'). + +-export([check/2, socket_handoff/8]). + +-include("ejabberd.hrl"). +-include("logger.hrl"). + +-include("jlib.hrl"). + +-include("ejabberd_http.hrl"). + +-define(CT_XML, {<<"Content-Type">>, <<"text/xml; charset=utf-8">>}). +-define(CT_PLAIN, {<<"Content-Type">>, <<"text/plain">>}). + +-define(AC_ALLOW_ORIGIN, {<<"Access-Control-Allow-Origin">>, <<"*">>}). +-define(AC_ALLOW_METHODS, {<<"Access-Control-Allow-Methods">>, <<"GET, OPTIONS">>}). +-define(AC_ALLOW_HEADERS, {<<"Access-Control-Allow-Headers">>, <<"Content-Type">>}). +-define(AC_MAX_AGE, {<<"Access-Control-Max-Age">>, <<"86400">>}). + +-define(OPTIONS_HEADER, [?CT_PLAIN, ?AC_ALLOW_ORIGIN, ?AC_ALLOW_METHODS, + ?AC_ALLOW_HEADERS, ?AC_MAX_AGE]). +-define(HEADER, [?CT_XML, ?AC_ALLOW_ORIGIN, ?AC_ALLOW_HEADERS]). + +check(_Path, Headers) -> + RequiredHeaders = [{'Upgrade', <<"websocket">>}, + {'Connection', ignore}, {'Host', ignore}, + {<<"Sec-Websocket-Key">>, ignore}, + {<<"Sec-Websocket-Version">>, <<"13">>}], + + F = fun ({Tag, Val}) -> + case lists:keyfind(Tag, 1, Headers) of + false -> true; % header not found, keep in list + {_, HVal} -> + case Val of + ignore -> false; % ignore value -> ok, remove from list + HVal -> false; % expected val -> ok, remove from list + _ -> + true % val is different, keep in list + end + end + end, + case lists:filter(F, RequiredHeaders) of + [] -> true; + _MissingHeaders -> false + end. + +socket_handoff(LocalPath, #request{method = 'GET', ip = IP, q = Q, path = Path, + headers = Headers, host = Host, port = Port}, + Socket, SockMod, Buf, _Opts, HandlerModule, InfoMsgFun) -> + case check(LocalPath, Headers) of + true -> + WS = #ws{socket = Socket, + sockmod = SockMod, + ip = IP, + q = Q, + host = Host, + port = Port, + path = Path, + headers = Headers, + local_path = LocalPath, + buf = Buf}, + + connect(WS, HandlerModule); + _ -> + {200, ?HEADER, InfoMsgFun()} + end; +socket_handoff(_, #request{method = 'OPTIONS'}, _, _, _, _, _, _) -> + {200, ?OPTIONS_HEADER, []}; +socket_handoff(_, #request{method = 'HEAD'}, _, _, _, _, _, _) -> + {200, ?HEADER, []}; +socket_handoff(_, _, _, _, _, _, _, _) -> + {400, ?HEADER, #xmlel{name = <<"h1">>, + children = [{xmlcdata, <<"400 Bad Request">>}]}}. + +connect(#ws{socket = Socket, sockmod = SockMod} = Ws, WsLoop) -> + {NewWs, HandshakeResponse} = handshake(Ws), + SockMod:send(Socket, HandshakeResponse), + + ?DEBUG("Sent handshake response : ~p", + [HandshakeResponse]), + Ws0 = {Ws, self()}, + {ok, WsHandleLoopPid} = WsLoop:start_link(Ws0), + erlang:monitor(process, WsHandleLoopPid), + + case NewWs#ws.buf of + <<>> -> + ok; + Data -> + self() ! {raw, Socket, Data} + end, + + % set opts + case SockMod of + gen_tcp -> + inet:setopts(Socket, [{packet, 0}, {active, true}]); + _ -> + SockMod:setopts(Socket, [{packet, 0}, {active, true}]) + end, + ws_loop(none, Socket, WsHandleLoopPid, SockMod). + +handshake(#ws{headers = Headers} = State) -> + {_, Key} = lists:keyfind(<<"Sec-Websocket-Key">>, 1, + Headers), + SubProtocolHeader = case find_subprotocol(Headers) of + false -> + []; + V -> + [<<"Sec-Websocket-Protocol:">>, V, <<"\r\n">>] + end, + Hash = jlib:encode_base64( + p1_sha:sha1(<>)), + {State, [<<"HTTP/1.1 101 Switching Protocols\r\n">>, + <<"Upgrade: websocket\r\n">>, + <<"Connection: Upgrade\r\n">>, + SubProtocolHeader, + <<"Sec-WebSocket-Accept: ">>, Hash, <<"\r\n\r\n">>]}. + +find_subprotocol(Headers) -> + case lists:keysearch(<<"Sec-Websocket-Protocol">>, 1, Headers) of + false -> + case lists:keysearch(<<"Websocket-Protocol">>, 1, Headers) of + false -> + false; + {value, {_, Protocol2}} -> + Protocol2 + end; + {value, {_, Protocol}} -> + Protocol + end. + + +ws_loop(FrameInfo, Socket, WsHandleLoopPid, SocketMode) -> + receive + {DataType, _Socket, Data} when DataType =:= tcp orelse DataType =:= raw -> + case handle_data(DataType, FrameInfo, Data, Socket, WsHandleLoopPid, SocketMode) of + {error, Error} -> + ?DEBUG("tls decode error ~p", [Error]), + websocket_close(Socket, WsHandleLoopPid, SocketMode, 1002); % protocol error + {NewFrameInfo, ToSend} -> + lists:foreach(fun(Pkt) -> SocketMode:send(Socket, Pkt) + end, ToSend), + ws_loop(NewFrameInfo, Socket, WsHandleLoopPid, SocketMode) + end; + {tcp_closed, _Socket} -> + ?DEBUG("tcp connection was closed, exit", []), + websocket_close(Socket, WsHandleLoopPid, SocketMode, 0); + {'DOWN', Ref, process, WsHandleLoopPid, Reason} -> + Code = case Reason of + normal -> + 1000; % normal close + _ -> + ?ERROR_MSG("linked websocket controlling loop crashed " + "with reason: ~p", + [Reason]), + 1011 % internal error + end, + erlang:demonitor(Ref), + websocket_close(Socket, WsHandleLoopPid, SocketMode, Code); + {send, Data} -> + SocketMode:send(Socket, encode_frame(Data, 1)), + ws_loop(FrameInfo, Socket, WsHandleLoopPid, + SocketMode); + {ping, Data} -> + SocketMode:send(Socket, encode_frame(Data, 9)), + ws_loop(FrameInfo, Socket, WsHandleLoopPid, + SocketMode); + shutdown -> + ?DEBUG("shutdown request received, closing websocket " + "with pid ~p", + [self()]), + websocket_close(Socket, WsHandleLoopPid, SocketMode, 1001); % going away + _Ignored -> + ?WARNING_MSG("received unexpected message, ignoring: ~p", + [_Ignored]), + ws_loop(FrameInfo, Socket, WsHandleLoopPid, + SocketMode) + end. + +encode_frame(Data, Opcode) -> + case byte_size(Data) of + S1 when S1 < 126 -> + <<1:1, 0:3, Opcode:4, 0:1, S1:7, Data/binary>>; + S2 when S2 < 65536 -> + <<1:1, 0:3, Opcode:4, 0:1, 126:7, S2:16, Data/binary>>; + S3 -> + <<1:1, 0:3, Opcode:4, 0:1, 127:7, S3:64, Data/binary>> + end. + +-record(frame_info, + {mask = none, offset = 0, left, final_frame = true, + opcode, unprocessed = <<>>, unmasked = <<>>, + unmasked_msg = <<>>}). + +decode_header(<>) + when Len < 126 -> + {Len, Final, Opcode, none, Data}; +decode_header(<>) -> + {Len, Final, Opcode, none, Data}; +decode_header(<>) -> + {Len, Final, Opcode, none, Data}; +decode_header(<>) + when Len < 126 -> + {Len, Final, Opcode, Mask, Data}; +decode_header(<>) -> + {Len, Final, Opcode, Mask, Data}; +decode_header(<>) -> + {Len, Final, Opcode, Mask, Data}; +decode_header(_) -> none. + +unmask_int(Offset, _, <<>>, Acc) -> + {Acc, Offset}; +unmask_int(0, <> = Mask, + <>, Acc) -> + unmask_int(0, Mask, Rest, + <>); +unmask_int(0, <> = Mask, + <>, Acc) -> + unmask_int(1, Mask, Rest, + <>); +unmask_int(1, <<_:8, M:8, _/binary>> = Mask, + <>, Acc) -> + unmask_int(2, Mask, Rest, + <>); +unmask_int(2, <<_:16, M:8, _/binary>> = Mask, + <>, Acc) -> + unmask_int(3, Mask, Rest, + <>); +unmask_int(3, <<_:24, M:8>> = Mask, + <>, Acc) -> + unmask_int(0, Mask, Rest, + <>). + +unmask(#frame_info{mask = none} = State, Data) -> + {State, Data}; +unmask(#frame_info{mask = Mask, offset = Offset} = State, Data) -> + {Unmasked, NewOffset} = unmask_int(Offset, Mask, + Data, <<>>), + {State#frame_info{offset = NewOffset}, Unmasked}. + +process_frame(none, Data) -> + process_frame(#frame_info{}, Data); +process_frame(#frame_info{left = Left} = FrameInfo, <<>>) when Left > 0 -> + {FrameInfo, [], []}; +process_frame(#frame_info{unprocessed = none, + unmasked = UnmaskedPre, left = Left} = + State, + Data) + when byte_size(Data) < Left -> + {State2, Unmasked} = unmask(State, Data), + {State2#frame_info{left = Left - byte_size(Data), + unmasked = [UnmaskedPre, Unmasked]}, + [], []}; +process_frame(#frame_info{unprocessed = none, + unmasked = UnmaskedPre, opcode = Opcode, + final_frame = Final, left = Left, + unmasked_msg = UnmaskedMsg} = + FrameInfo, + Data) -> + <> = Data, + {_, Unmasked} = unmask(FrameInfo, ToProcess), + case Final of + true -> + {FrameInfo3, Recv, Send} = process_frame(#frame_info{}, + Unprocessed), + case Opcode of + X when X < 3 -> + {FrameInfo3, + [iolist_to_binary([UnmaskedMsg, UnmaskedPre, Unmasked]) + | Recv], + Send}; + 9 -> % Ping + Frame = encode_frame(Unprocessed, 10), + {FrameInfo3#frame_info{unmasked_msg = UnmaskedMsg}, [ping | Recv], + [Frame | Send]}; + 10 -> % Pong + {FrameInfo3, [pong | Recv], Send}; + 8 -> % Close + CloseCode = case Unmasked of + <> -> + ?DEBUG("WebSocket close op: ~p ~s", + [Code, Message]), + Code; + <> -> + ?DEBUG("WebSocket close op: ~p", [Code]), + Code; + _ -> + ?DEBUG("WebSocket close op unknown: ~p", + [Unmasked]), + 1000 + end, + + Frame = encode_frame(<>, 8), + {FrameInfo3#frame_info{unmasked_msg=UnmaskedMsg}, Recv, + [Frame | Send]}; + _ -> + {FrameInfo3#frame_info{unmasked_msg = UnmaskedMsg}, Recv, + Send} + end; + _ -> + process_frame(#frame_info{unmasked_msg = + [UnmaskedMsg, UnmaskedPre, + Unmasked]}, + Unprocessed) + end; +process_frame(#frame_info{unprocessed = <<>>} = + FrameInfo, + Data) -> + case decode_header(Data) of + none -> + {FrameInfo#frame_info{unprocessed = Data}, [], []}; + {Len, Final, Opcode, Mask, Rest} -> + process_frame(FrameInfo#frame_info{mask = Mask, + final_frame = Final == 1, + left = Len, opcode = Opcode, + unprocessed = none}, + Rest) + end; +process_frame(#frame_info{unprocessed = + UnprocessedPre} = + FrameInfo, + Data) -> + process_frame(FrameInfo#frame_info{unprocessed = <<>>}, + <>). + +handle_data(tcp, FrameInfo, Data, Socket, WsHandleLoopPid, p1_tls) -> + case p1_tls:recv_data(Socket, Data) of + {ok, NewData} -> + handle_data_int(FrameInfo, NewData, Socket, WsHandleLoopPid, p1_tls); + {error, Error} -> + {error, Error} + end; +handle_data(_, FrameInfo, Data, Socket, WsHandleLoopPid, SockMod) -> + handle_data_int(FrameInfo, Data, Socket, WsHandleLoopPid, SockMod). + +handle_data_int(FrameInfo, Data, _Socket, WsHandleLoopPid, _SocketMode) -> + {NewFrameInfo, Recv, Send} = process_frame(FrameInfo, Data), + lists:foreach(fun (El) -> + case El of + pong -> + WsHandleLoopPid ! pong; + ping -> + WsHandleLoopPid ! ping; + _ -> + WsHandleLoopPid ! {received, El} + end + end, + Recv), + {NewFrameInfo, Send}. + +websocket_close(Socket, WsHandleLoopPid, + SocketMode, CloseCode) when CloseCode > 0 -> + Frame = encode_frame(<>, 8), + SocketMode:send(Socket, Frame), + websocket_close(Socket, WsHandleLoopPid, SocketMode, 0); +websocket_close(Socket, WsHandleLoopPid, SocketMode, _CloseCode) -> + WsHandleLoopPid ! closed, + SocketMode:close(Socket).