From 4311a5646fae97b988336f389b3ce6585dce4f4b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Chmielowski?= Date: Fri, 13 Jan 2023 19:40:53 +0100 Subject: [PATCH] Add support for websockets to mqtt bridge --- src/ejabberd_websocket.erl | 270 +++++++++---------------------- src/ejabberd_websocket_codec.erl | 175 ++++++++++++++++++++ src/mod_mqtt_bridge.erl | 13 +- src/mod_mqtt_bridge_session.erl | 161 +++++++++++++++--- 4 files changed, 392 insertions(+), 227 deletions(-) create mode 100644 src/ejabberd_websocket_codec.erl diff --git a/src/ejabberd_websocket.erl b/src/ejabberd_websocket.erl index dbc34a6a3..21123eb23 100644 --- a/src/ejabberd_websocket.erl +++ b/src/ejabberd_websocket.erl @@ -155,7 +155,7 @@ connect(#ws{socket = Socket, sockmod = SockMod} = Ws, WsLoop) -> _ -> SockMod:setopts(Socket, [{packet, 0}, {active, true}]) end, - ws_loop(none, Socket, WsHandleLoopPid, SockMod, none). + ws_loop(ejabberd_websocket_codec:new_server(), Socket, WsHandleLoopPid, SockMod, none). handshake(#ws{headers = Headers} = State) -> {_, Key} = lists:keyfind(<<"Sec-Websocket-Key">>, 1, @@ -188,17 +188,20 @@ find_subprotocol(Headers) -> end. -ws_loop(FrameInfo, Socket, WsHandleLoopPid, SockMod, Shaper) -> +ws_loop(Codec, Socket, WsHandleLoopPid, SockMod, Shaper) -> receive {DataType, _Socket, Data} when DataType =:= tcp orelse DataType =:= raw -> - case handle_data(DataType, FrameInfo, Data, Socket, WsHandleLoopPid, SockMod, Shaper) of - {error, Error} -> + case handle_data(DataType, Codec, Data, Socket, WsHandleLoopPid, SockMod, Shaper) of + {error, tls, Error} -> ?DEBUG("TLS decode error ~p", [Error]), - websocket_close(Socket, WsHandleLoopPid, SockMod, 1002); % protocol error - {NewFrameInfo, ToSend, NewShaper} -> + websocket_close(Codec, Socket, WsHandleLoopPid, SockMod, 1002); % protocol error + {error, protocol, Error} -> + ?DEBUG("Websocket decode error ~p", [Error]), + websocket_close(Codec, Socket, WsHandleLoopPid, SockMod, 1002); % protocol error + {NewCodec, ToSend, NewShaper} -> lists:foreach(fun(Pkt) -> SockMod:send(Socket, Pkt) end, ToSend), - ws_loop(NewFrameInfo, Socket, WsHandleLoopPid, SockMod, NewShaper) + ws_loop(NewCodec, Socket, WsHandleLoopPid, SockMod, NewShaper) end; {new_shaper, NewShaper} -> NewShaper = case NewShaper of @@ -207,13 +210,13 @@ ws_loop(FrameInfo, Socket, WsHandleLoopPid, SockMod, Shaper) -> _ -> NewShaper end, - ws_loop(FrameInfo, Socket, WsHandleLoopPid, SockMod, NewShaper); + ws_loop(Codec, Socket, WsHandleLoopPid, SockMod, NewShaper); {tcp_closed, _Socket} -> ?DEBUG("TCP connection was closed, exit", []), - websocket_close(Socket, WsHandleLoopPid, SockMod, 0); + websocket_close(Codec, Socket, WsHandleLoopPid, SockMod, 0); {tcp_error, Socket, Reason} -> ?DEBUG("TCP connection error: ~ts", [inet:format_error(Reason)]), - websocket_close(Socket, WsHandleLoopPid, SockMod, 0); + websocket_close(Codec, Socket, WsHandleLoopPid, SockMod, 0); {'DOWN', Ref, process, WsHandleLoopPid, Reason} -> Code = case Reason of normal -> @@ -225,224 +228,95 @@ ws_loop(FrameInfo, Socket, WsHandleLoopPid, SockMod, Shaper) -> 1011 % internal error end, erlang:demonitor(Ref), - websocket_close(Socket, WsHandleLoopPid, SockMod, Code); + websocket_close(Codec, Socket, WsHandleLoopPid, SockMod, Code); {text_with_reply, Data, Sender} -> - SockMod:send(Socket, encode_frame(Data, 1)), + SockMod:send(Socket, ejabberd_websocket_codec:encode(Codec, 1, Data)), Sender ! {text_reply, self()}, - ws_loop(FrameInfo, Socket, WsHandleLoopPid, + ws_loop(Codec, Socket, WsHandleLoopPid, SockMod, Shaper); {data_with_reply, Data, Sender} -> - SockMod:send(Socket, encode_frame(Data, 2)), + SockMod:send(Socket, ejabberd_websocket_codec:encode(Codec, 2, Data)), Sender ! {data_reply, self()}, - ws_loop(FrameInfo, Socket, WsHandleLoopPid, + ws_loop(Codec, Socket, WsHandleLoopPid, SockMod, Shaper); {text, Data} -> - SockMod:send(Socket, encode_frame(Data, 1)), - ws_loop(FrameInfo, Socket, WsHandleLoopPid, + SockMod:send(Socket, ejabberd_websocket_codec:encode(Codec, 1, Data)), + ws_loop(Codec, Socket, WsHandleLoopPid, SockMod, Shaper); {data, Data} -> - SockMod:send(Socket, encode_frame(Data, 2)), - ws_loop(FrameInfo, Socket, WsHandleLoopPid, + SockMod:send(Socket, ejabberd_websocket_codec:encode(Codec, 2, Data)), + ws_loop(Codec, Socket, WsHandleLoopPid, SockMod, Shaper); {ping, Data} -> - SockMod:send(Socket, encode_frame(Data, 9)), - ws_loop(FrameInfo, Socket, WsHandleLoopPid, + SockMod:send(Socket, ejabberd_websocket_codec:encode(Codec, 9, Data)), + ws_loop(Codec, Socket, WsHandleLoopPid, SockMod, Shaper); shutdown -> ?DEBUG("Shutdown request received, closing websocket " "with pid ~p", [self()]), - websocket_close(Socket, WsHandleLoopPid, SockMod, 1001); % going away + websocket_close(Codec, Socket, WsHandleLoopPid, SockMod, 1001); % going away _Ignored -> ?WARNING_MSG("Received unexpected message, ignoring: ~p", [_Ignored]), - ws_loop(FrameInfo, Socket, WsHandleLoopPid, + ws_loop(Codec, Socket, WsHandleLoopPid, SockMod, Shaper) end. -encode_frame(Data, Opcode) -> - case byte_size(Data) of - S1 when S1 < 126 -> - <<1:1, 0:3, Opcode:4, 0:1, S1:7, Data/binary>>; - S2 when S2 < 65536 -> - <<1:1, 0:3, Opcode:4, 0:1, 126:7, S2:16, Data/binary>>; - S3 -> - <<1:1, 0:3, Opcode:4, 0:1, 127:7, S3:64, Data/binary>> - end. - --record(frame_info, - {mask = none, offset = 0, left, final_frame = true, - opcode, unprocessed = <<>>, unmasked = <<>>, - unmasked_msg = <<>>}). - -decode_header(<>) - when Len < 126 -> - {Len, Final, Opcode, none, Data}; -decode_header(<>) -> - {Len, Final, Opcode, none, Data}; -decode_header(<>) -> - {Len, Final, Opcode, none, Data}; -decode_header(<>) - when Len < 126 -> - {Len, Final, Opcode, Mask, Data}; -decode_header(<>) -> - {Len, Final, Opcode, Mask, Data}; -decode_header(<>) -> - {Len, Final, Opcode, Mask, Data}; -decode_header(_) -> none. - -unmask_int(Offset, _, <<>>, Acc) -> - {Acc, Offset}; -unmask_int(0, <> = Mask, - <>, Acc) -> - unmask_int(0, Mask, Rest, - <>); -unmask_int(0, <> = Mask, - <>, Acc) -> - unmask_int(1, Mask, Rest, - <>); -unmask_int(1, <<_:8, M:8, _/binary>> = Mask, - <>, Acc) -> - unmask_int(2, Mask, Rest, - <>); -unmask_int(2, <<_:16, M:8, _/binary>> = Mask, - <>, Acc) -> - unmask_int(3, Mask, Rest, - <>); -unmask_int(3, <<_:24, M:8>> = Mask, - <>, Acc) -> - unmask_int(0, Mask, Rest, - <>). - -unmask(#frame_info{mask = none} = State, Data) -> - {State, Data}; -unmask(#frame_info{mask = Mask, offset = Offset} = State, Data) -> - {Unmasked, NewOffset} = unmask_int(Offset, Mask, - Data, <<>>), - {State#frame_info{offset = NewOffset}, Unmasked}. - -process_frame(none, Data) -> - process_frame(#frame_info{}, Data); -process_frame(#frame_info{left = Left} = FrameInfo, <<>>) when Left > 0 -> - {FrameInfo, [], []}; -process_frame(#frame_info{unprocessed = none, - unmasked = UnmaskedPre, left = Left} = - State, - Data) - when byte_size(Data) < Left -> - {State2, Unmasked} = unmask(State, Data), - {State2#frame_info{left = Left - byte_size(Data), - unmasked = [UnmaskedPre, Unmasked]}, - [], []}; -process_frame(#frame_info{unprocessed = none, - unmasked = UnmaskedPre, opcode = Opcode, - final_frame = Final, left = Left, - unmasked_msg = UnmaskedMsg} = - FrameInfo, - Data) -> - <> = Data, - {_, Unmasked} = unmask(FrameInfo, ToProcess), - case Final of - true -> - {FrameInfo3, Recv, Send} = process_frame(#frame_info{}, - Unprocessed), - case Opcode of - X when X < 3 -> - {FrameInfo3, - [iolist_to_binary([UnmaskedMsg, UnmaskedPre, Unmasked]) - | Recv], - Send}; - 9 -> % Ping - Frame = encode_frame(Unmasked, 10), - {FrameInfo3#frame_info{unmasked_msg = UnmaskedMsg}, [ping | Recv], - [Frame | Send]}; - 10 -> % Pong - {FrameInfo3, [pong | Recv], Send}; - 8 -> % Close - CloseCode = case Unmasked of - <> -> - ?DEBUG("WebSocket close op: ~p ~ts", - [Code, Message]), - Code; - <> -> - ?DEBUG("WebSocket close op: ~p", [Code]), - Code; - _ -> - ?DEBUG("WebSocket close op unknown: ~p", - [Unmasked]), - 1000 - end, - - Frame = encode_frame(<>, 8), - {FrameInfo3#frame_info{unmasked_msg=UnmaskedMsg}, Recv, - [Frame | Send]}; - _ -> - {FrameInfo3#frame_info{unmasked_msg = UnmaskedMsg}, Recv, - Send} - end; - _ -> - process_frame(#frame_info{unmasked_msg = - [UnmaskedMsg, UnmaskedPre, - Unmasked]}, - Unprocessed) - end; -process_frame(#frame_info{unprocessed = <<>>} = - FrameInfo, - Data) -> - case decode_header(Data) of - none -> - {FrameInfo#frame_info{unprocessed = Data}, [], []}; - {Len, Final, Opcode, Mask, Rest} -> - process_frame(FrameInfo#frame_info{mask = Mask, - final_frame = Final == 1, - left = Len, opcode = Opcode, - unprocessed = none}, - Rest) - end; -process_frame(#frame_info{unprocessed = - UnprocessedPre} = - FrameInfo, - Data) -> - process_frame(FrameInfo#frame_info{unprocessed = <<>>}, - <>). - -handle_data(tcp, FrameInfo, Data, Socket, WsHandleLoopPid, fast_tls, Shaper) -> +handle_data(tcp, Codec, Data, Socket, WsHandleLoopPid, fast_tls, Shaper) -> case fast_tls:recv_data(Socket, Data) of {ok, NewData} -> - handle_data_int(FrameInfo, NewData, Socket, WsHandleLoopPid, fast_tls, Shaper); + handle_data_int(Codec, NewData, Socket, WsHandleLoopPid, fast_tls, Shaper); {error, Error} -> - {error, Error} + {error, tls, Error} end; -handle_data(_, FrameInfo, Data, Socket, WsHandleLoopPid, SockMod, Shaper) -> - handle_data_int(FrameInfo, Data, Socket, WsHandleLoopPid, SockMod, Shaper). +handle_data(_, Codec, Data, Socket, WsHandleLoopPid, SockMod, Shaper) -> + handle_data_int(Codec, Data, Socket, WsHandleLoopPid, SockMod, Shaper). -handle_data_int(FrameInfo, Data, Socket, WsHandleLoopPid, SockMod, Shaper) -> - {NewFrameInfo, Recv, Send} = process_frame(FrameInfo, Data), - lists:foreach(fun (El) -> - case El of - pong -> - WsHandleLoopPid ! pong; - ping -> - WsHandleLoopPid ! ping; - _ -> - WsHandleLoopPid ! {received, El} - end - end, - Recv), - {NewFrameInfo, Send, handle_shaping(Data, Socket, SockMod, Shaper)}. +handle_data_int(Codec, Data, Socket, WsHandleLoopPid, SockMod, Shaper) -> + {Type, NewCodec, Recv} = ejabberd_websocket_codec:decode(Codec, Data), + Send = + lists:filtermap( + fun({Op, Payload}) when Op == 1; Op == 2 -> + WsHandleLoopPid ! {received, Payload}, + false; + ({8, Payload}) -> + CloseCode = + case Payload of + <> -> + ?DEBUG("WebSocket close op: ~p ~ts", + [Code, Message]), + Code; + <> -> + ?DEBUG("WebSocket close op: ~p", [Code]), + Code; + _ -> + ?DEBUG("WebSocket close op unknown: ~p", [Payload]), + 1000 + end, + Frame = ejabberd_websocket_codec:encode(Codec, 8, <>), + {true, Frame}; + ({9, Payload}) -> + WsHandleLoopPid ! ping, + Frame = ejabberd_websocket_codec:encode(Codec, 10, Payload), + {true, Frame}; + ({10, _Payload}) -> + WsHandleLoopPid ! pong, + false + end, Recv), + case Type of + error -> + {error, protocol, NewCodec}; + _ -> + {NewCodec, Send, handle_shaping(Data, Socket, SockMod, Shaper)} + end. -websocket_close(Socket, WsHandleLoopPid, +websocket_close(Codec, Socket, WsHandleLoopPid, SockMod, CloseCode) when CloseCode > 0 -> - Frame = encode_frame(<>, 8), + Frame = ejabberd_websocket_codec:encode(Codec, 8, <>), SockMod:send(Socket, Frame), - websocket_close(Socket, WsHandleLoopPid, SockMod, 0); -websocket_close(Socket, WsHandleLoopPid, SockMod, _CloseCode) -> + websocket_close(Codec, Socket, WsHandleLoopPid, SockMod, 0); +websocket_close(_Codec, Socket, WsHandleLoopPid, SockMod, _CloseCode) -> WsHandleLoopPid ! closed, SockMod:close(Socket). diff --git a/src/ejabberd_websocket_codec.erl b/src/ejabberd_websocket_codec.erl new file mode 100644 index 000000000..4cdc7a3db --- /dev/null +++ b/src/ejabberd_websocket_codec.erl @@ -0,0 +1,175 @@ +%% +% File : ejabberd_websocket_codec.erl +% Author : Paweł Chmielowski +% Purpose : Coder/Encoder of websocket frames +% Created : 9 sty 2023 by Paweł Chmielowski +% +% +% ejabberd, Copyright (C) 2002-2023 ProcessOne +% +% This program is free software; you can redistribute it and/or +% modify it under the terms of the GNU General Public License as +% published by the Free Software Foundation; either version 2 of the +% License, or (at your option) any later version. +% +% This program is distributed in the hope that it will be useful, +% but WITHOUT ANY WARRANTY; without even the implied warranty of +% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +% General Public License for more details. +% +% You should have received a copy of the GNU General Public License along +% with this program; if not, write to the Free Software Foundation, Inc., +% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +% +% +-module(ejabberd_websocket_codec). +-author("pawel@process-one.net"). + +%% API +-export([new_server/0, new_client/0, decode/2, encode/3]). + +-record(codec_state, { + our_mask = none :: none | binary(), + partial = none :: none | {non_neg_integer(), binary()}, + opcode = 0 :: non_neg_integer(), + is_fin = false :: boolean(), + mask = none :: none | binary(), + mask_offset = 0 :: non_neg_integer(), + required = -1 :: integer(), + data = <<>> :: binary() +}). + +-opaque codec_state() :: #codec_state{}. +-export_type([codec_state/0]). + +-spec new_server() -> codec_state(). +new_server() -> + #codec_state{}. + +new_client() -> + #codec_state{our_mask = p1_rand:bytes(4)}. + +-spec decode(codec_state(), binary()) -> {ok, codec_state(), [binary()]} | {error, atom(), [binary()]}. +decode(#codec_state{required = -1, data = PrevData, partial = Partial} = S, Data) -> + Data2 = <>, + case parse_header(Data2) of + none -> + {ok, S#codec_state{data = Data2}, []}; + {_, _, Opcode, _, _} when (Opcode > 2 andalso Opcode < 8) orelse (Opcode > 10) -> + {error, unknown_opcode, []}; + {_, 0, Opcode, _, _} when Opcode > 7 -> + {error, partial_control_frame, []}; + {_, _, Opcode, _, _} when Opcode > 0 andalso Opcode < 8 andalso Partial /= none -> + {error, partial_frame_non_finished, []}; + {Len, Final, Opcode, Mask, Payload} -> + decode(S#codec_state{opcode = Opcode, is_fin = Final == 1, + mask = Mask, mask_offset = 0, + required = Len, data = <<>>}, Payload) + end; +decode(#codec_state{required = Req, data = PrevData, + mask = Mask, mask_offset = Offset} = S, Data) + when byte_size(PrevData) + byte_size(Data) < Req -> + {Unmasked, NewOffset} = apply_mask(Offset, Mask, Data, PrevData), + {ok, S#codec_state{data = Unmasked, mask_offset = NewOffset}, []}; +decode(#codec_state{required = Req, data = PrevData, + mask = Mask, mask_offset = Offset, + is_fin = IsFin, opcode = Opcode, + partial = Partial} = S, Data) -> + Left = Req - byte_size(PrevData), + <> = Data, + {Unmasked, _} = apply_mask(Offset, Mask, CurrentPayload, PrevData), + {NS, Packets} = + case {IsFin, Partial} of + {false, none} -> + {S#codec_state{partial = {Opcode, Unmasked}, + data = <<>>, required = -1}, []}; + {false, {PartOp, PartData}} -> + {S#codec_state{partial = {PartOp, <>}, + data = <<>>, required = -1}, []}; + {true, none} -> + {S#codec_state{data = <<>>, required = -1}, [{Opcode, Unmasked}]}; + {true, {PartOp, PartData}} -> + {S#codec_state{partial = none, data = <<>>, required = -1}, + [{PartOp, <>}]} + end, + case NextPacketData of + <<>> -> + {ok, NS, Packets}; + _ -> + case decode(NS, NextPacketData) of + {T1, T2, Packets2} -> + {T1, T2, Packets ++ Packets2} + end + end. + +-spec encode(codec_state(), non_neg_integer(), binary()) -> binary(). +encode(#codec_state{our_mask = none}, Opcode, Data) -> + case byte_size(Data) of + S1 when S1 < 126 -> + <<1:1, 0:3, Opcode:4, 0:1, S1:7, Data/binary>>; + S2 when S2 < 65536 -> + <<1:1, 0:3, Opcode:4, 0:1, 126:7, S2:16, Data/binary>>; + S3 -> + <<1:1, 0:3, Opcode:4, 0:1, 127:7, S3:64, Data/binary>> + end; +encode(#codec_state{our_mask = Mask}, Opcode, Data) -> + {MaskedData, _} = apply_mask(0, Mask, Data, <<>>), + case byte_size(Data) of + S1 when S1 < 126 -> + <<1:1, 0:3, Opcode:4, 1:1, S1:7, Mask/binary, MaskedData/binary>>; + S2 when S2 < 65536 -> + <<1:1, 0:3, Opcode:4, 1:1, 126:7, S2:16, Mask/binary, MaskedData/binary>>; + S3 -> + <<1:1, 0:3, Opcode:4, 1:1, 127:7, S3:64, Mask/binary, MaskedData/binary>> + end. + + +-spec parse_header(binary()) -> none | {integer(), integer(), integer(), none | binary(), binary()}. +parse_header(<>) + when Len < 126 -> + {Len, Final, Opcode, none, Data}; +parse_header(<>) -> + {Len, Final, Opcode, none, Data}; +parse_header(<>) -> + {Len, Final, Opcode, none, Data}; +parse_header(<>) + when Len < 126 -> + {Len, Final, Opcode, Mask, Data}; +parse_header(<>) -> + {Len, Final, Opcode, Mask, Data}; +parse_header(<>) -> + {Len, Final, Opcode, Mask, Data}; +parse_header(_) -> + none. + +-spec apply_mask(integer(), none | binary(), binary(), binary()) -> {binary(), non_neg_integer()}. +apply_mask(_, none, Data, _) -> + {Data, 0}; +apply_mask(Offset, _, <<>>, Acc) -> + {Acc, Offset}; +apply_mask(0, <> = Mask, + <>, Acc) -> + apply_mask(0, Mask, Rest, + <>); +apply_mask(0, <> = Mask, + <>, Acc) -> + apply_mask(1, Mask, Rest, + <>); +apply_mask(1, <<_:8, M:8, _/binary>> = Mask, + <>, Acc) -> + apply_mask(2, Mask, Rest, + <>); +apply_mask(2, <<_:16, M:8, _/binary>> = Mask, + <>, Acc) -> + apply_mask(3, Mask, Rest, + <>); +apply_mask(3, <<_:24, M:8>> = Mask, + <>, Acc) -> + apply_mask(0, Mask, Rest, + <>). diff --git a/src/mod_mqtt_bridge.erl b/src/mod_mqtt_bridge.erl index 14879b6df..4aeafc7e6 100644 --- a/src/mod_mqtt_bridge.erl +++ b/src/mod_mqtt_bridge.erl @@ -35,7 +35,7 @@ start(Host, Opts) -> User = mod_mqtt_bridge_opt:replication_user(Opts), lists:foldl( - fun({Proc, Transport, HostAddr, Port, Publish, Subscribe, Authentication}, Started) -> + fun({Proc, Transport, HostAddr, Port, Path, Publish, Subscribe, Authentication}, Started) -> case Started of #{Proc := _} -> ?DEBUG("Already started ~p", [Proc]), @@ -43,7 +43,7 @@ start(Host, Opts) -> _ -> ChildSpec = {Proc, {mod_mqtt_bridge_session, start_link, - [Proc, Transport, HostAddr, Port, Publish, Subscribe, Authentication, User]}, + [Proc, Transport, HostAddr, Port, Path, Publish, Subscribe, Authentication, User]}, transient, 1000, worker, @@ -107,7 +107,7 @@ mod_opt_type(replication_user) -> econf:jid(); mod_opt_type(servers) -> econf:and_then( - econf:map(econf:url([mqtt, mqtts, mqtt5, mqtt5s]), + econf:map(econf:url([mqtt, mqtts, mqtt5, mqtt5s, ws, wss]), econf:options( #{ publish => econf:map(econf:binary(), econf:binary(), [{return, map}]), @@ -127,9 +127,10 @@ mod_opt_type(servers) -> fun(Servers) -> maps:fold( fun(Url, Opts, {HAcc, PAcc}) -> - {ok, Scheme, _UserInfo, Host, Port, _Path, _Query} = + {ok, Scheme, _UserInfo, Host, Port, Path, _Query} = misc:uri_parse(Url, [{mqtt, 1883}, {mqtts, 8883}, - {mqtt5, 1883}, {mqtt5s, 8883}]), + {mqtt5, 1883}, {mqtt5s, 8883}, + {ws, 80}, {wss, 443}]), Publish = maps:get(publish, Opts, #{}), Subscribe = maps:get(subscribe, Opts, #{}), Authentication = maps:get(authentication, Opts, []), @@ -139,7 +140,7 @@ mod_opt_type(servers) -> fun(Topic, _RemoteTopic, Acc) -> maps:update_with(Topic, fun(V) -> [Proc | V] end, [Proc], Acc) end, PAcc, Publish), - {[{Proc, Proto, Host, Port, Publish, Subscribe, Authentication} | HAcc], PAcc2} + {[{Proc, Proto, Host, Port, Path, Publish, Subscribe, Authentication} | HAcc], PAcc2} end, {[], #{}}, Servers) end ). diff --git a/src/mod_mqtt_bridge_session.erl b/src/mod_mqtt_bridge_session.erl index eeb40096f..1877d8941 100644 --- a/src/mod_mqtt_bridge_session.erl +++ b/src/mod_mqtt_bridge_session.erl @@ -21,7 +21,7 @@ -vsn(?VSN). %% API --export([start/8, start_link/8]). +-export([start/9, start_link/9]). %% gen_server callbacks -export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]). @@ -66,6 +66,7 @@ stop_reason :: undefined | error_reason(), subscriptions = #{}, publish = #{}, + ws_codec = none, id = 0 :: non_neg_integer(), codec :: mqtt_codec:state(), authentication :: #{username => binary(), password => binary(), certfile => binary()}}). @@ -75,23 +76,25 @@ %%%=================================================================== %%% API %%%=================================================================== -start(Proc, Transport, Host, Port, Publish, Subscribe, Authentication, ReplicationUser) -> - p1_server:start({local, Proc}, ?MODULE, [Proc, Transport, Host, Port, Publish, Subscribe, Authentication, +start(Proc, Transport, Host, Port, Path, Publish, Subscribe, Authentication, ReplicationUser) -> + p1_server:start({local, Proc}, ?MODULE, [Proc, Transport, Host, Port, Path, Publish, Subscribe, Authentication, ReplicationUser], []). -start_link(Proc, Transport, Host, Port, Publish, Subscribe, Authentication, ReplicationUser) -> - p1_server:start_link({local, Proc}, ?MODULE, [Proc, Transport, Host, Port, Publish, Subscribe, +start_link(Proc, Transport, Host, Port, Path, Publish, Subscribe, Authentication, ReplicationUser) -> + p1_server:start_link({local, Proc}, ?MODULE, [Proc, Transport, Host, Port, Path, Publish, Subscribe, Authentication, ReplicationUser], []). %%%=================================================================== %%% gen_server callbacks %%%=================================================================== -init([_Proc, Proto, Host, Port, Publish, Subscribe, Authentication, ReplicationUser]) -> +init([_Proc, Proto, Host, Port, Path, Publish, Subscribe, Authentication, ReplicationUser]) -> {Version, Transport} = case Proto of mqtt -> {4, gen_tcp}; mqtts -> {4, ssl}; mqtt5 -> {5, gen_tcp}; - mqtt5s -> {5, ssl} + mqtt5s -> {5, ssl}; + ws -> {4, gen_tcp}; + wss -> {4, ssl} end, State = #state{version = Version, id = p1_rand:uniform(65535), @@ -101,12 +104,20 @@ init([_Proc, Proto, Host, Port, Publish, Subscribe, Authentication, ReplicationU usr = jid:tolower(ReplicationUser), publish = Publish}, case Authentication of - #{certfile := Cert} when Proto == mqtts; Proto == mqtt5s -> - connect(ssl:connect(Host, Port, [binary, {certfile, Cert}]), State, ssl, none); + #{certfile := Cert} when Proto == mqtts; Proto == mqtt5s; Proto == wss -> + Sock = ssl:connect(Host, Port, [binary, {active, true}, {certfile, Cert}]), + if Proto == ws orelse Proto == wss -> + connect_ws(Host, Port, Path, Sock, State, ssl, none); + true -> connect(Sock, State, ssl, none) + end; #{username := User, password := Pass} -> - connect(Transport:connect(Host, Port, [binary]), State, Transport, {User, Pass}); + Sock = Transport:connect(Host, Port, [binary, {active, true}]), + if Proto == ws orelse Proto == wss -> + connect_ws(Host, Port, Path, Sock, State, Transport, {User, Pass}); + true -> connect(Sock, State, Transport, {User, Pass}) + end; _ -> - {stop, {error, <<"Certificate can be only used for encrypted connections">>}} + {stop, {error, <<"Certificate can be only used for encrypted connections">> }} end. handle_call(Request, From, State) -> @@ -118,20 +129,108 @@ handle_cast(Msg, State) -> {noreply, State}. handle_info({Tag, TCPSock, TCPData}, - #state{codec = Codec, socket = Socket} = State) when Tag == tcp; Tag == ssl -> + #state{ws_codec = {init, Hash, Auth, Last}} = State) + when (Tag == tcp orelse Tag == ssl) -> + Data = <>, + case erlang:decode_packet(http_bin, Data, []) of + {ok, {http_response, _, 101, _}, Rest} -> + handle_info({tcp, TCPSock, Rest}, State#state{ws_codec = {inith, Hash, none, Auth, <<>>}}); + {ok, {http_response, _, _, _}, _Rest} -> + stop(State, {socket, closed}); + {ok, {http_error, _}, _} -> + stop(State, {socket, closed}); + {error, _} -> + stop(State, {socket, closed}); + {more, _} -> + {noreply, State#state{ws_codec = {init, Hash, Auth, Data}}} + end; +handle_info({Tag, TCPSock, TCPData}, + #state{ws_codec = {inith, Hash, Upgrade, Auth, Last}, + socket = {Transport, _}} = State) + when (Tag == tcp orelse Tag == ssl) -> + Data = <>, + case erlang:decode_packet(httph_bin, Data, []) of + {ok, {http_header, _, <<"Sec-Websocket-Accept">>, _, Val}, Rest} -> + case str:to_lower(Val) of + Hash -> + handle_info({tcp, TCPSock, Rest}, + State#state{ws_codec = {inith, ok, Upgrade, Auth, <<>>}}); + _ -> + stop(State, {socket, closed}) + end; + {ok, {http_header, _, 'Connection', _, Val}, Rest} -> + case str:to_lower(Val) of + <<"upgrade">> -> + handle_info({tcp, TCPSock, Rest}, + State#state{ws_codec = {inith, Hash, ok, Auth, <<>>}}); + _ -> + stop(State, {socket, closed}) + end; + {ok, {http_header, _, _, _, _}, Rest} -> + handle_info({tcp, TCPSock, Rest}, State); + {ok, {http_error, _}, _} -> + stop(State, {socket, closed}); + {ok, http_eoh, Rest} -> + case {Hash, Upgrade} of + {ok, ok} -> + {ok, State2} = connect({ok, TCPSock}, + State#state{ws_codec = ejabberd_websocket_codec:new_client()}, + Transport, Auth), + handle_info({tcp, TCPSock, Rest}, State2); + _ -> + stop(State, {socket, closed}) + end; + {error, _} -> + stop(State, {socket, closed}); + {more, _} -> + {noreply, State#state{ws_codec = {inith, Hash, Upgrade, Data}}} + end; +handle_info({Tag, TCPSock, TCPData}, + #state{ws_codec = WSCodec} = State) + when (Tag == tcp orelse Tag == ssl) andalso WSCodec /= none -> + {Packets, Acc0} = + case ejabberd_websocket_codec:decode(WSCodec, TCPData) of + {ok, NewWSCodec, Packets0} -> + {Packets0, {State#state{ws_codec = NewWSCodec}, ok}}; + {error, _Error, Packets0} -> + {Packets0, {State, stop}} + end, + Res2 = + lists:foldl( + fun(_, {stop, _, _} = Res) -> Res; + ({_Op, Data}, {S, Res}) -> + case handle_info({tcp_decoded, TCPSock, Data}, S) of + {stop, _, _} = Stop -> + Stop; + {_, NewState, _} -> + {NewState, Res}; + {_, NewState} -> + {NewState, Res} + end + end, Acc0, Packets), + case Res2 of + {stop, _, _} -> + Res2; + {NewState2, ok} -> + {noreply, NewState2}; + {NewState2, stop} -> + stop(NewState2, {socket, closed}) + end; +handle_info({Tag, TCPSock, TCPData}, + #state{codec = Codec} = State) + when Tag == tcp; Tag == ssl; Tag == tcp_decoded -> case mqtt_codec:decode(Codec, TCPData) of {ok, Pkt, Codec1} -> ?DEBUG("Got MQTT packet:~n~ts", [pp(Pkt)]), State1 = State#state{codec = Codec1}, case handle_packet(Pkt, State1) of {ok, State2} -> - handle_info({tcp, TCPSock, <<>>}, State2); + handle_info({tcp_decoded, TCPSock, <<>>}, State2); {error, State2, Reason} -> stop(State2, Reason) end; {more, Codec1} -> State1 = State#state{codec = Codec1}, - activate(Socket), {noreply, State1}; {error, Why} -> stop(State, {codec, Why}) @@ -156,7 +255,7 @@ handle_info({publish, #publish{topic = Topic} = Pkt}, #state{publish = Publish} {noreply, State2} end; _ -> - State + {noreply, State} end; handle_info({timeout, _TRef, ping_timeout}, State) -> case send(State, #pingreq{}) of @@ -230,6 +329,22 @@ connect({ok, Sock}, State0, Transport, Auth) -> {ok, _, Codec2} = mqtt_codec:decode(State#state.codec, Pkt), {ok, State#state{codec = Codec2}}. +connect_ws(_Host, _Port, _Path, {error, Reason}, _State, _Transport, _Auth) -> + {stop, {error, Reason}}; +connect_ws(Host, Port, Path, {ok, Sock}, State0, Transport, Auth) -> + Key = base64:encode(p1_rand:get_string()), + Hash = str:to_lower(base64:encode(crypto:hash(sha, <>))), + Data = <<"GET ", (list_to_binary(Path))/binary, " HTTP/1.1\r\n", + "Host: ", (list_to_binary(Host))/binary, ":", (integer_to_binary(Port))/binary,"\r\n", + "Upgrade: websocket\r\n", + "Connection: Upgrade\r\n", + "Sec-WebSocket-Protocol: mqtt\r\n", + "Sec-WebSocket-Key: ", Key/binary, "\r\n", + "Sec-WebSocket-Version: 13\r\n\r\n">>, + Res = Transport:send(Sock, Data), + check_sock_result({Transport, Sock}, Res), + {ok, State0#state{ws_codec = {init, Hash, Auth, <<>>}, socket = {Transport, Sock}}}. + -spec stop(state(), error_reason()) -> {noreply, state(), infinity} | {stop, normal, state()}. @@ -286,6 +401,14 @@ send(State, Pkt) -> {ok, do_send(State, Pkt)}. -spec do_send(state(), mqtt_packet()) -> state(). +do_send(#state{ws_codec = WSCodec, socket = {SockMod, Sock} = Socket} = State, Pkt) + when WSCodec /= none -> + ?DEBUG("Send MQTT packet:~n~ts", [pp(Pkt)]), + Data = mqtt_codec:encode(State#state.version, Pkt), + WSData = ejabberd_websocket_codec:encode(WSCodec, 2, Data), + Res = SockMod:send(Sock, WSData), + check_sock_result(Socket, Res), + reset_ping_timer(State); do_send(#state{socket = {SockMod, Sock} = Socket} = State, Pkt) -> ?DEBUG("Send MQTT packet:~n~ts", [pp(Pkt)]), Data = mqtt_codec:encode(State#state.version, Pkt), @@ -295,14 +418,6 @@ do_send(#state{socket = {SockMod, Sock} = Socket} = State, Pkt) -> do_send(State, _Pkt) -> State. --spec activate(socket()) -> ok. -activate(Socket) -> - Res = case Socket of - {gen_tcp, Sock} -> inet:setopts(Sock, [{active, once}]); - {SockMod, Sock} -> SockMod:setopts(Sock, [{active, once}]) - end, - check_sock_result(Socket, Res). - -spec disconnect(state(), error_reason()) -> state(). disconnect(#state{socket = {SockMod, Sock}} = State, Err) -> State1 = case Err of