25
1
mirror of https://github.com/processone/ejabberd.git synced 2024-11-20 16:15:59 +01:00

Improve mod_multicast

This commit is contained in:
Paweł Chmielowski 2021-11-16 19:39:59 +01:00
parent 97b8373fd2
commit 405a5172d5

View File

@ -35,7 +35,7 @@
%% API
-export([start/2, stop/1, reload/3,
user_send_packet/1]).
user_send_packet/1]).
%% gen_server callbacks
-export([init/1, handle_info/2, handle_call/3,
@ -51,11 +51,6 @@
response,
ts :: integer()}).
-record(dest, {jid_string :: binary() | none,
jid_jid :: jid() | undefined,
type :: bcc | cc | noreply | ofrom | replyroom | replyto | to,
address :: address()}).
-type limit_value() :: {default | custom, integer()}.
-record(limits, {message :: limit_value(),
presence :: limit_value()}).
@ -63,14 +58,6 @@
-record(service_limits, {local :: #limits{},
remote :: #limits{}}).
-type routing() :: route_single | {route_multicast, binary(), #service_limits{}}.
-record(group, {server :: binary(),
dests :: [#dest{}],
multicast :: routing() | undefined,
others :: [address()],
addresses :: [address()]}).
-record(state, {lserver :: binary(),
lservice :: binary(),
access :: atom(),
@ -117,7 +104,7 @@ reload(LServerS, NewOpts, OldOpts) ->
user_send_packet({#presence{} = Packet, C2SState} = Acc) ->
case xmpp:get_subtag(Packet, #addresses{}) of
#addresses{list = Addresses} ->
{ToDeliver, _Delivereds} = split_addresses_todeliver(Addresses),
{CC, BCC, _Invalid, _Delivered} = partition_addresses(Addresses),
NewState =
lists:foldl(
fun(Address, St) ->
@ -138,7 +125,7 @@ user_send_packet({#presence{} = Packet, C2SState} = Acc) ->
undefined ->
St
end
end, C2SState, ToDeliver),
end, C2SState, CC ++ BCC),
{Packet, NewState};
false ->
Acc
@ -308,19 +295,10 @@ iq_vcard(Lang, State) ->
%%%-------------------------
-spec route_trusted(binary(), binary(), jid(), [jid()], stanza()) -> 'ok'.
route_trusted(LServiceS, LServerS, FromJID,
Destinations, Packet) ->
Packet_stripped = Packet,
Delivereds = [],
Dests2 = lists:map(
fun(D) ->
#dest{jid_string = jid:encode(D),
jid_jid = D, type = bcc,
address = #address{type = bcc, jid = D}}
end, Destinations),
Groups = group_dests(Dests2),
route_common(LServerS, LServiceS, FromJID, Groups,
Delivereds, Packet_stripped).
route_trusted(LServiceS, LServerS, FromJID, Destinations, Packet) ->
Addresses = [#address{type = bcc, jid = D} || D <- Destinations],
Groups = group_by_destinations(Addresses, #{}),
route_grouped(LServerS, LServiceS, FromJID, Groups, [], Packet).
-spec route_untrusted(binary(), binary(), atom(), #service_limits{}, stanza()) -> 'ok'.
route_untrusted(LServiceS, LServerS, Access, SLimits, Packet) ->
@ -356,50 +334,88 @@ route_untrusted(LServiceS, LServerS, Access, SLimits, Packet) ->
route_untrusted2(LServiceS, LServerS, Access, SLimits, Packet) ->
FromJID = xmpp:get_from(Packet),
ok = check_access(LServerS, Access, FromJID),
{ok, Packet_stripped, Addresses} = strip_addresses_element(Packet),
{To_deliver, Delivereds} = split_addresses_todeliver(Addresses),
Dests = convert_dest_record(To_deliver),
{Dests2, Not_jids} = split_dests_jid(Dests),
report_not_jid(FromJID, Packet, Not_jids),
ok = check_limit_dests(SLimits, FromJID, Packet, Dests2),
Groups = group_dests(Dests2),
{ok, PacketStripped, Addresses} = strip_addresses_element(Packet),
{CC, BCC, NotJids, Rest} = partition_addresses(Addresses),
report_not_jid(FromJID, Packet, NotJids),
ok = check_limit_dests(SLimits, FromJID, Packet, length(CC) + length(BCC)),
Groups0 = group_by_destinations(CC, #{}),
Groups = group_by_destinations(BCC, Groups0),
ok = check_relay(FromJID#jid.server, LServerS, Groups),
route_common(LServerS, LServiceS, FromJID, Groups,
Delivereds, Packet_stripped).
route_grouped(LServerS, LServiceS, FromJID, Groups, Rest, PacketStripped).
-spec route_common(binary(), binary(), jid(), [#group{}],
[address()], stanza()) -> 'ok'.
route_common(LServerS, LServiceS, FromJID, Groups,
Delivereds, Packet_stripped) ->
Groups2 = look_cached_servers(LServerS, LServiceS, Groups),
Groups3 = build_others_xml(Groups2),
Groups4 = add_addresses(Delivereds, Groups3),
AGroups = decide_action_groups(Groups4),
act_groups(FromJID, Packet_stripped, LServiceS,
AGroups).
-spec mark_as_delivered([address()]) -> [address()].
mark_as_delivered(Addresses) ->
[A#address{delivered = true} || A <- Addresses].
-spec act_groups(jid(), stanza(), binary(), [{routing(), #group{}}]) -> 'ok'.
act_groups(FromJID, Packet_stripped, LServiceS, AGroups) ->
-spec route_individual(jid(), [address()], [address()], [address()], stanza()) -> ok.
route_individual(From, CC, BCC, Other, Packet) ->
CCDelivered = mark_as_delivered(CC),
Addresses = CCDelivered ++ Other,
PacketWithAddresses = xmpp:append_subtags(Packet, [#addresses{list = Addresses}]),
lists:foreach(
fun(AGroup) ->
perform(FromJID, Packet_stripped, LServiceS,
AGroup)
end, AGroups).
-spec perform(jid(), stanza(), binary(),
{routing(), #group{}}) -> 'ok'.
perform(From, Packet, _,
{route_single, Group}) ->
fun(#address{jid = To}) ->
ejabberd_router:route(xmpp:set_from_to(PacketWithAddresses, From, To))
end, CC),
lists:foreach(
fun(ToUser) ->
Group_others = strip_other_bcc(ToUser, Group#group.others),
route_packet(From, ToUser, Packet,
Group_others, Group#group.addresses)
end, Group#group.dests);
perform(From, Packet, _,
{{route_multicast, JID, RLimits}, Group}) ->
route_packet_multicast(From, JID, Packet,
Group#group.dests, Group#group.addresses, RLimits).
fun(#address{jid = To} = Address) ->
Packet2 = case Addresses of
[] ->
Packet;
_ ->
xmpp:append_subtags(Packet, [#addresses{list = [Address | Addresses]}])
end,
ejabberd_router:route(xmpp:set_from_to(Packet2, From, To))
end, BCC).
-spec route_chunk(jid(), jid(), stanza(), [address()]) -> ok.
route_chunk(From, To, Packet, Addresses) ->
PacketWithAddresses = xmpp:append_subtags(Packet, [#addresses{list = Addresses}]),
ejabberd_router:route(xmpp:set_from_to(PacketWithAddresses, From, To)).
-spec route_in_chunks(jid(), jid(), stanza(), integer(), [address()], [address()], [address()]) -> ok.
route_in_chunks(_From, _To, _Packet, _Limit, [], [], _) ->
ok;
route_in_chunks(From, To, Packet, Limit, CC, BCC, RestOfAddresses) when length(CC) > Limit ->
{Chunk, Rest} = lists:split(Limit, CC),
route_chunk(From, To, Packet, Chunk ++ RestOfAddresses),
route_in_chunks(From, To, Packet, Limit, Rest, BCC, RestOfAddresses);
route_in_chunks(From, To, Packet, Limit, [], BCC, RestOfAddresses) when length(BCC) > Limit ->
{Chunk, Rest} = lists:split(Limit, BCC),
route_chunk(From, To, Packet, Chunk ++ RestOfAddresses),
route_in_chunks(From, To, Packet, Limit, [], Rest, RestOfAddresses);
route_in_chunks(From, To, Packet, Limit, CC, BCC, RestOfAddresses) when length(BCC) + length(CC) > Limit ->
{Chunk, Rest} = lists:split(Limit - length(CC), BCC),
route_chunk(From, To, Packet, CC ++ Chunk ++ RestOfAddresses),
route_in_chunks(From, To, Packet, Limit, [], Rest, RestOfAddresses);
route_in_chunks(From, To, Packet, _Limit, CC, BCC, RestOfAddresses) ->
route_chunk(From, To, Packet, CC ++ BCC ++ RestOfAddresses).
-spec route_multicast(jid(), jid(), [address()], [address()], [address()], stanza(), #limits{}) -> ok.
route_multicast(From, To, CC, BCC, RestOfAddresses, Packet, Limits) ->
{_Type, Limit} = get_limit_number(element(1, Packet),
Limits),
route_in_chunks(From, To, Packet, Limit, CC, BCC, RestOfAddresses).
-spec route_grouped(binary(), binary(), jid(), #{}, [address()], stanza()) -> ok.
route_grouped(LServer, LService, From, Groups, RestOfAddresses, Packet) ->
maps:fold(
fun(Server, {CC, BCC}, _) ->
OtherCC = maps:fold(
fun(Server2, _, Res) when Server2 == Server ->
Res;
(_, {CC2, _}, Res) ->
mark_as_delivered(CC2) ++ Res
end, [], Groups),
case search_server_on_cache(Server,
LServer, LService,
{?MAXTIME_CACHE_POSITIVE,
?MAXTIME_CACHE_NEGATIVE}) of
route_single ->
route_individual(From, CC, BCC, OtherCC ++ RestOfAddresses, Packet);
{route_multicast, Service, Limits} ->
route_multicast(From, Service, CC, BCC, OtherCC ++ RestOfAddresses, Packet, Limits)
end
end, ok, Groups).
%%%-------------------------
%%% Check access permission
@ -425,245 +441,89 @@ strip_addresses_element(Packet) ->
throw(eadsele)
end.
%%%-------------------------
%%% Strip third-party bcc 'addresses'
%%%-------------------------
strip_other_bcc(#dest{jid_jid = ToUserJid}, Group_others) ->
lists:filter(
fun(#address{jid = JID, type = Type}) ->
case {JID, Type} of
{ToUserJid, bcc} -> true;
{_, bcc} -> false;
_ -> true
end
end,
Group_others).
%%%-------------------------
%%% Split Addresses
%%%-------------------------
-spec split_addresses_todeliver([address()]) -> {[address()], [address()]}.
split_addresses_todeliver(Addresses) ->
lists:partition(
fun(#address{delivered = true}) ->
false;
(#address{type = Type}) ->
case Type of
to -> true;
cc -> true;
bcc -> true;
_ -> false
end
end, Addresses).
partition_addresses(Addresses) ->
lists:foldl(
fun(#address{delivered = true} = A, {C, B, I, D}) ->
{C, B, I, [A | D]};
(#address{type = T, jid = undefined} = A, {C, B, I, D})
when T == to; T == cc; T == bcc ->
{C, B, [A | I], D};
(#address{type = T} = A, {C, B, I, D})
when T == to; T == cc ->
{[A | C], B, I, D};
(#address{type = bcc} = A, {C, B, I, D}) ->
{C, [A | B], I, D};
(A, {C, B, I, D}) ->
{C, B, I, [A | D]}
end, {[], [], [], []}, Addresses).
%%%-------------------------
%%% Check does not exceed limit of destinations
%%%-------------------------
-spec check_limit_dests(#service_limits{}, jid(), stanza(), [address()]) -> ok.
check_limit_dests(SLimits, FromJID, Packet,
Addresses) ->
-spec check_limit_dests(#service_limits{}, jid(), stanza(), integer()) -> ok.
check_limit_dests(SLimits, FromJID, Packet, NumOfAddresses) ->
SenderT = sender_type(FromJID),
Limits = get_slimit_group(SenderT, SLimits),
Type_of_stanza = type_of_stanza(Packet),
{_Type, Limit_number} = get_limit_number(Type_of_stanza,
Limits),
case length(Addresses) > Limit_number of
StanzaType = type_of_stanza(Packet),
{_Type, Limit} = get_limit_number(StanzaType,
Limits),
case NumOfAddresses > Limit of
false -> ok;
true -> throw(etoorec)
end.
%%%-------------------------
%%% Convert Destination XML to record
%%%-------------------------
-spec convert_dest_record([address()]) -> [#dest{}].
convert_dest_record(Addrs) ->
lists:map(
fun(#address{jid = undefined, type = Type} = Addr) ->
#dest{jid_string = none,
type = Type, address = Addr};
(#address{jid = JID, type = Type} = Addr) ->
#dest{jid_string = jid:encode(JID), jid_jid = JID,
type = Type, address = Addr}
end, Addrs).
%%%-------------------------
%%% Split destinations by existence of JID
%%% and send error messages for other dests
%%%-------------------------
-spec split_dests_jid([#dest{}]) -> {[#dest{}], [#dest{}]}.
split_dests_jid(Dests) ->
lists:partition(fun (Dest) ->
case Dest#dest.jid_string of
none -> false;
_ -> true
end
end,
Dests).
-spec report_not_jid(jid(), stanza(), [#dest{}]) -> any().
report_not_jid(From, Packet, Dests) ->
Dests2 = [fxml:element_to_binary(xmpp:encode(Dest#dest.address))
|| Dest <- Dests],
[route_error(
xmpp:set_from_to(Packet, From, From), jid_malformed,
str:format(?T("This service can not process the address: ~s"), [D]))
|| D <- Dests2].
-spec report_not_jid(jid(), stanza(), [address()]) -> any().
report_not_jid(From, Packet, Addresses) ->
lists:foreach(
fun(Address) ->
route_error(
xmpp:set_from_to(Packet, From, From), jid_malformed,
str:format(?T("This service can not process the address: ~s"),
[fxml:element_to_binary(xmpp:encode(Address))]))
end, Addresses).
%%%-------------------------
%%% Group destinations by their servers
%%%-------------------------
-spec group_dests([#dest{}]) -> [#group{}].
group_dests(Dests) ->
D = lists:foldl(fun (Dest, Dict) ->
ServerS = (Dest#dest.jid_jid)#jid.server,
dict:append(ServerS, Dest, Dict)
end,
dict:new(), Dests),
Keys = dict:fetch_keys(D),
[#group{server = Key, dests = dict:fetch(Key, D),
addresses = [], others = []}
|| Key <- Keys].
%%%-------------------------
%%% Look for cached responses
%%%-------------------------
look_cached_servers(LServerS, LServiceS, Groups) ->
[look_cached(LServerS, LServiceS, Group) || Group <- Groups].
look_cached(LServerS, LServiceS, G) ->
Maxtime_positive = (?MAXTIME_CACHE_POSITIVE),
Maxtime_negative = (?MAXTIME_CACHE_NEGATIVE),
Cached_response = search_server_on_cache(G#group.server,
LServerS, LServiceS,
{Maxtime_positive,
Maxtime_negative}),
G#group{multicast = Cached_response}.
%%%-------------------------
%%% Build delivered XML element
%%%-------------------------
build_others_xml(Groups) ->
[Group#group{others =
build_other_xml(Group#group.dests)}
|| Group <- Groups].
build_other_xml(Dests) ->
lists:foldl(fun (Dest, R) ->
XML = Dest#dest.address,
case Dest#dest.type of
to -> [add_delivered(XML) | R];
cc -> [add_delivered(XML) | R];
_ -> [XML | R]
end
end,
[], Dests).
-spec add_delivered(address()) -> address().
add_delivered(Addr) ->
Addr#address{delivered = true}.
%%%-------------------------
%%% Add preliminary packets
%%%-------------------------
add_addresses(Delivereds, Groups) ->
Ps = [Group#group.others || Group <- Groups],
add_addresses2(Delivereds, Groups, [], [], Ps).
add_addresses2(_, [], Res, _, []) -> Res;
add_addresses2(Delivereds, [Group | Groups], Res, Pa,
[Pi | Pz]) ->
Addresses = lists:append([Delivereds] ++ Pa ++ Pz),
Group2 = Group#group{addresses = Addresses},
add_addresses2(Delivereds, Groups, [Group2 | Res],
[Pi | Pa], Pz).
%%%-------------------------
%%% Decide action groups
%%%-------------------------
-spec decide_action_groups([#group{}]) -> [{routing(), #group{}}].
decide_action_groups(Groups) ->
[{Group#group.multicast, Group}
|| Group <- Groups].
group_by_destinations(Addrs, Map) ->
lists:foldl(
fun
(#address{type = Type, jid = #jid{lserver = Server}} = Addr, Map2) when Type == to; Type == cc ->
maps:update_with(Server,
fun({CC, BCC}) ->
{[Addr | CC], BCC}
end, {[Addr], []}, Map2);
(#address{type = bcc, jid = #jid{lserver = Server}} = Addr, Map2) ->
maps:update_with(Server,
fun({CC, BCC}) ->
{CC, [Addr | BCC]}
end, {[], [Addr]}, Map2)
end, Map, Addrs).
%%%-------------------------
%%% Route packet
%%%-------------------------
-spec route_packet(jid(), #dest{}, stanza(), [addresses()], [addresses()]) -> 'ok'.
route_packet(From, ToDest, Packet, Others, Addresses) ->
Dests = case ToDest#dest.type of
bcc -> [];
_ -> [ToDest]
end,
route_packet2(From, ToDest#dest.jid_string, Dests,
Packet, {Others, Addresses}).
-spec route_packet_multicast(jid(), binary(), stanza(), [#dest{}], [address()], #limits{}) -> 'ok'.
route_packet_multicast(From, ToS, Packet, Dests,
Addresses, Limits) ->
Type_of_stanza = type_of_stanza(Packet),
{_Type, Limit_number} = get_limit_number(Type_of_stanza,
Limits),
Fragmented_dests = fragment_dests(Dests, Limit_number),
lists:foreach(fun(DFragment) ->
route_packet2(From, ToS, DFragment, Packet,
Addresses)
end, Fragmented_dests).
-spec route_packet2(jid(), binary(), [#dest{}], stanza(), {[address()], [address()]} | [address()]) -> 'ok'.
route_packet2(From, ToS, Dests, Packet, Addresses) ->
Els = case append_dests(Dests, Addresses) of
[] ->
xmpp:get_els(Packet);
ACs ->
[#addresses{list = ACs}|xmpp:get_els(Packet)]
end,
Packet2 = xmpp:set_els(Packet, Els),
ToJID = stj(ToS),
ejabberd_router:route(xmpp:set_from_to(Packet2, From, ToJID)).
-spec append_dests([#dest{}], {[address()], [address()]} | [address()]) -> [address()].
append_dests(_Dests, {Others, Addresses}) ->
Addresses ++ Others;
append_dests([], Addresses) -> Addresses;
append_dests([Dest | Dests], Addresses) ->
append_dests(Dests, [Dest#dest.address | Addresses]).
%%%-------------------------
%%% Check relay
%%%-------------------------
-spec check_relay(binary(), binary(), [#group{}]) -> ok.
-spec check_relay(binary(), binary(), #{}) -> ok.
check_relay(RS, LS, Gs) ->
case check_relay_required(RS, LS, Gs) of
false -> ok;
true -> throw(edrelay)
case lists:suffix(str:tokens(LS, <<".">>),
str:tokens(RS, <<".">>)) orelse
(maps:is_key(LS, Gs) andalso maps:size(Gs) == 1) of
true -> ok;
_ -> throw(edrelay)
end.
-spec check_relay_required(binary(), binary(), [#group{}]) -> boolean().
check_relay_required(RServer, LServerS, Groups) ->
case lists:suffix(str:tokens(LServerS, <<".">>),
str:tokens(RServer, <<".">>)) of
true -> false;
false -> check_relay_required(LServerS, Groups)
end.
-spec check_relay_required(binary(), [#group{}]) -> boolean().
check_relay_required(LServerS, Groups) ->
lists:any(fun (Group) -> Group#group.server /= LServerS
end,
Groups).
%%%-------------------------
%%% Check protocol support: Send request
%%%-------------------------
@ -1060,20 +920,6 @@ get_slimit_group(local, SLimits) ->
get_slimit_group(remote, SLimits) ->
SLimits#service_limits.remote.
fragment_dests(Dests, Limit_number) ->
{R, _} = lists:foldl(fun (Dest, {Res, Count}) ->
case Count of
Limit_number ->
Head2 = [Dest], {[Head2 | Res], 0};
_ ->
[Head | Tail] = Res,
Head2 = [Dest | Head],
{[Head2 | Tail], Count + 1}
end
end,
{[[]], 0}, Dests),
R.
%%%-------------------------
%%% Limits: XEP-0128 Service Discovery Extensions
%%%-------------------------