25
1
mirror of https://github.com/processone/ejabberd.git synced 2024-12-24 17:29:28 +01:00

Don't call Mod:function() in xmpp_stream callbacks

If a callback function is not defined by the `Mod` then
a call to code_server process is performed. Under heavy load
this may cause code_server to get overloaded. We now avoid this.
This commit is contained in:
Evgeniy Khramtsov 2018-05-26 09:06:24 +03:00
parent bfe2545c01
commit fc77051b68
2 changed files with 247 additions and 190 deletions

View File

@ -210,14 +210,14 @@ format_error(Err) ->
%%%===================================================================
%%% gen_server callbacks
%%%===================================================================
init([Module, {_SockMod, Socket}, Opts]) ->
init([Mod, {_SockMod, Socket}, Opts]) ->
Encrypted = proplists:get_bool(tls, Opts),
SocketMonitor = xmpp_socket:monitor(Socket),
case xmpp_socket:peername(Socket) of
{ok, IP} ->
Time = p1_time_compat:monotonic_time(milli_seconds),
State = #{owner => self(),
mod => Module,
mod => Mod,
socket => Socket,
socket_monitor => SocketMonitor,
stream_timeout => {timer:seconds(30), Time},
@ -238,15 +238,15 @@ init([Module, {_SockMod, Socket}, Opts]) ->
resource => <<"">>,
lserver => <<"">>,
ip => IP},
case try Module:init([State, Opts])
case try Mod:init([State, Opts])
catch _:undef -> {ok, State}
end of
{ok, State1} when not Encrypted ->
{_, State2, Timeout} = noreply(State1),
{ok, State2, Timeout};
{ok, State1} when Encrypted ->
TLSOpts = try Module:tls_options(State1)
catch _:undef -> []
TLSOpts = try callback(tls_options, State1)
catch _:{?MODULE, undef} -> []
end,
case xmpp_socket:starttls(Socket, TLSOpts) of
{ok, TLSSocket} ->
@ -276,14 +276,14 @@ handle_cast({close, Reason}, State) ->
true -> State1;
false -> process_stream_end({socket, Reason}, State)
end);
handle_cast(Cast, #{mod := Mod} = State) ->
noreply(try Mod:handle_cast(Cast, State)
catch _:undef -> State
handle_cast(Cast, State) ->
noreply(try callback(handle_cast, Cast, State)
catch _:{?MODULE, undef} -> State
end).
handle_call(Call, From, #{mod := Mod} = State) ->
noreply(try Mod:handle_call(Call, From, State)
catch _:undef -> State
handle_call(Call, From, State) ->
noreply(try callback(handle_call, Call, From, State)
catch _:{?MODULE, undef} -> State
end).
handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}},
@ -343,20 +343,20 @@ handle_info({'$gen_event', El}, #{stream_state := wait_for_stream} = State) ->
false -> send_pkt(State1, xmpp:serr_invalid_xml())
end);
handle_info({'$gen_event', {xmlstreamelement, El}},
#{xmlns := NS, mod := Mod, codec_options := Opts} = State) ->
#{xmlns := NS, codec_options := Opts} = State) ->
noreply(
try xmpp:decode(El, NS, Opts) of
Pkt ->
State1 = try Mod:handle_recv(El, Pkt, State)
catch _:undef -> State
State1 = try callback(handle_recv, El, Pkt, State)
catch _:{?MODULE, undef} -> State
end,
case is_disconnected(State1) of
true -> State1;
false -> process_element(Pkt, State1)
end
catch _:{xmpp_codec, Why} ->
State1 = try Mod:handle_recv(El, {error, Why}, State)
catch _:undef -> State
State1 = try callback(handle_recv, El, {error, Why}, State)
catch _:{?MODULE, undef} -> State
end,
case is_disconnected(State1) of
true -> State1;
@ -364,17 +364,17 @@ handle_info({'$gen_event', {xmlstreamelement, El}},
end
end);
handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}},
#{mod := Mod} = State) ->
noreply(try Mod:handle_cdata(Data, State)
catch _:undef -> State
State) ->
noreply(try callback(handle_cdata, Data, State)
catch _:{?MODULE, undef} -> State
end);
handle_info(timeout, #{mod := Mod, lang := Lang} = State) ->
handle_info(timeout, #{lang := Lang} = State) ->
Disconnected = is_disconnected(State),
noreply(try Mod:handle_timeout(State)
catch _:undef when not Disconnected ->
noreply(try callback(handle_timeout, State)
catch _:{?MODULE, undef} when not Disconnected ->
Txt = <<"Idle connection">>,
send_pkt(State, xmpp:serr_connection_timeout(Txt, Lang));
_:undef ->
_:{?MODULE, undef} ->
stop(State)
end);
handle_info({'DOWN', MRef, _Type, _Object, _Info},
@ -395,25 +395,25 @@ handle_info({tcp_closed, _}, State) ->
handle_info({'$gen_event', closed}, State);
handle_info({tcp_error, _, Reason}, State) ->
noreply(process_stream_end({socket, Reason}, State));
handle_info(Info, #{mod := Mod} = State) ->
noreply(try Mod:handle_info(Info, State)
catch _:undef -> State
handle_info(Info, State) ->
noreply(try callback(handle_info, Info, State)
catch _:{?MODULE, undef} -> State
end).
terminate(Reason, #{mod := Mod} = State) ->
terminate(Reason, State) ->
case get(already_terminated) of
true ->
State;
_ ->
put(already_terminated, true),
try Mod:terminate(Reason, State)
catch _:undef -> ok
try callback(terminate, Reason, State)
catch _:{?MODULE, undef} -> ok
end,
send_trailer(State)
end.
code_change(OldVsn, #{mod := Mod} = State, Extra) ->
Mod:code_change(OldVsn, State, Extra).
code_change(OldVsn, State, Extra) ->
callback(code_change, OldVsn, State, Extra).
%%%===================================================================
%%% Internal functions
@ -464,11 +464,11 @@ process_invalid_xml(#{lang := MyLang} = State, El, Reason) ->
-spec process_stream_end(stop_reason(), state()) -> state().
process_stream_end(_, #{stream_state := disconnected} = State) ->
State;
process_stream_end(Reason, #{mod := Mod} = State) ->
process_stream_end(Reason, State) ->
State1 = State#{stream_timeout => infinity,
stream_state => disconnected},
try Mod:handle_stream_end(Reason, State1)
catch _:undef -> stop(State1)
try callback(handle_stream_end, Reason, State1)
catch _:{?MODULE, undef} -> stop(State1)
end.
-spec process_stream(stream_start(), state()) -> state().
@ -503,17 +503,17 @@ process_stream(#stream_start{to = #jid{luser = U, lresource = R}},
Txt = <<"Improper 'to' attribute">>,
send_pkt(State, xmpp:serr_improper_addressing(Txt, Lang));
process_stream(#stream_start{to = #jid{lserver = RemoteServer}} = StreamStart,
#{xmlns := ?NS_COMPONENT, mod := Mod} = State) ->
#{xmlns := ?NS_COMPONENT} = State) ->
State1 = State#{remote_server => RemoteServer,
stream_state => wait_for_handshake},
try Mod:handle_stream_start(StreamStart, State1)
catch _:undef -> State1
try callback(handle_stream_start, StreamStart, State1)
catch _:{?MODULE, undef} -> State1
end;
process_stream(#stream_start{to = #jid{server = Server, lserver = LServer},
from = From} = StreamStart,
#{stream_authenticated := Authenticated,
stream_restarted := StreamWasRestarted,
mod := Mod, xmlns := NS, resource := Resource,
xmlns := NS, resource := Resource,
stream_encrypted := Encrypted} = State) ->
State1 = if not StreamWasRestarted ->
State#{server => Server, lserver => LServer};
@ -526,8 +526,8 @@ process_stream(#stream_start{to = #jid{server = Server, lserver = LServer},
_ ->
State1
end,
State3 = try Mod:handle_stream_start(StreamStart, State2)
catch _:undef -> State2
State3 = try callback(handle_stream_start, StreamStart, State2)
catch _:{?MODULE, undef} -> State2
end,
case is_disconnected(State3) of
true -> State3;
@ -604,21 +604,21 @@ process_element(Pkt, #{stream_state := StateName, lang := Lang} = State) ->
end.
-spec process_unauthenticated_packet(xmpp_element(), state()) -> state().
process_unauthenticated_packet(Pkt, #{mod := Mod} = State) ->
process_unauthenticated_packet(Pkt, State) ->
NewPkt = set_lang(Pkt, State),
try Mod:handle_unauthenticated_packet(NewPkt, State)
catch _:undef ->
try callback(handle_unauthenticated_packet, NewPkt, State)
catch _:{?MODULE, undef} ->
Err = xmpp:serr_not_authorized(),
send(State, Err)
end.
-spec process_authenticated_packet(xmpp_element(), state()) -> state().
process_authenticated_packet(Pkt, #{mod := Mod} = State) ->
process_authenticated_packet(Pkt, State) ->
Pkt1 = set_lang(Pkt, State),
case set_from_to(Pkt1, State) of
{ok, Pkt2} ->
try Mod:handle_authenticated_packet(Pkt2, State)
catch _:undef ->
try callback(handle_authenticated_packet, Pkt2, State)
catch _:{?MODULE, undef} ->
Err = xmpp:err_service_unavailable(),
send_error(State, Pkt, Err)
end;
@ -628,10 +628,10 @@ process_authenticated_packet(Pkt, #{mod := Mod} = State) ->
-spec process_bind(xmpp_element(), state()) -> state().
process_bind(#iq{type = set, sub_els = [_]} = Pkt,
#{xmlns := ?NS_CLIENT, mod := Mod, lang := MyLang} = State) ->
#{xmlns := ?NS_CLIENT, lang := MyLang} = State) ->
try xmpp:try_subtag(Pkt, #bind{}) of
#bind{resource = R} ->
case Mod:bind(R, State) of
case callback(bind, R, State) of
{ok, #{user := U, server := S, resource := NewR} = State1}
when NewR /= <<"">> ->
Reply = #bind{jid = jid:make(U, S, NewR)},
@ -641,8 +641,8 @@ process_bind(#iq{type = set, sub_els = [_]} = Pkt,
send_error(State1, Pkt, Err)
end;
_ ->
try Mod:handle_unbinded_packet(Pkt, State)
catch _:undef ->
try callback(handle_unbinded_packet, Pkt, State)
catch _:{?MODULE, undef} ->
Err = xmpp:err_not_authorized(),
send_error(State, Pkt, Err)
end
@ -652,19 +652,19 @@ process_bind(#iq{type = set, sub_els = [_]} = Pkt,
Err = xmpp:err_bad_request(Txt, Lang),
send_error(State, Pkt, Err)
end;
process_bind(Pkt, #{mod := Mod} = State) ->
try Mod:handle_unbinded_packet(Pkt, State)
catch _:undef ->
process_bind(Pkt, State) ->
try callback(handle_unbinded_packet, Pkt, State)
catch _:{?MODULE, undef} ->
Err = xmpp:err_not_authorized(),
send_error(State, Pkt, Err)
end.
-spec process_handshake(handshake(), state()) -> state().
process_handshake(#handshake{data = Digest},
#{mod := Mod, stream_id := StreamID,
#{stream_id := StreamID,
remote_server := RemoteServer} = State) ->
GetPW = try Mod:get_password_fun(State)
catch _:undef -> fun(_) -> {false, undefined} end
GetPW = try callback(get_password_fun, State)
catch _:{?MODULE, undef} -> fun(_) -> {false, undefined} end
end,
AuthRes = case GetPW(<<"">>) of
{false, _} ->
@ -674,9 +674,9 @@ process_handshake(#handshake{data = Digest},
end,
case AuthRes of
true ->
State1 = try Mod:handle_auth_success(
State1 = try callback(handle_auth_success,
RemoteServer, <<"handshake">>, undefined, State)
catch _:undef -> State
catch _:{?MODULE, undef} -> State
end,
case is_disconnected(State1) of
true -> State1;
@ -685,9 +685,9 @@ process_handshake(#handshake{data = Digest},
process_stream_established(State2)
end;
false ->
State1 = try Mod:handle_auth_failure(
State1 = try callback(handle_auth_failure,
RemoteServer, <<"handshake">>, <<"not authorized">>, State)
catch _:undef -> State
catch _:{?MODULE, undef} -> State
end,
case is_disconnected(State1) of
true -> State1;
@ -699,12 +699,12 @@ process_handshake(#handshake{data = Digest},
process_stream_established(#{stream_state := StateName} = State)
when StateName == disconnected; StateName == established ->
State;
process_stream_established(#{mod := Mod} = State) ->
process_stream_established(State) ->
State1 = State#{stream_authenticated => true,
stream_state => established,
stream_timeout => infinity},
try Mod:handle_stream_established(State1)
catch _:undef -> State1
try callback(handle_stream_established, State1)
catch _:{?MODULE, undef} -> State1
end.
-spec process_compress(compress(), state()) -> state().
@ -714,9 +714,9 @@ process_compress(#compress{},
when Compressed or not Authenticated ->
send_pkt(State, #compress_failure{reason = 'setup-failed'});
process_compress(#compress{methods = HisMethods},
#{socket := Socket, mod := Mod} = State) ->
MyMethods = try Mod:compress_methods(State)
catch _:undef -> []
#{socket := Socket} = State) ->
MyMethods = try callback(compress_methods, State)
catch _:{?MODULE, undef} -> []
end,
CommonMethods = lists_intersection(MyMethods, HisMethods),
case lists:member(<<"zlib">>, CommonMethods) of
@ -745,12 +745,11 @@ process_compress(#compress{methods = HisMethods},
-spec process_starttls(state()) -> state().
process_starttls(#{stream_encrypted := true} = State) ->
process_starttls_failure(already_encrypted, State);
process_starttls(#{socket := Socket,
mod := Mod} = State) ->
process_starttls(#{socket := Socket} = State) ->
case is_starttls_available(State) of
true ->
TLSOpts = try Mod:tls_options(State)
catch _:undef -> []
TLSOpts = try callback(tls_options, State)
catch _:{?MODULE, undef} -> []
end,
case xmpp_socket:starttls(Socket, TLSOpts) of
{ok, TLSSocket} ->
@ -782,7 +781,7 @@ process_starttls_failure(Why, State) ->
-spec process_sasl_request(sasl_auth(), state()) -> state().
process_sasl_request(#sasl_auth{mechanism = Mech, text = ClientIn},
#{mod := Mod, lserver := LServer} = State) ->
#{lserver := LServer} = State) ->
State1 = State#{sasl_mech => Mech},
Mechs = get_sasl_mechanisms(State1),
case lists:member(Mech, Mechs) of
@ -795,14 +794,14 @@ process_sasl_request(#sasl_auth{mechanism = Mech, text = ClientIn},
end,
process_sasl_result(Res, State1);
true ->
GetPW = try Mod:get_password_fun(State1)
catch _:undef -> fun(_) -> false end
GetPW = try callback(get_password_fun, State1)
catch _:{?MODULE, undef} -> fun(_) -> false end
end,
CheckPW = try Mod:check_password_fun(State1)
catch _:undef -> fun(_, _, _) -> false end
CheckPW = try callback(check_password_fun, State1)
catch _:{?MODULE, undef} -> fun(_, _, _) -> false end
end,
CheckPWDigest = try Mod:check_password_digest_fun(State1)
catch _:undef -> fun(_, _, _, _, _) -> false end
CheckPWDigest = try callback(check_password_digest_fun, State1)
catch _:{?MODULE, undef} -> fun(_, _, _, _, _) -> false end
end,
SASLState = cyrsasl:server_new(<<"jabber">>, LServer, <<"">>, [],
GetPW, CheckPW, CheckPWDigest),
@ -831,13 +830,13 @@ process_sasl_result({error, Reason, User}, State) ->
-spec process_sasl_success([cyrsasl:sasl_property()], binary(), state()) -> state().
process_sasl_success(Props, ServerOut,
#{socket := Socket,
mod := Mod, sasl_mech := Mech} = State) ->
sasl_mech := Mech} = State) ->
User = identity(Props),
AuthModule = proplists:get_value(auth_module, Props),
Socket1 = xmpp_socket:reset_stream(Socket),
State0 = State#{socket => Socket1},
State1 = try Mod:handle_auth_success(User, Mech, AuthModule, State0)
catch _:undef -> State
State1 = try callback(handle_auth_success, User, Mech, AuthModule, State0)
catch _:{?MODULE, undef} -> State
end,
case is_disconnected(State1) of
true -> State1;
@ -865,10 +864,10 @@ process_sasl_continue(ServerOut, NewSASLState, State) ->
-spec process_sasl_failure(atom(), binary(), state()) -> state().
process_sasl_failure(Err, User,
#{mod := Mod, sasl_mech := Mech, lang := Lang} = State) ->
#{sasl_mech := Mech, lang := Lang} = State) ->
{Reason, Text} = format_sasl_error(Mech, Err),
State1 = try Mod:handle_auth_failure(User, Mech, Text, State)
catch _:undef -> State
State1 = try callback(handle_auth_failure, User, Mech, Text, State)
catch _:{?MODULE, undef} -> State
end,
case is_disconnected(State1) of
true -> State1;
@ -906,21 +905,21 @@ send_features(State) ->
State.
-spec get_sasl_mechanisms(state()) -> [cyrsasl:mechanism()].
get_sasl_mechanisms(#{stream_encrypted := Encrypted, mod := Mod,
get_sasl_mechanisms(#{stream_encrypted := Encrypted,
xmlns := NS, lserver := LServer} = State) ->
Mechs = if NS == ?NS_CLIENT -> cyrsasl:listmech(LServer);
true -> []
end,
TLSVerify = try Mod:tls_verify(State)
catch _:undef -> false
TLSVerify = try callback(tls_verify, State)
catch _:{?MODULE, undef} -> false
end,
Mechs1 = if Encrypted andalso (TLSVerify orelse NS == ?NS_SERVER) ->
[<<"EXTERNAL">>|Mechs];
true ->
Mechs
end,
try Mod:sasl_mechanisms(Mechs1, State)
catch _:undef -> Mechs1
try callback(sasl_mechanisms, Mechs1, State)
catch _:{?MODULE, undef} -> Mechs1
end.
-spec get_sasl_feature(state()) -> [sasl_mechanisms()].
@ -937,12 +936,12 @@ get_sasl_feature(_) ->
[].
-spec get_compress_feature(state()) -> [compression()].
get_compress_feature(#{stream_compressed := false, mod := Mod,
get_compress_feature(#{stream_compressed := false,
stream_authenticated := true} = State) ->
try Mod:compress_methods(State) of
try callback(compress_methods, State) of
[] -> [];
Ms -> [#compression{methods = Ms}]
catch _:undef ->
catch _:{?MODULE, undef} ->
[]
end;
get_compress_feature(_) ->
@ -978,25 +977,25 @@ get_session_feature(_) ->
[].
-spec get_other_features(state()) -> [xmpp_element()].
get_other_features(#{stream_authenticated := Auth, mod := Mod} = State) ->
get_other_features(#{stream_authenticated := Auth} = State) ->
try
if Auth -> Mod:authenticated_stream_features(State);
true -> Mod:unauthenticated_stream_features(State)
if Auth -> callback(authenticated_stream_features, State);
true -> callback(unauthenticated_stream_features, State)
end
catch _:undef ->
catch _:{?MODULE, undef} ->
[]
end.
-spec is_starttls_available(state()) -> boolean().
is_starttls_available(#{mod := Mod} = State) ->
try Mod:tls_enabled(State)
catch _:undef -> true
is_starttls_available(State) ->
try callback(tls_enabled, State)
catch _:{?MODULE, undef} -> true
end.
-spec is_starttls_required(state()) -> boolean().
is_starttls_required(#{mod := Mod} = State) ->
try Mod:tls_required(State)
catch _:undef -> false
is_starttls_required(State) ->
try callback(tls_required, State)
catch _:{?MODULE, undef} -> false
end.
-spec set_from_to(xmpp_element(), state()) -> {ok, xmpp_element()} |
@ -1076,10 +1075,10 @@ send_header(State, _) ->
State.
-spec send_pkt(state(), xmpp_element() | xmlel()) -> state().
send_pkt(#{mod := Mod} = State, Pkt) ->
send_pkt(State, Pkt) ->
Result = socket_send(State, Pkt),
State1 = try Mod:handle_send(Pkt, Result, State)
catch _:undef -> State
State1 = try callback(handle_send, Pkt, Result, State)
catch _:{?MODULE, undef} -> State
end,
case Result of
_ when is_record(Pkt, stream_error) ->
@ -1200,3 +1199,36 @@ identity(Props) ->
<<>> -> proplists:get_value(username, Props, <<>>);
AuthzId -> AuthzId
end.
%%%===================================================================
%%% Callbacks
%%%===================================================================
callback(F, #{mod := Mod} = State) ->
case erlang:function_exported(Mod, F, 1) of
true -> Mod:F(State);
false -> erlang:error({?MODULE, undef})
end.
callback(F, Arg1, #{mod := Mod} = State) ->
case erlang:function_exported(Mod, F, 2) of
true -> Mod:F(Arg1, State);
false -> erlang:error({?MODULE, undef})
end.
callback(code_change, OldVsn, #{mod := Mod} = State, Extra) ->
%% code_change/3 callback is a special snowflake
case erlang:function_exported(Mod, code_change, 3) of
true -> Mod:code_change(OldVsn, State, Extra);
false -> {ok, State}
end;
callback(F, Arg1, Arg2, #{mod := Mod} = State) ->
case erlang:function_exported(Mod, F, 3) of
true -> Mod:F(Arg1, Arg2, State);
false -> erlang:error({?MODULE, undef})
end.
callback(F, Arg1, Arg2, Arg3, #{mod := Mod} = State) ->
case erlang:function_exported(Mod, F, 4) of
true -> Mod:F(Arg1, Arg2, Arg3, State);
false -> erlang:error({?MODULE, undef})
end.

View File

@ -266,9 +266,9 @@ init([Mod, _SockMod, From, To, Opts]) ->
end.
-spec handle_call(term(), term(), state()) -> noreply().
handle_call(Call, From, #{mod := Mod} = State) ->
noreply(try Mod:handle_call(Call, From, State)
catch _:undef -> State
handle_call(Call, From, State) ->
noreply(try callback(handle_call, Call, From, State)
catch _:{?MODULE, undef} -> State
end).
-spec handle_cast(term(), state()) -> noreply().
@ -311,9 +311,9 @@ handle_cast({close, Reason}, State) ->
true -> State1;
false -> process_stream_end({socket, Reason}, State)
end);
handle_cast(Cast, #{mod := Mod} = State) ->
noreply(try Mod:handle_cast(Cast, State)
catch _:undef -> State
handle_cast(Cast, State) ->
noreply(try callback(handle_cast, Cast, State)
catch _:{?MODULE, undef} -> State
end).
-spec handle_info(term(), state()) -> noreply().
@ -348,20 +348,20 @@ handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) ->
send_pkt(State1, Err)
end);
handle_info({'$gen_event', {xmlstreamelement, El}},
#{xmlns := NS, mod := Mod, codec_options := Opts} = State) ->
#{xmlns := NS, codec_options := Opts} = State) ->
noreply(
try xmpp:decode(El, NS, Opts) of
Pkt ->
State1 = try Mod:handle_recv(El, Pkt, State)
catch _:undef -> State
State1 = try callback(handle_recv, El, Pkt, State)
catch _:{?MODULE, undef} -> State
end,
case is_disconnected(State1) of
true -> State1;
false -> process_element(Pkt, State1)
end
catch _:{xmpp_codec, Why} ->
State1 = try Mod:handle_recv(El, {error, Why}, State)
catch _:undef -> State
State1 = try callback(handle_recv, El, {error, Why}, State)
catch _:{?MODULE, undef} -> State
end,
case is_disconnected(State1) of
true -> State1;
@ -369,21 +369,21 @@ handle_info({'$gen_event', {xmlstreamelement, El}},
end
end);
handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}},
#{mod := Mod} = State) ->
noreply(try Mod:handle_cdata(Data, State)
catch _:undef -> State
State) ->
noreply(try callback(handle_cdata, Data, State)
catch _:{?MODULE, undef} -> State
end);
handle_info({'$gen_event', {xmlstreamend, _}}, State) ->
noreply(process_stream_end({stream, reset}, State));
handle_info({'$gen_event', closed}, State) ->
noreply(process_stream_end({socket, closed}, State));
handle_info(timeout, #{mod := Mod, lang := Lang} = State) ->
handle_info(timeout, #{lang := Lang} = State) ->
Disconnected = is_disconnected(State),
noreply(try Mod:handle_timeout(State)
catch _:undef when not Disconnected ->
noreply(try callback(handle_timeout, State)
catch _:{?MODULE, undef} when not Disconnected ->
Txt = <<"Idle connection">>,
send_pkt(State, xmpp:serr_connection_timeout(Txt, Lang));
_:undef ->
_:{?MODULE, undef} ->
stop(State)
end);
handle_info({'DOWN', MRef, _Type, _Object, _Info},
@ -404,26 +404,26 @@ handle_info({tcp_closed, _}, State) ->
handle_info({'$gen_event', closed}, State);
handle_info({tcp_error, _, Reason}, State) ->
noreply(process_stream_end({socket, Reason}, State));
handle_info(Info, #{mod := Mod} = State) ->
noreply(try Mod:handle_info(Info, State)
catch _:undef -> State
handle_info(Info, State) ->
noreply(try callback(handle_info, Info, State)
catch _:{?MODULE, undef} -> State
end).
-spec terminate(term(), state()) -> any().
terminate(Reason, #{mod := Mod} = State) ->
terminate(Reason, State) ->
case get(already_terminated) of
true ->
State;
_ ->
put(already_terminated, true),
try Mod:terminate(Reason, State)
catch _:undef -> ok
try callback(terminate, Reason, State)
catch _:{?MODULE, undef} -> ok
end,
send_trailer(State)
end.
code_change(OldVsn, #{mod := Mod} = State, Extra) ->
Mod:code_change(OldVsn, State, Extra).
code_change(OldVsn, State, Extra) ->
callback(code_change, OldVsn, State, Extra).
%%%===================================================================
%%% Internal functions
@ -458,11 +458,11 @@ process_invalid_xml(#{lang := MyLang} = State, El, Reason) ->
-spec process_stream_end(stop_reason(), state()) -> state().
process_stream_end(_, #{stream_state := disconnected} = State) ->
State;
process_stream_end(Reason, #{mod := Mod} = State) ->
process_stream_end(Reason, State) ->
State1 = State#{stream_timeout => infinity,
stream_state => disconnected},
try Mod:handle_stream_end(Reason, State1)
catch _:undef -> stop(State1)
try callback(handle_stream_end, Reason, State1)
catch _:{?MODULE, undef} -> stop(State1)
end.
-spec process_stream(stream_start(), state()) -> state().
@ -475,10 +475,10 @@ process_stream(#stream_start{version = {N, _}}, State) when N > 1 ->
send_pkt(State, xmpp:serr_unsupported_version());
process_stream(#stream_start{lang = Lang, id = ID,
version = Version} = StreamStart,
#{mod := Mod} = State) ->
State) ->
State1 = State#{stream_remote_id => ID, lang => Lang},
State2 = try Mod:handle_stream_start(StreamStart, State1)
catch _:undef -> State1
State2 = try callback(handle_stream_start, StreamStart, State1)
catch _:{?MODULE, undef} -> State1
end,
case is_disconnected(State2) of
true -> State2;
@ -522,16 +522,16 @@ process_element(Pkt, #{stream_state := StateName} = State) ->
-spec process_features(stream_features(), state()) -> state().
process_features(StreamFeatures,
#{stream_authenticated := true, mod := Mod} = State) ->
State1 = try Mod:handle_authenticated_features(StreamFeatures, State)
catch _:undef -> State
#{stream_authenticated := true} = State) ->
State1 = try callback(handle_authenticated_features, StreamFeatures, State)
catch _:{?MODULE, undef} -> State
end,
process_stream_established(State1);
process_features(StreamFeatures,
#{stream_encrypted := Encrypted,
mod := Mod, lang := Lang} = State) ->
State1 = try Mod:handle_unauthenticated_features(StreamFeatures, State)
catch _:undef -> State
lang := Lang} = State) ->
State1 = try callback(handle_unauthenticated_features, StreamFeatures, State)
catch _:{?MODULE, undef} -> State
end,
case is_disconnected(State1) of
true -> State1;
@ -582,12 +582,12 @@ process_features(StreamFeatures,
process_stream_established(#{stream_state := StateName} = State)
when StateName == disconnected; StateName == established ->
State;
process_stream_established(#{mod := Mod} = State) ->
process_stream_established(State) ->
State1 = State#{stream_authenticated := true,
stream_state => established,
stream_timeout => infinity},
try Mod:handle_stream_established(State1)
catch _:undef -> State1
try callback(handle_stream_established, State1)
catch _:{?MODULE, undef} -> State1
end.
-spec process_sasl_mechanisms([binary()], state()) -> state().
@ -620,7 +620,7 @@ process_starttls(#{socket := Socket} = State) ->
-spec process_stream_downgrade(stream_start(), state()) -> state().
process_stream_downgrade(StreamStart,
#{mod := Mod, lang := Lang,
#{lang := Lang,
stream_encrypted := Encrypted} = State) ->
TLSRequired = is_starttls_required(State),
if not Encrypted and TLSRequired ->
@ -628,18 +628,17 @@ process_stream_downgrade(StreamStart,
send_pkt(State, xmpp:serr_policy_violation(Txt, Lang));
true ->
State1 = State#{stream_state => downgraded},
try Mod:handle_stream_downgraded(StreamStart, State1)
catch _:undef ->
try callback(handle_stream_downgraded, StreamStart, State1)
catch _:{?MODULE, undef} ->
send_pkt(State1, xmpp:serr_unsupported_version())
end
end.
-spec process_cert_verification(state()) -> state().
process_cert_verification(#{stream_encrypted := true,
stream_verified := false,
mod := Mod} = State) ->
case try Mod:tls_verify(State)
catch _:undef -> true
stream_verified := false} = State) ->
case try callback(tls_verify, State)
catch _:{?MODULE, undef} -> true
end of
true ->
case xmpp_stream_pkix:authenticate(State) of
@ -655,8 +654,7 @@ process_cert_verification(State) ->
State.
-spec process_sasl_success(state()) -> state().
process_sasl_success(#{mod := Mod,
socket := Socket} = State) ->
process_sasl_success(#{socket := Socket} = State) ->
Socket1 = xmpp_socket:reset_stream(Socket),
State0 = State#{socket => Socket1},
State1 = State0#{stream_id => new_id(),
@ -667,8 +665,8 @@ process_sasl_success(#{mod := Mod,
case is_disconnected(State2) of
true -> State2;
false ->
try Mod:handle_auth_success(<<"EXTERNAL">>, State2)
catch _:undef -> State2
try callback(handle_auth_success, <<"EXTERNAL">>, State2)
catch _:{?MODULE, undef} -> State2
end
end.
@ -677,27 +675,27 @@ process_sasl_failure(#sasl_failure{} = Failure, State) ->
Reason = format("Peer responded with error: ~s",
[format_sasl_failure(Failure)]),
process_sasl_failure(Reason, State);
process_sasl_failure(Reason, #{mod := Mod} = State) ->
try Mod:handle_auth_failure(<<"EXTERNAL">>, {auth, Reason}, State)
catch _:undef -> process_stream_end({auth, Reason}, State)
process_sasl_failure(Reason, State) ->
try callback(handle_auth_failure, <<"EXTERNAL">>, {auth, Reason}, State)
catch _:{?MODULE, undef} -> process_stream_end({auth, Reason}, State)
end.
-spec process_packet(xmpp_element(), state()) -> state().
process_packet(Pkt, #{mod := Mod} = State) ->
try Mod:handle_packet(Pkt, State)
catch _:undef -> State
process_packet(Pkt, State) ->
try callback(handle_packet, Pkt, State)
catch _:{?MODULE, undef} -> State
end.
-spec is_starttls_required(state()) -> boolean().
is_starttls_required(#{mod := Mod} = State) ->
try Mod:tls_required(State)
catch _:undef -> false
is_starttls_required(State) ->
try callback(tls_required, State)
catch _:{?MODULE, undef} -> false
end.
-spec is_starttls_available(state()) -> boolean().
is_starttls_available(#{mod := Mod} = State) ->
try Mod:tls_enabled(State)
catch _:undef -> true
is_starttls_available(State) ->
try callback(tls_enabled, State)
catch _:{?MODULE, undef} -> true
end.
-spec send_header(state()) -> state().
@ -731,10 +729,10 @@ send_header(#{remote_server := RemoteServer,
end.
-spec send_pkt(state(), xmpp_element() | xmlel()) -> state().
send_pkt(#{mod := Mod} = State, Pkt) ->
send_pkt(State, Pkt) ->
Result = socket_send(State, Pkt),
State1 = try Mod:handle_send(Pkt, Result, State)
catch _:undef -> State
State1 = try callback(handle_send, Pkt, Result, State)
catch _:{?MODULE, undef} -> State
end,
case Result of
_ when is_record(Pkt, stream_error) ->
@ -795,10 +793,10 @@ close_socket(State) ->
stream_state => disconnected}.
-spec starttls(term(), state()) -> {ok, term()} | {error, tls_error_reason()}.
starttls(Socket, #{mod := Mod, xmlns := NS,
starttls(Socket, #{xmlns := NS,
remote_server := RemoteServer} = State) ->
TLSOpts = try Mod:tls_options(State)
catch _:undef -> []
TLSOpts = try callback(tls_options, State)
catch _:{?MODULE, undef} -> []
end,
SNI = idna_to_ascii(RemoteServer),
ALPN = case NS of
@ -1077,32 +1075,59 @@ get_addr_type({_, _, _, _}) -> inet;
get_addr_type({_, _, _, _, _, _, _, _}) -> inet6.
-spec get_dns_timeout(state()) -> timeout().
get_dns_timeout(#{mod := Mod} = State) ->
try Mod:dns_timeout(State)
catch _:undef -> timer:seconds(10)
get_dns_timeout(State) ->
try callback(dns_timeout, State)
catch _:{?MODULE, undef} -> timer:seconds(10)
end.
-spec get_dns_retries(state()) -> non_neg_integer().
get_dns_retries(#{mod := Mod} = State) ->
try Mod:dns_retries(State)
catch _:undef -> 2
get_dns_retries(State) ->
try callback(dns_retries, State)
catch _:{?MODULE, undef} -> 2
end.
-spec get_default_port(state()) -> inet:port_number().
get_default_port(#{mod := Mod, xmlns := NS} = State) ->
try Mod:default_port(State)
catch _:undef when NS == ?NS_SERVER -> 5269;
_:undef when NS == ?NS_CLIENT -> 5222
get_default_port(#{xmlns := NS} = State) ->
try callback(default_port, State)
catch _:{?MODULE, undef} when NS == ?NS_SERVER -> 5269;
_:{?MODULE, undef} when NS == ?NS_CLIENT -> 5222
end.
-spec get_address_families(state()) -> [inet:address_family()].
get_address_families(#{mod := Mod} = State) ->
try Mod:address_families(State)
catch _:undef -> [inet, inet6]
get_address_families(State) ->
try callback(address_families, State)
catch _:{?MODULE, undef} -> [inet, inet6]
end.
-spec get_connect_timeout(state()) -> timeout().
get_connect_timeout(#{mod := Mod} = State) ->
try Mod:connect_timeout(State)
catch _:undef -> timer:seconds(10)
get_connect_timeout(State) ->
try callback(connect_timeout, State)
catch _:{?MODULE, undef} -> timer:seconds(10)
end.
%%%===================================================================
%%% Callbacks
%%%===================================================================
callback(F, #{mod := Mod} = State) ->
case erlang:function_exported(Mod, F, 1) of
true -> Mod:F(State);
false -> erlang:error({?MODULE, undef})
end.
callback(F, Arg1, #{mod := Mod} = State) ->
case erlang:function_exported(Mod, F, 2) of
true -> Mod:F(Arg1, State);
false -> erlang:error({?MODULE, undef})
end.
callback(code_change, OldVsn, #{mod := Mod} = State, Extra) ->
%% code_change/3 callback is a special snowflake
case erlang:function_exported(Mod, code_change, 3) of
true -> Mod:code_change(OldVsn, State, Extra);
false -> {ok, State}
end;
callback(F, Arg1, Arg2, #{mod := Mod} = State) ->
case erlang:function_exported(Mod, F, 3) of
true -> Mod:F(Arg1, Arg2, State);
false -> erlang:error({?MODULE, undef})
end.