From a3df791373c30ccc79a6082f4c910a378d726cdc Mon Sep 17 00:00:00 2001 From: Evgeny Khramtsov Date: Mon, 25 Feb 2019 11:42:09 +0300 Subject: [PATCH] Add MQTT support --- include/mqtt.hrl | 183 +++++ rebar.config | 1 + src/mod_mqtt.erl | 561 +++++++++++++++ src/mod_mqtt_mnesia.erl | 132 ++++ src/mod_mqtt_session.erl | 1318 +++++++++++++++++++++++++++++++++++ src/mod_mqtt_sql.erl | 151 ++++ src/mqtt_codec.erl | 1402 ++++++++++++++++++++++++++++++++++++++ 7 files changed, 3748 insertions(+) create mode 100644 include/mqtt.hrl create mode 100644 src/mod_mqtt.erl create mode 100644 src/mod_mqtt_mnesia.erl create mode 100644 src/mod_mqtt_session.erl create mode 100644 src/mod_mqtt_sql.erl create mode 100644 src/mqtt_codec.erl diff --git a/include/mqtt.hrl b/include/mqtt.hrl new file mode 100644 index 000000000..6756d9483 --- /dev/null +++ b/include/mqtt.hrl @@ -0,0 +1,183 @@ +%%%------------------------------------------------------------------- +%%% @author Evgeny Khramtsov +%%% @copyright (C) 2002-2019 ProcessOne, SARL. All Rights Reserved. +%%% +%%% Licensed under the Apache License, Version 2.0 (the "License"); +%%% you may not use this file except in compliance with the License. +%%% You may obtain a copy of the License at +%%% +%%% http://www.apache.org/licenses/LICENSE-2.0 +%%% +%%% Unless required by applicable law or agreed to in writing, software +%%% distributed under the License is distributed on an "AS IS" BASIS, +%%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%%% See the License for the specific language governing permissions and +%%% limitations under the License. +%%% +%%%------------------------------------------------------------------- +-define(MQTT_VERSION_4, 4). +-define(MQTT_VERSION_5, 5). + +-record(connect, {proto_level = 4 :: non_neg_integer(), + will :: undefined | publish(), + clean_start = true :: boolean(), + keep_alive = 0 :: non_neg_integer(), + client_id = <<>> :: binary(), + username = <<>> :: binary(), + password = <<>> :: binary(), + will_properties = #{} :: properties(), + properties = #{} :: properties()}). +-record(connack, {session_present = false :: boolean(), + code = success :: reason_code(), + properties = #{} :: properties()}). + +-record(publish, {id :: undefined | non_neg_integer(), + dup = false :: boolean(), + qos = 0 :: qos(), + retain = false :: boolean(), + topic :: binary(), + payload :: binary(), + properties = #{} :: properties(), + meta = #{} :: map()}). +-record(puback, {id :: non_neg_integer(), + code = success :: reason_code(), + properties = #{} :: properties()}). +-record(pubrec, {id :: non_neg_integer(), + code = success :: reason_code(), + properties = #{} :: properties()}). +-record(pubrel, {id :: non_neg_integer(), + code = success :: reason_code(), + properties = #{} :: properties(), + meta = #{} :: map()}). +-record(pubcomp, {id :: non_neg_integer(), + code = success :: reason_code(), + properties = #{} :: properties()}). + +-record(subscribe, {id :: non_neg_integer(), + filters :: [{binary(), sub_opts()}], + properties = #{} :: properties(), + meta = #{} :: map()}). +-record(suback, {id :: non_neg_integer(), + codes = [] :: [char() | reason_code()], + properties = #{} :: properties()}). + +-record(unsubscribe, {id :: non_neg_integer(), + filters :: [binary()], + properties = #{} :: properties(), + meta = #{} :: map()}). +-record(unsuback, {id :: non_neg_integer(), + codes = [] :: [reason_code()], + properties = #{} :: properties()}). + +-record(pingreq, {meta = #{} :: map()}). +-record(pingresp, {}). + +-record(disconnect, {code = 'normal-disconnection' :: reason_code(), + properties = #{} :: properties()}). + +-record(auth, {code = success :: reason_code(), + properties = #{} :: properties()}). + +-record(sub_opts, {qos = 0 :: qos(), + no_local = false :: boolean(), + retain_as_published = false :: boolean(), + retain_handling = 0 :: 0..2}). + +-type qos() :: 0|1|2. +-type sub_opts() :: #sub_opts{}. +-type utf8_pair() :: {binary(), binary()}. +-type properties() :: map(). +-type property() :: assigned_client_identifier | + authentication_data | + authentication_method | + content_type | + correlation_data | + maximum_packet_size | + maximum_qos | + message_expiry_interval | + payload_format_indicator | + reason_string | + receive_maximum | + request_problem_information | + request_response_information | + response_information | + response_topic | + retain_available | + server_keep_alive | + server_reference | + session_expiry_interval | + shared_subscription_available | + subscription_identifier | + subscription_identifiers_available | + topic_alias | + topic_alias_maximum | + user_property | + wildcard_subscription_available | + will_delay_interval. +-type reason_code() :: 'success' | + 'normal-disconnection' | + 'granted-qos-0' | + 'granted-qos-1' | + 'granted-qos-2' | + 'disconnect-with-will-message' | + 'no-matching-subscribers' | + 'no-subscription-existed' | + 'continue-authentication' | + 're-authenticate' | + 'unspecified-error' | + 'malformed-packet' | + 'protocol-error' | + 'implementation-specific-error' | + 'unsupported-protocol-version' | + 'client-identifier-not-valid' | + 'bad-user-name-or-password' | + 'not-authorized' | + 'server-unavailable' | + 'server-busy' | + 'banned' | + 'server-shutting-down' | + 'bad-authentication-method' | + 'keep-alive-timeout' | + 'session-taken-over' | + 'topic-filter-invalid' | + 'topic-name-invalid' | + 'packet-identifier-in-use' | + 'packet-identifier-not-found' | + 'receive-maximum-exceeded' | + 'topic-alias-invalid' | + 'packet-too-large' | + 'message-rate-too-high' | + 'quota-exceeded' | + 'administrative-action' | + 'payload-format-invalid' | + 'retain-not-supported' | + 'qos-not-supported' | + 'use-another-server' | + 'server-moved' | + 'shared-subscriptions-not-supported' | + 'connection-rate-exceeded' | + 'maximum-connect-time' | + 'subscription-identifiers-not-supported' | + 'wildcard-subscriptions-not-supported'. + +-type connect() :: #connect{}. +-type connack() :: #connack{}. +-type publish() :: #publish{}. +-type puback() :: #puback{}. +-type pubrel() :: #pubrel{}. +-type pubrec() :: #pubrec{}. +-type pubcomp() :: #pubcomp{}. +-type subscribe() :: #subscribe{}. +-type suback() :: #suback{}. +-type unsubscribe() :: #unsubscribe{}. +-type unsuback() :: #unsuback{}. +-type pingreq() :: #pingreq{}. +-type pingresp() :: #pingresp{}. +-type disconnect() :: #disconnect{}. +-type auth() :: #auth{}. + +-type mqtt_packet() :: connect() | connack() | publish() | puback() | + pubrel() | pubrec() | pubcomp() | subscribe() | + suback() | unsubscribe() | unsuback() | pingreq() | + pingresp() | disconnect() | auth(). +-type mqtt_version() :: ?MQTT_VERSION_4 | ?MQTT_VERSION_5. diff --git a/rebar.config b/rebar.config index 39608346d..b9955c7ab 100644 --- a/rebar.config +++ b/rebar.config @@ -31,6 +31,7 @@ {pkix, ".*", {git, "https://github.com/processone/pkix", {tag, "1.0.0"}}}, {jose, ".*", {git, "https://github.com/potatosalad/erlang-jose", {tag, "1.8.4"}}}, {eimp, ".*", {git, "https://github.com/processone/eimp", {tag, "1.0.9"}}}, + {mqtree, ".*", {git, "https://github.com/processone/mqtree", {tag, "1.0.1"}}}, {if_var_true, stun, {stun, ".*", {git, "https://github.com/processone/stun", {tag, "1.0.26"}}}}, {if_var_true, sip, {esip, ".*", {git, "https://github.com/processone/esip", {tag, "1.0.27"}}}}, {if_var_true, mysql, {p1_mysql, ".*", {git, "https://github.com/processone/p1_mysql", diff --git a/src/mod_mqtt.erl b/src/mod_mqtt.erl new file mode 100644 index 000000000..86aea87be --- /dev/null +++ b/src/mod_mqtt.erl @@ -0,0 +1,561 @@ +%%%------------------------------------------------------------------- +%%% @author Evgeny Khramtsov +%%% @copyright (C) 2002-2019 ProcessOne, SARL. All Rights Reserved. +%%% +%%% Licensed under the Apache License, Version 2.0 (the "License"); +%%% you may not use this file except in compliance with the License. +%%% You may obtain a copy of the License at +%%% +%%% http://www.apache.org/licenses/LICENSE-2.0 +%%% +%%% Unless required by applicable law or agreed to in writing, software +%%% distributed under the License is distributed on an "AS IS" BASIS, +%%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%%% See the License for the specific language governing permissions and +%%% limitations under the License. +%%% +%%%------------------------------------------------------------------- +-module(mod_mqtt). +-behaviour(p1_server). +-behaviour(gen_mod). +-behaviour(ejabberd_listener). + +%% gen_mod API +-export([start/2, stop/1, reload/3, depends/2, mod_options/1, mod_opt_type/1]). +%% gen_server callbacks +-export([init/1, handle_call/3, handle_cast/2, handle_info/2, + terminate/2, code_change/3]). +%% ejabberd_listener API +-export([start_link/2, listen_opt_type/1, listen_options/0, accept/1]). +%% Legacy ejabberd_listener API +-export([become_controller/2, socket_type/0]). +%% API +-export([open_session/1, close_session/1, lookup_session/1, + publish/3, subscribe/4, unsubscribe/2, select_retained/4, + check_publish_access/2, check_subscribe_access/2]). + +-include("logger.hrl"). +-include("mqtt.hrl"). + +-define(MQTT_TOPIC_CACHE, mqtt_topic_cache). +-define(MQTT_PAYLOAD_CACHE, mqtt_payload_cache). + +-type continuation() :: term(). +-type seconds() :: non_neg_integer(). + +%% RAM backend callbacks +-callback init() -> ok | {error, any()}. +-callback open_session(jid:ljid()) -> ok | {error, db_failure}. +-callback close_session(jid:ljid()) -> ok | {error, db_failure}. +-callback lookup_session(jid:ljid()) -> {ok, pid()} | {error, notfound | db_failure}. +-callback subscribe(jid:ljid(), binary(), sub_opts(), non_neg_integer()) -> ok | {error, db_failure}. +-callback unsubscribe(jid:ljid(), binary()) -> ok | {error, notfound | db_failure}. +-callback find_subscriber(binary(), binary() | continuation()) -> + {ok, {pid(), qos()}, continuation()} | {error, notfound | db_failure}. +%% Disc backend callbacks +-callback init(binary(), gen_mod:opts()) -> ok | {error, any()}. +-callback publish(jid:ljid(), binary(), binary(), qos(), properties(), seconds()) -> + ok | {error, db_failure}. +-callback delete_published(jid:ljid(), binary()) -> ok | {error, db_failure}. +-callback lookup_published(jid:ljid(), binary()) -> + {ok, {binary(), qos(), properties(), seconds()}} | + {error, notfound | db_failure}. +-callback list_topics(binary()) -> {ok, [binary()]} | {error, db_failure}. +-callback use_cache(binary()) -> boolean(). +-callback cache_nodes(binary()) -> [node()]. + +-optional_callbacks([use_cache/1, cache_nodes/1]). + +-record(state, {}). + +%%%=================================================================== +%%% API +%%%=================================================================== +start({SockMod, Sock}, ListenOpts) -> + mod_mqtt_session:start(SockMod, Sock, ListenOpts); +start(Host, Opts) -> + gen_mod:start_child(?MODULE, Host, Opts). + +start_link({SockMod, Sock}, ListenOpts) -> + mod_mqtt_session:start_link(SockMod, Sock, ListenOpts). + +stop(Host) -> + gen_mod:stop_child(?MODULE, Host). + +reload(_Host, _NewOpts, _OldOpts) -> + ok. + +depends(_Host, _Opts) -> + []. + +socket_type() -> + raw. + +become_controller(Pid, _) -> + accept(Pid). + +accept(Pid) -> + mod_mqtt_session:accept(Pid). + +open_session({U, S, R}) -> + Mod = gen_mod:ram_db_mod(S, ?MODULE), + Mod:open_session({U, S, R}). + +close_session({U, S, R}) -> + Mod = gen_mod:ram_db_mod(S, ?MODULE), + Mod:close_session({U, S, R}). + +lookup_session({U, S, R}) -> + Mod = gen_mod:ram_db_mod(S, ?MODULE), + Mod:lookup_session({U, S, R}). + +-spec publish(jid:ljid(), publish(), seconds()) -> + {ok, non_neg_integer()} | {error, db_failure | publish_forbidden}. +publish({_, S, _} = USR, Pkt, ExpiryTime) -> + case check_publish_access(Pkt#publish.topic, USR) of + allow -> + case retain(USR, Pkt, ExpiryTime) of + ok -> + Mod = gen_mod:ram_db_mod(S, ?MODULE), + route(Mod, S, Pkt, ExpiryTime); + {error, _} = Err -> + Err + end; + deny -> + {error, publish_forbidden} + end. + +-spec subscribe(jid:ljid(), binary(), sub_opts(), non_neg_integer()) -> + ok | {error, db_failure | subscribe_forbidden}. +subscribe({_, S, _} = USR, TopicFilter, SubOpts, ID) -> + Mod = gen_mod:ram_db_mod(S, ?MODULE), + Limit = gen_mod:get_module_opt(S, ?MODULE, max_topic_depth), + case check_topic_depth(TopicFilter, Limit) of + allow -> + case check_subscribe_access(TopicFilter, USR) of + allow -> + Mod:subscribe(USR, TopicFilter, SubOpts, ID); + deny -> + {error, subscribe_forbidden} + end; + deny -> + {error, subscribe_forbidden} + end. + +-spec unsubscribe(jid:ljid(), binary()) -> ok | {error, notfound | db_failure}. +unsubscribe({U, S, R}, Topic) -> + Mod = gen_mod:ram_db_mod(S, ?MODULE), + Mod:unsubscribe({U, S, R}, Topic). + +-spec select_retained(jid:ljid(), binary(), qos(), non_neg_integer()) -> + [{publish(), seconds()}]. +select_retained({_, S, _} = USR, TopicFilter, QoS, SubID) -> + Mod = gen_mod:db_mod(S, ?MODULE), + Limit = gen_mod:get_module_opt(S, ?MODULE, match_retained_limit), + select_retained(Mod, USR, TopicFilter, QoS, SubID, Limit). + +%%%=================================================================== +%%% gen_server callbacks +%%%=================================================================== +init([Host, Opts]) -> + Mod = gen_mod:db_mod(Host, Opts, ?MODULE), + RMod = gen_mod:ram_db_mod(Host, Opts, ?MODULE), + try + ok = Mod:init(Host, Opts), + ok = RMod:init(), + ok = init_cache(Mod, Host, Opts), + {ok, #state{}} + catch _:{badmatch, {error, Why}} -> + {stop, Why} + end. + +handle_call(_Request, _From, State) -> + Reply = ok, + {reply, Reply, State}. + +handle_cast(_Msg, State) -> + {noreply, State}. + +handle_info(_Info, State) -> + {noreply, State}. + +terminate(_Reason, _State) -> + ok. + +code_change(_OldVsn, State, _Extra) -> + {ok, State}. + +%%%=================================================================== +%%% Options +%%%=================================================================== +mod_options(Host) -> + [{match_retained_limit, 1000}, + {max_topic_depth, 8}, + {max_topic_aliases, 100}, + {session_expiry, 300}, + {max_queue, 5000}, + {access_subscribe, []}, + {access_publish, []}, + {db_type, ejabberd_config:default_db(Host, ?MODULE)}, + {ram_db_type, ejabberd_config:default_ram_db(Host, ?MODULE)}, + {queue_type, ejabberd_config:default_queue_type(Host)}, + {use_cache, ejabberd_config:use_cache(Host)}, + {cache_size, ejabberd_config:cache_size(Host)}, + {cache_missed, ejabberd_config:cache_missed(Host)}, + {cache_life_time, ejabberd_config:cache_life_time(Host)}]. + +mod_opt_type(max_queue) -> + fun(I) when is_integer(I), I > 0 -> I; + (infinity) -> unlimited; + (unlimited) -> unlimited + end; +mod_opt_type(session_expiry) -> + fun(I) when is_integer(I), I>= 0 -> I end; +mod_opt_type(match_retained_limit) -> + fun(I) when is_integer(I), I>0 -> I; + (unlimited) -> infinity; + (infinity) -> infinity + end; +mod_opt_type(max_topic_depth) -> + fun(I) when is_integer(I), I>0 -> I; + (unlimited) -> infinity; + (infinity) -> infinity + end; +mod_opt_type(max_topic_aliases) -> + fun(I) when is_integer(I), I>=0, I<65536 -> I end; +mod_opt_type(access_subscribe) -> + fun validate_topic_access/1; +mod_opt_type(access_publish) -> + fun validate_topic_access/1; +mod_opt_type(db_type) -> + fun(T) -> ejabberd_config:v_db(?MODULE, T) end; +mod_opt_type(ram_db_type) -> + fun(T) -> ejabberd_config:v_db(?MODULE, T) end; +mod_opt_type(queue_type) -> + fun(ram) -> ram; (file) -> file end; +mod_opt_type(O) when O == cache_life_time; O == cache_size -> + fun(I) when is_integer(I), I > 0 -> I; + (infinity) -> infinity + end; +mod_opt_type(O) when O == use_cache; O == cache_missed -> + fun (B) when is_boolean(B) -> B end. + +listen_opt_type(tls_verify) -> + fun(B) when is_boolean(B) -> B end; +listen_opt_type(max_payload_size) -> + fun(I) when is_integer(I), I>0 -> I; + (unlimited) -> infinity; + (infinity) -> infinity + end. + +listen_options() -> + [{max_fsm_queue, 5000}, + {max_payload_size, infinity}, + {tls, false}, + {tls_verify, false}]. + +%%%=================================================================== +%%% Internal functions +%%%=================================================================== +route(Mod, LServer, Pkt, ExpiryTime) -> + route(Mod, LServer, Pkt, ExpiryTime, Pkt#publish.topic, 0). + +route(Mod, LServer, Pkt, ExpiryTime, Continuation, Num) -> + case Mod:find_subscriber(LServer, Continuation) of + {ok, {Pid, #sub_opts{no_local = true}, _}, Continuation1} + when Pid == self() -> + route(Mod, LServer, Pkt, ExpiryTime, Continuation1, Num); + {ok, {Pid, SubOpts, ID}, Continuation1} -> + ?DEBUG("Route to ~p: ~s", [Pid, Pkt#publish.topic]), + MinQoS = min(SubOpts#sub_opts.qos, Pkt#publish.qos), + Retain = case SubOpts#sub_opts.retain_as_published of + false -> false; + true -> Pkt#publish.retain + end, + Props = set_sub_id(ID, Pkt#publish.properties), + mod_mqtt_session:route( + Pid, {Pkt#publish{qos = MinQoS, + dup = false, + retain = Retain, + properties = Props}, + ExpiryTime}), + route(Mod, LServer, Pkt, ExpiryTime, Continuation1, Num+1); + {error, _} -> + {ok, Num} + end. + +select_retained(Mod, {_, LServer, _} = USR, TopicFilter, QoS, SubID, Limit) -> + Topics = match_topics(TopicFilter, LServer, Limit), + lists:filtermap( + fun({{Filter, _}, Topic}) -> + case lookup_published(Mod, USR, Topic) of + {ok, {Payload, QoS1, Props, ExpiryTime}} -> + Props1 = set_sub_id(SubID, Props), + {true, {#publish{topic = Topic, + payload = Payload, + retain = true, + properties = Props1, + qos = min(QoS, QoS1)}, + ExpiryTime}}; + error -> + ets:delete(?MQTT_TOPIC_CACHE, {Filter, LServer}), + false; + _ -> + false + end + end, Topics). + +match_topics(Topic, LServer, Limit) -> + Filter = topic_filter(Topic), + case Limit of + infinity -> + ets:match_object(?MQTT_TOPIC_CACHE, {{Filter, LServer}, '_'}); + _ -> + case ets:select(?MQTT_TOPIC_CACHE, + [{{{Filter, LServer}, '_'}, [], ['$_']}], Limit) of + {Topics, _} -> Topics; + '$end_of_table' -> [] + end + end. + +retain({_, S, _} = USR, #publish{retain = true, + topic = Topic, payload = Data, + qos = QoS, properties = Props}, + ExpiryTime) -> + Mod = gen_mod:db_mod(S, ?MODULE), + TopicKey = topic_key(Topic), + case Data of + <<>> -> + ets:delete(?MQTT_TOPIC_CACHE, {TopicKey, S}), + case use_cache(Mod, S) of + true -> + ets_cache:delete(?MQTT_PAYLOAD_CACHE, {S, Topic}, + cache_nodes(Mod, S)); + false -> + ok + end, + Mod:delete_published(USR, Topic); + _ -> + ets:insert(?MQTT_TOPIC_CACHE, {{TopicKey, S}, Topic}), + case use_cache(Mod, S) of + true -> + case ets_cache:update( + ?MQTT_PAYLOAD_CACHE, {S, Topic}, + {ok, {Data, QoS, Props, ExpiryTime}}, + fun() -> + Mod:publish(USR, Topic, Data, + QoS, Props, ExpiryTime) + end, cache_nodes(Mod, S)) of + {ok, _} -> ok; + {error, _} = Err -> Err + end; + false -> + Mod:publish(USR, Topic, Data, QoS, Props, ExpiryTime) + end + end; +retain(_, _, _) -> + ok. + +lookup_published(Mod, {_, LServer, _} = USR, Topic) -> + case use_cache(Mod, LServer) of + true -> + ets_cache:lookup( + ?MQTT_PAYLOAD_CACHE, {LServer, Topic}, + fun() -> + Mod:lookup_published(USR, Topic) + end); + false -> + Mod:lookup_published(USR, Topic) + end. + +set_sub_id(0, Props) -> + Props; +set_sub_id(ID, Props) -> + Props#{subscription_identifier => [ID]}. + +%%%=================================================================== +%%% Matching functions +%%%=================================================================== +topic_key(S) -> + Parts = split_path(S), + case join_key(Parts) of + [<<>>|T] -> T; + T -> T + end. + +topic_filter(S) -> + Parts = split_path(S), + case join_filter(Parts) of + [<<>>|T] -> T; + T -> T + end. + +join_key([X,Y|T]) -> + [X, $/|join_key([Y|T])]; +join_key([X]) -> + [X]; +join_key([]) -> + []. + +join_filter([X, <<$#>>]) -> + [wildcard(X)|'_']; +join_filter([X,Y|T]) -> + [wildcard(X), $/|join_filter([Y|T])]; +join_filter([<<>>]) -> + []; +join_filter([<<$#>>]) -> + '_'; +join_filter([X]) -> + [wildcard(X)]; +join_filter([]) -> + []. + +wildcard(<<$+>>) -> '_'; +wildcard(Bin) -> Bin. + +check_topic_depth(_Topic, infinity) -> + allow; +check_topic_depth(_, N) when N=<0 -> + deny; +check_topic_depth(<<$/, T/binary>>, N) -> + check_topic_depth(T, N-1); +check_topic_depth(<<_, T/binary>>, N) -> + check_topic_depth(T, N); +check_topic_depth(<<>>, _) -> + allow. + +split_path(Path) -> + binary:split(Path, <<$/>>, [global]). + +%%%=================================================================== +%%% Validators +%%%=================================================================== +validate_topic_access(FilterRules) -> + lists:map( + fun({TopicFilter, Access}) -> + Rule = acl:access_rules_validator(Access), + try + mqtt_codec:topic_filter(TopicFilter), + {split_path(TopicFilter), Rule} + catch _:_ -> + ?ERROR_MSG("Invalid topic filter: ~s", [TopicFilter]), + erlang:error(badarg) + end + end, lists:reverse(lists:keysort(1, FilterRules))). + +%%%=================================================================== +%%% ACL checks +%%%=================================================================== +check_subscribe_access(Topic, {_, S, _} = USR) -> + Rules = gen_mod:get_module_opt(S, mod_mqtt, access_subscribe), + check_access(Topic, USR, Rules). + +check_publish_access(<<$$, _/binary>>, _) -> + deny; +check_publish_access(Topic, {_, S, _} = USR) -> + Rules = gen_mod:get_module_opt(S, mod_mqtt, access_publish), + check_access(Topic, USR, Rules). + +check_access(_, _, []) -> + allow; +check_access(Topic, {U, S, R} = USR, FilterRules) -> + TopicParts = binary:split(Topic, <<$/>>, [global]), + case lists:any( + fun({FilterParts, Rule}) -> + case match(TopicParts, FilterParts, U, S, R) of + true -> + allow == acl:match_rule(S, Rule, USR); + false -> + false + end + end, FilterRules) of + true -> allow; + false -> deny + end. + +match(_, [<<"#">>|_], _, _, _) -> + true; +match([], [<<>>, <<"#">>|_], _, _, _) -> + true; +match([_|T1], [<<"+">>|T2], U, S, R) -> + match(T1, T2, U, S, R); +match([H|T1], [<<"%u">>|T2], U, S, R) -> + case jid:nodeprep(H) of + U -> match(T1, T2, U, S, R); + _ -> false + end; +match([H|T1], [<<"%d">>|T2], U, S, R) -> + case jid:nameprep(H) of + S -> match(T1, T2, U, S, R); + _ -> false + end; +match([H|T1], [<<"%c">>|T2], U, S, R) -> + case jid:resourceprep(H) of + R -> match(T1, T2, U, S, R); + _ -> false + end; +match([H|T1], [H|T2], U, S, R) -> + match(T1, T2, U, S, R); +match([], [], _, _, _) -> + true; +match(_, _, _, _, _) -> + false. + +%%%=================================================================== +%%% Cache stuff +%%%=================================================================== +-spec init_cache(module(), binary(), gen_mod:opts()) -> ok | {error, db_failure}. +init_cache(Mod, Host, Opts) -> + init_payload_cache(Mod, Host, Opts), + init_topic_cache(Mod, Host). + +-spec init_topic_cache(module(), binary()) -> ok | {error, db_failure}. +init_topic_cache(Mod, Host) -> + catch ets:new(?MQTT_TOPIC_CACHE, + [named_table, ordered_set, public, + {heir, erlang:group_leader(), none}]), + ?INFO_MSG("Building MQTT cache for ~s, this may take a while", [Host]), + case Mod:list_topics(Host) of + {ok, Topics} -> + lists:foreach( + fun(Topic) -> + ets:insert(?MQTT_TOPIC_CACHE, + {{topic_key(Topic), Host}, Topic}) + end, Topics); + {error, _} = Err -> + Err + end. + +-spec init_payload_cache(module(), binary(), gen_mod:opts()) -> ok. +init_payload_cache(Mod, Host, Opts) -> + case use_cache(Mod, Host) of + true -> + CacheOpts = cache_opts(Opts), + ets_cache:new(?MQTT_PAYLOAD_CACHE, CacheOpts); + false -> + ets_cache:delete(?MQTT_PAYLOAD_CACHE) + end. + +-spec cache_opts(gen_mod:opts()) -> [proplists:property()]. +cache_opts(Opts) -> + MaxSize = gen_mod:get_opt(cache_size, Opts), + CacheMissed = gen_mod:get_opt(cache_missed, Opts), + LifeTime = case gen_mod:get_opt(cache_life_time, Opts) of + infinity -> infinity; + I -> timer:seconds(I) + end, + [{max_size, MaxSize}, {cache_missed, CacheMissed}, {life_time, LifeTime}]. + +-spec use_cache(module(), binary()) -> boolean(). +use_cache(Mod, Host) -> + case erlang:function_exported(Mod, use_cache, 1) of + true -> Mod:use_cache(Host); + false -> gen_mod:get_module_opt(Host, ?MODULE, use_cache) + end. + +-spec cache_nodes(module(), binary()) -> [node()]. +cache_nodes(Mod, Host) -> + case erlang:function_exported(Mod, cache_nodes, 1) of + true -> Mod:cache_nodes(Host); + false -> ejabberd_cluster:get_nodes() + end. diff --git a/src/mod_mqtt_mnesia.erl b/src/mod_mqtt_mnesia.erl new file mode 100644 index 000000000..3439c9304 --- /dev/null +++ b/src/mod_mqtt_mnesia.erl @@ -0,0 +1,132 @@ +%%%------------------------------------------------------------------- +%%% @author Evgeny Khramtsov +%%% @copyright (C) 2002-2019 ProcessOne, SARL. All Rights Reserved. +%%% +%%% Licensed under the Apache License, Version 2.0 (the "License"); +%%% you may not use this file except in compliance with the License. +%%% You may obtain a copy of the License at +%%% +%%% http://www.apache.org/licenses/LICENSE-2.0 +%%% +%%% Unless required by applicable law or agreed to in writing, software +%%% distributed under the License is distributed on an "AS IS" BASIS, +%%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%%% See the License for the specific language governing permissions and +%%% limitations under the License. +%%% +%%%------------------------------------------------------------------- +-module(mod_mqtt_mnesia). +-behaviour(mod_mqtt). + +%% API +-export([init/2, publish/6, delete_published/2, lookup_published/2]). +-export([list_topics/1, use_cache/1]). +%% Unsupported backend API +-export([init/0]). +-export([subscribe/4, unsubscribe/2, find_subscriber/2]). +-export([open_session/1, close_session/1, lookup_session/1]). + +-include("logger.hrl"). + +-record(mqtt_pub, {topic_server :: {binary(), binary()}, + user :: binary(), + resource :: binary(), + qos :: 0..2, + payload :: binary(), + expiry :: non_neg_integer(), + payload_format = binary :: binary | utf8, + response_topic = <<>> :: binary(), + correlation_data = <<>> :: binary(), + content_type = <<>> :: binary(), + user_properties = [] :: [{binary(), binary()}]}). + +%%%=================================================================== +%%% API +%%%=================================================================== +init(_Host, _Opts) -> + case ejabberd_mnesia:create( + ?MODULE, mqtt_pub, + [{disc_only_copies, [node()]}, + {attributes, record_info(fields, mqtt_pub)}]) of + {atomic, _} -> + ok; + Err -> + {error, Err} + end. + +use_cache(Host) -> + case mnesia:table_info(mqtt_pub, storage_type) of + disc_only_copies -> + gen_mod:get_module_opt(Host, mod_mqtt, use_cache); + _ -> + false + end. + +publish({U, LServer, R}, Topic, Payload, QoS, Props, ExpiryTime) -> + PayloadFormat = maps:get(payload_format_indicator, Props, binary), + ResponseTopic = maps:get(response_topic, Props, <<"">>), + CorrelationData = maps:get(correlation_data, Props, <<"">>), + ContentType = maps:get(content_type, Props, <<"">>), + UserProps = maps:get(user_property, Props, []), + mnesia:dirty_write(#mqtt_pub{topic_server = {Topic, LServer}, + user = U, + resource = R, + qos = QoS, + payload = Payload, + expiry = ExpiryTime, + payload_format = PayloadFormat, + response_topic = ResponseTopic, + correlation_data = CorrelationData, + content_type = ContentType, + user_properties = UserProps}). + +delete_published({_, S, _}, Topic) -> + mnesia:dirty_delete(mqtt_pub, {Topic, S}). + +lookup_published({_, S, _}, Topic) -> + case mnesia:dirty_read(mqtt_pub, {Topic, S}) of + [#mqtt_pub{qos = QoS, + payload = Payload, + expiry = ExpiryTime, + payload_format = PayloadFormat, + response_topic = ResponseTopic, + correlation_data = CorrelationData, + content_type = ContentType, + user_properties = UserProps}] -> + Props = #{payload_format => PayloadFormat, + response_topic => ResponseTopic, + correlation_data => CorrelationData, + content_type => ContentType, + user_property => UserProps}, + {ok, {Payload, QoS, Props, ExpiryTime}}; + [] -> + {error, notfound} + end. + +list_topics(S) -> + {ok, [Topic || {Topic, S1} <- mnesia:dirty_all_keys(mqtt_pub), S1 == S]}. + +init() -> + erlang:nif_error(unsupported_db). + +open_session(_) -> + erlang:nif_error(unsupported_db). + +close_session(_) -> + erlang:nif_error(unsupported_db). + +lookup_session(_) -> + erlang:nif_error(unsupported_db). + +subscribe(_, _, _, _) -> + erlang:nif_error(unsupported_db). + +unsubscribe(_, _) -> + erlang:nif_error(unsupported_db). + +find_subscriber(_, _) -> + erlang:nif_error(unsupported_db). + +%%%=================================================================== +%%% Internal functions +%%%=================================================================== diff --git a/src/mod_mqtt_session.erl b/src/mod_mqtt_session.erl new file mode 100644 index 000000000..3df36b8fb --- /dev/null +++ b/src/mod_mqtt_session.erl @@ -0,0 +1,1318 @@ +%%%------------------------------------------------------------------- +%%% @author Evgeny Khramtsov +%%% @copyright (C) 2002-2019 ProcessOne, SARL. All Rights Reserved. +%%% +%%% Licensed under the Apache License, Version 2.0 (the "License"); +%%% you may not use this file except in compliance with the License. +%%% You may obtain a copy of the License at +%%% +%%% http://www.apache.org/licenses/LICENSE-2.0 +%%% +%%% Unless required by applicable law or agreed to in writing, software +%%% distributed under the License is distributed on an "AS IS" BASIS, +%%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%%% See the License for the specific language governing permissions and +%%% limitations under the License. +%%% +%%%------------------------------------------------------------------- +-module(mod_mqtt_session). +-behaviour(p1_server). +-define(VSN, 1). +-vsn(?VSN). + +%% API +-export([start/3, start_link/3, accept/1, route/2]). +%% gen_server callbacks +-export([init/1, handle_call/3, handle_cast/2, handle_info/2, + terminate/2, code_change/3]). + +-include("logger.hrl"). +-include("mqtt.hrl"). +-include("xmpp.hrl"). + +-record(state, {vsn = ?VSN :: integer(), + version :: undefined | mqtt_version(), + socket :: undefined | socket(), + peername :: peername(), + timeout = infinity :: timer(), + jid :: undefined | jid:jid(), + session_expiry = 0 :: seconds(), + will :: undefined | publish(), + will_delay = 0 :: seconds(), + stop_reason :: undefined | error_reason(), + acks = #{} :: map(), + subscriptions = #{} :: map(), + topic_aliases = #{} :: map(), + id = 0 :: non_neg_integer(), + in_flight :: undefined | publish() | pubrel(), + codec :: mqtt_codec:state(), + queue :: undefined | p1_queue:queue()}). + +-type error_reason() :: {auth, reason_code()} | + {code, reason_code()} | + {peer_disconnected, reason_code(), binary()} | + {socket, socket_error_reason()} | + {codec, mqtt_codec:error_reason()} | + {unexpected_packet, atom()} | + {tls, inet:posix() | atom() | binary()} | + {replaced, pid()} | {resumed, pid()} | + subscribe_forbidden | publish_forbidden | + will_topic_forbidden | internal_server_error | + session_expired | idle_connection | + queue_full | shutdown | db_failure | + {payload_format_invalid, will | publish} | + session_expiry_non_zero | unknown_topic_alias. + +-type state() :: #state{}. +-type sockmod() :: gen_tcp | fast_tls. +-type socket() :: {sockmod(), inet:socket() | fast_tls:tls_socket()}. +-type peername() :: {inet:ip_address(), inet:port_number()}. +-type seconds() :: non_neg_integer(). +-type milli_seconds() :: non_neg_integer(). +-type timer() :: infinity | {milli_seconds(), integer()}. +-type socket_error_reason() :: closed | timeout | inet:posix(). + +-define(CALL_TIMEOUT, timer:minutes(1)). +-define(RELAY_TIMEOUT, timer:minutes(1)). +-define(MAX_UINT32, 4294967295). + +%%%=================================================================== +%%% API +%%%=================================================================== +start(SockMod, Socket, ListenOpts) -> + p1_server:start(?MODULE, [SockMod, Socket, ListenOpts], + ejabberd_config:fsm_limit_opts(ListenOpts)). + +start_link(SockMod, Socket, ListenOpts) -> + p1_server:start_link(?MODULE, [SockMod, Socket, ListenOpts], + ejabberd_config:fsm_limit_opts(ListenOpts)). + +-spec accept(pid()) -> ok. +accept(Pid) -> + p1_server:cast(Pid, accept). + +-spec route(pid(), term()) -> boolean(). +route(Pid, Term) -> + ejabberd_cluster:send(Pid, Term). + +-spec format_error(error_reason()) -> string(). +format_error(session_expired) -> + "Disconnected session is expired"; +format_error(idle_connection) -> + "Idle connection"; +format_error(queue_full) -> + "Message queue is overloaded"; +format_error(internal_server_error) -> + "Internal server error"; +format_error(db_failure) -> + "Database failure"; +format_error(shutdown) -> + "System shutting down"; +format_error(subscribe_forbidden) -> + "Subscribing to this topic is forbidden by service policy"; +format_error(publish_forbidden) -> + "Publishing to this topic is forbidden by service policy"; +format_error(will_topic_forbidden) -> + "Publishing to this will topic is forbidden by service policy"; +format_error(session_expiry_non_zero) -> + "Session Expiry Interval in DISCONNECT packet should have been zero"; +format_error(unknown_topic_alias) -> + "No mapping found for this Topic Alias"; +format_error({payload_format_invalid, will}) -> + "Will payload format doesn't match its indicator"; +format_error({payload_format_invalid, publish}) -> + "PUBLISH payload format doesn't match its indicator"; +format_error({peer_disconnected, Code, <<>>}) -> + format("Peer disconnected with reason: ~s", + [mqtt_codec:format_reason_code(Code)]); +format_error({peer_disconnected, Code, Reason}) -> + format("Peer disconnected with reason: ~s (~s)", [Reason, Code]); +format_error({replaced, Pid}) -> + format("Replaced by ~p at ~s", [Pid, node(Pid)]); +format_error({resumed, Pid}) -> + format("Resumed by ~p at ~s", [Pid, node(Pid)]); +format_error({unexpected_packet, Name}) -> + format("Unexpected ~s packet", [string:to_upper(atom_to_list(Name))]); +format_error({tls, Reason}) -> + format("TLS failed: ~s", [format_tls_error(Reason)]); +format_error({socket, A}) -> + format("Connection failed: ~s", [format_inet_error(A)]); +format_error({code, Code}) -> + format("Protocol error: ~s", [mqtt_codec:format_reason_code(Code)]); +format_error({auth, Code}) -> + format("Authentication failed: ~s", [mqtt_codec:format_reason_code(Code)]); +format_error({codec, CodecError}) -> + format("Protocol error: ~s", [mqtt_codec:format_error(CodecError)]); +format_error(A) when is_atom(A) -> + atom_to_list(A); +format_error(Reason) -> + format("Unrecognized error: ~w", [Reason]). + +%%%=================================================================== +%%% gen_server callbacks +%%%=================================================================== +init([SockMod, Socket, ListenOpts]) -> + MaxSize = proplists:get_value(max_payload_size, ListenOpts, infinity), + SockMod1 = case proplists:get_bool(tls, ListenOpts) of + true -> fast_tls; + false -> SockMod + end, + State1 = #state{socket = {SockMod1, Socket}, + id = p1_rand:uniform(65535), + codec = mqtt_codec:new(MaxSize)}, + Timeout = timer:seconds(30), + State2 = set_timeout(State1, Timeout), + {ok, State2, Timeout}. + +handle_call({get_state, _}, From, #state{stop_reason = {resumed, Pid}} = State) -> + p1_server:reply(From, {error, {resumed, Pid}}), + noreply(State); +handle_call({get_state, Pid}, From, State) -> + case stop(State, {resumed, Pid}) of + {stop, Status, State1} -> + {stop, Status, State1#state{stop_reason = {replaced, Pid}}}; + {noreply, State1, _} -> + ?DEBUG("Transfering MQTT session state to ~p at ~s", [Pid, node(Pid)]), + Q1 = p1_queue:file_to_ram(State1#state.queue), + p1_server:reply(From, {ok, State1#state{queue = Q1}}), + SessionExpiry = timer:seconds(State1#state.session_expiry), + State2 = set_timeout(State1, min(SessionExpiry, ?RELAY_TIMEOUT)), + State3 = State2#state{queue = undefined, + stop_reason = {resumed, Pid}, + acks = #{}, + will = undefined, + session_expiry = 0, + topic_aliases = #{}, + subscriptions = #{}}, + noreply(State3) + end; +handle_call(Request, From, State) -> + ?WARNING_MSG("Got unexpected call from ~p: ~p", [From, Request]), + noreply(State). + +handle_cast(accept, #state{socket = {_, TCPSock} = Socket} = State) -> + case inet:peername(TCPSock) of + {ok, IPPort} -> + State1 = State#state{peername = IPPort}, + case starttls(Socket) of + {ok, Socket1} -> + State2 = State1#state{socket = Socket1}, + handle_info({tcp, TCPSock, <<>>}, State2); + {error, Why} -> + stop(State1, Why) + end; + {error, Why} -> + stop(State, {socket, Why}) + end; +handle_cast(Msg, State) -> + ?WARNING_MSG("Got unexpected cast: ~p", [Msg]), + noreply(State). + +handle_info(Msg, #state{stop_reason = {resumed, Pid} = Reason} = State) -> + case Msg of + {#publish{}, _} -> + ?DEBUG("Relaying delayed publish to ~p at ~s", [Pid, node(Pid)]), + ejabberd_cluster:send(Pid, Msg), + noreply(State); + timeout -> + stop(State, Reason); + _ -> + noreply(State) + end; +handle_info({#publish{meta = Meta} = Pkt, ExpiryTime}, State) -> + ID = next_id(State#state.id), + Meta1 = Meta#{expiry_time => ExpiryTime}, + Pkt1 = Pkt#publish{id = ID, meta = Meta1}, + State1 = State#state{id = ID}, + case send(State1, Pkt1) of + {ok, State2} -> noreply(State2); + {error, State2, Reason} -> stop(State2, Reason) + end; +handle_info({tcp, TCPSock, TCPData}, + #state{codec = Codec, socket = Socket} = State) -> + case recv_data(Socket, TCPData) of + {ok, Data} -> + case mqtt_codec:decode(Codec, Data) of + {ok, Pkt, Codec1} -> + ?DEBUG("Got MQTT packet:~n~s", [pp(Pkt)]), + State1 = State#state{codec = Codec1}, + case handle_packet(Pkt, State1) of + {ok, State2} -> + handle_info({tcp, TCPSock, <<>>}, State2); + {error, State2, Reason} -> + stop(State2, Reason) + end; + {more, Codec1} -> + State1 = State#state{codec = Codec1}, + State2 = reset_keep_alive(State1), + activate(Socket), + noreply(State2); + {error, Why} -> + stop(State, {codec, Why}) + end; + {error, Why} -> + stop(State, Why) + end; +handle_info({tcp_closed, _Sock}, State) -> + ?DEBUG("MQTT connection reset by peer", []), + stop(State, {socket, closed}); +handle_info({tcp_error, _Sock, Reason}, State) -> + ?DEBUG("MQTT connection error: ~s", [format_inet_error(Reason)]), + stop(State, {socket, Reason}); +handle_info(timeout, #state{socket = Socket} = State) -> + case Socket of + undefined -> + ?DEBUG("MQTT session expired", []), + stop(State#state{session_expiry = 0}, session_expired); + _ -> + ?DEBUG("MQTT connection timed out", []), + stop(State, idle_connection) + end; +handle_info({replaced, Pid}, State) -> + stop(State#state{session_expiry = 0}, {replaced, Pid}); +handle_info({timeout, _TRef, publish_will}, State) -> + noreply(publish_will(State)); +handle_info({Ref, badarg}, State) when is_reference(Ref) -> + %% TODO: figure out from where this messages comes from + noreply(State); +handle_info(Info, State) -> + ?WARNING_MSG("Got unexpected info: ~p", [Info]), + noreply(State). + +-spec handle_packet(mqtt_packet(), state()) -> {ok, state()} | + {error, state(), error_reason()}. +handle_packet(#connect{proto_level = Version} = Pkt, State) -> + handle_connect(Pkt, State#state{version = Version}); +handle_packet(#publish{} = Pkt, State) -> + handle_publish(Pkt, State); +handle_packet(#puback{id = ID}, #state{in_flight = #publish{qos = 1, id = ID}} = State) -> + resend(State#state{in_flight = undefined}); +handle_packet(#puback{id = ID, code = Code}, State) -> + ?DEBUG("Ignoring unexpected PUBACK with id=~B and code '~s'", [ID, Code]), + {ok, State}; +handle_packet(#pubrec{id = ID, code = Code}, + #state{in_flight = #publish{qos = 2, id = ID}} = State) -> + case mqtt_codec:is_error_code(Code) of + true -> + ?DEBUG("Got PUBREC with error code '~s', " + "aborting acknowledgement", [Code]), + resend(State#state{in_flight = undefined}); + false -> + Pubrel = #pubrel{id = ID}, + send(State#state{in_flight = Pubrel}, Pubrel) + end; +handle_packet(#pubrec{id = ID, code = Code}, State) -> + case mqtt_codec:is_error_code(Code) of + true -> + ?DEBUG("Ignoring unexpected PUBREC with id=~B and code '~s'", + [ID, Code]), + {ok, State}; + false -> + Code1 = 'packet-identifier-not-found', + ?DEBUG("Got unexpected PUBREC with id=~B, " + "sending PUBREL with error code '~s'", [ID, Code1]), + send(State, #pubrel{id = ID, code = Code1}) + end; +handle_packet(#pubcomp{id = ID}, #state{in_flight = #pubrel{id = ID}} = State) -> + resend(State#state{in_flight = undefined}); +handle_packet(#pubcomp{id = ID}, State) -> + ?DEBUG("Ignoring unexpected PUBCOMP with id=~B: most likely " + "it's a repeated response to duplicated PUBREL", [ID]), + {ok, State}; +handle_packet(#pubrel{id = ID}, State) -> + case maps:take(ID, State#state.acks) of + {_, Acks} -> + send(State#state{acks = Acks}, #pubcomp{id = ID}); + error -> + Code = 'packet-identifier-not-found', + ?DEBUG("Got unexpected PUBREL with id=~B, " + "sending PUBCOMP with error code '~s'", [ID, Code]), + Pubcomp = #pubcomp{id = ID, code = Code}, + send(State, Pubcomp) + end; +handle_packet(#subscribe{} = Pkt, State) -> + handle_subscribe(Pkt, State); +handle_packet(#unsubscribe{} = Pkt, State) -> + handle_unsubscribe(Pkt, State); +handle_packet(#pingreq{}, State) -> + send(State, #pingresp{}); +handle_packet(#disconnect{properties = #{session_expiry_interval := SE}}, + #state{session_expiry = 0} = State) when SE>0 -> + %% Protocol violation + {error, State, session_expiry_non_zero}; +handle_packet(#disconnect{code = Code, properties = Props}, + #state{jid = #jid{lserver = Server}} = State) -> + Reason = maps:get(reason_string, Props, <<>>), + Expiry = case maps:get(session_expiry_interval, Props, undefined) of + undefined -> State#state.session_expiry; + SE -> min(SE, session_expiry(Server)) + end, + State1 = State#state{session_expiry = Expiry}, + State2 = case Code of + 'normal-disconnection' -> State1#state{will = undefined}; + _ -> State1 + end, + {error, State2, {peer_disconnected, Code, Reason}}; +handle_packet(Pkt, State) -> + ?WARNING_MSG("Unexpected packet:~n~s~n** when state:~n~s", + [pp(Pkt), pp(State)]), + {error, State, {unexpected_packet, element(1, Pkt)}}. + +terminate(_, #state{peername = undefined}) -> + ok; +terminate(Reason, State) -> + Reason1 = case Reason of + shutdown -> shutdown; + {shutdown, _} -> shutdown; + normal -> State#state.stop_reason; + {process_limit, _} -> queue_full; + _ -> internal_server_error + end, + case State#state.jid of + #jid{} -> unregister_session(State, Reason1); + undefined -> log_disconnection(State, Reason1) + end, + State1 = disconnect(State, Reason1), + publish_will(State1). + +code_change(_OldVsn, State, _Extra) -> + {ok, upgrade_state(State)}. + +%%%=================================================================== +%%% State transitions +%%%=================================================================== +-spec noreply(state()) -> {noreply, state(), non_neg_integer() | infinity}. +noreply(#state{timeout = infinity} = State) -> + {noreply, State, infinity}; +noreply(#state{timeout = {MSecs, StartTime}} = State) -> + CurrentTime = current_time(), + Timeout = max(0, MSecs - CurrentTime + StartTime), + {noreply, State, Timeout}. + +-spec stop(state(), error_reason()) -> {noreply, state(), infinity} | + {stop, normal, state()}. +stop(#state{session_expiry = 0} = State, Reason) -> + {stop, normal, State#state{stop_reason = Reason}}; +stop(#state{session_expiry = SessExp} = State, Reason) -> + case State#state.socket of + undefined -> + noreply(State); + _ -> + WillDelay = State#state.will_delay, + log_disconnection(State, Reason), + State1 = disconnect(State, Reason), + State2 = if WillDelay == 0 -> + publish_will(State1); + WillDelay < SessExp -> + erlang:start_timer( + timer:seconds(WillDelay), self(), publish_will), + State1; + true -> + State1 + end, + State3 = set_timeout(State2, timer:seconds(SessExp)), + State4 = State3#state{stop_reason = Reason}, + noreply(State4) + end. + +-spec upgrade_state(term()) -> state(). +upgrade_state(State) -> + %% Here will be the code upgrading state between different + %% code versions. This is needed when doing session resumption from + %% remote node running the version of the code with incompatible #state{} + %% record fields. Also used by code_change/3 callback. + %% Use element(2, State) for vsn comparison. + State. + +%%%=================================================================== +%%% Session management +%%%=================================================================== +-spec open_session(state(), jid(), boolean()) -> {ok, boolean(), state()} | + {error, state(), error_reason()}. +open_session(State, JID, _CleanStart = false) -> + USR = {_, S, _} = jid:tolower(JID), + case mod_mqtt:lookup_session(USR) of + {ok, Pid} -> + try p1_server:call(Pid, {get_state, self()}, ?CALL_TIMEOUT) of + {ok, State1} -> + State2 = upgrade_state(State1), + Q1 = case queue_type(S) of + ram -> State2#state.queue; + _ -> p1_queue:ram_to_file(State2#state.queue) + end, + Q2 = p1_queue:set_limit(Q1, queue_limit(S)), + State3 = State#state{queue = Q2, + acks = State2#state.acks, + subscriptions = State2#state.subscriptions, + id = State2#state.id, + in_flight = State2#state.in_flight}, + ?DEBUG("Resumed state from ~p at ~s:~n~s", + [Pid, node(Pid), pp(State3)]), + register_session(State3, JID, Pid); + {error, Why} -> + {error, State, Why} + catch exit:{Why, {p1_server, _, _}} -> + ?WARNING_MSG("Failed to copy session state from ~p at ~s: ~s", + [Pid, node(Pid), format_exit_reason(Why)]), + register_session(State, JID, undefined) + end; + {error, notfound} -> + register_session(State, JID, undefined); + {error, Why} -> + {error, State, Why} + end; +open_session(State, JID, _CleanStart = true) -> + register_session(State, JID, undefined). + +-spec register_session(state(), jid(), undefined | pid()) -> + {ok, boolean(), state()} | {error, state(), error_reason()}. +register_session(#state{peername = IP} = State, JID, Parent) -> + USR = {_, S, _} = jid:tolower(JID), + case mod_mqtt:open_session(USR) of + ok -> + case resubscribe(USR, State#state.subscriptions) of + ok -> + ?INFO_MSG("~s for ~s from ~s", + [if is_pid(Parent) -> + io_lib:format( + "Reopened MQTT session via ~p", + [Parent]); + true -> + "Opened MQTT session" + end, + jid:encode(JID), + ejabberd_config:may_hide_data( + misc:ip_to_list(IP))]), + Q = case State#state.queue of + undefined -> + p1_queue:new(queue_type(S), queue_limit(S)); + Q1 -> + Q1 + end, + {ok, is_pid(Parent), State#state{jid = JID, queue = Q}}; + {error, Why} -> + mod_mqtt:close_session(USR), + {error, State#state{session_expiry = 0}, Why} + end; + {error, Reason} -> + ?ERROR_MSG("Failed to register MQTT session for ~s from ~s: ~s", + err_args(JID, IP, Reason)), + {error, State, Reason} + end. + +-spec unregister_session(state(), error_reason()) -> ok. +unregister_session(#state{jid = #jid{} = JID, peername = IP} = State, Reason) -> + Msg = "Closing MQTT session for ~s from ~s: ~s", + case Reason of + {Tag, _} when Tag == replaced; Tag == resumed -> + ?DEBUG(Msg, err_args(JID, IP, Reason)); + {socket, _} -> + ?INFO_MSG(Msg, err_args(JID, IP, Reason)); + Tag when Tag == idle_connection; Tag == session_expired; Tag == shutdown -> + ?INFO_MSG(Msg, err_args(JID, IP, Reason)); + {peer_disconnected, Code, _} -> + case mqtt_codec:is_error_code(Code) of + true -> ?WARNING_MSG(Msg, err_args(JID, IP, Reason)); + false -> ?INFO_MSG(Msg, err_args(JID, IP, Reason)) + end; + _ -> + ?WARNING_MSG(Msg, err_args(JID, IP, Reason)) + end, + USR = jid:tolower(JID), + unsubscribe(maps:keys(State#state.subscriptions), USR, #{}), + case mod_mqtt:close_session(USR) of + ok -> ok; + {error, Why} -> + ?ERROR_MSG( + "Failed to close MQTT session for ~s from ~s: ~s", + err_args(JID, IP, Why)) + end; +unregister_session(_, _) -> + ok. + +%%%=================================================================== +%%% CONNECT/PUBLISH/SUBSCRIBE/UNSUBSCRIBE handlers +%%%=================================================================== +-spec handle_connect(connect(), state()) -> {ok, state()} | + {error, state(), error_reason()}. +handle_connect(#connect{clean_start = CleanStart} = Pkt, + #state{jid = undefined, peername = IP} = State) -> + case authenticate(Pkt, IP) of + {ok, JID} -> + case validate_will(Pkt, JID) of + ok -> + case open_session(State, JID, CleanStart) of + {ok, SessionPresent, State1} -> + State2 = set_session_properties(State1, Pkt), + ConnackProps = get_connack_properties(State2, Pkt), + Connack = #connack{session_present = SessionPresent, + properties = ConnackProps}, + case send(State2, Connack) of + {ok, State3} -> resend(State3); + {error, _, _} = Err -> Err + end; + {error, _, _} = Err -> + Err + end; + {error, Reason} -> + {error, State, Reason} + end; + {error, Code} -> + {error, State, {auth, Code}} + end. + +-spec handle_publish(publish(), state()) -> {ok, state()} | + {error, state(), error_reason()}. +handle_publish(#publish{qos = QoS, id = ID} = Publish, State) -> + case QoS == 2 andalso maps:is_key(ID, State#state.acks) of + true -> + send(State, maps:get(ID, State#state.acks)); + false -> + case validate_publish(Publish, State) of + ok -> + State1 = store_topic_alias(State, Publish), + Ret = publish(State1, Publish), + {Code, Props} = get_publish_code_props(Ret), + case Ret of + {ok, _} when QoS == 2 -> + Pkt = #pubrec{id = ID, code = Code, + properties = Props}, + Acks = maps:put(ID, Pkt, State1#state.acks), + State2 = State1#state{acks = Acks}, + send(State2, Pkt); + {error, _} when QoS == 2 -> + Pkt = #pubrec{id = ID, code = Code, + properties = Props}, + send(State1, Pkt); + _ when QoS == 1 -> + Pkt = #puback{id = ID, code = Code, + properties = Props}, + send(State1, Pkt); + _ -> + {ok, State1} + end; + {error, Why} -> + {error, State, Why} + end + end. + +-spec handle_subscribe(subscribe(), state()) -> + {ok, state()} | {error, state(), error_reason()}. +handle_subscribe(#subscribe{id = ID, filters = TopicFilters} = Pkt, State) -> + case validate_subscribe(Pkt) of + ok -> + USR = jid:tolower(State#state.jid), + SubID = maps:get(subscription_identifier, Pkt#subscribe.properties, 0), + OldSubs = State#state.subscriptions, + {Codes, NewSubs, Props} = subscribe(TopicFilters, USR, SubID), + Subs = maps:merge(OldSubs, NewSubs), + State1 = State#state{subscriptions = Subs}, + Suback = #suback{id = ID, codes = Codes, properties = Props}, + case send(State1, Suback) of + {ok, State2} -> + Pubs = select_retained(USR, NewSubs, OldSubs), + send_retained(State2, Pubs); + {error, _, _} = Err -> + Err + end; + {error, Why} -> + {error, State, Why} + end. + +-spec handle_unsubscribe(unsubscribe(), state()) -> + {ok, state()} | {error, state(), error_reason()}. +handle_unsubscribe(#unsubscribe{id = ID, filters = TopicFilters}, State) -> + USR = jid:tolower(State#state.jid), + {Codes, Subs, Props} = unsubscribe(TopicFilters, USR, State#state.subscriptions), + State1 = State#state{subscriptions = Subs}, + Unsuback = #unsuback{id = ID, codes = Codes, properties = Props}, + send(State1, Unsuback). + +%%%=================================================================== +%%% Aux functions for CONNECT/PUBLISH/SUBSCRIBE/UNSUBSCRIBE handlers +%%%=================================================================== +-spec set_session_properties(state(), connect()) -> state(). +set_session_properties(#state{version = Version, + jid = #jid{lserver = Server}} = State, + #connect{clean_start = CleanStart, + keep_alive = KeepAlive, + properties = Props} = Pkt) -> + SEMin = case CleanStart of + false when Version == ?MQTT_VERSION_4 -> infinity; + _ -> maps:get(session_expiry_interval, Props, 0) + end, + SEConfig = session_expiry(Server), + State1 = State#state{session_expiry = min(SEMin, SEConfig)}, + State2 = set_will_properties(State1, Pkt), + set_keep_alive(State2, KeepAlive). + +-spec set_will_properties(state(), connect()) -> state(). +set_will_properties(State, #connect{will = #publish{} = Will, + will_properties = Props}) -> + {WillDelay, Props1} = case maps:take(will_delay_interval, Props) of + error -> {0, Props}; + Ret -> Ret + end, + State#state{will = Will#publish{properties = Props1}, + will_delay = WillDelay}; +set_will_properties(State, _) -> + State. + +-spec get_connack_properties(state(), connect()) -> properties(). +get_connack_properties(#state{session_expiry = SessExp, jid = JID}, + #connect{client_id = ClientID, + keep_alive = KeepAlive}) -> + Props1 = case ClientID of + <<>> -> #{assigned_client_identifier => JID#jid.lresource}; + _ -> #{} + end, + Props1#{session_expiry_interval => SessExp, + shared_subscription_available => false, + topic_alias_maximum => topic_alias_maximum(JID#jid.lserver), + server_keep_alive => KeepAlive}. + +-spec subscribe([{binary(), sub_opts()}], jid:ljid(), non_neg_integer()) -> + {[reason_code()], map(), properties()}. +subscribe(TopicFilters, USR, SubID) -> + subscribe(TopicFilters, USR, SubID, [], #{}, ok). + +-spec subscribe([{binary(), sub_opts()}], jid:ljid(), non_neg_integer(), + [reason_code()], map(), ok | {error, error_reason()}) -> + {[reason_code()], map(), properties()}. +subscribe([{TopicFilter, SubOpts}|TopicFilters], USR, SubID, Codes, Subs, Err) -> + case mod_mqtt:subscribe(USR, TopicFilter, SubOpts, SubID) of + ok -> + Code = subscribe_reason_code(SubOpts#sub_opts.qos), + subscribe(TopicFilters, USR, SubID, [Code|Codes], + maps:put(TopicFilter, {SubOpts, SubID}, Subs), Err); + {error, Why} = Err1 -> + Code = subscribe_reason_code(Why), + subscribe(TopicFilters, USR, SubID, [Code|Codes], Subs, Err1) + end; +subscribe([], _USR, _SubID, Codes, Subs, Err) -> + Props = case Err of + ok -> #{}; + {error, Why} -> + #{reason_string => format_reason_string(Why)} + end, + {lists:reverse(Codes), Subs, Props}. + +-spec unsubscribe([binary()], jid:ljid(), map()) -> + {[reason_code()], map(), properties()}. +unsubscribe(TopicFilters, USR, Subs) -> + unsubscribe(TopicFilters, USR, [], Subs, ok). + +-spec unsubscribe([binary()], jid:ljid(), + [reason_code()], map(), + ok | {error, error_reason()}) -> + {[reason_code()], map(), properties()}. +unsubscribe([TopicFilter|TopicFilters], USR, Codes, Subs, Err) -> + case mod_mqtt:unsubscribe(USR, TopicFilter) of + ok -> + unsubscribe(TopicFilters, USR, [success|Codes], + maps:remove(TopicFilter, Subs), Err); + {error, notfound} -> + unsubscribe(TopicFilters, USR, + ['no-subscription-existed'|Codes], + maps:remove(TopicFilter, Subs), Err); + {error, Why} = Err1 -> + Code = unsubscribe_reason_code(Why), + unsubscribe(TopicFilters, USR, [Code|Codes], Subs, Err1) + end; +unsubscribe([], _USR, Codes, Subs, Err) -> + Props = case Err of + ok -> #{}; + {error, Why} -> + #{reason_string => format_reason_string(Why)} + end, + {lists:reverse(Codes), Subs, Props}. + +-spec select_retained(jid:ljid(), map(), map()) -> [{publish(), seconds()}]. +select_retained(USR, NewSubs, OldSubs) -> + lists:flatten( + maps:fold( + fun(_Filter, {#sub_opts{retain_handling = 2}, _SubID}, Acc) -> + Acc; + (Filter, {#sub_opts{retain_handling = 1, qos = QoS}, SubID}, Acc) -> + case maps:is_key(Filter, OldSubs) of + true -> Acc; + false -> [mod_mqtt:select_retained(USR, Filter, QoS, SubID)|Acc] + end; + (Filter, {#sub_opts{qos = QoS}, SubID}, Acc) -> + [mod_mqtt:select_retained(USR, Filter, QoS, SubID)|Acc] + end, [], NewSubs)). + +-spec send_retained(state(), [{publish(), seconds()}]) -> + {ok, state()} | {error, state(), error_reason()}. +send_retained(State, [{#publish{meta = Meta} = Pub, Expiry}|Pubs]) -> + I = next_id(State#state.id), + Meta1 = Meta#{expiry_time => Expiry}, + Pub1 = Pub#publish{id = I, retain = true, meta = Meta1}, + case send(State#state{id = I}, Pub1) of + {ok, State1} -> + send_retained(State1, Pubs); + Err -> + Err + end; +send_retained(State, []) -> + {ok, State}. + +-spec publish(state(), publish()) -> {ok, non_neg_integer()} | + {error, error_reason()}. +publish(State, #publish{topic = Topic, properties = Props} = Pkt) -> + MessageExpiry = maps:get(message_expiry_interval, Props, ?MAX_UINT32), + ExpiryTime = min(unix_time() + MessageExpiry, ?MAX_UINT32), + USR = jid:tolower(State#state.jid), + Props1 = maps:filter( + fun(payload_format_indicator, _) -> true; + (content_type, _) -> true; + (response_topic, _) -> true; + (correlation_data, _) -> true; + (user_property, _) -> true; + (_, _) -> false + end, Props), + Topic1 = case Topic of + <<>> -> + Alias = maps:get(topic_alias, Props), + maps:get(Alias, State#state.topic_aliases); + _ -> + Topic + end, + Pkt1 = Pkt#publish{topic = Topic1, properties = Props1}, + mod_mqtt:publish(USR, Pkt1, ExpiryTime). + +-spec store_topic_alias(state(), publish()) -> state(). +store_topic_alias(State, #publish{topic = <<_, _/binary>> = Topic, + properties = #{topic_alias := Alias}}) -> + Aliases = maps:put(Alias, Topic, State#state.topic_aliases), + State#state{topic_aliases = Aliases}; +store_topic_alias(State, _) -> + State. + +%%%=================================================================== +%%% Socket management +%%%=================================================================== +-spec send(state(), mqtt_packet()) -> {ok, state()} | + {error, state(), error_reason()}. +send(State, #publish{} = Pkt) -> + case is_expired(Pkt) of + {false, Pkt1} -> + case State#state.in_flight == undefined andalso + p1_queue:is_empty(State#state.queue) of + true -> + Dup = case Pkt1#publish.qos of + 0 -> undefined; + _ -> Pkt1 + end, + State1 = State#state{in_flight = Dup}, + {ok, do_send(State1, Pkt1)}; + false -> + ?DEBUG("Queueing packet:~n~s~n** when state:~n~s", + [pp(Pkt), pp(State)]), + try p1_queue:in(Pkt, State#state.queue) of + Q -> + State1 = State#state{queue = Q}, + {ok, State1} + catch error:full -> + Q = p1_queue:clear(State#state.queue), + State1 = State#state{queue = Q, session_expiry = 0}, + {error, State1, queue_full} + end + end; + true -> + {ok, State} + end; +send(State, Pkt) -> + {ok, do_send(State, Pkt)}. + +-spec resend(state()) -> {ok, state()} | {error, state(), error_reason()}. +resend(#state{in_flight = undefined} = State) -> + case p1_queue:out(State#state.queue) of + {{value, #publish{qos = QoS} = Pkt}, Q} -> + case is_expired(Pkt) of + true -> + resend(State#state{queue = Q}); + {false, Pkt1} when QoS > 0 -> + State1 = State#state{in_flight = Pkt1, queue = Q}, + {ok, do_send(State1, Pkt1)}; + {false, Pkt1} -> + State1 = do_send(State#state{queue = Q}, Pkt1), + resend(State1) + end; + {empty, _} -> + {ok, State} + end; +resend(#state{in_flight = Pkt} = State) -> + {ok, do_send(State, set_dup_flag(Pkt))}. + +-spec do_send(state(), mqtt_packet()) -> state(). +do_send(#state{socket = {SockMod, Sock} = Socket} = State, Pkt) -> + ?DEBUG("Send MQTT packet:~n~s", [pp(Pkt)]), + Data = mqtt_codec:encode(State#state.version, Pkt), + Res = SockMod:send(Sock, Data), + check_sock_result(Socket, Res), + State; +do_send(State, _Pkt) -> + State. + +-spec activate(socket()) -> ok. +activate({SockMod, Sock} = Socket) -> + Res = case SockMod of + gen_tcp -> inet:setopts(Sock, [{active, once}]); + _ -> SockMod:setopts(Sock, [{active, once}]) + end, + check_sock_result(Socket, Res). + +-spec disconnect(state(), error_reason()) -> state(). +disconnect(#state{socket = {SockMod, Sock}} = State, Err) -> + State1 = case Err of + {auth, Code} -> + do_send(State, #connack{code = Code}); + {codec, {Tag, _, _}} when Tag == unsupported_protocol_version; + Tag == unsupported_protocol_name -> + do_send(State#state{version = ?MQTT_VERSION_4}, + #connack{code = connack_reason_code(Err)}); + _ when State#state.version == undefined -> + State; + {Tag, _} when Tag == socket; Tag == tls -> + State; + {peer_disconnected, _, _} -> + State; + _ -> + Props = #{reason_string => format_reason_string(Err)}, + case State#state.jid of + undefined -> + Code = connack_reason_code(Err), + Pkt = #connack{code = Code, properties = Props}, + do_send(State, Pkt); + _ when State#state.version == ?MQTT_VERSION_5 -> + Code = disconnect_reason_code(Err), + Pkt = #disconnect{code = Code, properties = Props}, + do_send(State, Pkt); + _ -> + State + end + end, + SockMod:close(Sock), + State1#state{socket = undefined, + version = undefined, + codec = mqtt_codec:renew(State#state.codec)}; +disconnect(State, _) -> + State. + +-spec check_sock_result(socket(), ok | {error, inet:posix()}) -> ok. +check_sock_result(_, ok) -> + ok; +check_sock_result({_, Sock}, {error, Why}) -> + self() ! {tcp_closed, Sock}, + ?DEBUG("MQTT socket error: ~p", [format_inet_error(Why)]). + +-spec starttls(socket()) -> {ok, socket()} | {error, error_reason()}. +starttls({fast_tls, Socket}) -> + case ejabberd_pkix:get_certfile() of + {ok, Cert} -> + case fast_tls:tcp_to_tls(Socket, [{certfile, Cert}]) of + {ok, TLSSock} -> + {ok, {fast_tls, TLSSock}}; + {error, Why} -> + {error, {tls, Why}} + end; + error -> + {error, {tls, no_certfile}} + end; +starttls(Socket) -> + {ok, Socket}. + +-spec recv_data(socket(), binary()) -> {ok, binary()} | {error, error_reason()}. +recv_data({fast_tls, Sock}, Data) -> + case fast_tls:recv_data(Sock, Data) of + {ok, _} = OK -> OK; + {error, E} when is_atom(E) -> {error, {socket, E}}; + {error, E} when is_binary(E) -> {error, {tls, E}}; + {error, _} = Err -> Err + end; +recv_data(_, Data) -> + {ok, Data}. + +%%%=================================================================== +%%% Formatters +%%%=================================================================== +-spec pp(any()) -> iolist(). +pp(Term) -> + io_lib_pretty:print(Term, fun pp/2). + +-spec format_inet_error(socket_error_reason()) -> string(). +format_inet_error(closed) -> + "connection closed"; +format_inet_error(timeout) -> + format_inet_error(etimedout); +format_inet_error(Reason) -> + case inet:format_error(Reason) of + "unknown POSIX error" -> atom_to_list(Reason); + Txt -> Txt + end. + +-spec format_tls_error(atom() | binary()) -> string() | binary(). +format_tls_error(no_cerfile) -> + "certificate not found"; +format_tls_error(Reason) when is_atom(Reason) -> + format_inet_error(Reason); +format_tls_error(Reason) -> + Reason. + +-spec format_exit_reason(term()) -> string(). +format_exit_reason(noproc) -> + "process is dead"; +format_exit_reason(normal) -> + "process has exited"; +format_exit_reason(killed) -> + "process has been killed"; +format_exit_reason(timeout) -> + "remote call to process timed out"; +format_exit_reason(Why) -> + format("unexpected error: ~p", [Why]). + +%% Same as format_error/1, but hides sensitive data +%% and returns result as binary +-spec format_reason_string(error_reason()) -> binary(). +format_reason_string({resumed, _}) -> + <<"Resumed by another connection">>; +format_reason_string({replaced, _}) -> + <<"Replaced by another connection">>; +format_reason_string(Err) -> + list_to_binary(format_error(Err)). + +-spec format(io:format(), list()) -> string(). +format(Fmt, Args) -> + lists:flatten(io_lib:format(Fmt, Args)). + +-spec pp(atom(), non_neg_integer()) -> [atom()] | no. +pp(state, 17) -> record_info(fields, state); +pp(Rec, Size) -> mqtt_codec:pp(Rec, Size). + +-spec publish_reason_code(error_reason()) -> reason_code(). +publish_reason_code(publish_forbidden) -> 'topic-name-invalid'; +publish_reason_code(_) -> 'implementation-specific-error'. + +-spec subscribe_reason_code(qos() | error_reason()) -> reason_code(). +subscribe_reason_code(0) -> 'granted-qos-0'; +subscribe_reason_code(1) -> 'granted-qos-1'; +subscribe_reason_code(2) -> 'granted-qos-2'; +subscribe_reason_code(subscribe_forbidden) -> 'topic-filter-invalid'; +subscribe_reason_code(_) -> 'implementation-specific-error'. + +-spec unsubscribe_reason_code(error_reason()) -> reason_code(). +unsubscribe_reason_code(_) -> 'implementation-specific-error'. + +-spec disconnect_reason_code(error_reason()) -> reason_code(). +disconnect_reason_code({code, Code}) -> Code; +disconnect_reason_code({codec, Err}) -> mqtt_codec:error_reason_code(Err); +disconnect_reason_code({unexpected_packet, _}) -> 'protocol-error'; +disconnect_reason_code({replaced, _}) -> 'session-taken-over'; +disconnect_reason_code({resumed, _}) -> 'session-taken-over'; +disconnect_reason_code(internal_server_error) -> 'implementation-specific-error'; +disconnect_reason_code(db_failure) -> 'implementation-specific-error'; +disconnect_reason_code(idle_connection) -> 'keep-alive-timeout'; +disconnect_reason_code(queue_full) -> 'quota-exceeded'; +disconnect_reason_code(shutdown) -> 'server-shutting-down'; +disconnect_reason_code(subscribe_forbidden) -> 'topic-filter-invalid'; +disconnect_reason_code(publish_forbidden) -> 'topic-name-invalid'; +disconnect_reason_code(will_topic_forbidden) -> 'topic-name-invalid'; +disconnect_reason_code({payload_format_invalid, _}) -> 'payload-format-invalid'; +disconnect_reason_code(session_expiry_non_zero) -> 'protocol-error'; +disconnect_reason_code(unknown_topic_alias) -> 'protocol-error'; +disconnect_reason_code(_) -> 'unspecified-error'. + +-spec connack_reason_code(error_reason()) -> reason_code(). +connack_reason_code({Tag, Code}) when Tag == auth; Tag == code -> Code; +connack_reason_code({codec, Err}) -> mqtt_codec:error_reason_code(Err); +connack_reason_code({unexpected_packet, _}) -> 'protocol-error'; +connack_reason_code(internal_server_error) -> 'implementation-specific-error'; +connack_reason_code(db_failure) -> 'implementation-specific-error'; +connack_reason_code(idle_connection) -> 'keep-alive-timeout'; +connack_reason_code(queue_full) -> 'quota-exceeded'; +connack_reason_code(shutdown) -> 'server-shutting-down'; +connack_reason_code(will_topic_forbidden) -> 'topic-name-invalid'; +connack_reason_code({payload_format_invalid, _}) -> 'payload-format-invalid'; +connack_reason_code(session_expiry_non_zero) -> 'protocol-error'; +connack_reason_code(_) -> 'unspecified-error'. + +%%%=================================================================== +%%% Configuration processing +%%%=================================================================== +-spec queue_type(binary()) -> ram | file. +queue_type(Host) -> + gen_mod:get_module_opt(Host, mod_mqtt, queue_type). + +-spec queue_limit(binary()) -> non_neg_integer() | unlimited. +queue_limit(Host) -> + gen_mod:get_module_opt(Host, mod_mqtt, max_queue). + +-spec session_expiry(binary()) -> seconds(). +session_expiry(Host) -> + gen_mod:get_module_opt(Host, mod_mqtt, session_expiry). + +-spec topic_alias_maximum(binary()) -> non_neg_integer(). +topic_alias_maximum(Host) -> + gen_mod:get_module_opt(Host, mod_mqtt, max_topic_aliases). + +%%%=================================================================== +%%% Timings +%%%=================================================================== +-spec current_time() -> milli_seconds(). +current_time() -> + p1_time_compat:monotonic_time(milli_seconds). + +-spec unix_time() -> seconds(). +unix_time() -> + p1_time_compat:system_time(seconds). + +-spec set_keep_alive(state(), seconds()) -> state(). +set_keep_alive(State, 0) -> + ?DEBUG("Disabling MQTT keep-alive", []), + State#state{timeout = infinity}; +set_keep_alive(State, Secs) -> + Secs1 = round(Secs * 1.5), + ?DEBUG("Setting MQTT keep-alive to ~B seconds", [Secs1]), + set_timeout(State, timer:seconds(Secs1)). + +-spec reset_keep_alive(state()) -> state(). +reset_keep_alive(#state{timeout = {MSecs, _}, jid = #jid{}} = State) -> + set_timeout(State, MSecs); +reset_keep_alive(State) -> + State. + +-spec set_timeout(state(), milli_seconds()) -> state(). +set_timeout(State, MSecs) -> + Time = current_time(), + State#state{timeout = {MSecs, Time}}. + +-spec is_expired(publish()) -> true | {false, publish()}. +is_expired(#publish{meta = Meta, properties = Props} = Pkt) -> + case maps:get(expiry_time, Meta, ?MAX_UINT32) of + ?MAX_UINT32 -> + {false, Pkt}; + ExpiryTime -> + Left = ExpiryTime - unix_time(), + if Left > 0 -> + Props1 = Props#{message_expiry_interval => Left}, + {false, Pkt#publish{properties = Props1}}; + true -> + ?DEBUG("Dropping expired packet:~n~s", [pp(Pkt)]), + true + end + end. + +%%%=================================================================== +%%% Authentication +%%%=================================================================== +-spec parse_credentials(connect()) -> {ok, jid:jid()} | {error, reason_code()}. +parse_credentials(#connect{client_id = <<>>}) -> + parse_credentials(#connect{client_id = p1_rand:get_string()}); +parse_credentials(#connect{username = <<>>, client_id = ClientID}) -> + Host = ejabberd_config:get_myname(), + JID = case jid:make(ClientID, Host) of + error -> jid:make(str:sha(ClientID), Host); + J -> J + end, + parse_credentials(JID, ClientID); +parse_credentials(#connect{username = User} = Pkt) -> + try jid:decode(User) of + #jid{luser = <<>>} -> + case jid:make(User, ejabberd_config:get_myname()) of + error -> + {error, 'bad-user-name-or-password'}; + JID -> + parse_credentials(JID, Pkt#connect.client_id) + end; + JID -> + parse_credentials(JID, Pkt#connect.client_id) + catch _:{bad_jid, _} -> + {error, 'bad-user-name-or-password'} + end. + +-spec parse_credentials(jid:jid(), binary()) -> {ok, jid:jid()} | {error, reason_code()}. +parse_credentials(JID, ClientID) -> + case gen_mod:is_loaded(JID#jid.lserver, mod_mqtt) of + false -> + {error, 'server-unavailable'}; + true -> + case jid:replace_resource(JID, ClientID) of + error -> + {error, 'client-identifier-not-valid'}; + JID1 -> + {ok, JID1} + end + end. + +-spec authenticate(connect(), peername()) -> {ok, jid:jid()} | {error, reason_code()}. +authenticate(#connect{password = Pass} = Pkt, IP) -> + case parse_credentials(Pkt) of + {ok, #jid{luser = LUser, lserver = LServer} = JID} -> + case ejabberd_auth:check_password_with_authmodule( + LUser, <<>>, LServer, Pass) of + {true, AuthModule} -> + ?INFO_MSG( + "Accepted MQTT authentication for ~s " + "by ~s backend from ~s", + [jid:encode(JID), + ejabberd_auth:backend_type(AuthModule), + ejabberd_config:may_hide_data(misc:ip_to_list(IP))]), + {ok, JID}; + false -> + {error, 'not-authorized'} + end; + {error, _} = Err -> + Err + end. + +%%%=================================================================== +%%% Validators +%%%=================================================================== +-spec validate_will(connect(), jid:jid()) -> ok | {error, reason_code()}. +validate_will(#connect{will = undefined}, _) -> + ok; +validate_will(#connect{will = #publish{topic = Topic, payload = Payload}, + will_properties = Props}, JID) -> + case mod_mqtt:check_publish_access(Topic, jid:tolower(JID)) of + deny -> {error, will_topic_forbidden}; + allow -> validate_payload(Props, Payload, will) + end. + +-spec validate_publish(publish(), state()) -> ok | {error, error_reason()}. +validate_publish(#publish{topic = Topic, payload = Payload, + properties = Props}, State) -> + case validate_topic(Topic, Props, State) of + ok -> validate_payload(Props, Payload, publish); + Err -> Err + end. + +-spec validate_subscribe(subscribe()) -> ok | {error, error_reason()}. +validate_subscribe(#subscribe{filters = Filters}) -> + case lists:any( + fun({<<"$share/", _/binary>>, _}) -> true; + (_) -> false + end, Filters) of + true -> + {error, {code, 'shared-subscriptions-not-supported'}}; + false -> + ok + end. + +-spec validate_topic(binary(), properties(), state()) -> ok | {error, error_reason()}. +validate_topic(<<>>, Props, State) -> + case maps:get(topic_alias, Props, 0) of + 0 -> + {error, {code, 'topic-alias-invalid'}}; + Alias -> + case maps:is_key(Alias, State#state.topic_aliases) of + true -> ok; + false -> {error, unknown_topic_alias} + end + end; +validate_topic(_, #{topic_alias := Alias}, State) -> + JID = State#state.jid, + Max = topic_alias_maximum(JID#jid.lserver), + if Alias > Max -> + {error, {code, 'topic-alias-invalid'}}; + true -> + ok + end; +validate_topic(_, _, _) -> + ok. + +-spec validate_payload(properties(), binary(), will | publish) -> ok | {error, error_reason()}. +validate_payload(#{payload_format_indicator := utf8}, Payload, Type) -> + try mqtt_codec:utf8(Payload) of + _ -> ok + catch _:_ -> + {error, {payload_format_invalid, Type}} + end; +validate_payload(_, _, _) -> + ok. + +%%%=================================================================== +%%% Misc +%%%=================================================================== +-spec resubscribe(jid:ljid(), map()) -> ok | {error, error_reason()}. +resubscribe(USR, Subs) -> + case maps:fold( + fun(TopicFilter, {SubOpts, ID}, ok) -> + mod_mqtt:subscribe(USR, TopicFilter, SubOpts, ID); + (_, _, {error, _} = Err) -> + Err + end, ok, Subs) of + ok -> + ok; + {error, _} = Err1 -> + unsubscribe(maps:keys(Subs), USR, #{}), + Err1 + end. + +-spec publish_will(state()) -> state(). +publish_will(#state{will = #publish{} = Will, + jid = #jid{} = JID} = State) -> + case publish(State, Will) of + {ok, _} -> + ?DEBUG("Will of ~s has been published to ~s", + [jid:encode(JID), Will#publish.topic]); + {error, Why} -> + ?WARNING_MSG("Failed to publish will of ~s to ~s: ~s", + [jid:encode(JID), Will#publish.topic, + format_error(Why)]) + end, + State#state{will = undefined}; +publish_will(State) -> + State. + +-spec next_id(non_neg_integer()) -> pos_integer(). +next_id(ID) -> + (ID rem 65535) + 1. + +-spec set_dup_flag(mqtt_packet()) -> mqtt_packet(). +set_dup_flag(#publish{qos = QoS} = Pkt) when QoS>0 -> + Pkt#publish{dup = true}; +set_dup_flag(Pkt) -> + Pkt. + +-spec get_publish_code_props({ok, non_neg_integer()} | + {error, error_reason()}) -> {reason_code(), properties()}. +get_publish_code_props({ok, 0}) -> + {'no-matching-subscribers', #{}}; +get_publish_code_props({ok, _}) -> + {success, #{}}; +get_publish_code_props({error, Err}) -> + Code = publish_reason_code(Err), + Reason = format_reason_string(Err), + {Code, #{reason_string => Reason}}. + +-spec err_args(undefined | jid:jid(), peername(), error_reason()) -> iolist(). +err_args(undefined, IP, Reason) -> + [ejabberd_config:may_hide_data(misc:ip_to_list(IP)), + format_error(Reason)]; +err_args(JID, IP, Reason) -> + [jid:encode(JID), + ejabberd_config:may_hide_data(misc:ip_to_list(IP)), + format_error(Reason)]. + +-spec log_disconnection(state(), error_reason()) -> ok. +log_disconnection(#state{jid = JID, peername = IP}, Reason) -> + Msg = case JID of + undefined -> "Rejected MQTT connection from ~s: ~s"; + _ -> "Closing MQTT connection for ~s from ~s: ~s" + end, + case Reason of + {Tag, _} when Tag == replaced; Tag == resumed; Tag == socket -> + ?DEBUG(Msg, err_args(JID, IP, Reason)); + idle_connection -> + ?DEBUG(Msg, err_args(JID, IP, Reason)); + Tag when Tag == session_expired; Tag == shutdown -> + ?INFO_MSG(Msg, err_args(JID, IP, Reason)); + {peer_disconnected, Code, _} -> + case mqtt_codec:is_error_code(Code) of + true -> ?WARNING_MSG(Msg, err_args(JID, IP, Reason)); + false -> ?DEBUG(Msg, err_args(JID, IP, Reason)) + end; + _ -> + ?WARNING_MSG(Msg, err_args(JID, IP, Reason)) + end. diff --git a/src/mod_mqtt_sql.erl b/src/mod_mqtt_sql.erl new file mode 100644 index 000000000..a11f8e04c --- /dev/null +++ b/src/mod_mqtt_sql.erl @@ -0,0 +1,151 @@ +%%%------------------------------------------------------------------- +%%% @author Evgeny Khramtsov +%%% @copyright (C) 2002-2019 ProcessOne, SARL. All Rights Reserved. +%%% +%%% Licensed under the Apache License, Version 2.0 (the "License"); +%%% you may not use this file except in compliance with the License. +%%% You may obtain a copy of the License at +%%% +%%% http://www.apache.org/licenses/LICENSE-2.0 +%%% +%%% Unless required by applicable law or agreed to in writing, software +%%% distributed under the License is distributed on an "AS IS" BASIS, +%%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%%% See the License for the specific language governing permissions and +%%% limitations under the License. +%%% +%%%------------------------------------------------------------------- +-module(mod_mqtt_sql). +-behaviour(mod_mqtt). +-compile([{parse_transform, ejabberd_sql_pt}]). + +%% API +-export([init/2, publish/6, delete_published/2, lookup_published/2]). +-export([list_topics/1]). +%% Unsupported backend API +-export([init/0]). +-export([subscribe/4, unsubscribe/2, find_subscriber/2]). +-export([open_session/1, close_session/1, lookup_session/1]). + +-include("logger.hrl"). +-include("ejabberd_sql_pt.hrl"). + +%%%=================================================================== +%%% API +%%%=================================================================== +init() -> + ?ERROR_MSG("Backend 'sql' is only supported for db_type", []), + {error, db_failure}. + +init(_Host, _Opts) -> + ok. + +publish({U, LServer, R}, Topic, Payload, QoS, Props, ExpiryTime) -> + PayloadFormat = encode_pfi(maps:get(payload_format_indicator, Props, binary)), + ResponseTopic = maps:get(response_topic, Props, <<"">>), + CorrelationData = maps:get(correlation_data, Props, <<"">>), + ContentType = maps:get(content_type, Props, <<"">>), + UserProps = encode_props(maps:get(user_property, Props, [])), + case ?SQL_UPSERT(LServer, "mqtt_pub", + ["!topic=%(Topic)s", + "!server_host=%(LServer)s", + "username=%(U)s", + "resource=%(R)s", + "payload=%(Payload)s", + "qos=%(QoS)d", + "payload_format=%(PayloadFormat)d", + "response_topic=%(ResponseTopic)s", + "correlation_data=%(CorrelationData)s", + "content_type=%(ContentType)s", + "user_properties=%(UserProps)s", + "expiry=%(ExpiryTime)d"]) of + ok -> ok; + _Err -> {error, db_failure} + end. + +delete_published({_, LServer, _}, Topic) -> + case ejabberd_sql:sql_query( + LServer, + ?SQL("delete from mqtt_pub where " + "topic=%(Topic)s and %(LServer)H")) of + {updated, _} -> ok; + _Err -> {error, db_failure} + end. + +lookup_published({_, LServer, _}, Topic) -> + case ejabberd_sql:sql_query( + LServer, + ?SQL("select @(payload)s, @(qos)d, @(payload_format)d, " + "@(content_type)s, @(response_topic)s, " + "@(correlation_data)s, @(user_properties)s, @(expiry)d " + "from mqtt_pub where topic=%(Topic)s and %(LServer)H")) of + {selected, [{Payload, QoS, PayloadFormat, ContentType, + ResponseTopic, CorrelationData, EncProps, Expiry}]} -> + try decode_props(EncProps) of + UserProps -> + try decode_pfi(PayloadFormat) of + PFI -> + Props = #{payload_format_indicator => PFI, + content_type => ContentType, + response_topic => ResponseTopic, + correlation_data => CorrelationData, + user_property => UserProps}, + {ok, {Payload, QoS, Props, Expiry}} + catch _:badarg -> + ?ERROR_MSG("Malformed value of 'payload_format' column " + "for topic '~s'", [Topic]), + {error, db_failure} + end + catch _:badarg -> + ?ERROR_MSG("Malformed value of 'user_properties' column " + "for topic '~s'", [Topic]), + {error, db_failure} + end; + {selected, []} -> + {error, notfound}; + _ -> + {error, db_failure} + end. + +list_topics(LServer) -> + case ejabberd_sql:sql_query( + LServer, + ?SQL("select @(topic)s from mqtt_pub where %(LServer)H")) of + {selected, Res} -> + {ok, [Topic || {Topic} <- Res]}; + _ -> + {error, db_failure} + end. + +open_session(_) -> + erlang:nif_error(unsupported_db). + +close_session(_) -> + erlang:nif_error(unsupported_db). + +lookup_session(_) -> + erlang:nif_error(unsupported_db). + +subscribe(_, _, _, _) -> + erlang:nif_error(unsupported_db). + +unsubscribe(_, _) -> + erlang:nif_error(unsupported_db). + +find_subscriber(_, _) -> + erlang:nif_error(unsupported_db). + +%%%=================================================================== +%%% Internal functions +%%%=================================================================== +encode_pfi(binary) -> 0; +encode_pfi(utf8) -> 1. + +decode_pfi(0) -> binary; +decode_pfi(1) -> utf8. + +encode_props([]) -> <<"">>; +encode_props(L) -> term_to_binary(L). + +decode_props(<<"">>) -> []; +decode_props(Bin) -> binary_to_term(Bin). diff --git a/src/mqtt_codec.erl b/src/mqtt_codec.erl new file mode 100644 index 000000000..4cc23e1d0 --- /dev/null +++ b/src/mqtt_codec.erl @@ -0,0 +1,1402 @@ +%%%------------------------------------------------------------------- +%%% @author Evgeny Khramtsov +%%% @copyright (C) 2002-2019 ProcessOne, SARL. All Rights Reserved. +%%% +%%% Licensed under the Apache License, Version 2.0 (the "License"); +%%% you may not use this file except in compliance with the License. +%%% You may obtain a copy of the License at +%%% +%%% http://www.apache.org/licenses/LICENSE-2.0 +%%% +%%% Unless required by applicable law or agreed to in writing, software +%%% distributed under the License is distributed on an "AS IS" BASIS, +%%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%%% See the License for the specific language governing permissions and +%%% limitations under the License. +%%% +%%%------------------------------------------------------------------- +-module(mqtt_codec). + +%% API +-export([new/1, new/2, renew/1, decode/2, encode/2]). +-export([pp/1, pp/2, format_error/1, format_reason_code/1]). +-export([error_reason_code/1, is_error_code/1]). +%% Validators +-export([topic/1, topic_filter/1, qos/1, utf8/1]). +-export([decode_varint/1]). + +-include("mqtt.hrl"). + +-define(MAX_UINT16, 65535). +-define(MAX_UINT32, 4294967295). +-define(MAX_VARINT, 268435456). + +-record(codec_state, {version :: undefined | mqtt_version(), + type :: undefined | non_neg_integer(), + flags :: undefined | non_neg_integer(), + size :: undefined | non_neg_integer(), + max_size :: pos_integer() | infinity, + buf = <<>> :: binary()}). + +-type error_reason() :: bad_varint | + {payload_too_big, integer()} | + {bad_packet_type, char()} | + {bad_packet, atom()} | + {unexpected_packet, atom()} | + {bad_reason_code, atom(), char()} | + {bad_properties, atom()} | + {bad_property, atom(), atom()} | + {duplicated_property, atom(), atom()} | + bad_will_topic_or_message | + bad_connect_username_or_password | + bad_publish_id_or_payload | + {bad_topic_filters, atom()} | + {bad_qos, char()} | + bad_topic | bad_topic_filter | bad_utf8_string | + {unsupported_protocol_name, binary(), binary()} | + {unsupported_protocol_version, char(), iodata()} | + {{bad_flag, atom()}, char(), term()} | + {{bad_flags, atom()}, char(), char()}. + +-opaque state() :: #codec_state{}. +-export_type([state/0, error_reason/0]). + +%%%=================================================================== +%%% API +%%%=================================================================== +-spec new(pos_integer() | infinity) -> state(). +new(MaxSize) -> + new(MaxSize, undefined). + +-spec new(pos_integer() | infinity, undefined | mqtt_version()) -> state(). +new(MaxSize, Version) -> + #codec_state{max_size = MaxSize, version = Version}. + +-spec renew(state()) -> state(). +renew(#codec_state{version = Version, max_size = MaxSize}) -> + #codec_state{version = Version, max_size = MaxSize}. + +-spec decode(state(), binary()) -> {ok, mqtt_packet(), state()} | + {more, state()} | + {error, error_reason()}. +decode(#codec_state{size = undefined, buf = Buf} = State, Data) -> + Buf1 = <>, + case Buf1 of + <> -> + try + case decode_varint(Data1) of + {Len, _} when Len >= State#codec_state.max_size -> + err({payload_too_big, State#codec_state.max_size}); + {Len, Data2} when size(Data2) >= Len -> + <> = Data2, + Version = State#codec_state.version, + Pkt = decode_pkt(Version, Type, Flags, Payload), + State1 = case Pkt of + #connect{proto_level = V} -> + State#codec_state{version = V}; + _ -> + State + end, + {ok, Pkt, State1#codec_state{buf = Data3}}; + {Len, Data2} -> + {more, State#codec_state{type = Type, + flags = Flags, + size = Len, + buf = Data2}}; + more -> + {more, State#codec_state{buf = Buf1}} + end + catch _:{?MODULE, Why} -> + {error, Why} + end; + <<>> -> + {more, State} + end; +decode(#codec_state{size = Len, buf = Buf, + version = Version, + type = Type, flags = Flags} = State, Data) -> + Buf1 = <>, + if size(Buf1) >= Len -> + <> = Buf1, + try + Pkt = decode_pkt(Version, Type, Flags, Payload), + State1 = case Pkt of + #connect{proto_level = V} -> + State#codec_state{version = V}; + _ -> + State + end, + {ok, Pkt, State1#codec_state{type = undefined, + flags = undefined, + size = undefined, + buf = Data1}} + catch _:{?MODULE, Why} -> + {error, Why} + end; + true -> + {more, State#codec_state{buf = Buf1}} + end. + +-spec encode(mqtt_version(), mqtt_packet()) -> binary(). +encode(Version, Pkt) -> + case Pkt of + #connect{proto_level = Version} -> encode_connect(Pkt); + #connack{} -> encode_connack(Version, Pkt); + #publish{} -> encode_publish(Version, Pkt); + #puback{} -> encode_puback(Version, Pkt); + #pubrec{} -> encode_pubrec(Version, Pkt); + #pubrel{} -> encode_pubrel(Version, Pkt); + #pubcomp{} -> encode_pubcomp(Version, Pkt); + #subscribe{} -> encode_subscribe(Version, Pkt); + #suback{} -> encode_suback(Version, Pkt); + #unsubscribe{} -> encode_unsubscribe(Version, Pkt); + #unsuback{} -> encode_unsuback(Version, Pkt); + #pingreq{} -> encode_pingreq(); + #pingresp{} -> encode_pingresp(); + #disconnect{} -> encode_disconnect(Version, Pkt); + #auth{} -> encode_auth(Pkt) + end. + +-spec pp(any()) -> iolist(). +pp(Term) -> + io_lib_pretty:print(Term, fun pp/2). + +-spec format_error(error_reason()) -> string(). +format_error({payload_too_big, Max}) -> + format("Payload exceeds ~B bytes", [Max]); +format_error(bad_varint) -> + "Variable Integer is out of boundaries"; +format_error({bad_packet_type, Type}) -> + format("Unexpected packet type: ~B", [Type]); +format_error({bad_packet, Name}) -> + format("Malformed ~s packet", [string:to_upper(atom_to_list(Name))]); +format_error({unexpected_packet, Name}) -> + format("Unexpected ~s packet", [string:to_upper(atom_to_list(Name))]); +format_error({bad_reason_code, Name, Code}) -> + format("Unexpected reason code in ~s code: ~B", + [string:to_upper(atom_to_list(Name)), Code]); +format_error({bad_properties, Name}) -> + format("Malformed properties of ~s packet", + [string:to_upper(atom_to_list(Name))]); +format_error({bad_property, Pkt, Prop}) -> + format("Malformed property ~s of ~s packet", + [Prop, string:to_upper(atom_to_list(Pkt))]); +format_error({duplicated_property, Pkt, Prop}) -> + format("Property ~s is included more than once into ~s packet", + [Prop, string:to_upper(atom_to_list(Pkt))]); +format_error(bad_will_topic_or_message) -> + "Malformed Will Topic or Will Message"; +format_error(bad_connect_username_or_password) -> + "Malformed username or password of CONNECT packet"; +format_error(bad_publish_id_or_payload) -> + "Malformed id or payload of PUBLISH packet"; +format_error({bad_topic_filters, Name}) -> + format("Malformed topic filters of ~s packet", + [string:to_upper(atom_to_list(Name))]); +format_error({bad_qos, Q}) -> + format_got_expected("Malformed QoS value", Q, "0, 1 or 2"); +format_error(bad_topic) -> + "Malformed topic"; +format_error(bad_topic_filter) -> + "Malformed topic filter"; +format_error(bad_utf8_string) -> + "Malformed UTF-8 string"; +format_error({unsupported_protocol_name, Got, Expected}) -> + format_got_expected("Unsupported protocol name", Got, Expected); +format_error({unsupported_protocol_version, Got, Expected}) -> + format_got_expected("Unsupported protocol version", Got, Expected); +format_error({{bad_flag, Name}, Got, Expected}) -> + Txt = "Unexpected " ++ atom_to_list(Name) ++ " flag", + format_got_expected(Txt, Got, Expected); +format_error({{bad_flags, Name}, Got, Expected}) -> + Txt = "Unexpected " ++ string:to_upper(atom_to_list(Name)) ++ " flags", + format_got_expected(Txt, Got, Expected); +format_error(Reason) -> + format("Unexpected error: ~w", [Reason]). + +-spec error_reason_code(error_reason()) -> reason_code(). +error_reason_code({unsupported_protocol_name, _, _}) -> + 'unsupported-protocol-version'; +error_reason_code({unsupported_protocol_version, _, _}) -> + 'unsupported-protocol-version'; +error_reason_code({payload_too_big, _}) -> 'packet-too-large'; +error_reason_code({unexpected_packet, _}) -> 'protocol-error'; +error_reason_code(_) -> 'malformed-packet'. + +-spec format_reason_code(reason_code()) -> string(). +format_reason_code('success') -> "Success"; +format_reason_code('normal-disconnection') -> "Normal disconnection"; +format_reason_code('granted-qos-0') -> "Granted QoS 0"; +format_reason_code('granted-qos-1') -> "Granted QoS 1"; +format_reason_code('granted-qos-2') -> "Granted QoS 2"; +format_reason_code('no-matching-subscribers') -> "No matching subscribers"; +format_reason_code('no-subscription-existed') -> "No subscription existed"; +format_reason_code('continue-authentication') -> "Continue authentication"; +format_reason_code('re-authenticate') -> "Re-authenticate"; +format_reason_code('unspecified-error') -> "Unspecified error"; +format_reason_code('malformed-packet') -> "Malformed Packet"; +format_reason_code('protocol-error') -> "Protocol Error"; +format_reason_code('bad-user-name-or-password') -> "Bad User Name or Password"; +format_reason_code('not-authorized') -> "Not authorized"; +format_reason_code('server-unavailable') -> "Server unavailable"; +format_reason_code('server-busy') -> "Server busy"; +format_reason_code('banned') -> "Banned"; +format_reason_code('server-shutting-down') -> "Server shutting down"; +format_reason_code('bad-authentication-method') -> "Bad authentication method"; +format_reason_code('keep-alive-timeout') -> "Keep Alive timeout"; +format_reason_code('session-taken-over') -> "Session taken over"; +format_reason_code('topic-filter-invalid') -> "Topic Filter invalid"; +format_reason_code('topic-name-invalid') -> "Topic Name invalid"; +format_reason_code('packet-identifier-in-use') -> "Packet Identifier in use"; +format_reason_code('receive-maximum-exceeded') -> "Receive Maximum exceeded"; +format_reason_code('topic-alias-invalid') -> "Topic Alias invalid"; +format_reason_code('packet-too-large') -> "Packet too large"; +format_reason_code('message-rate-too-high') -> "Message rate too high"; +format_reason_code('quota-exceeded') -> "Quota exceeded"; +format_reason_code('administrative-action') -> "Administrative action"; +format_reason_code('payload-format-invalid') -> "Payload format invalid"; +format_reason_code('retain-not-supported') -> "Retain not supported"; +format_reason_code('qos-not-supported') -> "QoS not supported"; +format_reason_code('use-another-server') -> "Use another server"; +format_reason_code('server-moved') -> "Server moved"; +format_reason_code('connection-rate-exceeded') -> "Connection rate exceeded"; +format_reason_code('maximum-connect-time') -> "Maximum connect time"; +format_reason_code('unsupported-protocol-version') -> + "Unsupported Protocol Version"; +format_reason_code('client-identifier-not-valid') -> + "Client Identifier not valid"; +format_reason_code('packet-identifier-not-found') -> + "Packet Identifier not found"; +format_reason_code('disconnect-with-will-message') -> + "Disconnect with Will Message"; +format_reason_code('implementation-specific-error') -> + "Implementation specific error"; +format_reason_code('shared-subscriptions-not-supported') -> + "Shared Subscriptions not supported"; +format_reason_code('subscription-identifiers-not-supported') -> + "Subscription Identifiers not supported"; +format_reason_code('wildcard-subscriptions-not-supported') -> + "Wildcard Subscriptions not supported"; +format_reason_code(Code) -> + format("Unexpected error: ~w", [Code]). + +-spec is_error_code(char() | reason_code()) -> boolean(). +is_error_code('success') -> false; +is_error_code('normal-disconnection') -> false; +is_error_code('granted-qos-0') -> false; +is_error_code('granted-qos-1') -> false; +is_error_code('granted-qos-2') -> false; +is_error_code('disconnect-with-will-message') -> false; +is_error_code('no-matching-subscribers') -> false; +is_error_code('no-subscription-existed') -> false; +is_error_code('continue-authentication') -> false; +is_error_code('re-authenticate') -> false; +is_error_code(Code) when is_integer(Code) -> Code >= 128; +is_error_code(_) -> true. + +%%%=================================================================== +%%% Decoder +%%%=================================================================== +-spec decode_varint(binary()) -> {non_neg_integer(), binary()} | more. +decode_varint(Data) -> + decode_varint(Data, 0, 1). + +-spec decode_varint(binary(), non_neg_integer(), pos_integer()) -> + {non_neg_integer(), binary()} | more. +decode_varint(<>, Val, Mult) -> + NewVal = Val + (C band 127) * Mult, + NewMult = Mult*128, + if NewMult > ?MAX_VARINT -> + err(bad_varint); + (C band 128) == 0 -> + {NewVal, Data}; + true -> + decode_varint(Data, NewVal, NewMult) + end; +decode_varint(_, _, _) -> + more. + +-spec decode_pkt(mqtt_version() | undefined, + non_neg_integer(), non_neg_integer(), binary()) -> mqtt_packet(). +decode_pkt(undefined, 1, Flags, Data) -> + decode_connect(Flags, Data); +decode_pkt(Version, Type, Flags, Data) when Version /= undefined, Type>1 -> + case Type of + 2 -> decode_connack(Version, Flags, Data); + 3 -> decode_publish(Version, Flags, Data); + 4 -> decode_puback(Version, Flags, Data); + 5 -> decode_pubrec(Version, Flags, Data); + 6 -> decode_pubrel(Version, Flags, Data); + 7 -> decode_pubcomp(Version, Flags, Data); + 8 -> decode_subscribe(Version, Flags, Data); + 9 -> decode_suback(Version, Flags, Data); + 10 -> decode_unsubscribe(Version, Flags, Data); + 11 -> decode_unsuback(Version, Flags, Data); + 12 -> decode_pingreq(Flags, Data); + 13 -> decode_pingresp(Flags, Data); + 14 -> decode_disconnect(Version, Flags, Data); + 15 when Version == ?MQTT_VERSION_5 -> decode_auth(Flags, Data); + _ -> err({bad_packet_type, Type}) + end; +decode_pkt(_, Type, _, _) -> + err({unexpected_packet, decode_packet_type(Type)}). + +-spec decode_connect(non_neg_integer(), binary()) -> connect(). +decode_connect(Flags, <>) -> + assert(Proto, <<"MQTT">>, unsupported_protocol_name), + if ProtoLevel == ?MQTT_VERSION_4; ProtoLevel == ?MQTT_VERSION_5 -> + decode_connect(ProtoLevel, Flags, Data); + true -> + err({unsupported_protocol_version, ProtoLevel, "4 or 5"}) + end; +decode_connect(_, _) -> + err({bad_packet, connect}). + +-spec decode_connect(mqtt_version(), non_neg_integer(), binary()) -> connect(). +decode_connect(Version, Flags, + <>) -> + assert(Flags, 0, {bad_flags, connect}), + assert(Reserved, 0, {bad_flag, reserved}), + {Props, Data1} = case Version of + ?MQTT_VERSION_5 -> decode_props(connect, Data); + ?MQTT_VERSION_4 -> {#{}, Data} + end, + case Data1 of + <> -> + {Will, WillProps, Data3} = + decode_will(Version, WillFlag, WillRetain, WillQoS, Data2), + {Username, Password} = decode_user_pass(UserFlag, PassFlag, Data3), + #connect{proto_level = Version, + will = Will, + will_properties = WillProps, + properties = Props, + clean_start = dec_bool(CleanStart), + keep_alive = KeepAlive, + client_id = utf8(ClientID), + username = utf8(Username), + password = Password}; + _ -> + err({bad_packet, connect}) + end; +decode_connect(_, _, _) -> + err({bad_packet, connect}). + +-spec decode_connack(mqtt_version(), non_neg_integer(), binary()) -> connack(). +decode_connack(Version, Flags, <<0:7, SessionPresent:1, Data/binary>>) -> + assert(Flags, 0, {bad_flags, connack}), + {Code, PropMap} = decode_code_with_props(Version, connack, Data), + #connack{session_present = dec_bool(SessionPresent), + code = Code, properties = PropMap}; +decode_connack(_, _, _) -> + err({bad_packet, connack}). + +-spec decode_publish(mqtt_version(), non_neg_integer(), binary()) -> publish(). +decode_publish(Version, Flags, <>) -> + Retain = Flags band 1, + QoS = qos((Flags bsr 1) band 3), + DUP = Flags band 8, + {ID, Props, Payload} = decode_id_props_payload(Version, QoS, Data), + #publish{dup = dec_bool(DUP), + qos = QoS, + retain = dec_bool(Retain), + topic = topic(Topic, Props), + id = ID, + properties = Props, + payload = Payload}; +decode_publish(_, _, _) -> + err({bad_packet, publish}). + +-spec decode_puback(mqtt_version(), non_neg_integer(), binary()) -> puback(). +decode_puback(Version, Flags, <>) when ID>0 -> + assert(Flags, 0, {bad_flags, puback}), + {Code, PropMap} = decode_code_with_props(Version, puback, Data), + #puback{id = ID, code = Code, properties = PropMap}; +decode_puback(_, _, _) -> + err({bad_packet, puback}). + +-spec decode_pubrec(mqtt_version(), non_neg_integer(), binary()) -> pubrec(). +decode_pubrec(Version, Flags, <>) when ID>0 -> + assert(Flags, 0, {bad_flags, pubrec}), + {Code, PropMap} = decode_code_with_props(Version, pubrec, Data), + #pubrec{id = ID, code = Code, properties = PropMap}; +decode_pubrec(_, _, _) -> + err({bad_packet, pubrec}). + +-spec decode_pubrel(mqtt_version(), non_neg_integer(), binary()) -> pubrel(). +decode_pubrel(Version, Flags, <>) when ID>0 -> + assert(Flags, 2, {bad_flags, pubrel}), + {Code, PropMap} = decode_code_with_props(Version, pubrel, Data), + #pubrel{id = ID, code = Code, properties = PropMap}; +decode_pubrel(_, _, _) -> + err({bad_packet, pubrel}). + +-spec decode_pubcomp(mqtt_version(), non_neg_integer(), binary()) -> pubcomp(). +decode_pubcomp(Version, Flags, <>) when ID>0 -> + assert(Flags, 0, {bad_flags, pubcomp}), + {Code, PropMap} = decode_code_with_props(Version, pubcomp, Data), + #pubcomp{id = ID, code = Code, properties = PropMap}; +decode_pubcomp(_, _, _) -> + err({bad_packet, pubcomp}). + +-spec decode_subscribe(mqtt_version(), non_neg_integer(), binary()) -> subscribe(). +decode_subscribe(Version, Flags, <>) when ID>0 -> + assert(Flags, 2, {bad_flags, subscribe}), + case Version of + ?MQTT_VERSION_4 -> + Filters = decode_subscribe_filters(Data), + #subscribe{id = ID, filters = Filters}; + ?MQTT_VERSION_5 -> + {Props, Payload} = decode_props(subscribe, Data), + Filters = decode_subscribe_filters(Payload), + #subscribe{id = ID, filters = Filters, properties = Props} + end; +decode_subscribe(_, _, _) -> + err({bad_packet, subscribe}). + +-spec decode_suback(mqtt_version(), non_neg_integer(), binary()) -> suback(). +decode_suback(Version, Flags, <>) when ID>0 -> + assert(Flags, 0, {bad_flags, suback}), + case Version of + ?MQTT_VERSION_4 -> + #suback{id = ID, + codes = decode_suback_codes(Data)}; + ?MQTT_VERSION_5 -> + {PropMap, Tail} = decode_props(suback, Data), + #suback{id = ID, + codes = decode_suback_codes(Tail), + properties = PropMap} + end; +decode_suback(_, _, _) -> + err({bad_packet, suback}). + +-spec decode_unsubscribe(mqtt_version(), non_neg_integer(), binary()) -> unsubscribe(). +decode_unsubscribe(Version, Flags, <>) when ID>0 -> + assert(Flags, 2, {bad_flags, unsubscribe}), + case Version of + ?MQTT_VERSION_4 -> + Filters = decode_unsubscribe_filters(Data), + #unsubscribe{id = ID, filters = Filters}; + ?MQTT_VERSION_5 -> + {Props, Payload} = decode_props(unsubscribe, Data), + Filters = decode_unsubscribe_filters(Payload), + #unsubscribe{id = ID, filters = Filters, properties = Props} + end; +decode_unsubscribe(_, _, _) -> + err({bad_packet, unsubscribe}). + +-spec decode_unsuback(mqtt_version(), non_neg_integer(), binary()) -> unsuback(). +decode_unsuback(Version, Flags, <>) when ID>0 -> + assert(Flags, 0, {bad_flags, unsuback}), + case Version of + ?MQTT_VERSION_4 -> + #unsuback{id = ID}; + ?MQTT_VERSION_5 -> + {PropMap, Tail} = decode_props(unsuback, Data), + #unsuback{id = ID, + codes = decode_unsuback_codes(Tail), + properties = PropMap} + end; +decode_unsuback(_, _, _) -> + err({bad_packet, unsuback}). + +-spec decode_pingreq(non_neg_integer(), binary()) -> pingreq(). +decode_pingreq(Flags, <<>>) -> + assert(Flags, 0, {bad_flags, pingreq}), + #pingreq{}; +decode_pingreq(_, _) -> + err({bad_packet, pingreq}). + +-spec decode_pingresp(non_neg_integer(), binary()) -> pingresp(). +decode_pingresp(Flags, <<>>) -> + assert(Flags, 0, {bad_flags, pingresp}), + #pingresp{}; +decode_pingresp(_, _) -> + err({bad_packet, pingresp}). + +-spec decode_disconnect(mqtt_version(), non_neg_integer(), binary()) -> disconnect(). +decode_disconnect(Version, Flags, Payload) -> + assert(Flags, 0, {bad_flags, disconnect}), + {Code, PropMap} = decode_code_with_props(Version, disconnect, Payload), + #disconnect{code = Code, properties = PropMap}. + +-spec decode_auth(non_neg_integer(), binary()) -> auth(). +decode_auth(Flags, Payload) -> + assert(Flags, 0, {bad_flags, auth}), + {Code, PropMap} = decode_code_with_props(?MQTT_VERSION_5, auth, Payload), + #auth{code = Code, properties = PropMap}. + +-spec decode_packet_type(char()) -> atom(). +decode_packet_type(1) -> connect; +decode_packet_type(2) -> connack; +decode_packet_type(3) -> publish; +decode_packet_type(4) -> puback; +decode_packet_type(5) -> pubrec; +decode_packet_type(6) -> pubrel; +decode_packet_type(7) -> pubcomp; +decode_packet_type(8) -> subscribe; +decode_packet_type(9) -> suback; +decode_packet_type(10) -> unsubscribe; +decode_packet_type(11) -> unsuback; +decode_packet_type(12) -> pingreq; +decode_packet_type(13) -> pingresp; +decode_packet_type(14) -> disconnect; +decode_packet_type(15) -> auth; +decode_packet_type(T) -> err({bad_packet_type, T}). + +-spec decode_will(mqtt_version(), 0|1, 0|1, qos(), binary()) -> + {undefined | publish(), properties(), binary()}. +decode_will(_, 0, WillRetain, WillQoS, Data) -> + assert(WillRetain, 0, {bad_flag, will_retain}), + assert(WillQoS, 0, {bad_flag, will_qos}), + {undefined, #{}, Data}; +decode_will(Version, 1, WillRetain, WillQoS, Data) -> + {Props, Data1} = case Version of + ?MQTT_VERSION_5 -> decode_props(connect, Data); + ?MQTT_VERSION_4 -> {#{}, Data} + end, + case Data1 of + <> -> + {#publish{retain = dec_bool(WillRetain), + qos = qos(WillQoS), + topic = topic(Topic), + payload = Message}, + Props, Data2}; + _ -> + err(bad_will_topic_or_message) + end. + +-spec decode_user_pass(non_neg_integer(), non_neg_integer(), + binary()) -> {binary(), binary()}. +decode_user_pass(1, 0, <>) -> + {utf8(User), <<>>}; +decode_user_pass(1, 1, <>) -> + {utf8(User), Pass}; +decode_user_pass(0, Flag, <<>>) -> + assert(Flag, 0, {bad_flag, password}), + {<<>>, <<>>}; +decode_user_pass(_, _, _) -> + err(bad_connect_username_or_password). + +-spec decode_id_props_payload(mqtt_version(), non_neg_integer(), binary()) -> + {undefined | non_neg_integer(), properties(), binary()}. +decode_id_props_payload(Version, 0, Data) -> + case Version of + ?MQTT_VERSION_4 -> + {undefined, #{}, Data}; + ?MQTT_VERSION_5 -> + {Props, Payload} = decode_props(publish, Data), + {undefined, Props, Payload} + end; +decode_id_props_payload(Version, _, <>) when ID>0 -> + case Version of + ?MQTT_VERSION_4 -> + {ID, #{}, Data}; + ?MQTT_VERSION_5 -> + {Props, Payload} = decode_props(publish, Data), + {ID, Props, Payload} + end; +decode_id_props_payload(_, _, _) -> + err(bad_publish_id_or_payload). + +-spec decode_subscribe_filters(binary()) -> [{binary(), sub_opts()}]. +decode_subscribe_filters(<>) -> + assert(Reserved, 0, {bad_flag, reserved}), + case RH of + 3 -> err({{bad_flag, retain_handling}, RH, "0, 1 or 2"}); + _ -> ok + end, + Opts = #sub_opts{qos = qos(QoS), + no_local = dec_bool(NL), + retain_as_published = dec_bool(RAP), + retain_handling = RH}, + [{topic_filter(Filter), Opts}|decode_subscribe_filters(Tail)]; +decode_subscribe_filters(<<>>) -> + []; +decode_subscribe_filters(_) -> + err({bad_topic_filters, subscribe}). + +-spec decode_unsubscribe_filters(binary()) -> [binary()]. +decode_unsubscribe_filters(<>) -> + [topic_filter(Filter)|decode_unsubscribe_filters(Tail)]; +decode_unsubscribe_filters(<<>>) -> + []; +decode_unsubscribe_filters(_) -> + err({bad_topic_filters, unsubscribe}). + +-spec decode_suback_codes(binary()) -> [reason_code()]. +decode_suback_codes(<>) -> + [decode_suback_code(Code)|decode_suback_codes(Data)]; +decode_suback_codes(<<>>) -> + []. + +-spec decode_unsuback_codes(binary()) -> [reason_code()]. +decode_unsuback_codes(<>) -> + [decode_unsuback_code(Code)|decode_unsuback_codes(Data)]; +decode_unsuback_codes(<<>>) -> + []. + +-spec decode_utf8_pair(binary()) -> {utf8_pair(), binary()}. +decode_utf8_pair(<>) -> + {{utf8(Name), utf8(Val)}, Tail}; +decode_utf8_pair(_) -> + err(bad_utf8_pair). + +-spec decode_props(atom(), binary()) -> {properties(), binary()}. +decode_props(Pkt, Data) -> + try + {Len, Data1} = decode_varint(Data), + <> = Data1, + {decode_props(Pkt, PData, #{}), Tail} + catch _:{badmatch, _} -> + err({bad_properties, Pkt}) + end. + +-spec decode_props(atom(), binary(), properties()) -> properties(). +decode_props(_, <<>>, Props) -> + Props; +decode_props(Pkt, Data, Props) -> + {Type, Payload} = decode_varint(Data), + {Name, Val, Tail} = decode_prop(Pkt, Type, Payload), + Props1 = maps:update_with( + Name, + fun(Vals) when is_list(Val) -> + Vals ++ Val; + (_) -> + err({duplicated_property, Pkt, Name}) + end, Val, Props), + decode_props(Pkt, Tail, Props1). + +-spec decode_prop(atom(), char(), binary()) -> {property(), term(), binary()}. +decode_prop(_, 18, <>) -> + {assigned_client_identifier, utf8(Data), Bin}; +decode_prop(_, 22, <>) -> + {authentication_data, Data, Bin}; +decode_prop(_, 21, <>) -> + {authentication_method, utf8(Data), Bin}; +decode_prop(_, 3, <>) -> + {content_type, utf8(Data), Bin}; +decode_prop(_, 9, <>) -> + {correlation_data, Data, Bin}; +decode_prop(_, 39, <>) when Size>0 -> + {maximum_packet_size, Size, Bin}; +decode_prop(Pkt, 36, <>) -> + {maximum_qos, + case QoS of + 0 -> 0; + 1 -> 1; + _ -> err({bad_property, Pkt, maximum_qos}) + end, Bin}; +decode_prop(_, 2, <>) -> + {message_expiry_interval, I, Bin}; +decode_prop(Pkt, 1, <>) -> + {payload_format_indicator, + case I of + 0 -> binary; + 1 -> utf8; + _ -> err({bad_property, Pkt, payload_format_indicator}) + end, Bin}; +decode_prop(_, 31, <>) -> + {reason_string, utf8(Data), Bin}; +decode_prop(_, 33, <>) when Max>0 -> + {receive_maximum, Max, Bin}; +decode_prop(Pkt, 23, Data) -> + decode_bool_prop(Pkt, request_problem_information, Data); +decode_prop(Pkt, 25, Data) -> + decode_bool_prop(Pkt, request_response_information, Data); +decode_prop(_, 26, <>) -> + {response_information, utf8(Data), Bin}; +decode_prop(_, 8, <>) -> + {response_topic, topic(Data), Bin}; +decode_prop(Pkt, 37, Data) -> + decode_bool_prop(Pkt, retain_available, Data); +decode_prop(_, 19, <>) -> + {server_keep_alive, Secs, Bin}; +decode_prop(_, 28, <>) -> + {server_reference, utf8(Data), Bin}; +decode_prop(_, 17, <>) -> + {session_expiry_interval, I, Bin}; +decode_prop(Pkt, 42, Data) -> + decode_bool_prop(Pkt, shared_subscription_available, Data); +decode_prop(Pkt, 11, Data) when Pkt == publish; Pkt == subscribe -> + case decode_varint(Data) of + {ID, Bin} when Pkt == publish -> + {subscription_identifier, [ID], Bin}; + {ID, Bin} when Pkt == subscribe -> + {subscription_identifier, ID, Bin}; + _ -> + err({bad_property, publish, subscription_identifier}) + end; +decode_prop(Pkt, 41, Data) -> + decode_bool_prop(Pkt, subscription_identifiers_available, Data); +decode_prop(_, 35, <>) when Alias>0 -> + {topic_alias, Alias, Bin}; +decode_prop(_, 34, <>) -> + {topic_alias_maximum, Max, Bin}; +decode_prop(_, 38, Data) -> + {Pair, Bin} = decode_utf8_pair(Data), + {user_property, [Pair], Bin}; +decode_prop(Pkt, 40, Data) -> + decode_bool_prop(Pkt, wildcard_subscription_available, Data); +decode_prop(_, 24, <>) -> + {will_delay_interval, I, Bin}; +decode_prop(Pkt, _, _) -> + err({bad_properties, Pkt}). + +decode_bool_prop(Pkt, Name, <>) -> + case Val of + 0 -> {Name, false, Bin}; + 1 -> {Name, true, Bin}; + _ -> err({bad_property, Pkt, Name}) + end; +decode_bool_prop(Pkt, Name, _) -> + err({bad_property, Pkt, Name}). + +-spec decode_code_with_props(mqtt_version(), atom(), binary()) -> + {reason_code(), properties()}. +decode_code_with_props(_, connack, <>) -> + {decode_connack_code(Code), + case Props of + <<>> -> + #{}; + _ -> + {PropMap, <<>>} = decode_props(connack, Props), + PropMap + end}; +decode_code_with_props(_, Pkt, <<>>) -> + {decode_reason_code(Pkt, 0), #{}}; +decode_code_with_props(?MQTT_VERSION_5, Pkt, <>) -> + {decode_reason_code(Pkt, Code), #{}}; +decode_code_with_props(?MQTT_VERSION_5, Pkt, <>) -> + {PropMap, <<>>} = decode_props(Pkt, Props), + {decode_reason_code(Pkt, Code), PropMap}; +decode_code_with_props(_, Pkt, _) -> + err({bad_packet, Pkt}). + +-spec decode_pubcomp_code(char()) -> reason_code(). +decode_pubcomp_code(0) -> 'success'; +decode_pubcomp_code(146) -> 'packet-identifier-not-found'; +decode_pubcomp_code(Code) -> err({bad_reason_code, pubcomp, Code}). + +-spec decode_pubrec_code(char()) -> reason_code(). +decode_pubrec_code(0) -> 'success'; +decode_pubrec_code(16) -> 'no-matching-subscribers'; +decode_pubrec_code(128) -> 'unspecified-error'; +decode_pubrec_code(131) -> 'implementation-specific-error'; +decode_pubrec_code(135) -> 'not-authorized'; +decode_pubrec_code(144) -> 'topic-name-invalid'; +decode_pubrec_code(145) -> 'packet-identifier-in-use'; +decode_pubrec_code(151) -> 'quota-exceeded'; +decode_pubrec_code(153) -> 'payload-format-invalid'; +decode_pubrec_code(Code) -> err({bad_reason_code, pubrec, Code}). + +-spec decode_disconnect_code(char()) -> reason_code(). +decode_disconnect_code(0) -> 'normal-disconnection'; +decode_disconnect_code(4) -> 'disconnect-with-will-message'; +decode_disconnect_code(128) -> 'unspecified-error'; +decode_disconnect_code(129) -> 'malformed-packet'; +decode_disconnect_code(130) -> 'protocol-error'; +decode_disconnect_code(131) -> 'implementation-specific-error'; +decode_disconnect_code(135) -> 'not-authorized'; +decode_disconnect_code(137) -> 'server-busy'; +decode_disconnect_code(139) -> 'server-shutting-down'; +decode_disconnect_code(140) -> 'bad-authentication-method'; +decode_disconnect_code(141) -> 'keep-alive-timeout'; +decode_disconnect_code(142) -> 'session-taken-over'; +decode_disconnect_code(143) -> 'topic-filter-invalid'; +decode_disconnect_code(144) -> 'topic-name-invalid'; +decode_disconnect_code(147) -> 'receive-maximum-exceeded'; +decode_disconnect_code(148) -> 'topic-alias-invalid'; +decode_disconnect_code(149) -> 'packet-too-large'; +decode_disconnect_code(150) -> 'message-rate-too-high'; +decode_disconnect_code(151) -> 'quota-exceeded'; +decode_disconnect_code(152) -> 'administrative-action'; +decode_disconnect_code(153) -> 'payload-format-invalid'; +decode_disconnect_code(154) -> 'retain-not-supported'; +decode_disconnect_code(155) -> 'qos-not-supported'; +decode_disconnect_code(156) -> 'use-another-server'; +decode_disconnect_code(157) -> 'server-moved'; +decode_disconnect_code(158) -> 'shared-subscriptions-not-supported'; +decode_disconnect_code(159) -> 'connection-rate-exceeded'; +decode_disconnect_code(160) -> 'maximum-connect-time'; +decode_disconnect_code(161) -> 'subscription-identifiers-not-supported'; +decode_disconnect_code(162) -> 'wildcard-subscriptions-not-supported'; +decode_disconnect_code(Code) -> err({bad_reason_code, disconnect, Code}). + +-spec decode_auth_code(char()) -> reason_code(). +decode_auth_code(0) -> 'success'; +decode_auth_code(24) -> 'continue-authentication'; +decode_auth_code(25) -> 're-authenticate'; +decode_auth_code(Code) -> err({bad_reason_code, auth, Code}). + +-spec decode_suback_code(char()) -> 0..2 | reason_code(). +decode_suback_code(0) -> 0; +decode_suback_code(1) -> 1; +decode_suback_code(2) -> 2; +decode_suback_code(128) -> 'unspecified-error'; +decode_suback_code(131) -> 'implementation-specific-error'; +decode_suback_code(135) -> 'not-authorized'; +decode_suback_code(143) -> 'topic-filter-invalid'; +decode_suback_code(145) -> 'packet-identifier-in-use'; +decode_suback_code(151) -> 'quota-exceeded'; +decode_suback_code(158) -> 'shared-subscriptions-not-supported'; +decode_suback_code(161) -> 'subscription-identifiers-not-supported'; +decode_suback_code(162) -> 'wildcard-subscriptions-not-supported'; +decode_suback_code(Code) -> err({bad_reason_code, suback, Code}). + +-spec decode_unsuback_code(char()) -> reason_code(). +decode_unsuback_code(0) -> 'success'; +decode_unsuback_code(17) -> 'no-subscription-existed'; +decode_unsuback_code(128) -> 'unspecified-error'; +decode_unsuback_code(131) -> 'implementation-specific-error'; +decode_unsuback_code(135) -> 'not-authorized'; +decode_unsuback_code(143) -> 'topic-filter-invalid'; +decode_unsuback_code(145) -> 'packet-identifier-in-use'; +decode_unsuback_code(Code) -> err({bad_reason_code, unsuback, Code}). + +-spec decode_puback_code(char()) -> reason_code(). +decode_puback_code(0) -> 'success'; +decode_puback_code(16) -> 'no-matching-subscribers'; +decode_puback_code(128) -> 'unspecified-error'; +decode_puback_code(131) -> 'implementation-specific-error'; +decode_puback_code(135) -> 'not-authorized'; +decode_puback_code(144) -> 'topic-name-invalid'; +decode_puback_code(145) -> 'packet-identifier-in-use'; +decode_puback_code(151) -> 'quota-exceeded'; +decode_puback_code(153) -> 'payload-format-invalid'; +decode_puback_code(Code) -> err({bad_reason_code, puback, Code}). + +-spec decode_pubrel_code(char()) -> reason_code(). +decode_pubrel_code(0) -> 'success'; +decode_pubrel_code(146) -> 'packet-identifier-not-found'; +decode_pubrel_code(Code) -> err({bad_reason_code, pubrel, Code}). + +-spec decode_connack_code(char()) -> reason_code(). +decode_connack_code(0) -> 'success'; +decode_connack_code(1) -> 'unsupported-protocol-version'; +decode_connack_code(2) -> 'client-identifier-not-valid'; +decode_connack_code(3) -> 'server-unavailable'; +decode_connack_code(4) -> 'bad-user-name-or-password'; +decode_connack_code(5) -> 'not-authorized'; +decode_connack_code(128) -> 'unspecified-error'; +decode_connack_code(129) -> 'malformed-packet'; +decode_connack_code(130) -> 'protocol-error'; +decode_connack_code(131) -> 'implementation-specific-error'; +decode_connack_code(132) -> 'unsupported-protocol-version'; +decode_connack_code(133) -> 'client-identifier-not-valid'; +decode_connack_code(134) -> 'bad-user-name-or-password'; +decode_connack_code(135) -> 'not-authorized'; +decode_connack_code(136) -> 'server-unavailable'; +decode_connack_code(137) -> 'server-busy'; +decode_connack_code(138) -> 'banned'; +decode_connack_code(140) -> 'bad-authentication-method'; +decode_connack_code(144) -> 'topic-name-invalid'; +decode_connack_code(149) -> 'packet-too-large'; +decode_connack_code(151) -> 'quota-exceeded'; +decode_connack_code(153) -> 'payload-format-invalid'; +decode_connack_code(154) -> 'retain-not-supported'; +decode_connack_code(155) -> 'qos-not-supported'; +decode_connack_code(156) -> 'use-another-server'; +decode_connack_code(157) -> 'server-moved'; +decode_connack_code(159) -> 'connection-rate-exceeded'; +decode_connack_code(Code) -> err({bad_reason_code, connack, Code}). + +-spec decode_reason_code(atom(), char()) -> reason_code(). +decode_reason_code(pubcomp, Code) -> decode_pubcomp_code(Code); +decode_reason_code(pubrec, Code) -> decode_pubrec_code(Code); +decode_reason_code(disconnect, Code) -> decode_disconnect_code(Code); +decode_reason_code(auth, Code) -> decode_auth_code(Code); +decode_reason_code(puback, Code) -> decode_puback_code(Code); +decode_reason_code(pubrel, Code) -> decode_pubrel_code(Code); +decode_reason_code(connack, Code) -> decode_connack_code(Code). + +%%%=================================================================== +%%% Encoder +%%%=================================================================== +encode_connect(#connect{proto_level = Version, properties = Props, + will = Will, will_properties = WillProps, + clean_start = CleanStart, + keep_alive = KeepAlive, client_id = ClientID, + username = Username, password = Password}) -> + UserFlag = Username /= <<>>, + PassFlag = UserFlag andalso Password /= <<>>, + WillFlag = is_record(Will, publish), + WillRetain = WillFlag andalso Will#publish.retain, + WillQoS = if WillFlag -> Will#publish.qos; + true -> 0 + end, + Header = <<4:16, "MQTT", Version, (enc_bool(UserFlag)):1, + (enc_bool(PassFlag)):1, (enc_bool(WillRetain)):1, + WillQoS:2, (enc_bool(WillFlag)):1, + (enc_bool(CleanStart)):1, 0:1, + KeepAlive:16>>, + EncClientID = <<(size(ClientID)):16, ClientID/binary>>, + EncWill = encode_will(Will), + EncUserPass = encode_user_pass(Username, Password), + Payload = case Version of + ?MQTT_VERSION_5 -> + [Header, encode_props(Props), EncClientID, + if WillFlag -> encode_props(WillProps); + true -> <<>> + end, + EncWill, EncUserPass]; + _ -> + [Header, EncClientID, EncWill, EncUserPass] + end, + <<1:4, 0:4, (encode_with_len(Payload))/binary>>. + +encode_connack(Version, #connack{session_present = SP, + code = Code, properties = Props}) -> + Payload = [enc_bool(SP), + encode_connack_code(Version, Code), + encode_props(Version, Props)], + <<2:4, 0:4, (encode_with_len(Payload))/binary>>. + +encode_publish(Version, #publish{qos = QoS, retain = Retain, dup = Dup, + topic = Topic, id = ID, payload = Payload, + properties = Props}) -> + Data1 = <<(size(Topic)):16, Topic/binary>>, + Data2 = case QoS of + 0 -> <<>>; + _ when ID>0 -> <> + end, + Data3 = encode_props(Version, Props), + Data4 = encode_with_len([Data1, Data2, Data3, Payload]), + <<3:4, (enc_bool(Dup)):1, QoS:2, (enc_bool(Retain)):1, Data4/binary>>. + +encode_puback(Version, #puback{id = ID, code = Code, + properties = Props}) when ID>0 -> + Data = encode_code_with_props(Version, Code, Props), + <<4:4, 0:4, (encode_with_len([<>|Data]))/binary>>. + +encode_pubrec(Version, #pubrec{id = ID, code = Code, + properties = Props}) when ID>0 -> + Data = encode_code_with_props(Version, Code, Props), + <<5:4, 0:4, (encode_with_len([<>|Data]))/binary>>. + +encode_pubrel(Version, #pubrel{id = ID, code = Code, + properties = Props}) when ID>0 -> + Data = encode_code_with_props(Version, Code, Props), + <<6:4, 2:4, (encode_with_len([<>|Data]))/binary>>. + +encode_pubcomp(Version, #pubcomp{id = ID, code = Code, + properties = Props}) when ID>0 -> + Data = encode_code_with_props(Version, Code, Props), + <<7:4, 0:4, (encode_with_len([<>|Data]))/binary>>. + +encode_subscribe(Version, #subscribe{id = ID, + filters = [_|_] = Filters, + properties = Props}) when ID>0 -> + EncFilters = [<<(size(Filter)):16, Filter/binary, + (encode_subscription_options(SubOpts))>> || + {Filter, SubOpts} <- Filters], + Payload = [<>, encode_props(Version, Props), EncFilters], + <<8:4, 2:4, (encode_with_len(Payload))/binary>>. + +encode_suback(Version, #suback{id = ID, codes = Codes, + properties = Props}) when ID>0 -> + Payload = [<>, encode_props(Version, Props) + |[encode_reason_code(Code) || Code <- Codes]], + <<9:4, 0:4, (encode_with_len(Payload))/binary>>. + +encode_unsubscribe(Version, #unsubscribe{id = ID, + filters = [_|_] = Filters, + properties = Props}) when ID>0 -> + EncFilters = [<<(size(Filter)):16, Filter/binary>> || Filter <- Filters], + Payload = [<>, encode_props(Version, Props), EncFilters], + <<10:4, 2:4, (encode_with_len(Payload))/binary>>. + +encode_unsuback(Version, #unsuback{id = ID, codes = Codes, + properties = Props}) when ID>0 -> + EncCodes = case Version of + ?MQTT_VERSION_5 -> + [encode_reason_code(Code) || Code <- Codes]; + ?MQTT_VERSION_4 -> + [] + end, + Payload = [<>, encode_props(Version, Props)|EncCodes], + <<11:4, 0:4, (encode_with_len(Payload))/binary>>. + +encode_pingreq() -> + <<12:4, 0:4, 0>>. + +encode_pingresp() -> + <<13:4, 0:4, 0>>. + +encode_disconnect(Version, #disconnect{code = Code, properties = Props}) -> + Data = encode_code_with_props(Version, Code, Props), + <<14:4, 0:4, (encode_with_len(Data))/binary>>. + +encode_auth(#auth{code = Code, properties = Props}) -> + Data = encode_code_with_props(?MQTT_VERSION_5, Code, Props), + <<15:4, 0:4, (encode_with_len(Data))/binary>>. + +-spec encode_with_len(iodata()) -> binary(). +encode_with_len(IOData) -> + Data = iolist_to_binary(IOData), + Len = encode_varint(size(Data)), + <>. + +-spec encode_varint(non_neg_integer()) -> binary(). +encode_varint(X) when X < 128 -> + <<0:1, X:7>>; +encode_varint(X) when X < ?MAX_VARINT -> + <<1:1, (X rem 128):7, (encode_varint(X div 128))/binary>>. + +-spec encode_props(mqtt_version(), properties()) -> binary(). +encode_props(?MQTT_VERSION_5, Props) -> + encode_props(Props); +encode_props(?MQTT_VERSION_4, _) -> + <<>>. + +-spec encode_props(properties()) -> binary(). +encode_props(Props) -> + encode_with_len( + maps:fold( + fun(Name, Val, Acc) -> + [encode_prop(Name, Val)|Acc] + end, [], Props)). + +-spec encode_prop(property(), term()) -> iodata(). +encode_prop(assigned_client_identifier, <<>>) -> + <<>>; +encode_prop(assigned_client_identifier, ID) -> + <<18, (size(ID)):16, ID/binary>>; +encode_prop(authentication_data, <<>>) -> + <<>>; +encode_prop(authentication_data, Data) -> + <<22, (size(Data)):16, Data/binary>>; +encode_prop(authentication_method, <<>>) -> + <<>>; +encode_prop(authentication_method, M) -> + <<21, (size(M)):16, M/binary>>; +encode_prop(content_type, <<>>) -> + <<>>; +encode_prop(content_type, T) -> + <<3, (size(T)):16, T/binary>>; +encode_prop(correlation_data, <<>>) -> + <<>>; +encode_prop(correlation_data, Data) -> + <<9, (size(Data)):16, Data/binary>>; +encode_prop(maximum_packet_size, Size) when Size>0, Size= + <<39, Size:32>>; +encode_prop(maximum_qos, QoS) when QoS>=0, QoS<2 -> + <<36, QoS>>; +encode_prop(message_expiry_interval, I) when I>=0, I= + <<2, I:32>>; +encode_prop(payload_format_indicator, binary) -> + <<>>; +encode_prop(payload_format_indicator, utf8) -> + <<1, 1>>; +encode_prop(reason_string, <<>>) -> + <<>>; +encode_prop(reason_string, S) -> + <<31, (size(S)):16, S/binary>>; +encode_prop(receive_maximum, Max) when Max>0, Max= + <<33, Max:16>>; +encode_prop(request_problem_information, true) -> + <<>>; +encode_prop(request_problem_information, false) -> + <<23, 0>>; +encode_prop(request_response_information, false) -> + <<>>; +encode_prop(request_response_information, true) -> + <<25, 1>>; +encode_prop(response_information, <<>>) -> + <<>>; +encode_prop(response_information, S) -> + <<26, (size(S)):16, S/binary>>; +encode_prop(response_topic, <<>>) -> + <<>>; +encode_prop(response_topic, T) -> + <<8, (size(T)):16, T/binary>>; +encode_prop(retain_available, true) -> + <<>>; +encode_prop(retain_available, false) -> + <<37, 0>>; +encode_prop(server_keep_alive, Secs) when Secs>=0, Secs= + <<19, Secs:16>>; +encode_prop(server_reference, <<>>) -> + <<>>; +encode_prop(server_reference, S) -> + <<28, (size(S)):16, S/binary>>; +encode_prop(session_expiry_interval, I) when I>=0, I= + <<17, I:32>>; +encode_prop(shared_subscription_available, true) -> + <<>>; +encode_prop(shared_subscription_available, false) -> + <<42, 0>>; +encode_prop(subscription_identifier, [_|_] = IDs) -> + [encode_prop(subscription_identifier, ID) || ID <- IDs]; +encode_prop(subscription_identifier, ID) when ID>0, ID + <<11, (encode_varint(ID))/binary>>; +encode_prop(subscription_identifiers_available, true) -> + <<>>; +encode_prop(subscription_identifiers_available, false) -> + <<41, 0>>; +encode_prop(topic_alias, Alias) when Alias>0, Alias= + <<35, Alias:16>>; +encode_prop(topic_alias_maximum, 0) -> + <<>>; +encode_prop(topic_alias_maximum, Max) when Max>0, Max= + <<34, Max:16>>; +encode_prop(user_property, Pairs) -> + [<<38, (encode_utf8_pair(Pair))/binary>> || Pair <- Pairs]; +encode_prop(wildcard_subscription_available, true) -> + <<>>; +encode_prop(wildcard_subscription_available, false) -> + <<40, 0>>; +encode_prop(will_delay_interval, 0) -> + <<>>; +encode_prop(will_delay_interval, I) when I>0, I= + <<24, I:32>>. + +-spec encode_user_pass(binary(), binary()) -> binary(). +encode_user_pass(User, Pass) when User /= <<>> andalso Pass /= <<>> -> + <<(size(User)):16, User/binary, (size(Pass)):16, Pass/binary>>; +encode_user_pass(User, _) when User /= <<>> -> + <<(size(User)):16, User/binary>>; +encode_user_pass(_, _) -> + <<>>. + +-spec encode_will(undefined | publish()) -> binary(). +encode_will(#publish{topic = Topic, payload = Payload}) -> + <<(size(Topic)):16, Topic/binary, + (size(Payload)):16, Payload/binary>>; +encode_will(undefined) -> + <<>>. + +encode_subscription_options(#sub_opts{qos = QoS, + no_local = NL, + retain_as_published = RAP, + retain_handling = RH}) + when QoS>=0, RH>=0, QoS<3, RH<3 -> + (RH bsl 4) bor (enc_bool(RAP) bsl 3) bor (enc_bool(NL) bsl 2) bor QoS. + +-spec encode_code_with_props(mqtt_version(), reason_code(), properties()) -> [binary()]. +encode_code_with_props(Version, Code, Props) -> + if Version == ?MQTT_VERSION_4 orelse + (Code == success andalso Props == #{}) -> + []; + Props == #{} -> + [encode_reason_code(Code)]; + true -> + [encode_reason_code(Code), encode_props(Props)] + end. + +-spec encode_utf8_pair({binary(), binary()}) -> binary(). +encode_utf8_pair({Key, Val}) -> + <<(size(Key)):16, Key/binary, (size(Val)):16, Val/binary>>. + +-spec encode_connack_code(mqtt_version(), atom()) -> char(). +encode_connack_code(?MQTT_VERSION_5, Reason) -> encode_reason_code(Reason); +encode_connack_code(_, success) -> 0; +encode_connack_code(_, 'unsupported-protocol-version') -> 1; +encode_connack_code(_, 'client-identifier-not-valid') -> 2; +encode_connack_code(_, 'server-unavailable') -> 3; +encode_connack_code(_, 'bad-user-name-or-password') -> 4; +encode_connack_code(_, 'not-authorized') -> 5; +encode_connack_code(_, _) -> 128. + +-spec encode_reason_code(char() | reason_code()) -> char(). +encode_reason_code('success') -> 0; +encode_reason_code('normal-disconnection') -> 0; +encode_reason_code('granted-qos-0') -> 0; +encode_reason_code('granted-qos-1') -> 1; +encode_reason_code('granted-qos-2') -> 2; +encode_reason_code('disconnect-with-will-message') -> 4; +encode_reason_code('no-matching-subscribers') -> 16; +encode_reason_code('no-subscription-existed') -> 17; +encode_reason_code('continue-authentication') -> 24; +encode_reason_code('re-authenticate') -> 25; +encode_reason_code('unspecified-error') -> 128; +encode_reason_code('malformed-packet') -> 129; +encode_reason_code('protocol-error') -> 130; +encode_reason_code('implementation-specific-error') -> 131; +encode_reason_code('unsupported-protocol-version') -> 132; +encode_reason_code('client-identifier-not-valid') -> 133; +encode_reason_code('bad-user-name-or-password') -> 134; +encode_reason_code('not-authorized') -> 135; +encode_reason_code('server-unavailable') -> 136; +encode_reason_code('server-busy') -> 137; +encode_reason_code('banned') -> 138; +encode_reason_code('server-shutting-down') -> 139; +encode_reason_code('bad-authentication-method') -> 140; +encode_reason_code('keep-alive-timeout') -> 141; +encode_reason_code('session-taken-over') -> 142; +encode_reason_code('topic-filter-invalid') -> 143; +encode_reason_code('topic-name-invalid') -> 144; +encode_reason_code('packet-identifier-in-use') -> 145; +encode_reason_code('packet-identifier-not-found') -> 146; +encode_reason_code('receive-maximum-exceeded') -> 147; +encode_reason_code('topic-alias-invalid') -> 148; +encode_reason_code('packet-too-large') -> 149; +encode_reason_code('message-rate-too-high') -> 150; +encode_reason_code('quota-exceeded') -> 151; +encode_reason_code('administrative-action') -> 152; +encode_reason_code('payload-format-invalid') -> 153; +encode_reason_code('retain-not-supported') -> 154; +encode_reason_code('qos-not-supported') -> 155; +encode_reason_code('use-another-server') -> 156; +encode_reason_code('server-moved') -> 157; +encode_reason_code('shared-subscriptions-not-supported') -> 158; +encode_reason_code('connection-rate-exceeded') -> 159; +encode_reason_code('maximum-connect-time') -> 160; +encode_reason_code('subscription-identifiers-not-supported') -> 161; +encode_reason_code('wildcard-subscriptions-not-supported') -> 162; +encode_reason_code(Code) when is_integer(Code) -> Code. + +%%%=================================================================== +%%% Formatters +%%%=================================================================== +-spec pp(atom(), non_neg_integer()) -> [atom()] | no. +pp(codec_state, 6) -> record_info(fields, codec_state); +pp(connect, 9) -> record_info(fields, connect); +pp(connack, 3) -> record_info(fields, connack); +pp(publish, 8) -> record_info(fields, publish); +pp(puback, 3) -> record_info(fields, puback); +pp(pubrec, 3) -> record_info(fields, pubrec); +pp(pubrel, 4) -> record_info(fields, pubrel); +pp(pubcomp, 3) -> record_info(fields, pubcomp); +pp(subscribe, 4) -> record_info(fields, subscribe); +pp(suback, 3) -> record_info(fields, suback); +pp(unsubscribe, 3) -> record_info(fields, unsubscribe); +pp(unsuback, 1) -> record_info(fields, unsuback); +pp(pingreq, 1) -> record_info(fields, pingreq); +pp(pingresp, 0) -> record_info(fields, pingresp); +pp(disconnect, 2) -> record_info(fields, disconnect); +pp(sub_opts, 4) -> record_info(fields, sub_opts); +pp(_, _) -> no. + +-spec format(io:format(), list()) -> string(). +format(Fmt, Args) -> + lists:flatten(io_lib:format(Fmt, Args)). + +format_got_expected(Txt, Got, Expected) -> + FmtGot = term_format(Got), + FmtExp = term_format(Expected), + format("~s: " ++ FmtGot ++ " (expected: " ++ FmtExp ++ ")", + [Txt, Got, Expected]). + +term_format(I) when is_integer(I) -> + "~B"; +term_format(B) when is_binary(B) -> + term_format(binary_to_list(B)); +term_format(A) when is_atom(A) -> + term_format(atom_to_list(A)); +term_format(T) -> + case io_lib:printable_latin1_list(T) of + true -> "~s"; + false -> "~w" + end. + +%%%=================================================================== +%%% Validators +%%%=================================================================== +-spec assert(T, any(), any()) -> T. +assert(Got, Got, _) -> + Got; +assert(Got, Expected, Reason) -> + err({Reason, Got, Expected}). + +-spec qos(qos()) -> qos(). +qos(QoS) when is_integer(QoS), QoS>=0, QoS<3 -> + QoS; +qos(QoS) -> + err({bad_qos, QoS}). + +-spec topic(binary()) -> binary(). +topic(Topic) -> + topic(Topic, #{}). + +-spec topic(binary(), properties()) -> binary(). +topic(<<>>, Props) -> + case maps:is_key(topic_alias, Props) of + true -> <<>>; + false -> err(bad_topic) + end; +topic(Bin, _) when is_binary(Bin) -> + ok = check_topic(Bin), + ok = check_utf8(Bin), + Bin; +topic(_, _) -> + err(bad_topic). + +-spec topic_filter(binary()) -> binary(). +topic_filter(<<>>) -> + err(bad_topic_filter); +topic_filter(Bin) when is_binary(Bin) -> + ok = check_topic_filter(Bin, $/), + ok = check_utf8(Bin), + Bin; +topic_filter(_) -> + err(bad_topic_filter). + +-spec utf8(binary()) -> binary(). +utf8(Bin) -> + ok = check_utf8(Bin), + ok = check_zero(Bin), + Bin. + +-spec check_topic(binary()) -> ok. +check_topic(<>) when H == $#; H == $+; H == 0 -> + err(bad_topic); +check_topic(<<_, T/binary>>) -> + check_topic(T); +check_topic(<<>>) -> + ok. + +-spec check_topic_filter(binary(), char()) -> ok. +check_topic_filter(<<>>, _) -> + ok; +check_topic_filter(_, $#) -> + err(bad_topic_filter); +check_topic_filter(<<$#, _/binary>>, C) when C /= $/ -> + err(bad_topic_filter); +check_topic_filter(<<$+, _/binary>>, C) when C /= $/ -> + err(bad_topic_filter); +check_topic_filter(<>, $+) when C /= $/ -> + err(bad_topic_filter); +check_topic_filter(<<0, _/binary>>, _) -> + err(bad_topic_filter); +check_topic_filter(<>, _) -> + check_topic_filter(T, H). + +-spec check_utf8(binary()) -> ok. +check_utf8(Bin) -> + case unicode:characters_to_binary(Bin, utf8) of + UTF8Str when is_binary(UTF8Str) -> + ok; + _ -> + err(bad_utf8_string) + end. + +-spec check_zero(binary()) -> ok. +check_zero(<<0, _/binary>>) -> + err(bad_utf8_string); +check_zero(<<_, T/binary>>) -> + check_zero(T); +check_zero(<<>>) -> + ok. + +%%%=================================================================== +%%% Internal functions +%%%=================================================================== +-spec dec_bool(non_neg_integer()) -> boolean(). +dec_bool(0) -> false; +dec_bool(_) -> true. + +-spec enc_bool(boolean()) -> 0..1. +enc_bool(true) -> 1; +enc_bool(false) -> 0. + +-spec err(any()) -> no_return(). +err(Reason) -> + erlang:error({?MODULE, Reason}).