diff --git a/src/mod_mqtt_session.erl b/src/mod_mqtt_session.erl index 8ce040669..6a551f00f 100644 --- a/src/mod_mqtt_session.erl +++ b/src/mod_mqtt_session.erl @@ -29,6 +29,7 @@ -include("logger.hrl"). -include("mqtt.hrl"). -include_lib("xmpp/include/xmpp.hrl"). +-include_lib("public_key/include/public_key.hrl"). -record(state, {vsn = ?VSN :: integer(), version :: undefined | mqtt_version(), @@ -47,7 +48,8 @@ in_flight :: undefined | publish() | pubrel(), codec :: mqtt_codec:state(), queue :: undefined | p1_queue:queue(publish()), - tls :: boolean()}). + tls :: boolean(), + tls_verify :: boolean()}). -type acks() :: #{non_neg_integer() => pubrec()}. -type subscriptions() :: #{binary() => {sub_opts(), non_neg_integer()}}. @@ -162,6 +164,7 @@ init([SockMod, Socket, ListenOpts]) -> State1 = #state{socket = {SockMod, Socket}, id = p1_rand:uniform(65535), tls = proplists:get_bool(tls, ListenOpts), + tls_verify = proplists:get_bool(tls_verify, ListenOpts), codec = mqtt_codec:new(MaxSize)}, Timeout = timer:seconds(30), State2 = set_timeout(State1, Timeout), @@ -553,7 +556,7 @@ unregister_session(_, _) -> {error, state(), error_reason()}. handle_connect(#connect{clean_start = CleanStart} = Pkt, #state{jid = undefined, peername = IP} = State) -> - case authenticate(Pkt, IP) of + case authenticate(Pkt, IP, State) of {ok, JID} -> case validate_will(Pkt, JID) of ok -> @@ -939,7 +942,12 @@ check_sock_result({_, Sock}, {error, Why}) -> starttls(#state{socket = {gen_tcp, Socket}, tls = true}) -> case ejabberd_pkix:get_certfile() of {ok, Cert} -> - case fast_tls:tcp_to_tls(Socket, [{certfile, Cert}]) of + CAFileOpt = + case ejabberd_option:c2s_cafile(ejabberd_config:get_myname()) of + undefined -> []; + CAFile -> [{cafile, CAFile}] + end, + case fast_tls:tcp_to_tls(Socket, [{certfile, Cert}] ++ CAFileOpt) of {ok, TLSSock} -> {ok, {fast_tls, TLSSock}}; {error, Why} -> @@ -1172,9 +1180,9 @@ parse_credentials(JID, ClientID) -> end end. --spec authenticate(connect(), peername()) -> {ok, jid:jid()} | {error, reason_code()}. -authenticate(Pkt, IP) -> - case authenticate(Pkt) of +-spec authenticate(connect(), peername(), state()) -> {ok, jid:jid()} | {error, reason_code()}. +authenticate(Pkt, IP, State) -> + case authenticate(Pkt, State) of {ok, JID, AuthModule} -> ?INFO_MSG("Accepted MQTT authentication for ~ts by ~s backend from ~s", [jid:encode(JID), @@ -1185,8 +1193,8 @@ authenticate(Pkt, IP) -> Err end. --spec authenticate(connect()) -> {ok, jid:jid(), module()} | {error, reason_code()}. -authenticate(#connect{password = Pass, properties = Props} = Pkt) -> +-spec authenticate(connect(), state()) -> {ok, jid:jid(), module()} | {error, reason_code()}. +authenticate(#connect{password = Pass, properties = Props} = Pkt, State) -> case parse_credentials(Pkt) of {ok, #jid{luser = LUser, lserver = LServer} = JID} -> case maps:find(authentication_method, Props) of @@ -1200,16 +1208,82 @@ authenticate(#connect{password = Pass, properties = Props} = Pkt) -> {ok, _} -> {error, 'bad-authentication-method'}; error -> - case ejabberd_auth:check_password_with_authmodule( - LUser, <<>>, LServer, Pass) of - {true, AuthModule} -> {ok, JID, AuthModule}; - false -> {error, 'not-authorized'} - end + case Pass of + <<>> -> + case tls_auth(JID, State) of + true -> + {ok, JID, pkix}; + false -> + {error, 'not-authorized'} + end; + _ -> + case ejabberd_auth:check_password_with_authmodule( + LUser, <<>>, LServer, Pass) of + {true, AuthModule} -> {ok, JID, AuthModule}; + false -> {error, 'not-authorized'} + end + end end; {error, _} = Err -> Err end. +-spec tls_auth(jid:jid(), state()) -> boolean(). +tls_auth(_JID, #state{tls_verify = false}) -> + false; +tls_auth(JID, State) -> + case State#state.socket of + {fast_tls, Sock} -> + case fast_tls:get_peer_certificate(Sock, otp) of + {ok, Cert} -> + case fast_tls:get_verify_result(Sock) of + 0 -> + case get_cert_jid(Cert) of + {ok, JID2} -> + jid:remove_resource(jid:tolower(JID)) == + jid:remove_resource(jid:tolower(JID2)); + error -> + false + end; + VerifyRes -> + Reason = fast_tls:get_cert_verify_string(VerifyRes, Cert), + ?WARNING_MSG("TLS verify failed: ~s", [Reason]), + false + end; + error -> + false + end; + _ -> + false + end. + +get_cert_jid(Cert) -> + case Cert#'OTPCertificate'.tbsCertificate#'OTPTBSCertificate'.subject of + {rdnSequence, Attrs1} -> + Attrs = lists:flatten(Attrs1), + case lists:keyfind(?'id-at-commonName', + #'AttributeTypeAndValue'.type, Attrs) of + #'AttributeTypeAndValue'{value = {utf8String, Val}} -> + try jid:decode(Val) of + #jid{luser = <<>>} -> + case jid:make(Val, ejabberd_config:get_myname()) of + error -> + error; + JID -> + {ok, JID} + end; + JID -> + {ok, JID} + catch _:{bad_jid, _} -> + error + end; + _ -> + error + end; + _ -> + error + end. + %%%=================================================================== %%% Validators %%%===================================================================