25
1
mirror of https://github.com/processone/ejabberd.git synced 2024-11-24 16:23:40 +01:00

Improve mod_multicast

This commit is contained in:
Paweł Chmielowski 2021-11-16 19:39:59 +01:00
parent 97b8373fd2
commit 405a5172d5

View File

@ -35,7 +35,7 @@
%% API %% API
-export([start/2, stop/1, reload/3, -export([start/2, stop/1, reload/3,
user_send_packet/1]). user_send_packet/1]).
%% gen_server callbacks %% gen_server callbacks
-export([init/1, handle_info/2, handle_call/3, -export([init/1, handle_info/2, handle_call/3,
@ -51,11 +51,6 @@
response, response,
ts :: integer()}). ts :: integer()}).
-record(dest, {jid_string :: binary() | none,
jid_jid :: jid() | undefined,
type :: bcc | cc | noreply | ofrom | replyroom | replyto | to,
address :: address()}).
-type limit_value() :: {default | custom, integer()}. -type limit_value() :: {default | custom, integer()}.
-record(limits, {message :: limit_value(), -record(limits, {message :: limit_value(),
presence :: limit_value()}). presence :: limit_value()}).
@ -63,14 +58,6 @@
-record(service_limits, {local :: #limits{}, -record(service_limits, {local :: #limits{},
remote :: #limits{}}). remote :: #limits{}}).
-type routing() :: route_single | {route_multicast, binary(), #service_limits{}}.
-record(group, {server :: binary(),
dests :: [#dest{}],
multicast :: routing() | undefined,
others :: [address()],
addresses :: [address()]}).
-record(state, {lserver :: binary(), -record(state, {lserver :: binary(),
lservice :: binary(), lservice :: binary(),
access :: atom(), access :: atom(),
@ -117,7 +104,7 @@ reload(LServerS, NewOpts, OldOpts) ->
user_send_packet({#presence{} = Packet, C2SState} = Acc) -> user_send_packet({#presence{} = Packet, C2SState} = Acc) ->
case xmpp:get_subtag(Packet, #addresses{}) of case xmpp:get_subtag(Packet, #addresses{}) of
#addresses{list = Addresses} -> #addresses{list = Addresses} ->
{ToDeliver, _Delivereds} = split_addresses_todeliver(Addresses), {CC, BCC, _Invalid, _Delivered} = partition_addresses(Addresses),
NewState = NewState =
lists:foldl( lists:foldl(
fun(Address, St) -> fun(Address, St) ->
@ -138,7 +125,7 @@ user_send_packet({#presence{} = Packet, C2SState} = Acc) ->
undefined -> undefined ->
St St
end end
end, C2SState, ToDeliver), end, C2SState, CC ++ BCC),
{Packet, NewState}; {Packet, NewState};
false -> false ->
Acc Acc
@ -308,19 +295,10 @@ iq_vcard(Lang, State) ->
%%%------------------------- %%%-------------------------
-spec route_trusted(binary(), binary(), jid(), [jid()], stanza()) -> 'ok'. -spec route_trusted(binary(), binary(), jid(), [jid()], stanza()) -> 'ok'.
route_trusted(LServiceS, LServerS, FromJID, route_trusted(LServiceS, LServerS, FromJID, Destinations, Packet) ->
Destinations, Packet) -> Addresses = [#address{type = bcc, jid = D} || D <- Destinations],
Packet_stripped = Packet, Groups = group_by_destinations(Addresses, #{}),
Delivereds = [], route_grouped(LServerS, LServiceS, FromJID, Groups, [], Packet).
Dests2 = lists:map(
fun(D) ->
#dest{jid_string = jid:encode(D),
jid_jid = D, type = bcc,
address = #address{type = bcc, jid = D}}
end, Destinations),
Groups = group_dests(Dests2),
route_common(LServerS, LServiceS, FromJID, Groups,
Delivereds, Packet_stripped).
-spec route_untrusted(binary(), binary(), atom(), #service_limits{}, stanza()) -> 'ok'. -spec route_untrusted(binary(), binary(), atom(), #service_limits{}, stanza()) -> 'ok'.
route_untrusted(LServiceS, LServerS, Access, SLimits, Packet) -> route_untrusted(LServiceS, LServerS, Access, SLimits, Packet) ->
@ -356,50 +334,88 @@ route_untrusted(LServiceS, LServerS, Access, SLimits, Packet) ->
route_untrusted2(LServiceS, LServerS, Access, SLimits, Packet) -> route_untrusted2(LServiceS, LServerS, Access, SLimits, Packet) ->
FromJID = xmpp:get_from(Packet), FromJID = xmpp:get_from(Packet),
ok = check_access(LServerS, Access, FromJID), ok = check_access(LServerS, Access, FromJID),
{ok, Packet_stripped, Addresses} = strip_addresses_element(Packet), {ok, PacketStripped, Addresses} = strip_addresses_element(Packet),
{To_deliver, Delivereds} = split_addresses_todeliver(Addresses), {CC, BCC, NotJids, Rest} = partition_addresses(Addresses),
Dests = convert_dest_record(To_deliver), report_not_jid(FromJID, Packet, NotJids),
{Dests2, Not_jids} = split_dests_jid(Dests), ok = check_limit_dests(SLimits, FromJID, Packet, length(CC) + length(BCC)),
report_not_jid(FromJID, Packet, Not_jids), Groups0 = group_by_destinations(CC, #{}),
ok = check_limit_dests(SLimits, FromJID, Packet, Dests2), Groups = group_by_destinations(BCC, Groups0),
Groups = group_dests(Dests2),
ok = check_relay(FromJID#jid.server, LServerS, Groups), ok = check_relay(FromJID#jid.server, LServerS, Groups),
route_common(LServerS, LServiceS, FromJID, Groups, route_grouped(LServerS, LServiceS, FromJID, Groups, Rest, PacketStripped).
Delivereds, Packet_stripped).
-spec route_common(binary(), binary(), jid(), [#group{}], -spec mark_as_delivered([address()]) -> [address()].
[address()], stanza()) -> 'ok'. mark_as_delivered(Addresses) ->
route_common(LServerS, LServiceS, FromJID, Groups, [A#address{delivered = true} || A <- Addresses].
Delivereds, Packet_stripped) ->
Groups2 = look_cached_servers(LServerS, LServiceS, Groups),
Groups3 = build_others_xml(Groups2),
Groups4 = add_addresses(Delivereds, Groups3),
AGroups = decide_action_groups(Groups4),
act_groups(FromJID, Packet_stripped, LServiceS,
AGroups).
-spec act_groups(jid(), stanza(), binary(), [{routing(), #group{}}]) -> 'ok'. -spec route_individual(jid(), [address()], [address()], [address()], stanza()) -> ok.
act_groups(FromJID, Packet_stripped, LServiceS, AGroups) -> route_individual(From, CC, BCC, Other, Packet) ->
CCDelivered = mark_as_delivered(CC),
Addresses = CCDelivered ++ Other,
PacketWithAddresses = xmpp:append_subtags(Packet, [#addresses{list = Addresses}]),
lists:foreach( lists:foreach(
fun(AGroup) -> fun(#address{jid = To}) ->
perform(FromJID, Packet_stripped, LServiceS, ejabberd_router:route(xmpp:set_from_to(PacketWithAddresses, From, To))
AGroup) end, CC),
end, AGroups).
-spec perform(jid(), stanza(), binary(),
{routing(), #group{}}) -> 'ok'.
perform(From, Packet, _,
{route_single, Group}) ->
lists:foreach( lists:foreach(
fun(ToUser) -> fun(#address{jid = To} = Address) ->
Group_others = strip_other_bcc(ToUser, Group#group.others), Packet2 = case Addresses of
route_packet(From, ToUser, Packet, [] ->
Group_others, Group#group.addresses) Packet;
end, Group#group.dests); _ ->
perform(From, Packet, _, xmpp:append_subtags(Packet, [#addresses{list = [Address | Addresses]}])
{{route_multicast, JID, RLimits}, Group}) -> end,
route_packet_multicast(From, JID, Packet, ejabberd_router:route(xmpp:set_from_to(Packet2, From, To))
Group#group.dests, Group#group.addresses, RLimits). end, BCC).
-spec route_chunk(jid(), jid(), stanza(), [address()]) -> ok.
route_chunk(From, To, Packet, Addresses) ->
PacketWithAddresses = xmpp:append_subtags(Packet, [#addresses{list = Addresses}]),
ejabberd_router:route(xmpp:set_from_to(PacketWithAddresses, From, To)).
-spec route_in_chunks(jid(), jid(), stanza(), integer(), [address()], [address()], [address()]) -> ok.
route_in_chunks(_From, _To, _Packet, _Limit, [], [], _) ->
ok;
route_in_chunks(From, To, Packet, Limit, CC, BCC, RestOfAddresses) when length(CC) > Limit ->
{Chunk, Rest} = lists:split(Limit, CC),
route_chunk(From, To, Packet, Chunk ++ RestOfAddresses),
route_in_chunks(From, To, Packet, Limit, Rest, BCC, RestOfAddresses);
route_in_chunks(From, To, Packet, Limit, [], BCC, RestOfAddresses) when length(BCC) > Limit ->
{Chunk, Rest} = lists:split(Limit, BCC),
route_chunk(From, To, Packet, Chunk ++ RestOfAddresses),
route_in_chunks(From, To, Packet, Limit, [], Rest, RestOfAddresses);
route_in_chunks(From, To, Packet, Limit, CC, BCC, RestOfAddresses) when length(BCC) + length(CC) > Limit ->
{Chunk, Rest} = lists:split(Limit - length(CC), BCC),
route_chunk(From, To, Packet, CC ++ Chunk ++ RestOfAddresses),
route_in_chunks(From, To, Packet, Limit, [], Rest, RestOfAddresses);
route_in_chunks(From, To, Packet, _Limit, CC, BCC, RestOfAddresses) ->
route_chunk(From, To, Packet, CC ++ BCC ++ RestOfAddresses).
-spec route_multicast(jid(), jid(), [address()], [address()], [address()], stanza(), #limits{}) -> ok.
route_multicast(From, To, CC, BCC, RestOfAddresses, Packet, Limits) ->
{_Type, Limit} = get_limit_number(element(1, Packet),
Limits),
route_in_chunks(From, To, Packet, Limit, CC, BCC, RestOfAddresses).
-spec route_grouped(binary(), binary(), jid(), #{}, [address()], stanza()) -> ok.
route_grouped(LServer, LService, From, Groups, RestOfAddresses, Packet) ->
maps:fold(
fun(Server, {CC, BCC}, _) ->
OtherCC = maps:fold(
fun(Server2, _, Res) when Server2 == Server ->
Res;
(_, {CC2, _}, Res) ->
mark_as_delivered(CC2) ++ Res
end, [], Groups),
case search_server_on_cache(Server,
LServer, LService,
{?MAXTIME_CACHE_POSITIVE,
?MAXTIME_CACHE_NEGATIVE}) of
route_single ->
route_individual(From, CC, BCC, OtherCC ++ RestOfAddresses, Packet);
{route_multicast, Service, Limits} ->
route_multicast(From, Service, CC, BCC, OtherCC ++ RestOfAddresses, Packet, Limits)
end
end, ok, Groups).
%%%------------------------- %%%-------------------------
%%% Check access permission %%% Check access permission
@ -425,245 +441,89 @@ strip_addresses_element(Packet) ->
throw(eadsele) throw(eadsele)
end. end.
%%%-------------------------
%%% Strip third-party bcc 'addresses'
%%%-------------------------
strip_other_bcc(#dest{jid_jid = ToUserJid}, Group_others) ->
lists:filter(
fun(#address{jid = JID, type = Type}) ->
case {JID, Type} of
{ToUserJid, bcc} -> true;
{_, bcc} -> false;
_ -> true
end
end,
Group_others).
%%%------------------------- %%%-------------------------
%%% Split Addresses %%% Split Addresses
%%%------------------------- %%%-------------------------
-spec split_addresses_todeliver([address()]) -> {[address()], [address()]}. partition_addresses(Addresses) ->
split_addresses_todeliver(Addresses) -> lists:foldl(
lists:partition( fun(#address{delivered = true} = A, {C, B, I, D}) ->
fun(#address{delivered = true}) -> {C, B, I, [A | D]};
false; (#address{type = T, jid = undefined} = A, {C, B, I, D})
(#address{type = Type}) -> when T == to; T == cc; T == bcc ->
case Type of {C, B, [A | I], D};
to -> true; (#address{type = T} = A, {C, B, I, D})
cc -> true; when T == to; T == cc ->
bcc -> true; {[A | C], B, I, D};
_ -> false (#address{type = bcc} = A, {C, B, I, D}) ->
end {C, [A | B], I, D};
end, Addresses). (A, {C, B, I, D}) ->
{C, B, I, [A | D]}
end, {[], [], [], []}, Addresses).
%%%------------------------- %%%-------------------------
%%% Check does not exceed limit of destinations %%% Check does not exceed limit of destinations
%%%------------------------- %%%-------------------------
-spec check_limit_dests(#service_limits{}, jid(), stanza(), [address()]) -> ok. -spec check_limit_dests(#service_limits{}, jid(), stanza(), integer()) -> ok.
check_limit_dests(SLimits, FromJID, Packet, check_limit_dests(SLimits, FromJID, Packet, NumOfAddresses) ->
Addresses) ->
SenderT = sender_type(FromJID), SenderT = sender_type(FromJID),
Limits = get_slimit_group(SenderT, SLimits), Limits = get_slimit_group(SenderT, SLimits),
Type_of_stanza = type_of_stanza(Packet), StanzaType = type_of_stanza(Packet),
{_Type, Limit_number} = get_limit_number(Type_of_stanza, {_Type, Limit} = get_limit_number(StanzaType,
Limits), Limits),
case length(Addresses) > Limit_number of case NumOfAddresses > Limit of
false -> ok; false -> ok;
true -> throw(etoorec) true -> throw(etoorec)
end. end.
%%%-------------------------
%%% Convert Destination XML to record
%%%-------------------------
-spec convert_dest_record([address()]) -> [#dest{}]. -spec report_not_jid(jid(), stanza(), [address()]) -> any().
convert_dest_record(Addrs) -> report_not_jid(From, Packet, Addresses) ->
lists:map( lists:foreach(
fun(#address{jid = undefined, type = Type} = Addr) -> fun(Address) ->
#dest{jid_string = none, route_error(
type = Type, address = Addr}; xmpp:set_from_to(Packet, From, From), jid_malformed,
(#address{jid = JID, type = Type} = Addr) -> str:format(?T("This service can not process the address: ~s"),
#dest{jid_string = jid:encode(JID), jid_jid = JID, [fxml:element_to_binary(xmpp:encode(Address))]))
type = Type, address = Addr} end, Addresses).
end, Addrs).
%%%-------------------------
%%% Split destinations by existence of JID
%%% and send error messages for other dests
%%%-------------------------
-spec split_dests_jid([#dest{}]) -> {[#dest{}], [#dest{}]}.
split_dests_jid(Dests) ->
lists:partition(fun (Dest) ->
case Dest#dest.jid_string of
none -> false;
_ -> true
end
end,
Dests).
-spec report_not_jid(jid(), stanza(), [#dest{}]) -> any().
report_not_jid(From, Packet, Dests) ->
Dests2 = [fxml:element_to_binary(xmpp:encode(Dest#dest.address))
|| Dest <- Dests],
[route_error(
xmpp:set_from_to(Packet, From, From), jid_malformed,
str:format(?T("This service can not process the address: ~s"), [D]))
|| D <- Dests2].
%%%------------------------- %%%-------------------------
%%% Group destinations by their servers %%% Group destinations by their servers
%%%------------------------- %%%-------------------------
-spec group_dests([#dest{}]) -> [#group{}]. group_by_destinations(Addrs, Map) ->
group_dests(Dests) -> lists:foldl(
D = lists:foldl(fun (Dest, Dict) -> fun
ServerS = (Dest#dest.jid_jid)#jid.server, (#address{type = Type, jid = #jid{lserver = Server}} = Addr, Map2) when Type == to; Type == cc ->
dict:append(ServerS, Dest, Dict) maps:update_with(Server,
end, fun({CC, BCC}) ->
dict:new(), Dests), {[Addr | CC], BCC}
Keys = dict:fetch_keys(D), end, {[Addr], []}, Map2);
[#group{server = Key, dests = dict:fetch(Key, D), (#address{type = bcc, jid = #jid{lserver = Server}} = Addr, Map2) ->
addresses = [], others = []} maps:update_with(Server,
|| Key <- Keys]. fun({CC, BCC}) ->
{CC, [Addr | BCC]}
%%%------------------------- end, {[], [Addr]}, Map2)
%%% Look for cached responses end, Map, Addrs).
%%%-------------------------
look_cached_servers(LServerS, LServiceS, Groups) ->
[look_cached(LServerS, LServiceS, Group) || Group <- Groups].
look_cached(LServerS, LServiceS, G) ->
Maxtime_positive = (?MAXTIME_CACHE_POSITIVE),
Maxtime_negative = (?MAXTIME_CACHE_NEGATIVE),
Cached_response = search_server_on_cache(G#group.server,
LServerS, LServiceS,
{Maxtime_positive,
Maxtime_negative}),
G#group{multicast = Cached_response}.
%%%-------------------------
%%% Build delivered XML element
%%%-------------------------
build_others_xml(Groups) ->
[Group#group{others =
build_other_xml(Group#group.dests)}
|| Group <- Groups].
build_other_xml(Dests) ->
lists:foldl(fun (Dest, R) ->
XML = Dest#dest.address,
case Dest#dest.type of
to -> [add_delivered(XML) | R];
cc -> [add_delivered(XML) | R];
_ -> [XML | R]
end
end,
[], Dests).
-spec add_delivered(address()) -> address().
add_delivered(Addr) ->
Addr#address{delivered = true}.
%%%-------------------------
%%% Add preliminary packets
%%%-------------------------
add_addresses(Delivereds, Groups) ->
Ps = [Group#group.others || Group <- Groups],
add_addresses2(Delivereds, Groups, [], [], Ps).
add_addresses2(_, [], Res, _, []) -> Res;
add_addresses2(Delivereds, [Group | Groups], Res, Pa,
[Pi | Pz]) ->
Addresses = lists:append([Delivereds] ++ Pa ++ Pz),
Group2 = Group#group{addresses = Addresses},
add_addresses2(Delivereds, Groups, [Group2 | Res],
[Pi | Pa], Pz).
%%%-------------------------
%%% Decide action groups
%%%-------------------------
-spec decide_action_groups([#group{}]) -> [{routing(), #group{}}].
decide_action_groups(Groups) ->
[{Group#group.multicast, Group}
|| Group <- Groups].
%%%------------------------- %%%-------------------------
%%% Route packet %%% Route packet
%%%------------------------- %%%-------------------------
-spec route_packet(jid(), #dest{}, stanza(), [addresses()], [addresses()]) -> 'ok'.
route_packet(From, ToDest, Packet, Others, Addresses) ->
Dests = case ToDest#dest.type of
bcc -> [];
_ -> [ToDest]
end,
route_packet2(From, ToDest#dest.jid_string, Dests,
Packet, {Others, Addresses}).
-spec route_packet_multicast(jid(), binary(), stanza(), [#dest{}], [address()], #limits{}) -> 'ok'.
route_packet_multicast(From, ToS, Packet, Dests,
Addresses, Limits) ->
Type_of_stanza = type_of_stanza(Packet),
{_Type, Limit_number} = get_limit_number(Type_of_stanza,
Limits),
Fragmented_dests = fragment_dests(Dests, Limit_number),
lists:foreach(fun(DFragment) ->
route_packet2(From, ToS, DFragment, Packet,
Addresses)
end, Fragmented_dests).
-spec route_packet2(jid(), binary(), [#dest{}], stanza(), {[address()], [address()]} | [address()]) -> 'ok'.
route_packet2(From, ToS, Dests, Packet, Addresses) ->
Els = case append_dests(Dests, Addresses) of
[] ->
xmpp:get_els(Packet);
ACs ->
[#addresses{list = ACs}|xmpp:get_els(Packet)]
end,
Packet2 = xmpp:set_els(Packet, Els),
ToJID = stj(ToS),
ejabberd_router:route(xmpp:set_from_to(Packet2, From, ToJID)).
-spec append_dests([#dest{}], {[address()], [address()]} | [address()]) -> [address()].
append_dests(_Dests, {Others, Addresses}) ->
Addresses ++ Others;
append_dests([], Addresses) -> Addresses;
append_dests([Dest | Dests], Addresses) ->
append_dests(Dests, [Dest#dest.address | Addresses]).
%%%------------------------- %%%-------------------------
%%% Check relay %%% Check relay
%%%------------------------- %%%-------------------------
-spec check_relay(binary(), binary(), [#group{}]) -> ok. -spec check_relay(binary(), binary(), #{}) -> ok.
check_relay(RS, LS, Gs) -> check_relay(RS, LS, Gs) ->
case check_relay_required(RS, LS, Gs) of case lists:suffix(str:tokens(LS, <<".">>),
false -> ok; str:tokens(RS, <<".">>)) orelse
true -> throw(edrelay) (maps:is_key(LS, Gs) andalso maps:size(Gs) == 1) of
true -> ok;
_ -> throw(edrelay)
end. end.
-spec check_relay_required(binary(), binary(), [#group{}]) -> boolean().
check_relay_required(RServer, LServerS, Groups) ->
case lists:suffix(str:tokens(LServerS, <<".">>),
str:tokens(RServer, <<".">>)) of
true -> false;
false -> check_relay_required(LServerS, Groups)
end.
-spec check_relay_required(binary(), [#group{}]) -> boolean().
check_relay_required(LServerS, Groups) ->
lists:any(fun (Group) -> Group#group.server /= LServerS
end,
Groups).
%%%------------------------- %%%-------------------------
%%% Check protocol support: Send request %%% Check protocol support: Send request
%%%------------------------- %%%-------------------------
@ -1060,20 +920,6 @@ get_slimit_group(local, SLimits) ->
get_slimit_group(remote, SLimits) -> get_slimit_group(remote, SLimits) ->
SLimits#service_limits.remote. SLimits#service_limits.remote.
fragment_dests(Dests, Limit_number) ->
{R, _} = lists:foldl(fun (Dest, {Res, Count}) ->
case Count of
Limit_number ->
Head2 = [Dest], {[Head2 | Res], 0};
_ ->
[Head | Tail] = Res,
Head2 = [Dest | Head],
{[Head2 | Tail], Count + 1}
end
end,
{[[]], 0}, Dests),
R.
%%%------------------------- %%%-------------------------
%%% Limits: XEP-0128 Service Discovery Extensions %%% Limits: XEP-0128 Service Discovery Extensions
%%%------------------------- %%%-------------------------