25
1
mirror of https://github.com/processone/ejabberd.git synced 2024-12-22 17:28:25 +01:00

Add support for websockets to mqtt bridge

This commit is contained in:
Paweł Chmielowski 2023-01-13 19:40:53 +01:00
parent c103182bc7
commit 4311a5646f
4 changed files with 392 additions and 227 deletions

View File

@ -155,7 +155,7 @@ connect(#ws{socket = Socket, sockmod = SockMod} = Ws, WsLoop) ->
_ ->
SockMod:setopts(Socket, [{packet, 0}, {active, true}])
end,
ws_loop(none, Socket, WsHandleLoopPid, SockMod, none).
ws_loop(ejabberd_websocket_codec:new_server(), Socket, WsHandleLoopPid, SockMod, none).
handshake(#ws{headers = Headers} = State) ->
{_, Key} = lists:keyfind(<<"Sec-Websocket-Key">>, 1,
@ -188,17 +188,20 @@ find_subprotocol(Headers) ->
end.
ws_loop(FrameInfo, Socket, WsHandleLoopPid, SockMod, Shaper) ->
ws_loop(Codec, Socket, WsHandleLoopPid, SockMod, Shaper) ->
receive
{DataType, _Socket, Data} when DataType =:= tcp orelse DataType =:= raw ->
case handle_data(DataType, FrameInfo, Data, Socket, WsHandleLoopPid, SockMod, Shaper) of
{error, Error} ->
case handle_data(DataType, Codec, Data, Socket, WsHandleLoopPid, SockMod, Shaper) of
{error, tls, Error} ->
?DEBUG("TLS decode error ~p", [Error]),
websocket_close(Socket, WsHandleLoopPid, SockMod, 1002); % protocol error
{NewFrameInfo, ToSend, NewShaper} ->
websocket_close(Codec, Socket, WsHandleLoopPid, SockMod, 1002); % protocol error
{error, protocol, Error} ->
?DEBUG("Websocket decode error ~p", [Error]),
websocket_close(Codec, Socket, WsHandleLoopPid, SockMod, 1002); % protocol error
{NewCodec, ToSend, NewShaper} ->
lists:foreach(fun(Pkt) -> SockMod:send(Socket, Pkt)
end, ToSend),
ws_loop(NewFrameInfo, Socket, WsHandleLoopPid, SockMod, NewShaper)
ws_loop(NewCodec, Socket, WsHandleLoopPid, SockMod, NewShaper)
end;
{new_shaper, NewShaper} ->
NewShaper = case NewShaper of
@ -207,13 +210,13 @@ ws_loop(FrameInfo, Socket, WsHandleLoopPid, SockMod, Shaper) ->
_ ->
NewShaper
end,
ws_loop(FrameInfo, Socket, WsHandleLoopPid, SockMod, NewShaper);
ws_loop(Codec, Socket, WsHandleLoopPid, SockMod, NewShaper);
{tcp_closed, _Socket} ->
?DEBUG("TCP connection was closed, exit", []),
websocket_close(Socket, WsHandleLoopPid, SockMod, 0);
websocket_close(Codec, Socket, WsHandleLoopPid, SockMod, 0);
{tcp_error, Socket, Reason} ->
?DEBUG("TCP connection error: ~ts", [inet:format_error(Reason)]),
websocket_close(Socket, WsHandleLoopPid, SockMod, 0);
websocket_close(Codec, Socket, WsHandleLoopPid, SockMod, 0);
{'DOWN', Ref, process, WsHandleLoopPid, Reason} ->
Code = case Reason of
normal ->
@ -225,224 +228,95 @@ ws_loop(FrameInfo, Socket, WsHandleLoopPid, SockMod, Shaper) ->
1011 % internal error
end,
erlang:demonitor(Ref),
websocket_close(Socket, WsHandleLoopPid, SockMod, Code);
websocket_close(Codec, Socket, WsHandleLoopPid, SockMod, Code);
{text_with_reply, Data, Sender} ->
SockMod:send(Socket, encode_frame(Data, 1)),
SockMod:send(Socket, ejabberd_websocket_codec:encode(Codec, 1, Data)),
Sender ! {text_reply, self()},
ws_loop(FrameInfo, Socket, WsHandleLoopPid,
ws_loop(Codec, Socket, WsHandleLoopPid,
SockMod, Shaper);
{data_with_reply, Data, Sender} ->
SockMod:send(Socket, encode_frame(Data, 2)),
SockMod:send(Socket, ejabberd_websocket_codec:encode(Codec, 2, Data)),
Sender ! {data_reply, self()},
ws_loop(FrameInfo, Socket, WsHandleLoopPid,
ws_loop(Codec, Socket, WsHandleLoopPid,
SockMod, Shaper);
{text, Data} ->
SockMod:send(Socket, encode_frame(Data, 1)),
ws_loop(FrameInfo, Socket, WsHandleLoopPid,
SockMod:send(Socket, ejabberd_websocket_codec:encode(Codec, 1, Data)),
ws_loop(Codec, Socket, WsHandleLoopPid,
SockMod, Shaper);
{data, Data} ->
SockMod:send(Socket, encode_frame(Data, 2)),
ws_loop(FrameInfo, Socket, WsHandleLoopPid,
SockMod:send(Socket, ejabberd_websocket_codec:encode(Codec, 2, Data)),
ws_loop(Codec, Socket, WsHandleLoopPid,
SockMod, Shaper);
{ping, Data} ->
SockMod:send(Socket, encode_frame(Data, 9)),
ws_loop(FrameInfo, Socket, WsHandleLoopPid,
SockMod:send(Socket, ejabberd_websocket_codec:encode(Codec, 9, Data)),
ws_loop(Codec, Socket, WsHandleLoopPid,
SockMod, Shaper);
shutdown ->
?DEBUG("Shutdown request received, closing websocket "
"with pid ~p",
[self()]),
websocket_close(Socket, WsHandleLoopPid, SockMod, 1001); % going away
websocket_close(Codec, Socket, WsHandleLoopPid, SockMod, 1001); % going away
_Ignored ->
?WARNING_MSG("Received unexpected message, ignoring: ~p",
[_Ignored]),
ws_loop(FrameInfo, Socket, WsHandleLoopPid,
ws_loop(Codec, Socket, WsHandleLoopPid,
SockMod, Shaper)
end.
encode_frame(Data, Opcode) ->
case byte_size(Data) of
S1 when S1 < 126 ->
<<1:1, 0:3, Opcode:4, 0:1, S1:7, Data/binary>>;
S2 when S2 < 65536 ->
<<1:1, 0:3, Opcode:4, 0:1, 126:7, S2:16, Data/binary>>;
S3 ->
<<1:1, 0:3, Opcode:4, 0:1, 127:7, S3:64, Data/binary>>
end.
-record(frame_info,
{mask = none, offset = 0, left, final_frame = true,
opcode, unprocessed = <<>>, unmasked = <<>>,
unmasked_msg = <<>>}).
decode_header(<<Final:1, _:3, Opcode:4, 0:1,
Len:7, Data/binary>>)
when Len < 126 ->
{Len, Final, Opcode, none, Data};
decode_header(<<Final:1, _:3, Opcode:4, 0:1,
126:7, Len:16/integer, Data/binary>>) ->
{Len, Final, Opcode, none, Data};
decode_header(<<Final:1, _:3, Opcode:4, 0:1,
127:7, Len:64/integer, Data/binary>>) ->
{Len, Final, Opcode, none, Data};
decode_header(<<Final:1, _:3, Opcode:4, 1:1,
Len:7, Mask:4/binary, Data/binary>>)
when Len < 126 ->
{Len, Final, Opcode, Mask, Data};
decode_header(<<Final:1, _:3, Opcode:4, 1:1,
126:7, Len:16/integer, Mask:4/binary, Data/binary>>) ->
{Len, Final, Opcode, Mask, Data};
decode_header(<<Final:1, _:3, Opcode:4, 1:1,
127:7, Len:64/integer, Mask:4/binary, Data/binary>>) ->
{Len, Final, Opcode, Mask, Data};
decode_header(_) -> none.
unmask_int(Offset, _, <<>>, Acc) ->
{Acc, Offset};
unmask_int(0, <<M:32>> = Mask,
<<N:32, Rest/binary>>, Acc) ->
unmask_int(0, Mask, Rest,
<<Acc/binary, (M bxor N):32>>);
unmask_int(0, <<M:8, _/binary>> = Mask,
<<N:8, Rest/binary>>, Acc) ->
unmask_int(1, Mask, Rest,
<<Acc/binary, (M bxor N):8>>);
unmask_int(1, <<_:8, M:8, _/binary>> = Mask,
<<N:8, Rest/binary>>, Acc) ->
unmask_int(2, Mask, Rest,
<<Acc/binary, (M bxor N):8>>);
unmask_int(2, <<_:16, M:8, _/binary>> = Mask,
<<N:8, Rest/binary>>, Acc) ->
unmask_int(3, Mask, Rest,
<<Acc/binary, (M bxor N):8>>);
unmask_int(3, <<_:24, M:8>> = Mask,
<<N:8, Rest/binary>>, Acc) ->
unmask_int(0, Mask, Rest,
<<Acc/binary, (M bxor N):8>>).
unmask(#frame_info{mask = none} = State, Data) ->
{State, Data};
unmask(#frame_info{mask = Mask, offset = Offset} = State, Data) ->
{Unmasked, NewOffset} = unmask_int(Offset, Mask,
Data, <<>>),
{State#frame_info{offset = NewOffset}, Unmasked}.
process_frame(none, Data) ->
process_frame(#frame_info{}, Data);
process_frame(#frame_info{left = Left} = FrameInfo, <<>>) when Left > 0 ->
{FrameInfo, [], []};
process_frame(#frame_info{unprocessed = none,
unmasked = UnmaskedPre, left = Left} =
State,
Data)
when byte_size(Data) < Left ->
{State2, Unmasked} = unmask(State, Data),
{State2#frame_info{left = Left - byte_size(Data),
unmasked = [UnmaskedPre, Unmasked]},
[], []};
process_frame(#frame_info{unprocessed = none,
unmasked = UnmaskedPre, opcode = Opcode,
final_frame = Final, left = Left,
unmasked_msg = UnmaskedMsg} =
FrameInfo,
Data) ->
<<ToProcess:(Left)/binary, Unprocessed/binary>> = Data,
{_, Unmasked} = unmask(FrameInfo, ToProcess),
case Final of
true ->
{FrameInfo3, Recv, Send} = process_frame(#frame_info{},
Unprocessed),
case Opcode of
X when X < 3 ->
{FrameInfo3,
[iolist_to_binary([UnmaskedMsg, UnmaskedPre, Unmasked])
| Recv],
Send};
9 -> % Ping
Frame = encode_frame(Unmasked, 10),
{FrameInfo3#frame_info{unmasked_msg = UnmaskedMsg}, [ping | Recv],
[Frame | Send]};
10 -> % Pong
{FrameInfo3, [pong | Recv], Send};
8 -> % Close
CloseCode = case Unmasked of
<<Code:16/integer-big, Message/binary>> ->
?DEBUG("WebSocket close op: ~p ~ts",
[Code, Message]),
Code;
<<Code:16/integer-big>> ->
?DEBUG("WebSocket close op: ~p", [Code]),
Code;
_ ->
?DEBUG("WebSocket close op unknown: ~p",
[Unmasked]),
1000
end,
Frame = encode_frame(<<CloseCode:16/integer-big>>, 8),
{FrameInfo3#frame_info{unmasked_msg=UnmaskedMsg}, Recv,
[Frame | Send]};
_ ->
{FrameInfo3#frame_info{unmasked_msg = UnmaskedMsg}, Recv,
Send}
end;
_ ->
process_frame(#frame_info{unmasked_msg =
[UnmaskedMsg, UnmaskedPre,
Unmasked]},
Unprocessed)
end;
process_frame(#frame_info{unprocessed = <<>>} =
FrameInfo,
Data) ->
case decode_header(Data) of
none ->
{FrameInfo#frame_info{unprocessed = Data}, [], []};
{Len, Final, Opcode, Mask, Rest} ->
process_frame(FrameInfo#frame_info{mask = Mask,
final_frame = Final == 1,
left = Len, opcode = Opcode,
unprocessed = none},
Rest)
end;
process_frame(#frame_info{unprocessed =
UnprocessedPre} =
FrameInfo,
Data) ->
process_frame(FrameInfo#frame_info{unprocessed = <<>>},
<<UnprocessedPre/binary, Data/binary>>).
handle_data(tcp, FrameInfo, Data, Socket, WsHandleLoopPid, fast_tls, Shaper) ->
handle_data(tcp, Codec, Data, Socket, WsHandleLoopPid, fast_tls, Shaper) ->
case fast_tls:recv_data(Socket, Data) of
{ok, NewData} ->
handle_data_int(FrameInfo, NewData, Socket, WsHandleLoopPid, fast_tls, Shaper);
handle_data_int(Codec, NewData, Socket, WsHandleLoopPid, fast_tls, Shaper);
{error, Error} ->
{error, Error}
{error, tls, Error}
end;
handle_data(_, FrameInfo, Data, Socket, WsHandleLoopPid, SockMod, Shaper) ->
handle_data_int(FrameInfo, Data, Socket, WsHandleLoopPid, SockMod, Shaper).
handle_data(_, Codec, Data, Socket, WsHandleLoopPid, SockMod, Shaper) ->
handle_data_int(Codec, Data, Socket, WsHandleLoopPid, SockMod, Shaper).
handle_data_int(FrameInfo, Data, Socket, WsHandleLoopPid, SockMod, Shaper) ->
{NewFrameInfo, Recv, Send} = process_frame(FrameInfo, Data),
lists:foreach(fun (El) ->
case El of
pong ->
WsHandleLoopPid ! pong;
ping ->
WsHandleLoopPid ! ping;
_ ->
WsHandleLoopPid ! {received, El}
end
end,
Recv),
{NewFrameInfo, Send, handle_shaping(Data, Socket, SockMod, Shaper)}.
handle_data_int(Codec, Data, Socket, WsHandleLoopPid, SockMod, Shaper) ->
{Type, NewCodec, Recv} = ejabberd_websocket_codec:decode(Codec, Data),
Send =
lists:filtermap(
fun({Op, Payload}) when Op == 1; Op == 2 ->
WsHandleLoopPid ! {received, Payload},
false;
({8, Payload}) ->
CloseCode =
case Payload of
<<Code:16/integer-big, Message/binary>> ->
?DEBUG("WebSocket close op: ~p ~ts",
[Code, Message]),
Code;
<<Code:16/integer-big>> ->
?DEBUG("WebSocket close op: ~p", [Code]),
Code;
_ ->
?DEBUG("WebSocket close op unknown: ~p", [Payload]),
1000
end,
Frame = ejabberd_websocket_codec:encode(Codec, 8, <<CloseCode:16/integer-big>>),
{true, Frame};
({9, Payload}) ->
WsHandleLoopPid ! ping,
Frame = ejabberd_websocket_codec:encode(Codec, 10, Payload),
{true, Frame};
({10, _Payload}) ->
WsHandleLoopPid ! pong,
false
end, Recv),
case Type of
error ->
{error, protocol, NewCodec};
_ ->
{NewCodec, Send, handle_shaping(Data, Socket, SockMod, Shaper)}
end.
websocket_close(Socket, WsHandleLoopPid,
websocket_close(Codec, Socket, WsHandleLoopPid,
SockMod, CloseCode) when CloseCode > 0 ->
Frame = encode_frame(<<CloseCode:16/integer-big>>, 8),
Frame = ejabberd_websocket_codec:encode(Codec, 8, <<CloseCode:16/integer-big>>),
SockMod:send(Socket, Frame),
websocket_close(Socket, WsHandleLoopPid, SockMod, 0);
websocket_close(Socket, WsHandleLoopPid, SockMod, _CloseCode) ->
websocket_close(Codec, Socket, WsHandleLoopPid, SockMod, 0);
websocket_close(_Codec, Socket, WsHandleLoopPid, SockMod, _CloseCode) ->
WsHandleLoopPid ! closed,
SockMod:close(Socket).

View File

@ -0,0 +1,175 @@
%%
% File : ejabberd_websocket_codec.erl
% Author : Paweł Chmielowski <pawel@process-one.net>
% Purpose : Coder/Encoder of websocket frames
% Created : 9 sty 2023 by Paweł Chmielowski <pawel@process-one.net>
%
%
% ejabberd, Copyright (C) 2002-2023 ProcessOne
%
% This program is free software; you can redistribute it and/or
% modify it under the terms of the GNU General Public License as
% published by the Free Software Foundation; either version 2 of the
% License, or (at your option) any later version.
%
% This program is distributed in the hope that it will be useful,
% but WITHOUT ANY WARRANTY; without even the implied warranty of
% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
% General Public License for more details.
%
% You should have received a copy of the GNU General Public License along
% with this program; if not, write to the Free Software Foundation, Inc.,
% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
%
%
-module(ejabberd_websocket_codec).
-author("pawel@process-one.net").
%% API
-export([new_server/0, new_client/0, decode/2, encode/3]).
-record(codec_state, {
our_mask = none :: none | binary(),
partial = none :: none | {non_neg_integer(), binary()},
opcode = 0 :: non_neg_integer(),
is_fin = false :: boolean(),
mask = none :: none | binary(),
mask_offset = 0 :: non_neg_integer(),
required = -1 :: integer(),
data = <<>> :: binary()
}).
-opaque codec_state() :: #codec_state{}.
-export_type([codec_state/0]).
-spec new_server() -> codec_state().
new_server() ->
#codec_state{}.
new_client() ->
#codec_state{our_mask = p1_rand:bytes(4)}.
-spec decode(codec_state(), binary()) -> {ok, codec_state(), [binary()]} | {error, atom(), [binary()]}.
decode(#codec_state{required = -1, data = PrevData, partial = Partial} = S, Data) ->
Data2 = <<PrevData/binary, Data/binary>>,
case parse_header(Data2) of
none ->
{ok, S#codec_state{data = Data2}, []};
{_, _, Opcode, _, _} when (Opcode > 2 andalso Opcode < 8) orelse (Opcode > 10) ->
{error, unknown_opcode, []};
{_, 0, Opcode, _, _} when Opcode > 7 ->
{error, partial_control_frame, []};
{_, _, Opcode, _, _} when Opcode > 0 andalso Opcode < 8 andalso Partial /= none ->
{error, partial_frame_non_finished, []};
{Len, Final, Opcode, Mask, Payload} ->
decode(S#codec_state{opcode = Opcode, is_fin = Final == 1,
mask = Mask, mask_offset = 0,
required = Len, data = <<>>}, Payload)
end;
decode(#codec_state{required = Req, data = PrevData,
mask = Mask, mask_offset = Offset} = S, Data)
when byte_size(PrevData) + byte_size(Data) < Req ->
{Unmasked, NewOffset} = apply_mask(Offset, Mask, Data, PrevData),
{ok, S#codec_state{data = Unmasked, mask_offset = NewOffset}, []};
decode(#codec_state{required = Req, data = PrevData,
mask = Mask, mask_offset = Offset,
is_fin = IsFin, opcode = Opcode,
partial = Partial} = S, Data) ->
Left = Req - byte_size(PrevData),
<<CurrentPayload:Left/binary, NextPacketData/binary>> = Data,
{Unmasked, _} = apply_mask(Offset, Mask, CurrentPayload, PrevData),
{NS, Packets} =
case {IsFin, Partial} of
{false, none} ->
{S#codec_state{partial = {Opcode, Unmasked},
data = <<>>, required = -1}, []};
{false, {PartOp, PartData}} ->
{S#codec_state{partial = {PartOp, <<PartData/binary, Unmasked/binary>>},
data = <<>>, required = -1}, []};
{true, none} ->
{S#codec_state{data = <<>>, required = -1}, [{Opcode, Unmasked}]};
{true, {PartOp, PartData}} ->
{S#codec_state{partial = none, data = <<>>, required = -1},
[{PartOp, <<PartData/binary, Unmasked/binary>>}]}
end,
case NextPacketData of
<<>> ->
{ok, NS, Packets};
_ ->
case decode(NS, NextPacketData) of
{T1, T2, Packets2} ->
{T1, T2, Packets ++ Packets2}
end
end.
-spec encode(codec_state(), non_neg_integer(), binary()) -> binary().
encode(#codec_state{our_mask = none}, Opcode, Data) ->
case byte_size(Data) of
S1 when S1 < 126 ->
<<1:1, 0:3, Opcode:4, 0:1, S1:7, Data/binary>>;
S2 when S2 < 65536 ->
<<1:1, 0:3, Opcode:4, 0:1, 126:7, S2:16, Data/binary>>;
S3 ->
<<1:1, 0:3, Opcode:4, 0:1, 127:7, S3:64, Data/binary>>
end;
encode(#codec_state{our_mask = Mask}, Opcode, Data) ->
{MaskedData, _} = apply_mask(0, Mask, Data, <<>>),
case byte_size(Data) of
S1 when S1 < 126 ->
<<1:1, 0:3, Opcode:4, 1:1, S1:7, Mask/binary, MaskedData/binary>>;
S2 when S2 < 65536 ->
<<1:1, 0:3, Opcode:4, 1:1, 126:7, S2:16, Mask/binary, MaskedData/binary>>;
S3 ->
<<1:1, 0:3, Opcode:4, 1:1, 127:7, S3:64, Mask/binary, MaskedData/binary>>
end.
-spec parse_header(binary()) -> none | {integer(), integer(), integer(), none | binary(), binary()}.
parse_header(<<Final:1, _:3, Opcode:4, 0:1,
Len:7, Data/binary>>)
when Len < 126 ->
{Len, Final, Opcode, none, Data};
parse_header(<<Final:1, _:3, Opcode:4, 0:1,
126:7, Len:16/integer, Data/binary>>) ->
{Len, Final, Opcode, none, Data};
parse_header(<<Final:1, _:3, Opcode:4, 0:1,
127:7, Len:64/integer, Data/binary>>) ->
{Len, Final, Opcode, none, Data};
parse_header(<<Final:1, _:3, Opcode:4, 1:1,
Len:7, Mask:4/binary, Data/binary>>)
when Len < 126 ->
{Len, Final, Opcode, Mask, Data};
parse_header(<<Final:1, _:3, Opcode:4, 1:1,
126:7, Len:16/integer, Mask:4/binary, Data/binary>>) ->
{Len, Final, Opcode, Mask, Data};
parse_header(<<Final:1, _:3, Opcode:4, 1:1,
127:7, Len:64/integer, Mask:4/binary, Data/binary>>) ->
{Len, Final, Opcode, Mask, Data};
parse_header(_) ->
none.
-spec apply_mask(integer(), none | binary(), binary(), binary()) -> {binary(), non_neg_integer()}.
apply_mask(_, none, Data, _) ->
{Data, 0};
apply_mask(Offset, _, <<>>, Acc) ->
{Acc, Offset};
apply_mask(0, <<M:32>> = Mask,
<<N:32, Rest/binary>>, Acc) ->
apply_mask(0, Mask, Rest,
<<Acc/binary, (M bxor N):32>>);
apply_mask(0, <<M:8, _/binary>> = Mask,
<<N:8, Rest/binary>>, Acc) ->
apply_mask(1, Mask, Rest,
<<Acc/binary, (M bxor N):8>>);
apply_mask(1, <<_:8, M:8, _/binary>> = Mask,
<<N:8, Rest/binary>>, Acc) ->
apply_mask(2, Mask, Rest,
<<Acc/binary, (M bxor N):8>>);
apply_mask(2, <<_:16, M:8, _/binary>> = Mask,
<<N:8, Rest/binary>>, Acc) ->
apply_mask(3, Mask, Rest,
<<Acc/binary, (M bxor N):8>>);
apply_mask(3, <<_:24, M:8>> = Mask,
<<N:8, Rest/binary>>, Acc) ->
apply_mask(0, Mask, Rest,
<<Acc/binary, (M bxor N):8>>).

View File

@ -35,7 +35,7 @@
start(Host, Opts) ->
User = mod_mqtt_bridge_opt:replication_user(Opts),
lists:foldl(
fun({Proc, Transport, HostAddr, Port, Publish, Subscribe, Authentication}, Started) ->
fun({Proc, Transport, HostAddr, Port, Path, Publish, Subscribe, Authentication}, Started) ->
case Started of
#{Proc := _} ->
?DEBUG("Already started ~p", [Proc]),
@ -43,7 +43,7 @@ start(Host, Opts) ->
_ ->
ChildSpec = {Proc,
{mod_mqtt_bridge_session, start_link,
[Proc, Transport, HostAddr, Port, Publish, Subscribe, Authentication, User]},
[Proc, Transport, HostAddr, Port, Path, Publish, Subscribe, Authentication, User]},
transient,
1000,
worker,
@ -107,7 +107,7 @@ mod_opt_type(replication_user) ->
econf:jid();
mod_opt_type(servers) ->
econf:and_then(
econf:map(econf:url([mqtt, mqtts, mqtt5, mqtt5s]),
econf:map(econf:url([mqtt, mqtts, mqtt5, mqtt5s, ws, wss]),
econf:options(
#{
publish => econf:map(econf:binary(), econf:binary(), [{return, map}]),
@ -127,9 +127,10 @@ mod_opt_type(servers) ->
fun(Servers) ->
maps:fold(
fun(Url, Opts, {HAcc, PAcc}) ->
{ok, Scheme, _UserInfo, Host, Port, _Path, _Query} =
{ok, Scheme, _UserInfo, Host, Port, Path, _Query} =
misc:uri_parse(Url, [{mqtt, 1883}, {mqtts, 8883},
{mqtt5, 1883}, {mqtt5s, 8883}]),
{mqtt5, 1883}, {mqtt5s, 8883},
{ws, 80}, {wss, 443}]),
Publish = maps:get(publish, Opts, #{}),
Subscribe = maps:get(subscribe, Opts, #{}),
Authentication = maps:get(authentication, Opts, []),
@ -139,7 +140,7 @@ mod_opt_type(servers) ->
fun(Topic, _RemoteTopic, Acc) ->
maps:update_with(Topic, fun(V) -> [Proc | V] end, [Proc], Acc)
end, PAcc, Publish),
{[{Proc, Proto, Host, Port, Publish, Subscribe, Authentication} | HAcc], PAcc2}
{[{Proc, Proto, Host, Port, Path, Publish, Subscribe, Authentication} | HAcc], PAcc2}
end, {[], #{}}, Servers)
end
).

View File

@ -21,7 +21,7 @@
-vsn(?VSN).
%% API
-export([start/8, start_link/8]).
-export([start/9, start_link/9]).
%% gen_server callbacks
-export([init/1, handle_call/3, handle_cast/2, handle_info/2,
terminate/2, code_change/3]).
@ -66,6 +66,7 @@
stop_reason :: undefined | error_reason(),
subscriptions = #{},
publish = #{},
ws_codec = none,
id = 0 :: non_neg_integer(),
codec :: mqtt_codec:state(),
authentication :: #{username => binary(), password => binary(), certfile => binary()}}).
@ -75,23 +76,25 @@
%%%===================================================================
%%% API
%%%===================================================================
start(Proc, Transport, Host, Port, Publish, Subscribe, Authentication, ReplicationUser) ->
p1_server:start({local, Proc}, ?MODULE, [Proc, Transport, Host, Port, Publish, Subscribe, Authentication,
start(Proc, Transport, Host, Port, Path, Publish, Subscribe, Authentication, ReplicationUser) ->
p1_server:start({local, Proc}, ?MODULE, [Proc, Transport, Host, Port, Path, Publish, Subscribe, Authentication,
ReplicationUser], []).
start_link(Proc, Transport, Host, Port, Publish, Subscribe, Authentication, ReplicationUser) ->
p1_server:start_link({local, Proc}, ?MODULE, [Proc, Transport, Host, Port, Publish, Subscribe,
start_link(Proc, Transport, Host, Port, Path, Publish, Subscribe, Authentication, ReplicationUser) ->
p1_server:start_link({local, Proc}, ?MODULE, [Proc, Transport, Host, Port, Path, Publish, Subscribe,
Authentication, ReplicationUser], []).
%%%===================================================================
%%% gen_server callbacks
%%%===================================================================
init([_Proc, Proto, Host, Port, Publish, Subscribe, Authentication, ReplicationUser]) ->
init([_Proc, Proto, Host, Port, Path, Publish, Subscribe, Authentication, ReplicationUser]) ->
{Version, Transport} = case Proto of
mqtt -> {4, gen_tcp};
mqtts -> {4, ssl};
mqtt5 -> {5, gen_tcp};
mqtt5s -> {5, ssl}
mqtt5s -> {5, ssl};
ws -> {4, gen_tcp};
wss -> {4, ssl}
end,
State = #state{version = Version,
id = p1_rand:uniform(65535),
@ -101,12 +104,20 @@ init([_Proc, Proto, Host, Port, Publish, Subscribe, Authentication, ReplicationU
usr = jid:tolower(ReplicationUser),
publish = Publish},
case Authentication of
#{certfile := Cert} when Proto == mqtts; Proto == mqtt5s ->
connect(ssl:connect(Host, Port, [binary, {certfile, Cert}]), State, ssl, none);
#{certfile := Cert} when Proto == mqtts; Proto == mqtt5s; Proto == wss ->
Sock = ssl:connect(Host, Port, [binary, {active, true}, {certfile, Cert}]),
if Proto == ws orelse Proto == wss ->
connect_ws(Host, Port, Path, Sock, State, ssl, none);
true -> connect(Sock, State, ssl, none)
end;
#{username := User, password := Pass} ->
connect(Transport:connect(Host, Port, [binary]), State, Transport, {User, Pass});
Sock = Transport:connect(Host, Port, [binary, {active, true}]),
if Proto == ws orelse Proto == wss ->
connect_ws(Host, Port, Path, Sock, State, Transport, {User, Pass});
true -> connect(Sock, State, Transport, {User, Pass})
end;
_ ->
{stop, {error, <<"Certificate can be only used for encrypted connections">>}}
{stop, {error, <<"Certificate can be only used for encrypted connections">> }}
end.
handle_call(Request, From, State) ->
@ -118,20 +129,108 @@ handle_cast(Msg, State) ->
{noreply, State}.
handle_info({Tag, TCPSock, TCPData},
#state{codec = Codec, socket = Socket} = State) when Tag == tcp; Tag == ssl ->
#state{ws_codec = {init, Hash, Auth, Last}} = State)
when (Tag == tcp orelse Tag == ssl) ->
Data = <<Last/binary, TCPData/binary>>,
case erlang:decode_packet(http_bin, Data, []) of
{ok, {http_response, _, 101, _}, Rest} ->
handle_info({tcp, TCPSock, Rest}, State#state{ws_codec = {inith, Hash, none, Auth, <<>>}});
{ok, {http_response, _, _, _}, _Rest} ->
stop(State, {socket, closed});
{ok, {http_error, _}, _} ->
stop(State, {socket, closed});
{error, _} ->
stop(State, {socket, closed});
{more, _} ->
{noreply, State#state{ws_codec = {init, Hash, Auth, Data}}}
end;
handle_info({Tag, TCPSock, TCPData},
#state{ws_codec = {inith, Hash, Upgrade, Auth, Last},
socket = {Transport, _}} = State)
when (Tag == tcp orelse Tag == ssl) ->
Data = <<Last/binary, TCPData/binary>>,
case erlang:decode_packet(httph_bin, Data, []) of
{ok, {http_header, _, <<"Sec-Websocket-Accept">>, _, Val}, Rest} ->
case str:to_lower(Val) of
Hash ->
handle_info({tcp, TCPSock, Rest},
State#state{ws_codec = {inith, ok, Upgrade, Auth, <<>>}});
_ ->
stop(State, {socket, closed})
end;
{ok, {http_header, _, 'Connection', _, Val}, Rest} ->
case str:to_lower(Val) of
<<"upgrade">> ->
handle_info({tcp, TCPSock, Rest},
State#state{ws_codec = {inith, Hash, ok, Auth, <<>>}});
_ ->
stop(State, {socket, closed})
end;
{ok, {http_header, _, _, _, _}, Rest} ->
handle_info({tcp, TCPSock, Rest}, State);
{ok, {http_error, _}, _} ->
stop(State, {socket, closed});
{ok, http_eoh, Rest} ->
case {Hash, Upgrade} of
{ok, ok} ->
{ok, State2} = connect({ok, TCPSock},
State#state{ws_codec = ejabberd_websocket_codec:new_client()},
Transport, Auth),
handle_info({tcp, TCPSock, Rest}, State2);
_ ->
stop(State, {socket, closed})
end;
{error, _} ->
stop(State, {socket, closed});
{more, _} ->
{noreply, State#state{ws_codec = {inith, Hash, Upgrade, Data}}}
end;
handle_info({Tag, TCPSock, TCPData},
#state{ws_codec = WSCodec} = State)
when (Tag == tcp orelse Tag == ssl) andalso WSCodec /= none ->
{Packets, Acc0} =
case ejabberd_websocket_codec:decode(WSCodec, TCPData) of
{ok, NewWSCodec, Packets0} ->
{Packets0, {State#state{ws_codec = NewWSCodec}, ok}};
{error, _Error, Packets0} ->
{Packets0, {State, stop}}
end,
Res2 =
lists:foldl(
fun(_, {stop, _, _} = Res) -> Res;
({_Op, Data}, {S, Res}) ->
case handle_info({tcp_decoded, TCPSock, Data}, S) of
{stop, _, _} = Stop ->
Stop;
{_, NewState, _} ->
{NewState, Res};
{_, NewState} ->
{NewState, Res}
end
end, Acc0, Packets),
case Res2 of
{stop, _, _} ->
Res2;
{NewState2, ok} ->
{noreply, NewState2};
{NewState2, stop} ->
stop(NewState2, {socket, closed})
end;
handle_info({Tag, TCPSock, TCPData},
#state{codec = Codec} = State)
when Tag == tcp; Tag == ssl; Tag == tcp_decoded ->
case mqtt_codec:decode(Codec, TCPData) of
{ok, Pkt, Codec1} ->
?DEBUG("Got MQTT packet:~n~ts", [pp(Pkt)]),
State1 = State#state{codec = Codec1},
case handle_packet(Pkt, State1) of
{ok, State2} ->
handle_info({tcp, TCPSock, <<>>}, State2);
handle_info({tcp_decoded, TCPSock, <<>>}, State2);
{error, State2, Reason} ->
stop(State2, Reason)
end;
{more, Codec1} ->
State1 = State#state{codec = Codec1},
activate(Socket),
{noreply, State1};
{error, Why} ->
stop(State, {codec, Why})
@ -156,7 +255,7 @@ handle_info({publish, #publish{topic = Topic} = Pkt}, #state{publish = Publish}
{noreply, State2}
end;
_ ->
State
{noreply, State}
end;
handle_info({timeout, _TRef, ping_timeout}, State) ->
case send(State, #pingreq{}) of
@ -230,6 +329,22 @@ connect({ok, Sock}, State0, Transport, Auth) ->
{ok, _, Codec2} = mqtt_codec:decode(State#state.codec, Pkt),
{ok, State#state{codec = Codec2}}.
connect_ws(_Host, _Port, _Path, {error, Reason}, _State, _Transport, _Auth) ->
{stop, {error, Reason}};
connect_ws(Host, Port, Path, {ok, Sock}, State0, Transport, Auth) ->
Key = base64:encode(p1_rand:get_string()),
Hash = str:to_lower(base64:encode(crypto:hash(sha, <<Key/binary, "258EAFA5-E914-47DA-95CA-C5AB0DC85B11">>))),
Data = <<"GET ", (list_to_binary(Path))/binary, " HTTP/1.1\r\n",
"Host: ", (list_to_binary(Host))/binary, ":", (integer_to_binary(Port))/binary,"\r\n",
"Upgrade: websocket\r\n",
"Connection: Upgrade\r\n",
"Sec-WebSocket-Protocol: mqtt\r\n",
"Sec-WebSocket-Key: ", Key/binary, "\r\n",
"Sec-WebSocket-Version: 13\r\n\r\n">>,
Res = Transport:send(Sock, Data),
check_sock_result({Transport, Sock}, Res),
{ok, State0#state{ws_codec = {init, Hash, Auth, <<>>}, socket = {Transport, Sock}}}.
-spec stop(state(), error_reason()) ->
{noreply, state(), infinity} |
{stop, normal, state()}.
@ -286,6 +401,14 @@ send(State, Pkt) ->
{ok, do_send(State, Pkt)}.
-spec do_send(state(), mqtt_packet()) -> state().
do_send(#state{ws_codec = WSCodec, socket = {SockMod, Sock} = Socket} = State, Pkt)
when WSCodec /= none ->
?DEBUG("Send MQTT packet:~n~ts", [pp(Pkt)]),
Data = mqtt_codec:encode(State#state.version, Pkt),
WSData = ejabberd_websocket_codec:encode(WSCodec, 2, Data),
Res = SockMod:send(Sock, WSData),
check_sock_result(Socket, Res),
reset_ping_timer(State);
do_send(#state{socket = {SockMod, Sock} = Socket} = State, Pkt) ->
?DEBUG("Send MQTT packet:~n~ts", [pp(Pkt)]),
Data = mqtt_codec:encode(State#state.version, Pkt),
@ -295,14 +418,6 @@ do_send(#state{socket = {SockMod, Sock} = Socket} = State, Pkt) ->
do_send(State, _Pkt) ->
State.
-spec activate(socket()) -> ok.
activate(Socket) ->
Res = case Socket of
{gen_tcp, Sock} -> inet:setopts(Sock, [{active, once}]);
{SockMod, Sock} -> SockMod:setopts(Sock, [{active, once}])
end,
check_sock_result(Socket, Res).
-spec disconnect(state(), error_reason()) -> state().
disconnect(#state{socket = {SockMod, Sock}} = State, Err) ->
State1 = case Err of