diff --git a/src/mod_multicast.erl b/src/mod_multicast.erl index 161d3a4c4..fa076da70 100644 --- a/src/mod_multicast.erl +++ b/src/mod_multicast.erl @@ -35,7 +35,7 @@ %% API -export([start/2, stop/1, reload/3, - user_send_packet/1]). + user_send_packet/1]). %% gen_server callbacks -export([init/1, handle_info/2, handle_call/3, @@ -51,11 +51,6 @@ response, ts :: integer()}). --record(dest, {jid_string :: binary() | none, - jid_jid :: jid() | undefined, - type :: bcc | cc | noreply | ofrom | replyroom | replyto | to, - address :: address()}). - -type limit_value() :: {default | custom, integer()}. -record(limits, {message :: limit_value(), presence :: limit_value()}). @@ -63,14 +58,6 @@ -record(service_limits, {local :: #limits{}, remote :: #limits{}}). --type routing() :: route_single | {route_multicast, binary(), #service_limits{}}. - --record(group, {server :: binary(), - dests :: [#dest{}], - multicast :: routing() | undefined, - others :: [address()], - addresses :: [address()]}). - -record(state, {lserver :: binary(), lservice :: binary(), access :: atom(), @@ -117,7 +104,7 @@ reload(LServerS, NewOpts, OldOpts) -> user_send_packet({#presence{} = Packet, C2SState} = Acc) -> case xmpp:get_subtag(Packet, #addresses{}) of #addresses{list = Addresses} -> - {ToDeliver, _Delivereds} = split_addresses_todeliver(Addresses), + {CC, BCC, _Invalid, _Delivered} = partition_addresses(Addresses), NewState = lists:foldl( fun(Address, St) -> @@ -138,7 +125,7 @@ user_send_packet({#presence{} = Packet, C2SState} = Acc) -> undefined -> St end - end, C2SState, ToDeliver), + end, C2SState, CC ++ BCC), {Packet, NewState}; false -> Acc @@ -308,19 +295,10 @@ iq_vcard(Lang, State) -> %%%------------------------- -spec route_trusted(binary(), binary(), jid(), [jid()], stanza()) -> 'ok'. -route_trusted(LServiceS, LServerS, FromJID, - Destinations, Packet) -> - Packet_stripped = Packet, - Delivereds = [], - Dests2 = lists:map( - fun(D) -> - #dest{jid_string = jid:encode(D), - jid_jid = D, type = bcc, - address = #address{type = bcc, jid = D}} - end, Destinations), - Groups = group_dests(Dests2), - route_common(LServerS, LServiceS, FromJID, Groups, - Delivereds, Packet_stripped). +route_trusted(LServiceS, LServerS, FromJID, Destinations, Packet) -> + Addresses = [#address{type = bcc, jid = D} || D <- Destinations], + Groups = group_by_destinations(Addresses, #{}), + route_grouped(LServerS, LServiceS, FromJID, Groups, [], Packet). -spec route_untrusted(binary(), binary(), atom(), #service_limits{}, stanza()) -> 'ok'. route_untrusted(LServiceS, LServerS, Access, SLimits, Packet) -> @@ -356,50 +334,88 @@ route_untrusted(LServiceS, LServerS, Access, SLimits, Packet) -> route_untrusted2(LServiceS, LServerS, Access, SLimits, Packet) -> FromJID = xmpp:get_from(Packet), ok = check_access(LServerS, Access, FromJID), - {ok, Packet_stripped, Addresses} = strip_addresses_element(Packet), - {To_deliver, Delivereds} = split_addresses_todeliver(Addresses), - Dests = convert_dest_record(To_deliver), - {Dests2, Not_jids} = split_dests_jid(Dests), - report_not_jid(FromJID, Packet, Not_jids), - ok = check_limit_dests(SLimits, FromJID, Packet, Dests2), - Groups = group_dests(Dests2), + {ok, PacketStripped, Addresses} = strip_addresses_element(Packet), + {CC, BCC, NotJids, Rest} = partition_addresses(Addresses), + report_not_jid(FromJID, Packet, NotJids), + ok = check_limit_dests(SLimits, FromJID, Packet, length(CC) + length(BCC)), + Groups0 = group_by_destinations(CC, #{}), + Groups = group_by_destinations(BCC, Groups0), ok = check_relay(FromJID#jid.server, LServerS, Groups), - route_common(LServerS, LServiceS, FromJID, Groups, - Delivereds, Packet_stripped). + route_grouped(LServerS, LServiceS, FromJID, Groups, Rest, PacketStripped). --spec route_common(binary(), binary(), jid(), [#group{}], - [address()], stanza()) -> 'ok'. -route_common(LServerS, LServiceS, FromJID, Groups, - Delivereds, Packet_stripped) -> - Groups2 = look_cached_servers(LServerS, LServiceS, Groups), - Groups3 = build_others_xml(Groups2), - Groups4 = add_addresses(Delivereds, Groups3), - AGroups = decide_action_groups(Groups4), - act_groups(FromJID, Packet_stripped, LServiceS, - AGroups). +-spec mark_as_delivered([address()]) -> [address()]. +mark_as_delivered(Addresses) -> + [A#address{delivered = true} || A <- Addresses]. --spec act_groups(jid(), stanza(), binary(), [{routing(), #group{}}]) -> 'ok'. -act_groups(FromJID, Packet_stripped, LServiceS, AGroups) -> +-spec route_individual(jid(), [address()], [address()], [address()], stanza()) -> ok. +route_individual(From, CC, BCC, Other, Packet) -> + CCDelivered = mark_as_delivered(CC), + Addresses = CCDelivered ++ Other, + PacketWithAddresses = xmpp:append_subtags(Packet, [#addresses{list = Addresses}]), lists:foreach( - fun(AGroup) -> - perform(FromJID, Packet_stripped, LServiceS, - AGroup) - end, AGroups). - --spec perform(jid(), stanza(), binary(), - {routing(), #group{}}) -> 'ok'. -perform(From, Packet, _, - {route_single, Group}) -> + fun(#address{jid = To}) -> + ejabberd_router:route(xmpp:set_from_to(PacketWithAddresses, From, To)) + end, CC), lists:foreach( - fun(ToUser) -> - Group_others = strip_other_bcc(ToUser, Group#group.others), - route_packet(From, ToUser, Packet, - Group_others, Group#group.addresses) - end, Group#group.dests); -perform(From, Packet, _, - {{route_multicast, JID, RLimits}, Group}) -> - route_packet_multicast(From, JID, Packet, - Group#group.dests, Group#group.addresses, RLimits). + fun(#address{jid = To} = Address) -> + Packet2 = case Addresses of + [] -> + Packet; + _ -> + xmpp:append_subtags(Packet, [#addresses{list = [Address | Addresses]}]) + end, + ejabberd_router:route(xmpp:set_from_to(Packet2, From, To)) + end, BCC). + +-spec route_chunk(jid(), jid(), stanza(), [address()]) -> ok. +route_chunk(From, To, Packet, Addresses) -> + PacketWithAddresses = xmpp:append_subtags(Packet, [#addresses{list = Addresses}]), + ejabberd_router:route(xmpp:set_from_to(PacketWithAddresses, From, To)). + +-spec route_in_chunks(jid(), jid(), stanza(), integer(), [address()], [address()], [address()]) -> ok. +route_in_chunks(_From, _To, _Packet, _Limit, [], [], _) -> + ok; +route_in_chunks(From, To, Packet, Limit, CC, BCC, RestOfAddresses) when length(CC) > Limit -> + {Chunk, Rest} = lists:split(Limit, CC), + route_chunk(From, To, Packet, Chunk ++ RestOfAddresses), + route_in_chunks(From, To, Packet, Limit, Rest, BCC, RestOfAddresses); +route_in_chunks(From, To, Packet, Limit, [], BCC, RestOfAddresses) when length(BCC) > Limit -> + {Chunk, Rest} = lists:split(Limit, BCC), + route_chunk(From, To, Packet, Chunk ++ RestOfAddresses), + route_in_chunks(From, To, Packet, Limit, [], Rest, RestOfAddresses); +route_in_chunks(From, To, Packet, Limit, CC, BCC, RestOfAddresses) when length(BCC) + length(CC) > Limit -> + {Chunk, Rest} = lists:split(Limit - length(CC), BCC), + route_chunk(From, To, Packet, CC ++ Chunk ++ RestOfAddresses), + route_in_chunks(From, To, Packet, Limit, [], Rest, RestOfAddresses); +route_in_chunks(From, To, Packet, _Limit, CC, BCC, RestOfAddresses) -> + route_chunk(From, To, Packet, CC ++ BCC ++ RestOfAddresses). + +-spec route_multicast(jid(), jid(), [address()], [address()], [address()], stanza(), #limits{}) -> ok. +route_multicast(From, To, CC, BCC, RestOfAddresses, Packet, Limits) -> + {_Type, Limit} = get_limit_number(element(1, Packet), + Limits), + route_in_chunks(From, To, Packet, Limit, CC, BCC, RestOfAddresses). + +-spec route_grouped(binary(), binary(), jid(), #{}, [address()], stanza()) -> ok. +route_grouped(LServer, LService, From, Groups, RestOfAddresses, Packet) -> + maps:fold( + fun(Server, {CC, BCC}, _) -> + OtherCC = maps:fold( + fun(Server2, _, Res) when Server2 == Server -> + Res; + (_, {CC2, _}, Res) -> + mark_as_delivered(CC2) ++ Res + end, [], Groups), + case search_server_on_cache(Server, + LServer, LService, + {?MAXTIME_CACHE_POSITIVE, + ?MAXTIME_CACHE_NEGATIVE}) of + route_single -> + route_individual(From, CC, BCC, OtherCC ++ RestOfAddresses, Packet); + {route_multicast, Service, Limits} -> + route_multicast(From, Service, CC, BCC, OtherCC ++ RestOfAddresses, Packet, Limits) + end + end, ok, Groups). %%%------------------------- %%% Check access permission @@ -425,245 +441,89 @@ strip_addresses_element(Packet) -> throw(eadsele) end. -%%%------------------------- -%%% Strip third-party bcc 'addresses' -%%%------------------------- - -strip_other_bcc(#dest{jid_jid = ToUserJid}, Group_others) -> - lists:filter( - fun(#address{jid = JID, type = Type}) -> - case {JID, Type} of - {ToUserJid, bcc} -> true; - {_, bcc} -> false; - _ -> true - end - end, - Group_others). - %%%------------------------- %%% Split Addresses %%%------------------------- --spec split_addresses_todeliver([address()]) -> {[address()], [address()]}. -split_addresses_todeliver(Addresses) -> - lists:partition( - fun(#address{delivered = true}) -> - false; - (#address{type = Type}) -> - case Type of - to -> true; - cc -> true; - bcc -> true; - _ -> false - end - end, Addresses). +partition_addresses(Addresses) -> + lists:foldl( + fun(#address{delivered = true} = A, {C, B, I, D}) -> + {C, B, I, [A | D]}; + (#address{type = T, jid = undefined} = A, {C, B, I, D}) + when T == to; T == cc; T == bcc -> + {C, B, [A | I], D}; + (#address{type = T} = A, {C, B, I, D}) + when T == to; T == cc -> + {[A | C], B, I, D}; + (#address{type = bcc} = A, {C, B, I, D}) -> + {C, [A | B], I, D}; + (A, {C, B, I, D}) -> + {C, B, I, [A | D]} + end, {[], [], [], []}, Addresses). %%%------------------------- %%% Check does not exceed limit of destinations %%%------------------------- --spec check_limit_dests(#service_limits{}, jid(), stanza(), [address()]) -> ok. -check_limit_dests(SLimits, FromJID, Packet, - Addresses) -> +-spec check_limit_dests(#service_limits{}, jid(), stanza(), integer()) -> ok. +check_limit_dests(SLimits, FromJID, Packet, NumOfAddresses) -> SenderT = sender_type(FromJID), Limits = get_slimit_group(SenderT, SLimits), - Type_of_stanza = type_of_stanza(Packet), - {_Type, Limit_number} = get_limit_number(Type_of_stanza, - Limits), - case length(Addresses) > Limit_number of + StanzaType = type_of_stanza(Packet), + {_Type, Limit} = get_limit_number(StanzaType, + Limits), + case NumOfAddresses > Limit of false -> ok; true -> throw(etoorec) end. -%%%------------------------- -%%% Convert Destination XML to record -%%%------------------------- --spec convert_dest_record([address()]) -> [#dest{}]. -convert_dest_record(Addrs) -> - lists:map( - fun(#address{jid = undefined, type = Type} = Addr) -> - #dest{jid_string = none, - type = Type, address = Addr}; - (#address{jid = JID, type = Type} = Addr) -> - #dest{jid_string = jid:encode(JID), jid_jid = JID, - type = Type, address = Addr} - end, Addrs). - -%%%------------------------- -%%% Split destinations by existence of JID -%%% and send error messages for other dests -%%%------------------------- - --spec split_dests_jid([#dest{}]) -> {[#dest{}], [#dest{}]}. -split_dests_jid(Dests) -> - lists:partition(fun (Dest) -> - case Dest#dest.jid_string of - none -> false; - _ -> true - end - end, - Dests). - --spec report_not_jid(jid(), stanza(), [#dest{}]) -> any(). -report_not_jid(From, Packet, Dests) -> - Dests2 = [fxml:element_to_binary(xmpp:encode(Dest#dest.address)) - || Dest <- Dests], - [route_error( - xmpp:set_from_to(Packet, From, From), jid_malformed, - str:format(?T("This service can not process the address: ~s"), [D])) - || D <- Dests2]. +-spec report_not_jid(jid(), stanza(), [address()]) -> any(). +report_not_jid(From, Packet, Addresses) -> + lists:foreach( + fun(Address) -> + route_error( + xmpp:set_from_to(Packet, From, From), jid_malformed, + str:format(?T("This service can not process the address: ~s"), + [fxml:element_to_binary(xmpp:encode(Address))])) + end, Addresses). %%%------------------------- %%% Group destinations by their servers %%%------------------------- --spec group_dests([#dest{}]) -> [#group{}]. -group_dests(Dests) -> - D = lists:foldl(fun (Dest, Dict) -> - ServerS = (Dest#dest.jid_jid)#jid.server, - dict:append(ServerS, Dest, Dict) - end, - dict:new(), Dests), - Keys = dict:fetch_keys(D), - [#group{server = Key, dests = dict:fetch(Key, D), - addresses = [], others = []} - || Key <- Keys]. - -%%%------------------------- -%%% Look for cached responses -%%%------------------------- - -look_cached_servers(LServerS, LServiceS, Groups) -> - [look_cached(LServerS, LServiceS, Group) || Group <- Groups]. - -look_cached(LServerS, LServiceS, G) -> - Maxtime_positive = (?MAXTIME_CACHE_POSITIVE), - Maxtime_negative = (?MAXTIME_CACHE_NEGATIVE), - Cached_response = search_server_on_cache(G#group.server, - LServerS, LServiceS, - {Maxtime_positive, - Maxtime_negative}), - G#group{multicast = Cached_response}. - -%%%------------------------- -%%% Build delivered XML element -%%%------------------------- - -build_others_xml(Groups) -> - [Group#group{others = - build_other_xml(Group#group.dests)} - || Group <- Groups]. - -build_other_xml(Dests) -> - lists:foldl(fun (Dest, R) -> - XML = Dest#dest.address, - case Dest#dest.type of - to -> [add_delivered(XML) | R]; - cc -> [add_delivered(XML) | R]; - _ -> [XML | R] - end - end, - [], Dests). - --spec add_delivered(address()) -> address(). -add_delivered(Addr) -> - Addr#address{delivered = true}. - -%%%------------------------- -%%% Add preliminary packets -%%%------------------------- - -add_addresses(Delivereds, Groups) -> - Ps = [Group#group.others || Group <- Groups], - add_addresses2(Delivereds, Groups, [], [], Ps). - -add_addresses2(_, [], Res, _, []) -> Res; -add_addresses2(Delivereds, [Group | Groups], Res, Pa, - [Pi | Pz]) -> - Addresses = lists:append([Delivereds] ++ Pa ++ Pz), - Group2 = Group#group{addresses = Addresses}, - add_addresses2(Delivereds, Groups, [Group2 | Res], - [Pi | Pa], Pz). - -%%%------------------------- -%%% Decide action groups -%%%------------------------- - --spec decide_action_groups([#group{}]) -> [{routing(), #group{}}]. -decide_action_groups(Groups) -> - [{Group#group.multicast, Group} - || Group <- Groups]. +group_by_destinations(Addrs, Map) -> + lists:foldl( + fun + (#address{type = Type, jid = #jid{lserver = Server}} = Addr, Map2) when Type == to; Type == cc -> + maps:update_with(Server, + fun({CC, BCC}) -> + {[Addr | CC], BCC} + end, {[Addr], []}, Map2); + (#address{type = bcc, jid = #jid{lserver = Server}} = Addr, Map2) -> + maps:update_with(Server, + fun({CC, BCC}) -> + {CC, [Addr | BCC]} + end, {[], [Addr]}, Map2) + end, Map, Addrs). %%%------------------------- %%% Route packet %%%------------------------- --spec route_packet(jid(), #dest{}, stanza(), [addresses()], [addresses()]) -> 'ok'. -route_packet(From, ToDest, Packet, Others, Addresses) -> - Dests = case ToDest#dest.type of - bcc -> []; - _ -> [ToDest] - end, - route_packet2(From, ToDest#dest.jid_string, Dests, - Packet, {Others, Addresses}). - --spec route_packet_multicast(jid(), binary(), stanza(), [#dest{}], [address()], #limits{}) -> 'ok'. -route_packet_multicast(From, ToS, Packet, Dests, - Addresses, Limits) -> - Type_of_stanza = type_of_stanza(Packet), - {_Type, Limit_number} = get_limit_number(Type_of_stanza, - Limits), - Fragmented_dests = fragment_dests(Dests, Limit_number), - lists:foreach(fun(DFragment) -> - route_packet2(From, ToS, DFragment, Packet, - Addresses) - end, Fragmented_dests). - --spec route_packet2(jid(), binary(), [#dest{}], stanza(), {[address()], [address()]} | [address()]) -> 'ok'. -route_packet2(From, ToS, Dests, Packet, Addresses) -> - Els = case append_dests(Dests, Addresses) of - [] -> - xmpp:get_els(Packet); - ACs -> - [#addresses{list = ACs}|xmpp:get_els(Packet)] - end, - Packet2 = xmpp:set_els(Packet, Els), - ToJID = stj(ToS), - ejabberd_router:route(xmpp:set_from_to(Packet2, From, ToJID)). - --spec append_dests([#dest{}], {[address()], [address()]} | [address()]) -> [address()]. -append_dests(_Dests, {Others, Addresses}) -> - Addresses ++ Others; -append_dests([], Addresses) -> Addresses; -append_dests([Dest | Dests], Addresses) -> - append_dests(Dests, [Dest#dest.address | Addresses]). - %%%------------------------- %%% Check relay %%%------------------------- --spec check_relay(binary(), binary(), [#group{}]) -> ok. +-spec check_relay(binary(), binary(), #{}) -> ok. check_relay(RS, LS, Gs) -> - case check_relay_required(RS, LS, Gs) of - false -> ok; - true -> throw(edrelay) + case lists:suffix(str:tokens(LS, <<".">>), + str:tokens(RS, <<".">>)) orelse + (maps:is_key(LS, Gs) andalso maps:size(Gs) == 1) of + true -> ok; + _ -> throw(edrelay) end. --spec check_relay_required(binary(), binary(), [#group{}]) -> boolean(). -check_relay_required(RServer, LServerS, Groups) -> - case lists:suffix(str:tokens(LServerS, <<".">>), - str:tokens(RServer, <<".">>)) of - true -> false; - false -> check_relay_required(LServerS, Groups) - end. - --spec check_relay_required(binary(), [#group{}]) -> boolean(). -check_relay_required(LServerS, Groups) -> - lists:any(fun (Group) -> Group#group.server /= LServerS - end, - Groups). - %%%------------------------- %%% Check protocol support: Send request %%%------------------------- @@ -1060,20 +920,6 @@ get_slimit_group(local, SLimits) -> get_slimit_group(remote, SLimits) -> SLimits#service_limits.remote. -fragment_dests(Dests, Limit_number) -> - {R, _} = lists:foldl(fun (Dest, {Res, Count}) -> - case Count of - Limit_number -> - Head2 = [Dest], {[Head2 | Res], 0}; - _ -> - [Head | Tail] = Res, - Head2 = [Dest | Head], - {[Head2 | Tail], Count + 1} - end - end, - {[[]], 0}, Dests), - R. - %%%------------------------- %%% Limits: XEP-0128 Service Discovery Extensions %%%-------------------------