From 639183a783d8da5a42da3941c9c7f419fcaee219 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Chmielowski?= Date: Wed, 7 Dec 2022 13:35:04 +0100 Subject: [PATCH] Add support for certificate authentication in mqtt bridge --- src/mod_mqtt_bridge.erl | 48 ++++++++++++--------- src/mod_mqtt_bridge_opt.erl | 2 +- src/mod_mqtt_bridge_session.erl | 76 ++++++++++++++++++++++----------- 3 files changed, 79 insertions(+), 47 deletions(-) diff --git a/src/mod_mqtt_bridge.erl b/src/mod_mqtt_bridge.erl index 5c8c25f01..21aec3cfe 100644 --- a/src/mod_mqtt_bridge.erl +++ b/src/mod_mqtt_bridge.erl @@ -72,13 +72,11 @@ reload(_Host, _NewOpts, _OldOpts) -> depends(_Host, _Opts) -> [{mod_mqtt, hard}]. -proc_name(Transport, Host, Port) -> +proc_name(Proto, Host, Port) -> HostB = list_to_binary(Host), - case Transport of - gen_tcp -> - binary_to_atom(<<"mod_mqtt_bridge_mqtt_", HostB/binary, "_", (integer_to_binary(Port))/binary>>, utf8); - _ -> binary_to_atom(<<"mod_mqtt_bridge_mqtts_", HostB/binary, "_", (integer_to_binary(Port))/binary>>, utf8) - end. + TransportB = list_to_binary(Proto), + binary_to_atom(<<"mod_mqtt_bridge_", TransportB/binary, "_", HostB/binary, + "_", (integer_to_binary(Port))/binary>>, utf8). -spec mqtt_publish_hook(jid:ljid(), publish(), non_neg_integer()) -> ok. mqtt_publish_hook({_, S, _}, #publish{topic = Topic} = Pkt, _ExpiryTime) -> @@ -97,8 +95,8 @@ mqtt_publish_hook({_, S, _}, #publish{topic = Topic} = Pkt, _ExpiryTime) -> %%%=================================================================== -spec mod_options(binary()) -> [{servers, - {[{atom(), gen_tcp | ssl, binary(), non_neg_integer(), - #{binary() => binary()}, #{binary() => binary()}, binary()}], + {[{atom(), mqtt | mqtts | mqtt5 | mqtt5s, binary(), non_neg_integer(), + #{binary() => binary()}, #{binary() => binary()}, map()}], #{binary() => [atom()]}}} | {atom(), any()}]. mod_options(Host) -> @@ -109,29 +107,39 @@ mod_opt_type(replication_user) -> econf:jid(); mod_opt_type(servers) -> econf:and_then( - econf:map(econf:url([mqtt, mqtts]), - econf:options(#{ - publish => econf:map(econf:binary(), econf:binary(), [{return, map}]), - subscribe => econf:map(econf:binary(), econf:binary(), [{return, map}]), - authentication => econf:binary()}, - [{return, map}]), + econf:map(econf:url([mqtt, mqtts, mqtt5, mqtt5s]), + econf:options( + #{ + publish => econf:map(econf:binary(), econf:binary(), [{return, map}]), + subscribe => econf:map(econf:binary(), econf:binary(), [{return, map}]), + authentication => econf:either( + econf:options( + #{ + username => econf:binary(), + password => econf:binary() + }, [{return, map}]), + econf:options( + #{ + certfile => econf:pem() + }, [{return, map}]) + )}, [{return, map}]), [{return, map}]), fun(Servers) -> maps:fold( fun(Url, Opts, {HAcc, PAcc}) -> - {ok, Scheme, _UserInfo, Host, Port, _Path, _Query} = misc:uri_parse(Url), + {ok, Scheme, _UserInfo, Host, Port, _Path, _Query} = + misc:uri_parse(Url, [{mqtt, 1883}, {mqtts, 8883}, + {mqtt5, 1883}, {mqtt5s, 8883}]), Publish = maps:get(publish, Opts, #{}), Subscribe = maps:get(subscribe, Opts, #{}), Authentication = maps:get(authentication, Opts, []), - Transport = case Scheme of "mqtt" -> gen_tcp; - _ -> ssl - end, - Proc = proc_name(Transport, Host, Port), + Proto = list_to_atom(Scheme), + Proc = proc_name(Scheme, Host, Port), PAcc2 = maps:fold( fun(Topic, _RemoteTopic, Acc) -> maps:update_with(Topic, fun(V) -> [Proc | V] end, [Proc], Acc) end, PAcc, Publish), - {[{Proc, Transport, Host, Port, Publish, Subscribe, Authentication} | HAcc], PAcc2} + {[{Proc, Proto, Host, Port, Publish, Subscribe, Authentication} | HAcc], PAcc2} end, {[], #{}}, Servers) end ). diff --git a/src/mod_mqtt_bridge_opt.erl b/src/mod_mqtt_bridge_opt.erl index fe423811f..e10f72e1d 100644 --- a/src/mod_mqtt_bridge_opt.erl +++ b/src/mod_mqtt_bridge_opt.erl @@ -12,7 +12,7 @@ replication_user(Opts) when is_map(Opts) -> replication_user(Host) -> gen_mod:get_module_opt(Host, mod_mqtt_bridge, replication_user). --spec servers(gen_mod:opts() | global | binary()) -> {[{atom(),'gen_tcp' | 'ssl',binary(),non_neg_integer(),#{binary()=>binary()},#{binary()=>binary()},binary()}],#{binary()=>[atom()]}}. +-spec servers(gen_mod:opts() | global | binary()) -> {[{atom(),'mqtt' | 'mqtts' | 'mqtt5' | 'mqtt5s',binary(),non_neg_integer(),#{binary()=>binary()},#{binary()=>binary()},map()}],#{binary()=>[atom()]}}. servers(Opts) when is_map(Opts) -> gen_mod:get_opt(servers, Opts); servers(Host) -> diff --git a/src/mod_mqtt_bridge_session.erl b/src/mod_mqtt_bridge_session.erl index fb9c21d47..3ad0e9fee 100644 --- a/src/mod_mqtt_bridge_session.erl +++ b/src/mod_mqtt_bridge_session.erl @@ -68,7 +68,7 @@ publish = #{}, id = 0 :: non_neg_integer(), codec :: mqtt_codec:state(), - authentication}). + authentication :: #{}}). -type state() :: #state{}. @@ -86,19 +86,27 @@ start_link(Proc, Transport, Host, Port, Publish, Subscribe, Authentication, Repl %%%=================================================================== %%% gen_server callbacks %%%=================================================================== -init([_Proc, Transport, Host, Port, Publish, Subscribe, Authentication, ReplicationUser]) -> - case Transport:connect(Host, Port, [binary]) of - {ok, Sock} -> - State1 = #state{socket = {Transport, Sock}, - version = 5, - id = p1_rand:uniform(65535), - codec = mqtt_codec:new(4096), - subscriptions = Subscribe, - authentication = Authentication, - usr = jid:tolower(ReplicationUser), - publish = Publish}, - State2 = connect(State1, Authentication), - {ok, State2} +init([_Proc, Proto, Host, Port, Publish, Subscribe, Authentication, ReplicationUser]) -> + {Version, Transport} = case Proto of + mqtt -> {4, gen_tcp}; + mqtts -> {4, ssl}; + mqtt5 -> {5, gen_tcp}; + mqtt5s -> {5, ssl} + end, + State = #state{version = Version, + id = p1_rand:uniform(65535), + codec = mqtt_codec:new(4096), + subscriptions = Subscribe, + authentication = Authentication, + usr = jid:tolower(ReplicationUser), + publish = Publish}, + case Authentication of + #{certfile := Cert} when Proto == mqtts; Proto == mqtt5s -> + connect(ssl:connect(Host, Port, [binary, {certfile, Cert}]), State, ssl, none); + #{username := User, password := Pass} -> + connect(Transport:connect(Host, Port, [binary]), State, Transport, {User, Pass}); + _ -> + {stop, {error, <<"Certificate can be only used for encrypted connections">>, Authentication, Proto}} end. handle_call(Request, From, State) -> @@ -109,8 +117,8 @@ handle_cast(Msg, State) -> ?WARNING_MSG("Unexpected cast: ~p", [Msg]), {noreply, State}. -handle_info({tcp, TCPSock, TCPData}, - #state{codec = Codec, socket = Socket} = State) -> +handle_info({Tag, TCPSock, TCPData}, + #state{codec = Codec, socket = Socket} = State) when Tag == tcp; Tag == ssl -> case mqtt_codec:decode(Codec, TCPData) of {ok, Pkt, Codec1} -> ?DEBUG("Got MQTT packet:~n~ts", [pp(Pkt)]), @@ -131,9 +139,15 @@ handle_info({tcp, TCPSock, TCPData}, handle_info({tcp_closed, _Sock}, State) -> ?DEBUG("MQTT connection reset by peer", []), stop(State, {socket, closed}); +handle_info({ssl_closed, _Sock}, State) -> + ?DEBUG("MQTT connection reset by peer", []), + stop(State, {socket, closed}); handle_info({tcp_error, _Sock, Reason}, State) -> ?DEBUG("MQTT connection error: ~ts", [format_inet_error(Reason)]), stop(State, {socket, Reason}); +handle_info({ssl_error, _Sock, Reason}, State) -> + ?DEBUG("MQTT connection error: ~ts", [format_inet_error(Reason)]), + stop(State, {socket, Reason}); handle_info({publish, #publish{topic = Topic} = Pkt}, #state{publish = Publish} = State) -> case maps:find(Topic, Publish) of {ok, RemoteTopic} -> @@ -193,18 +207,28 @@ code_change(_OldVsn, State, _Extra) -> %%%=================================================================== %%% State transitions %%%=================================================================== -connect(State, AuthString) -> - [User, Pass] = binary:split(AuthString, <<":">>), - Connect = #connect{client_id = integer_to_binary(State#state.id), - clean_start = true, - username = User, - password = Pass, - keep_alive = 60, - proto_level = 5}, - Pkt = mqtt_codec:encode(5, Connect), +connect({error, Reason}, _State, _Transport, _Auth) -> + {stop, {error, Reason}}; +connect({ok, Sock}, State0, Transport, Auth) -> + State = State0#state{socket = {Transport, Sock}}, + Connect = case Auth of + {User, Pass} -> + #connect{client_id = integer_to_binary(State#state.id), + clean_start = true, + username = User, + password = Pass, + keep_alive = 60, + proto_level = State#state.version}; + _ -> + #connect{client_id = integer_to_binary(State#state.id), + clean_start = true, + keep_alive = 60, + proto_level = State#state.version} + end, + Pkt = mqtt_codec:encode(State#state.version, Connect), send(State, Connect), {ok, _, Codec2} = mqtt_codec:decode(State#state.codec, Pkt), - State#state{codec = Codec2}. + {ok, State#state{codec = Codec2}}. -spec stop(state(), error_reason()) -> {noreply, state(), infinity} |