Add TLS certificate authentication for MQTT connections

This commit is contained in:
Alexey Shchepin 2022-03-14 15:37:21 +03:00
parent 871e26a01e
commit 5506b838c8
1 changed files with 87 additions and 13 deletions

View File

@ -29,6 +29,7 @@
-include("logger.hrl").
-include("mqtt.hrl").
-include_lib("xmpp/include/xmpp.hrl").
-include_lib("public_key/include/public_key.hrl").
-record(state, {vsn = ?VSN :: integer(),
version :: undefined | mqtt_version(),
@ -47,7 +48,8 @@
in_flight :: undefined | publish() | pubrel(),
codec :: mqtt_codec:state(),
queue :: undefined | p1_queue:queue(publish()),
tls :: boolean()}).
tls :: boolean(),
tls_verify :: boolean()}).
-type acks() :: #{non_neg_integer() => pubrec()}.
-type subscriptions() :: #{binary() => {sub_opts(), non_neg_integer()}}.
@ -162,6 +164,7 @@ init([SockMod, Socket, ListenOpts]) ->
State1 = #state{socket = {SockMod, Socket},
id = p1_rand:uniform(65535),
tls = proplists:get_bool(tls, ListenOpts),
tls_verify = proplists:get_bool(tls_verify, ListenOpts),
codec = mqtt_codec:new(MaxSize)},
Timeout = timer:seconds(30),
State2 = set_timeout(State1, Timeout),
@ -553,7 +556,7 @@ unregister_session(_, _) ->
{error, state(), error_reason()}.
handle_connect(#connect{clean_start = CleanStart} = Pkt,
#state{jid = undefined, peername = IP} = State) ->
case authenticate(Pkt, IP) of
case authenticate(Pkt, IP, State) of
{ok, JID} ->
case validate_will(Pkt, JID) of
ok ->
@ -939,7 +942,12 @@ check_sock_result({_, Sock}, {error, Why}) ->
starttls(#state{socket = {gen_tcp, Socket}, tls = true}) ->
case ejabberd_pkix:get_certfile() of
{ok, Cert} ->
case fast_tls:tcp_to_tls(Socket, [{certfile, Cert}]) of
CAFileOpt =
case ejabberd_option:c2s_cafile(ejabberd_config:get_myname()) of
undefined -> [];
CAFile -> [{cafile, CAFile}]
end,
case fast_tls:tcp_to_tls(Socket, [{certfile, Cert}] ++ CAFileOpt) of
{ok, TLSSock} ->
{ok, {fast_tls, TLSSock}};
{error, Why} ->
@ -1172,9 +1180,9 @@ parse_credentials(JID, ClientID) ->
end
end.
-spec authenticate(connect(), peername()) -> {ok, jid:jid()} | {error, reason_code()}.
authenticate(Pkt, IP) ->
case authenticate(Pkt) of
-spec authenticate(connect(), peername(), state()) -> {ok, jid:jid()} | {error, reason_code()}.
authenticate(Pkt, IP, State) ->
case authenticate(Pkt, State) of
{ok, JID, AuthModule} ->
?INFO_MSG("Accepted MQTT authentication for ~ts by ~s backend from ~s",
[jid:encode(JID),
@ -1185,8 +1193,8 @@ authenticate(Pkt, IP) ->
Err
end.
-spec authenticate(connect()) -> {ok, jid:jid(), module()} | {error, reason_code()}.
authenticate(#connect{password = Pass, properties = Props} = Pkt) ->
-spec authenticate(connect(), state()) -> {ok, jid:jid(), module()} | {error, reason_code()}.
authenticate(#connect{password = Pass, properties = Props} = Pkt, State) ->
case parse_credentials(Pkt) of
{ok, #jid{luser = LUser, lserver = LServer} = JID} ->
case maps:find(authentication_method, Props) of
@ -1200,16 +1208,82 @@ authenticate(#connect{password = Pass, properties = Props} = Pkt) ->
{ok, _} ->
{error, 'bad-authentication-method'};
error ->
case ejabberd_auth:check_password_with_authmodule(
LUser, <<>>, LServer, Pass) of
{true, AuthModule} -> {ok, JID, AuthModule};
false -> {error, 'not-authorized'}
end
case Pass of
<<>> ->
case tls_auth(JID, State) of
true ->
{ok, JID, pkix};
false ->
{error, 'not-authorized'}
end;
_ ->
case ejabberd_auth:check_password_with_authmodule(
LUser, <<>>, LServer, Pass) of
{true, AuthModule} -> {ok, JID, AuthModule};
false -> {error, 'not-authorized'}
end
end
end;
{error, _} = Err ->
Err
end.
-spec tls_auth(jid:jid(), state()) -> boolean().
tls_auth(_JID, #state{tls_verify = false}) ->
false;
tls_auth(JID, State) ->
case State#state.socket of
{fast_tls, Sock} ->
case fast_tls:get_peer_certificate(Sock, otp) of
{ok, Cert} ->
case fast_tls:get_verify_result(Sock) of
0 ->
case get_cert_jid(Cert) of
{ok, JID2} ->
jid:remove_resource(jid:tolower(JID)) ==
jid:remove_resource(jid:tolower(JID2));
error ->
false
end;
VerifyRes ->
Reason = fast_tls:get_cert_verify_string(VerifyRes, Cert),
?WARNING_MSG("TLS verify failed: ~s", [Reason]),
false
end;
error ->
false
end;
_ ->
false
end.
get_cert_jid(Cert) ->
case Cert#'OTPCertificate'.tbsCertificate#'OTPTBSCertificate'.subject of
{rdnSequence, Attrs1} ->
Attrs = lists:flatten(Attrs1),
case lists:keyfind(?'id-at-commonName',
#'AttributeTypeAndValue'.type, Attrs) of
#'AttributeTypeAndValue'{value = {utf8String, Val}} ->
try jid:decode(Val) of
#jid{luser = <<>>} ->
case jid:make(Val, ejabberd_config:get_myname()) of
error ->
error;
JID ->
{ok, JID}
end;
JID ->
{ok, JID}
catch _:{bad_jid, _} ->
error
end;
_ ->
error
end;
_ ->
error
end.
%%%===================================================================
%%% Validators
%%%===================================================================