25
1
mirror of https://github.com/processone/ejabberd.git synced 2024-11-20 16:15:59 +01:00

Add support for certificate authentication in mqtt bridge

This commit is contained in:
Paweł Chmielowski 2022-12-07 13:35:04 +01:00
parent 80477f71b3
commit 639183a783
3 changed files with 79 additions and 47 deletions

View File

@ -72,13 +72,11 @@ reload(_Host, _NewOpts, _OldOpts) ->
depends(_Host, _Opts) -> depends(_Host, _Opts) ->
[{mod_mqtt, hard}]. [{mod_mqtt, hard}].
proc_name(Transport, Host, Port) -> proc_name(Proto, Host, Port) ->
HostB = list_to_binary(Host), HostB = list_to_binary(Host),
case Transport of TransportB = list_to_binary(Proto),
gen_tcp -> binary_to_atom(<<"mod_mqtt_bridge_", TransportB/binary, "_", HostB/binary,
binary_to_atom(<<"mod_mqtt_bridge_mqtt_", HostB/binary, "_", (integer_to_binary(Port))/binary>>, utf8); "_", (integer_to_binary(Port))/binary>>, utf8).
_ -> binary_to_atom(<<"mod_mqtt_bridge_mqtts_", HostB/binary, "_", (integer_to_binary(Port))/binary>>, utf8)
end.
-spec mqtt_publish_hook(jid:ljid(), publish(), non_neg_integer()) -> ok. -spec mqtt_publish_hook(jid:ljid(), publish(), non_neg_integer()) -> ok.
mqtt_publish_hook({_, S, _}, #publish{topic = Topic} = Pkt, _ExpiryTime) -> mqtt_publish_hook({_, S, _}, #publish{topic = Topic} = Pkt, _ExpiryTime) ->
@ -97,8 +95,8 @@ mqtt_publish_hook({_, S, _}, #publish{topic = Topic} = Pkt, _ExpiryTime) ->
%%%=================================================================== %%%===================================================================
-spec mod_options(binary()) -> -spec mod_options(binary()) ->
[{servers, [{servers,
{[{atom(), gen_tcp | ssl, binary(), non_neg_integer(), {[{atom(), mqtt | mqtts | mqtt5 | mqtt5s, binary(), non_neg_integer(),
#{binary() => binary()}, #{binary() => binary()}, binary()}], #{binary() => binary()}, #{binary() => binary()}, map()}],
#{binary() => [atom()]}}} | #{binary() => [atom()]}}} |
{atom(), any()}]. {atom(), any()}].
mod_options(Host) -> mod_options(Host) ->
@ -109,29 +107,39 @@ mod_opt_type(replication_user) ->
econf:jid(); econf:jid();
mod_opt_type(servers) -> mod_opt_type(servers) ->
econf:and_then( econf:and_then(
econf:map(econf:url([mqtt, mqtts]), econf:map(econf:url([mqtt, mqtts, mqtt5, mqtt5s]),
econf:options(#{ econf:options(
publish => econf:map(econf:binary(), econf:binary(), [{return, map}]), #{
subscribe => econf:map(econf:binary(), econf:binary(), [{return, map}]), publish => econf:map(econf:binary(), econf:binary(), [{return, map}]),
authentication => econf:binary()}, subscribe => econf:map(econf:binary(), econf:binary(), [{return, map}]),
[{return, map}]), authentication => econf:either(
econf:options(
#{
username => econf:binary(),
password => econf:binary()
}, [{return, map}]),
econf:options(
#{
certfile => econf:pem()
}, [{return, map}])
)}, [{return, map}]),
[{return, map}]), [{return, map}]),
fun(Servers) -> fun(Servers) ->
maps:fold( maps:fold(
fun(Url, Opts, {HAcc, PAcc}) -> fun(Url, Opts, {HAcc, PAcc}) ->
{ok, Scheme, _UserInfo, Host, Port, _Path, _Query} = misc:uri_parse(Url), {ok, Scheme, _UserInfo, Host, Port, _Path, _Query} =
misc:uri_parse(Url, [{mqtt, 1883}, {mqtts, 8883},
{mqtt5, 1883}, {mqtt5s, 8883}]),
Publish = maps:get(publish, Opts, #{}), Publish = maps:get(publish, Opts, #{}),
Subscribe = maps:get(subscribe, Opts, #{}), Subscribe = maps:get(subscribe, Opts, #{}),
Authentication = maps:get(authentication, Opts, []), Authentication = maps:get(authentication, Opts, []),
Transport = case Scheme of "mqtt" -> gen_tcp; Proto = list_to_atom(Scheme),
_ -> ssl Proc = proc_name(Scheme, Host, Port),
end,
Proc = proc_name(Transport, Host, Port),
PAcc2 = maps:fold( PAcc2 = maps:fold(
fun(Topic, _RemoteTopic, Acc) -> fun(Topic, _RemoteTopic, Acc) ->
maps:update_with(Topic, fun(V) -> [Proc | V] end, [Proc], Acc) maps:update_with(Topic, fun(V) -> [Proc | V] end, [Proc], Acc)
end, PAcc, Publish), end, PAcc, Publish),
{[{Proc, Transport, Host, Port, Publish, Subscribe, Authentication} | HAcc], PAcc2} {[{Proc, Proto, Host, Port, Publish, Subscribe, Authentication} | HAcc], PAcc2}
end, {[], #{}}, Servers) end, {[], #{}}, Servers)
end end
). ).

View File

@ -12,7 +12,7 @@ replication_user(Opts) when is_map(Opts) ->
replication_user(Host) -> replication_user(Host) ->
gen_mod:get_module_opt(Host, mod_mqtt_bridge, replication_user). gen_mod:get_module_opt(Host, mod_mqtt_bridge, replication_user).
-spec servers(gen_mod:opts() | global | binary()) -> {[{atom(),'gen_tcp' | 'ssl',binary(),non_neg_integer(),#{binary()=>binary()},#{binary()=>binary()},binary()}],#{binary()=>[atom()]}}. -spec servers(gen_mod:opts() | global | binary()) -> {[{atom(),'mqtt' | 'mqtts' | 'mqtt5' | 'mqtt5s',binary(),non_neg_integer(),#{binary()=>binary()},#{binary()=>binary()},map()}],#{binary()=>[atom()]}}.
servers(Opts) when is_map(Opts) -> servers(Opts) when is_map(Opts) ->
gen_mod:get_opt(servers, Opts); gen_mod:get_opt(servers, Opts);
servers(Host) -> servers(Host) ->

View File

@ -68,7 +68,7 @@
publish = #{}, publish = #{},
id = 0 :: non_neg_integer(), id = 0 :: non_neg_integer(),
codec :: mqtt_codec:state(), codec :: mqtt_codec:state(),
authentication}). authentication :: #{}}).
-type state() :: #state{}. -type state() :: #state{}.
@ -86,19 +86,27 @@ start_link(Proc, Transport, Host, Port, Publish, Subscribe, Authentication, Repl
%%%=================================================================== %%%===================================================================
%%% gen_server callbacks %%% gen_server callbacks
%%%=================================================================== %%%===================================================================
init([_Proc, Transport, Host, Port, Publish, Subscribe, Authentication, ReplicationUser]) -> init([_Proc, Proto, Host, Port, Publish, Subscribe, Authentication, ReplicationUser]) ->
case Transport:connect(Host, Port, [binary]) of {Version, Transport} = case Proto of
{ok, Sock} -> mqtt -> {4, gen_tcp};
State1 = #state{socket = {Transport, Sock}, mqtts -> {4, ssl};
version = 5, mqtt5 -> {5, gen_tcp};
id = p1_rand:uniform(65535), mqtt5s -> {5, ssl}
codec = mqtt_codec:new(4096), end,
subscriptions = Subscribe, State = #state{version = Version,
authentication = Authentication, id = p1_rand:uniform(65535),
usr = jid:tolower(ReplicationUser), codec = mqtt_codec:new(4096),
publish = Publish}, subscriptions = Subscribe,
State2 = connect(State1, Authentication), authentication = Authentication,
{ok, State2} usr = jid:tolower(ReplicationUser),
publish = Publish},
case Authentication of
#{certfile := Cert} when Proto == mqtts; Proto == mqtt5s ->
connect(ssl:connect(Host, Port, [binary, {certfile, Cert}]), State, ssl, none);
#{username := User, password := Pass} ->
connect(Transport:connect(Host, Port, [binary]), State, Transport, {User, Pass});
_ ->
{stop, {error, <<"Certificate can be only used for encrypted connections">>, Authentication, Proto}}
end. end.
handle_call(Request, From, State) -> handle_call(Request, From, State) ->
@ -109,8 +117,8 @@ handle_cast(Msg, State) ->
?WARNING_MSG("Unexpected cast: ~p", [Msg]), ?WARNING_MSG("Unexpected cast: ~p", [Msg]),
{noreply, State}. {noreply, State}.
handle_info({tcp, TCPSock, TCPData}, handle_info({Tag, TCPSock, TCPData},
#state{codec = Codec, socket = Socket} = State) -> #state{codec = Codec, socket = Socket} = State) when Tag == tcp; Tag == ssl ->
case mqtt_codec:decode(Codec, TCPData) of case mqtt_codec:decode(Codec, TCPData) of
{ok, Pkt, Codec1} -> {ok, Pkt, Codec1} ->
?DEBUG("Got MQTT packet:~n~ts", [pp(Pkt)]), ?DEBUG("Got MQTT packet:~n~ts", [pp(Pkt)]),
@ -131,9 +139,15 @@ handle_info({tcp, TCPSock, TCPData},
handle_info({tcp_closed, _Sock}, State) -> handle_info({tcp_closed, _Sock}, State) ->
?DEBUG("MQTT connection reset by peer", []), ?DEBUG("MQTT connection reset by peer", []),
stop(State, {socket, closed}); stop(State, {socket, closed});
handle_info({ssl_closed, _Sock}, State) ->
?DEBUG("MQTT connection reset by peer", []),
stop(State, {socket, closed});
handle_info({tcp_error, _Sock, Reason}, State) -> handle_info({tcp_error, _Sock, Reason}, State) ->
?DEBUG("MQTT connection error: ~ts", [format_inet_error(Reason)]), ?DEBUG("MQTT connection error: ~ts", [format_inet_error(Reason)]),
stop(State, {socket, Reason}); stop(State, {socket, Reason});
handle_info({ssl_error, _Sock, Reason}, State) ->
?DEBUG("MQTT connection error: ~ts", [format_inet_error(Reason)]),
stop(State, {socket, Reason});
handle_info({publish, #publish{topic = Topic} = Pkt}, #state{publish = Publish} = State) -> handle_info({publish, #publish{topic = Topic} = Pkt}, #state{publish = Publish} = State) ->
case maps:find(Topic, Publish) of case maps:find(Topic, Publish) of
{ok, RemoteTopic} -> {ok, RemoteTopic} ->
@ -193,18 +207,28 @@ code_change(_OldVsn, State, _Extra) ->
%%%=================================================================== %%%===================================================================
%%% State transitions %%% State transitions
%%%=================================================================== %%%===================================================================
connect(State, AuthString) -> connect({error, Reason}, _State, _Transport, _Auth) ->
[User, Pass] = binary:split(AuthString, <<":">>), {stop, {error, Reason}};
Connect = #connect{client_id = integer_to_binary(State#state.id), connect({ok, Sock}, State0, Transport, Auth) ->
clean_start = true, State = State0#state{socket = {Transport, Sock}},
username = User, Connect = case Auth of
password = Pass, {User, Pass} ->
keep_alive = 60, #connect{client_id = integer_to_binary(State#state.id),
proto_level = 5}, clean_start = true,
Pkt = mqtt_codec:encode(5, Connect), username = User,
password = Pass,
keep_alive = 60,
proto_level = State#state.version};
_ ->
#connect{client_id = integer_to_binary(State#state.id),
clean_start = true,
keep_alive = 60,
proto_level = State#state.version}
end,
Pkt = mqtt_codec:encode(State#state.version, Connect),
send(State, Connect), send(State, Connect),
{ok, _, Codec2} = mqtt_codec:decode(State#state.codec, Pkt), {ok, _, Codec2} = mqtt_codec:decode(State#state.codec, Pkt),
State#state{codec = Codec2}. {ok, State#state{codec = Codec2}}.
-spec stop(state(), error_reason()) -> -spec stop(state(), error_reason()) ->
{noreply, state(), infinity} | {noreply, state(), infinity} |