Add TLS certificate authentication for MQTT connections
This commit is contained in:
parent
871e26a01e
commit
5506b838c8
|
@ -29,6 +29,7 @@
|
||||||
-include("logger.hrl").
|
-include("logger.hrl").
|
||||||
-include("mqtt.hrl").
|
-include("mqtt.hrl").
|
||||||
-include_lib("xmpp/include/xmpp.hrl").
|
-include_lib("xmpp/include/xmpp.hrl").
|
||||||
|
-include_lib("public_key/include/public_key.hrl").
|
||||||
|
|
||||||
-record(state, {vsn = ?VSN :: integer(),
|
-record(state, {vsn = ?VSN :: integer(),
|
||||||
version :: undefined | mqtt_version(),
|
version :: undefined | mqtt_version(),
|
||||||
|
@ -47,7 +48,8 @@
|
||||||
in_flight :: undefined | publish() | pubrel(),
|
in_flight :: undefined | publish() | pubrel(),
|
||||||
codec :: mqtt_codec:state(),
|
codec :: mqtt_codec:state(),
|
||||||
queue :: undefined | p1_queue:queue(publish()),
|
queue :: undefined | p1_queue:queue(publish()),
|
||||||
tls :: boolean()}).
|
tls :: boolean(),
|
||||||
|
tls_verify :: boolean()}).
|
||||||
|
|
||||||
-type acks() :: #{non_neg_integer() => pubrec()}.
|
-type acks() :: #{non_neg_integer() => pubrec()}.
|
||||||
-type subscriptions() :: #{binary() => {sub_opts(), non_neg_integer()}}.
|
-type subscriptions() :: #{binary() => {sub_opts(), non_neg_integer()}}.
|
||||||
|
@ -162,6 +164,7 @@ init([SockMod, Socket, ListenOpts]) ->
|
||||||
State1 = #state{socket = {SockMod, Socket},
|
State1 = #state{socket = {SockMod, Socket},
|
||||||
id = p1_rand:uniform(65535),
|
id = p1_rand:uniform(65535),
|
||||||
tls = proplists:get_bool(tls, ListenOpts),
|
tls = proplists:get_bool(tls, ListenOpts),
|
||||||
|
tls_verify = proplists:get_bool(tls_verify, ListenOpts),
|
||||||
codec = mqtt_codec:new(MaxSize)},
|
codec = mqtt_codec:new(MaxSize)},
|
||||||
Timeout = timer:seconds(30),
|
Timeout = timer:seconds(30),
|
||||||
State2 = set_timeout(State1, Timeout),
|
State2 = set_timeout(State1, Timeout),
|
||||||
|
@ -553,7 +556,7 @@ unregister_session(_, _) ->
|
||||||
{error, state(), error_reason()}.
|
{error, state(), error_reason()}.
|
||||||
handle_connect(#connect{clean_start = CleanStart} = Pkt,
|
handle_connect(#connect{clean_start = CleanStart} = Pkt,
|
||||||
#state{jid = undefined, peername = IP} = State) ->
|
#state{jid = undefined, peername = IP} = State) ->
|
||||||
case authenticate(Pkt, IP) of
|
case authenticate(Pkt, IP, State) of
|
||||||
{ok, JID} ->
|
{ok, JID} ->
|
||||||
case validate_will(Pkt, JID) of
|
case validate_will(Pkt, JID) of
|
||||||
ok ->
|
ok ->
|
||||||
|
@ -939,7 +942,12 @@ check_sock_result({_, Sock}, {error, Why}) ->
|
||||||
starttls(#state{socket = {gen_tcp, Socket}, tls = true}) ->
|
starttls(#state{socket = {gen_tcp, Socket}, tls = true}) ->
|
||||||
case ejabberd_pkix:get_certfile() of
|
case ejabberd_pkix:get_certfile() of
|
||||||
{ok, Cert} ->
|
{ok, Cert} ->
|
||||||
case fast_tls:tcp_to_tls(Socket, [{certfile, Cert}]) of
|
CAFileOpt =
|
||||||
|
case ejabberd_option:c2s_cafile(ejabberd_config:get_myname()) of
|
||||||
|
undefined -> [];
|
||||||
|
CAFile -> [{cafile, CAFile}]
|
||||||
|
end,
|
||||||
|
case fast_tls:tcp_to_tls(Socket, [{certfile, Cert}] ++ CAFileOpt) of
|
||||||
{ok, TLSSock} ->
|
{ok, TLSSock} ->
|
||||||
{ok, {fast_tls, TLSSock}};
|
{ok, {fast_tls, TLSSock}};
|
||||||
{error, Why} ->
|
{error, Why} ->
|
||||||
|
@ -1172,9 +1180,9 @@ parse_credentials(JID, ClientID) ->
|
||||||
end
|
end
|
||||||
end.
|
end.
|
||||||
|
|
||||||
-spec authenticate(connect(), peername()) -> {ok, jid:jid()} | {error, reason_code()}.
|
-spec authenticate(connect(), peername(), state()) -> {ok, jid:jid()} | {error, reason_code()}.
|
||||||
authenticate(Pkt, IP) ->
|
authenticate(Pkt, IP, State) ->
|
||||||
case authenticate(Pkt) of
|
case authenticate(Pkt, State) of
|
||||||
{ok, JID, AuthModule} ->
|
{ok, JID, AuthModule} ->
|
||||||
?INFO_MSG("Accepted MQTT authentication for ~ts by ~s backend from ~s",
|
?INFO_MSG("Accepted MQTT authentication for ~ts by ~s backend from ~s",
|
||||||
[jid:encode(JID),
|
[jid:encode(JID),
|
||||||
|
@ -1185,8 +1193,8 @@ authenticate(Pkt, IP) ->
|
||||||
Err
|
Err
|
||||||
end.
|
end.
|
||||||
|
|
||||||
-spec authenticate(connect()) -> {ok, jid:jid(), module()} | {error, reason_code()}.
|
-spec authenticate(connect(), state()) -> {ok, jid:jid(), module()} | {error, reason_code()}.
|
||||||
authenticate(#connect{password = Pass, properties = Props} = Pkt) ->
|
authenticate(#connect{password = Pass, properties = Props} = Pkt, State) ->
|
||||||
case parse_credentials(Pkt) of
|
case parse_credentials(Pkt) of
|
||||||
{ok, #jid{luser = LUser, lserver = LServer} = JID} ->
|
{ok, #jid{luser = LUser, lserver = LServer} = JID} ->
|
||||||
case maps:find(authentication_method, Props) of
|
case maps:find(authentication_method, Props) of
|
||||||
|
@ -1200,16 +1208,82 @@ authenticate(#connect{password = Pass, properties = Props} = Pkt) ->
|
||||||
{ok, _} ->
|
{ok, _} ->
|
||||||
{error, 'bad-authentication-method'};
|
{error, 'bad-authentication-method'};
|
||||||
error ->
|
error ->
|
||||||
|
case Pass of
|
||||||
|
<<>> ->
|
||||||
|
case tls_auth(JID, State) of
|
||||||
|
true ->
|
||||||
|
{ok, JID, pkix};
|
||||||
|
false ->
|
||||||
|
{error, 'not-authorized'}
|
||||||
|
end;
|
||||||
|
_ ->
|
||||||
case ejabberd_auth:check_password_with_authmodule(
|
case ejabberd_auth:check_password_with_authmodule(
|
||||||
LUser, <<>>, LServer, Pass) of
|
LUser, <<>>, LServer, Pass) of
|
||||||
{true, AuthModule} -> {ok, JID, AuthModule};
|
{true, AuthModule} -> {ok, JID, AuthModule};
|
||||||
false -> {error, 'not-authorized'}
|
false -> {error, 'not-authorized'}
|
||||||
end
|
end
|
||||||
|
end
|
||||||
end;
|
end;
|
||||||
{error, _} = Err ->
|
{error, _} = Err ->
|
||||||
Err
|
Err
|
||||||
end.
|
end.
|
||||||
|
|
||||||
|
-spec tls_auth(jid:jid(), state()) -> boolean().
|
||||||
|
tls_auth(_JID, #state{tls_verify = false}) ->
|
||||||
|
false;
|
||||||
|
tls_auth(JID, State) ->
|
||||||
|
case State#state.socket of
|
||||||
|
{fast_tls, Sock} ->
|
||||||
|
case fast_tls:get_peer_certificate(Sock, otp) of
|
||||||
|
{ok, Cert} ->
|
||||||
|
case fast_tls:get_verify_result(Sock) of
|
||||||
|
0 ->
|
||||||
|
case get_cert_jid(Cert) of
|
||||||
|
{ok, JID2} ->
|
||||||
|
jid:remove_resource(jid:tolower(JID)) ==
|
||||||
|
jid:remove_resource(jid:tolower(JID2));
|
||||||
|
error ->
|
||||||
|
false
|
||||||
|
end;
|
||||||
|
VerifyRes ->
|
||||||
|
Reason = fast_tls:get_cert_verify_string(VerifyRes, Cert),
|
||||||
|
?WARNING_MSG("TLS verify failed: ~s", [Reason]),
|
||||||
|
false
|
||||||
|
end;
|
||||||
|
error ->
|
||||||
|
false
|
||||||
|
end;
|
||||||
|
_ ->
|
||||||
|
false
|
||||||
|
end.
|
||||||
|
|
||||||
|
get_cert_jid(Cert) ->
|
||||||
|
case Cert#'OTPCertificate'.tbsCertificate#'OTPTBSCertificate'.subject of
|
||||||
|
{rdnSequence, Attrs1} ->
|
||||||
|
Attrs = lists:flatten(Attrs1),
|
||||||
|
case lists:keyfind(?'id-at-commonName',
|
||||||
|
#'AttributeTypeAndValue'.type, Attrs) of
|
||||||
|
#'AttributeTypeAndValue'{value = {utf8String, Val}} ->
|
||||||
|
try jid:decode(Val) of
|
||||||
|
#jid{luser = <<>>} ->
|
||||||
|
case jid:make(Val, ejabberd_config:get_myname()) of
|
||||||
|
error ->
|
||||||
|
error;
|
||||||
|
JID ->
|
||||||
|
{ok, JID}
|
||||||
|
end;
|
||||||
|
JID ->
|
||||||
|
{ok, JID}
|
||||||
|
catch _:{bad_jid, _} ->
|
||||||
|
error
|
||||||
|
end;
|
||||||
|
_ ->
|
||||||
|
error
|
||||||
|
end;
|
||||||
|
_ ->
|
||||||
|
error
|
||||||
|
end.
|
||||||
|
|
||||||
%%%===================================================================
|
%%%===================================================================
|
||||||
%%% Validators
|
%%% Validators
|
||||||
%%%===================================================================
|
%%%===================================================================
|
||||||
|
|
Loading…
Reference in New Issue