From fc77051b68a923d958f876193fd5745af34208db Mon Sep 17 00:00:00 2001 From: Evgeniy Khramtsov Date: Sat, 26 May 2018 09:06:24 +0300 Subject: [PATCH] Don't call Mod:function() in xmpp_stream callbacks If a callback function is not defined by the `Mod` then a call to code_server process is performed. Under heavy load this may cause code_server to get overloaded. We now avoid this. --- src/xmpp_stream_in.erl | 240 +++++++++++++++++++++++----------------- src/xmpp_stream_out.erl | 197 +++++++++++++++++++-------------- 2 files changed, 247 insertions(+), 190 deletions(-) diff --git a/src/xmpp_stream_in.erl b/src/xmpp_stream_in.erl index 55fa3a4bf..675425bd0 100644 --- a/src/xmpp_stream_in.erl +++ b/src/xmpp_stream_in.erl @@ -210,14 +210,14 @@ format_error(Err) -> %%%=================================================================== %%% gen_server callbacks %%%=================================================================== -init([Module, {_SockMod, Socket}, Opts]) -> +init([Mod, {_SockMod, Socket}, Opts]) -> Encrypted = proplists:get_bool(tls, Opts), SocketMonitor = xmpp_socket:monitor(Socket), case xmpp_socket:peername(Socket) of {ok, IP} -> Time = p1_time_compat:monotonic_time(milli_seconds), State = #{owner => self(), - mod => Module, + mod => Mod, socket => Socket, socket_monitor => SocketMonitor, stream_timeout => {timer:seconds(30), Time}, @@ -238,15 +238,15 @@ init([Module, {_SockMod, Socket}, Opts]) -> resource => <<"">>, lserver => <<"">>, ip => IP}, - case try Module:init([State, Opts]) + case try Mod:init([State, Opts]) catch _:undef -> {ok, State} end of {ok, State1} when not Encrypted -> {_, State2, Timeout} = noreply(State1), {ok, State2, Timeout}; {ok, State1} when Encrypted -> - TLSOpts = try Module:tls_options(State1) - catch _:undef -> [] + TLSOpts = try callback(tls_options, State1) + catch _:{?MODULE, undef} -> [] end, case xmpp_socket:starttls(Socket, TLSOpts) of {ok, TLSSocket} -> @@ -276,14 +276,14 @@ handle_cast({close, Reason}, State) -> true -> State1; false -> process_stream_end({socket, Reason}, State) end); -handle_cast(Cast, #{mod := Mod} = State) -> - noreply(try Mod:handle_cast(Cast, State) - catch _:undef -> State +handle_cast(Cast, State) -> + noreply(try callback(handle_cast, Cast, State) + catch _:{?MODULE, undef} -> State end). -handle_call(Call, From, #{mod := Mod} = State) -> - noreply(try Mod:handle_call(Call, From, State) - catch _:undef -> State +handle_call(Call, From, State) -> + noreply(try callback(handle_call, Call, From, State) + catch _:{?MODULE, undef} -> State end). handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}}, @@ -343,20 +343,20 @@ handle_info({'$gen_event', El}, #{stream_state := wait_for_stream} = State) -> false -> send_pkt(State1, xmpp:serr_invalid_xml()) end); handle_info({'$gen_event', {xmlstreamelement, El}}, - #{xmlns := NS, mod := Mod, codec_options := Opts} = State) -> + #{xmlns := NS, codec_options := Opts} = State) -> noreply( try xmpp:decode(El, NS, Opts) of Pkt -> - State1 = try Mod:handle_recv(El, Pkt, State) - catch _:undef -> State + State1 = try callback(handle_recv, El, Pkt, State) + catch _:{?MODULE, undef} -> State end, case is_disconnected(State1) of true -> State1; false -> process_element(Pkt, State1) end catch _:{xmpp_codec, Why} -> - State1 = try Mod:handle_recv(El, {error, Why}, State) - catch _:undef -> State + State1 = try callback(handle_recv, El, {error, Why}, State) + catch _:{?MODULE, undef} -> State end, case is_disconnected(State1) of true -> State1; @@ -364,17 +364,17 @@ handle_info({'$gen_event', {xmlstreamelement, El}}, end end); handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}}, - #{mod := Mod} = State) -> - noreply(try Mod:handle_cdata(Data, State) - catch _:undef -> State + State) -> + noreply(try callback(handle_cdata, Data, State) + catch _:{?MODULE, undef} -> State end); -handle_info(timeout, #{mod := Mod, lang := Lang} = State) -> +handle_info(timeout, #{lang := Lang} = State) -> Disconnected = is_disconnected(State), - noreply(try Mod:handle_timeout(State) - catch _:undef when not Disconnected -> + noreply(try callback(handle_timeout, State) + catch _:{?MODULE, undef} when not Disconnected -> Txt = <<"Idle connection">>, send_pkt(State, xmpp:serr_connection_timeout(Txt, Lang)); - _:undef -> + _:{?MODULE, undef} -> stop(State) end); handle_info({'DOWN', MRef, _Type, _Object, _Info}, @@ -395,25 +395,25 @@ handle_info({tcp_closed, _}, State) -> handle_info({'$gen_event', closed}, State); handle_info({tcp_error, _, Reason}, State) -> noreply(process_stream_end({socket, Reason}, State)); -handle_info(Info, #{mod := Mod} = State) -> - noreply(try Mod:handle_info(Info, State) - catch _:undef -> State +handle_info(Info, State) -> + noreply(try callback(handle_info, Info, State) + catch _:{?MODULE, undef} -> State end). -terminate(Reason, #{mod := Mod} = State) -> +terminate(Reason, State) -> case get(already_terminated) of true -> State; _ -> put(already_terminated, true), - try Mod:terminate(Reason, State) - catch _:undef -> ok + try callback(terminate, Reason, State) + catch _:{?MODULE, undef} -> ok end, send_trailer(State) end. -code_change(OldVsn, #{mod := Mod} = State, Extra) -> - Mod:code_change(OldVsn, State, Extra). +code_change(OldVsn, State, Extra) -> + callback(code_change, OldVsn, State, Extra). %%%=================================================================== %%% Internal functions @@ -464,11 +464,11 @@ process_invalid_xml(#{lang := MyLang} = State, El, Reason) -> -spec process_stream_end(stop_reason(), state()) -> state(). process_stream_end(_, #{stream_state := disconnected} = State) -> State; -process_stream_end(Reason, #{mod := Mod} = State) -> +process_stream_end(Reason, State) -> State1 = State#{stream_timeout => infinity, stream_state => disconnected}, - try Mod:handle_stream_end(Reason, State1) - catch _:undef -> stop(State1) + try callback(handle_stream_end, Reason, State1) + catch _:{?MODULE, undef} -> stop(State1) end. -spec process_stream(stream_start(), state()) -> state(). @@ -503,17 +503,17 @@ process_stream(#stream_start{to = #jid{luser = U, lresource = R}}, Txt = <<"Improper 'to' attribute">>, send_pkt(State, xmpp:serr_improper_addressing(Txt, Lang)); process_stream(#stream_start{to = #jid{lserver = RemoteServer}} = StreamStart, - #{xmlns := ?NS_COMPONENT, mod := Mod} = State) -> + #{xmlns := ?NS_COMPONENT} = State) -> State1 = State#{remote_server => RemoteServer, stream_state => wait_for_handshake}, - try Mod:handle_stream_start(StreamStart, State1) - catch _:undef -> State1 + try callback(handle_stream_start, StreamStart, State1) + catch _:{?MODULE, undef} -> State1 end; process_stream(#stream_start{to = #jid{server = Server, lserver = LServer}, from = From} = StreamStart, #{stream_authenticated := Authenticated, stream_restarted := StreamWasRestarted, - mod := Mod, xmlns := NS, resource := Resource, + xmlns := NS, resource := Resource, stream_encrypted := Encrypted} = State) -> State1 = if not StreamWasRestarted -> State#{server => Server, lserver => LServer}; @@ -526,8 +526,8 @@ process_stream(#stream_start{to = #jid{server = Server, lserver = LServer}, _ -> State1 end, - State3 = try Mod:handle_stream_start(StreamStart, State2) - catch _:undef -> State2 + State3 = try callback(handle_stream_start, StreamStart, State2) + catch _:{?MODULE, undef} -> State2 end, case is_disconnected(State3) of true -> State3; @@ -604,21 +604,21 @@ process_element(Pkt, #{stream_state := StateName, lang := Lang} = State) -> end. -spec process_unauthenticated_packet(xmpp_element(), state()) -> state(). -process_unauthenticated_packet(Pkt, #{mod := Mod} = State) -> +process_unauthenticated_packet(Pkt, State) -> NewPkt = set_lang(Pkt, State), - try Mod:handle_unauthenticated_packet(NewPkt, State) - catch _:undef -> + try callback(handle_unauthenticated_packet, NewPkt, State) + catch _:{?MODULE, undef} -> Err = xmpp:serr_not_authorized(), send(State, Err) end. -spec process_authenticated_packet(xmpp_element(), state()) -> state(). -process_authenticated_packet(Pkt, #{mod := Mod} = State) -> +process_authenticated_packet(Pkt, State) -> Pkt1 = set_lang(Pkt, State), case set_from_to(Pkt1, State) of {ok, Pkt2} -> - try Mod:handle_authenticated_packet(Pkt2, State) - catch _:undef -> + try callback(handle_authenticated_packet, Pkt2, State) + catch _:{?MODULE, undef} -> Err = xmpp:err_service_unavailable(), send_error(State, Pkt, Err) end; @@ -628,10 +628,10 @@ process_authenticated_packet(Pkt, #{mod := Mod} = State) -> -spec process_bind(xmpp_element(), state()) -> state(). process_bind(#iq{type = set, sub_els = [_]} = Pkt, - #{xmlns := ?NS_CLIENT, mod := Mod, lang := MyLang} = State) -> + #{xmlns := ?NS_CLIENT, lang := MyLang} = State) -> try xmpp:try_subtag(Pkt, #bind{}) of #bind{resource = R} -> - case Mod:bind(R, State) of + case callback(bind, R, State) of {ok, #{user := U, server := S, resource := NewR} = State1} when NewR /= <<"">> -> Reply = #bind{jid = jid:make(U, S, NewR)}, @@ -641,8 +641,8 @@ process_bind(#iq{type = set, sub_els = [_]} = Pkt, send_error(State1, Pkt, Err) end; _ -> - try Mod:handle_unbinded_packet(Pkt, State) - catch _:undef -> + try callback(handle_unbinded_packet, Pkt, State) + catch _:{?MODULE, undef} -> Err = xmpp:err_not_authorized(), send_error(State, Pkt, Err) end @@ -652,19 +652,19 @@ process_bind(#iq{type = set, sub_els = [_]} = Pkt, Err = xmpp:err_bad_request(Txt, Lang), send_error(State, Pkt, Err) end; -process_bind(Pkt, #{mod := Mod} = State) -> - try Mod:handle_unbinded_packet(Pkt, State) - catch _:undef -> +process_bind(Pkt, State) -> + try callback(handle_unbinded_packet, Pkt, State) + catch _:{?MODULE, undef} -> Err = xmpp:err_not_authorized(), send_error(State, Pkt, Err) end. -spec process_handshake(handshake(), state()) -> state(). process_handshake(#handshake{data = Digest}, - #{mod := Mod, stream_id := StreamID, + #{stream_id := StreamID, remote_server := RemoteServer} = State) -> - GetPW = try Mod:get_password_fun(State) - catch _:undef -> fun(_) -> {false, undefined} end + GetPW = try callback(get_password_fun, State) + catch _:{?MODULE, undef} -> fun(_) -> {false, undefined} end end, AuthRes = case GetPW(<<"">>) of {false, _} -> @@ -674,9 +674,9 @@ process_handshake(#handshake{data = Digest}, end, case AuthRes of true -> - State1 = try Mod:handle_auth_success( + State1 = try callback(handle_auth_success, RemoteServer, <<"handshake">>, undefined, State) - catch _:undef -> State + catch _:{?MODULE, undef} -> State end, case is_disconnected(State1) of true -> State1; @@ -685,9 +685,9 @@ process_handshake(#handshake{data = Digest}, process_stream_established(State2) end; false -> - State1 = try Mod:handle_auth_failure( + State1 = try callback(handle_auth_failure, RemoteServer, <<"handshake">>, <<"not authorized">>, State) - catch _:undef -> State + catch _:{?MODULE, undef} -> State end, case is_disconnected(State1) of true -> State1; @@ -699,12 +699,12 @@ process_handshake(#handshake{data = Digest}, process_stream_established(#{stream_state := StateName} = State) when StateName == disconnected; StateName == established -> State; -process_stream_established(#{mod := Mod} = State) -> +process_stream_established(State) -> State1 = State#{stream_authenticated => true, stream_state => established, stream_timeout => infinity}, - try Mod:handle_stream_established(State1) - catch _:undef -> State1 + try callback(handle_stream_established, State1) + catch _:{?MODULE, undef} -> State1 end. -spec process_compress(compress(), state()) -> state(). @@ -714,9 +714,9 @@ process_compress(#compress{}, when Compressed or not Authenticated -> send_pkt(State, #compress_failure{reason = 'setup-failed'}); process_compress(#compress{methods = HisMethods}, - #{socket := Socket, mod := Mod} = State) -> - MyMethods = try Mod:compress_methods(State) - catch _:undef -> [] + #{socket := Socket} = State) -> + MyMethods = try callback(compress_methods, State) + catch _:{?MODULE, undef} -> [] end, CommonMethods = lists_intersection(MyMethods, HisMethods), case lists:member(<<"zlib">>, CommonMethods) of @@ -745,12 +745,11 @@ process_compress(#compress{methods = HisMethods}, -spec process_starttls(state()) -> state(). process_starttls(#{stream_encrypted := true} = State) -> process_starttls_failure(already_encrypted, State); -process_starttls(#{socket := Socket, - mod := Mod} = State) -> +process_starttls(#{socket := Socket} = State) -> case is_starttls_available(State) of true -> - TLSOpts = try Mod:tls_options(State) - catch _:undef -> [] + TLSOpts = try callback(tls_options, State) + catch _:{?MODULE, undef} -> [] end, case xmpp_socket:starttls(Socket, TLSOpts) of {ok, TLSSocket} -> @@ -782,7 +781,7 @@ process_starttls_failure(Why, State) -> -spec process_sasl_request(sasl_auth(), state()) -> state(). process_sasl_request(#sasl_auth{mechanism = Mech, text = ClientIn}, - #{mod := Mod, lserver := LServer} = State) -> + #{lserver := LServer} = State) -> State1 = State#{sasl_mech => Mech}, Mechs = get_sasl_mechanisms(State1), case lists:member(Mech, Mechs) of @@ -795,14 +794,14 @@ process_sasl_request(#sasl_auth{mechanism = Mech, text = ClientIn}, end, process_sasl_result(Res, State1); true -> - GetPW = try Mod:get_password_fun(State1) - catch _:undef -> fun(_) -> false end + GetPW = try callback(get_password_fun, State1) + catch _:{?MODULE, undef} -> fun(_) -> false end end, - CheckPW = try Mod:check_password_fun(State1) - catch _:undef -> fun(_, _, _) -> false end + CheckPW = try callback(check_password_fun, State1) + catch _:{?MODULE, undef} -> fun(_, _, _) -> false end end, - CheckPWDigest = try Mod:check_password_digest_fun(State1) - catch _:undef -> fun(_, _, _, _, _) -> false end + CheckPWDigest = try callback(check_password_digest_fun, State1) + catch _:{?MODULE, undef} -> fun(_, _, _, _, _) -> false end end, SASLState = cyrsasl:server_new(<<"jabber">>, LServer, <<"">>, [], GetPW, CheckPW, CheckPWDigest), @@ -831,13 +830,13 @@ process_sasl_result({error, Reason, User}, State) -> -spec process_sasl_success([cyrsasl:sasl_property()], binary(), state()) -> state(). process_sasl_success(Props, ServerOut, #{socket := Socket, - mod := Mod, sasl_mech := Mech} = State) -> + sasl_mech := Mech} = State) -> User = identity(Props), AuthModule = proplists:get_value(auth_module, Props), Socket1 = xmpp_socket:reset_stream(Socket), State0 = State#{socket => Socket1}, - State1 = try Mod:handle_auth_success(User, Mech, AuthModule, State0) - catch _:undef -> State + State1 = try callback(handle_auth_success, User, Mech, AuthModule, State0) + catch _:{?MODULE, undef} -> State end, case is_disconnected(State1) of true -> State1; @@ -865,10 +864,10 @@ process_sasl_continue(ServerOut, NewSASLState, State) -> -spec process_sasl_failure(atom(), binary(), state()) -> state(). process_sasl_failure(Err, User, - #{mod := Mod, sasl_mech := Mech, lang := Lang} = State) -> + #{sasl_mech := Mech, lang := Lang} = State) -> {Reason, Text} = format_sasl_error(Mech, Err), - State1 = try Mod:handle_auth_failure(User, Mech, Text, State) - catch _:undef -> State + State1 = try callback(handle_auth_failure, User, Mech, Text, State) + catch _:{?MODULE, undef} -> State end, case is_disconnected(State1) of true -> State1; @@ -906,21 +905,21 @@ send_features(State) -> State. -spec get_sasl_mechanisms(state()) -> [cyrsasl:mechanism()]. -get_sasl_mechanisms(#{stream_encrypted := Encrypted, mod := Mod, +get_sasl_mechanisms(#{stream_encrypted := Encrypted, xmlns := NS, lserver := LServer} = State) -> Mechs = if NS == ?NS_CLIENT -> cyrsasl:listmech(LServer); true -> [] end, - TLSVerify = try Mod:tls_verify(State) - catch _:undef -> false + TLSVerify = try callback(tls_verify, State) + catch _:{?MODULE, undef} -> false end, Mechs1 = if Encrypted andalso (TLSVerify orelse NS == ?NS_SERVER) -> [<<"EXTERNAL">>|Mechs]; true -> Mechs end, - try Mod:sasl_mechanisms(Mechs1, State) - catch _:undef -> Mechs1 + try callback(sasl_mechanisms, Mechs1, State) + catch _:{?MODULE, undef} -> Mechs1 end. -spec get_sasl_feature(state()) -> [sasl_mechanisms()]. @@ -937,12 +936,12 @@ get_sasl_feature(_) -> []. -spec get_compress_feature(state()) -> [compression()]. -get_compress_feature(#{stream_compressed := false, mod := Mod, +get_compress_feature(#{stream_compressed := false, stream_authenticated := true} = State) -> - try Mod:compress_methods(State) of + try callback(compress_methods, State) of [] -> []; Ms -> [#compression{methods = Ms}] - catch _:undef -> + catch _:{?MODULE, undef} -> [] end; get_compress_feature(_) -> @@ -978,25 +977,25 @@ get_session_feature(_) -> []. -spec get_other_features(state()) -> [xmpp_element()]. -get_other_features(#{stream_authenticated := Auth, mod := Mod} = State) -> +get_other_features(#{stream_authenticated := Auth} = State) -> try - if Auth -> Mod:authenticated_stream_features(State); - true -> Mod:unauthenticated_stream_features(State) + if Auth -> callback(authenticated_stream_features, State); + true -> callback(unauthenticated_stream_features, State) end - catch _:undef -> + catch _:{?MODULE, undef} -> [] end. -spec is_starttls_available(state()) -> boolean(). -is_starttls_available(#{mod := Mod} = State) -> - try Mod:tls_enabled(State) - catch _:undef -> true +is_starttls_available(State) -> + try callback(tls_enabled, State) + catch _:{?MODULE, undef} -> true end. -spec is_starttls_required(state()) -> boolean(). -is_starttls_required(#{mod := Mod} = State) -> - try Mod:tls_required(State) - catch _:undef -> false +is_starttls_required(State) -> + try callback(tls_required, State) + catch _:{?MODULE, undef} -> false end. -spec set_from_to(xmpp_element(), state()) -> {ok, xmpp_element()} | @@ -1076,10 +1075,10 @@ send_header(State, _) -> State. -spec send_pkt(state(), xmpp_element() | xmlel()) -> state(). -send_pkt(#{mod := Mod} = State, Pkt) -> +send_pkt(State, Pkt) -> Result = socket_send(State, Pkt), - State1 = try Mod:handle_send(Pkt, Result, State) - catch _:undef -> State + State1 = try callback(handle_send, Pkt, Result, State) + catch _:{?MODULE, undef} -> State end, case Result of _ when is_record(Pkt, stream_error) -> @@ -1200,3 +1199,36 @@ identity(Props) -> <<>> -> proplists:get_value(username, Props, <<>>); AuthzId -> AuthzId end. + +%%%=================================================================== +%%% Callbacks +%%%=================================================================== +callback(F, #{mod := Mod} = State) -> + case erlang:function_exported(Mod, F, 1) of + true -> Mod:F(State); + false -> erlang:error({?MODULE, undef}) + end. + +callback(F, Arg1, #{mod := Mod} = State) -> + case erlang:function_exported(Mod, F, 2) of + true -> Mod:F(Arg1, State); + false -> erlang:error({?MODULE, undef}) + end. + +callback(code_change, OldVsn, #{mod := Mod} = State, Extra) -> + %% code_change/3 callback is a special snowflake + case erlang:function_exported(Mod, code_change, 3) of + true -> Mod:code_change(OldVsn, State, Extra); + false -> {ok, State} + end; +callback(F, Arg1, Arg2, #{mod := Mod} = State) -> + case erlang:function_exported(Mod, F, 3) of + true -> Mod:F(Arg1, Arg2, State); + false -> erlang:error({?MODULE, undef}) + end. + +callback(F, Arg1, Arg2, Arg3, #{mod := Mod} = State) -> + case erlang:function_exported(Mod, F, 4) of + true -> Mod:F(Arg1, Arg2, Arg3, State); + false -> erlang:error({?MODULE, undef}) + end. diff --git a/src/xmpp_stream_out.erl b/src/xmpp_stream_out.erl index b2367a09b..da0a14e22 100644 --- a/src/xmpp_stream_out.erl +++ b/src/xmpp_stream_out.erl @@ -266,9 +266,9 @@ init([Mod, _SockMod, From, To, Opts]) -> end. -spec handle_call(term(), term(), state()) -> noreply(). -handle_call(Call, From, #{mod := Mod} = State) -> - noreply(try Mod:handle_call(Call, From, State) - catch _:undef -> State +handle_call(Call, From, State) -> + noreply(try callback(handle_call, Call, From, State) + catch _:{?MODULE, undef} -> State end). -spec handle_cast(term(), state()) -> noreply(). @@ -311,9 +311,9 @@ handle_cast({close, Reason}, State) -> true -> State1; false -> process_stream_end({socket, Reason}, State) end); -handle_cast(Cast, #{mod := Mod} = State) -> - noreply(try Mod:handle_cast(Cast, State) - catch _:undef -> State +handle_cast(Cast, State) -> + noreply(try callback(handle_cast, Cast, State) + catch _:{?MODULE, undef} -> State end). -spec handle_info(term(), state()) -> noreply(). @@ -348,20 +348,20 @@ handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) -> send_pkt(State1, Err) end); handle_info({'$gen_event', {xmlstreamelement, El}}, - #{xmlns := NS, mod := Mod, codec_options := Opts} = State) -> + #{xmlns := NS, codec_options := Opts} = State) -> noreply( try xmpp:decode(El, NS, Opts) of Pkt -> - State1 = try Mod:handle_recv(El, Pkt, State) - catch _:undef -> State + State1 = try callback(handle_recv, El, Pkt, State) + catch _:{?MODULE, undef} -> State end, case is_disconnected(State1) of true -> State1; false -> process_element(Pkt, State1) end catch _:{xmpp_codec, Why} -> - State1 = try Mod:handle_recv(El, {error, Why}, State) - catch _:undef -> State + State1 = try callback(handle_recv, El, {error, Why}, State) + catch _:{?MODULE, undef} -> State end, case is_disconnected(State1) of true -> State1; @@ -369,21 +369,21 @@ handle_info({'$gen_event', {xmlstreamelement, El}}, end end); handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}}, - #{mod := Mod} = State) -> - noreply(try Mod:handle_cdata(Data, State) - catch _:undef -> State + State) -> + noreply(try callback(handle_cdata, Data, State) + catch _:{?MODULE, undef} -> State end); handle_info({'$gen_event', {xmlstreamend, _}}, State) -> noreply(process_stream_end({stream, reset}, State)); handle_info({'$gen_event', closed}, State) -> noreply(process_stream_end({socket, closed}, State)); -handle_info(timeout, #{mod := Mod, lang := Lang} = State) -> +handle_info(timeout, #{lang := Lang} = State) -> Disconnected = is_disconnected(State), - noreply(try Mod:handle_timeout(State) - catch _:undef when not Disconnected -> + noreply(try callback(handle_timeout, State) + catch _:{?MODULE, undef} when not Disconnected -> Txt = <<"Idle connection">>, send_pkt(State, xmpp:serr_connection_timeout(Txt, Lang)); - _:undef -> + _:{?MODULE, undef} -> stop(State) end); handle_info({'DOWN', MRef, _Type, _Object, _Info}, @@ -404,26 +404,26 @@ handle_info({tcp_closed, _}, State) -> handle_info({'$gen_event', closed}, State); handle_info({tcp_error, _, Reason}, State) -> noreply(process_stream_end({socket, Reason}, State)); -handle_info(Info, #{mod := Mod} = State) -> - noreply(try Mod:handle_info(Info, State) - catch _:undef -> State +handle_info(Info, State) -> + noreply(try callback(handle_info, Info, State) + catch _:{?MODULE, undef} -> State end). -spec terminate(term(), state()) -> any(). -terminate(Reason, #{mod := Mod} = State) -> +terminate(Reason, State) -> case get(already_terminated) of true -> State; _ -> put(already_terminated, true), - try Mod:terminate(Reason, State) - catch _:undef -> ok + try callback(terminate, Reason, State) + catch _:{?MODULE, undef} -> ok end, send_trailer(State) end. -code_change(OldVsn, #{mod := Mod} = State, Extra) -> - Mod:code_change(OldVsn, State, Extra). +code_change(OldVsn, State, Extra) -> + callback(code_change, OldVsn, State, Extra). %%%=================================================================== %%% Internal functions @@ -458,11 +458,11 @@ process_invalid_xml(#{lang := MyLang} = State, El, Reason) -> -spec process_stream_end(stop_reason(), state()) -> state(). process_stream_end(_, #{stream_state := disconnected} = State) -> State; -process_stream_end(Reason, #{mod := Mod} = State) -> +process_stream_end(Reason, State) -> State1 = State#{stream_timeout => infinity, stream_state => disconnected}, - try Mod:handle_stream_end(Reason, State1) - catch _:undef -> stop(State1) + try callback(handle_stream_end, Reason, State1) + catch _:{?MODULE, undef} -> stop(State1) end. -spec process_stream(stream_start(), state()) -> state(). @@ -475,10 +475,10 @@ process_stream(#stream_start{version = {N, _}}, State) when N > 1 -> send_pkt(State, xmpp:serr_unsupported_version()); process_stream(#stream_start{lang = Lang, id = ID, version = Version} = StreamStart, - #{mod := Mod} = State) -> + State) -> State1 = State#{stream_remote_id => ID, lang => Lang}, - State2 = try Mod:handle_stream_start(StreamStart, State1) - catch _:undef -> State1 + State2 = try callback(handle_stream_start, StreamStart, State1) + catch _:{?MODULE, undef} -> State1 end, case is_disconnected(State2) of true -> State2; @@ -522,16 +522,16 @@ process_element(Pkt, #{stream_state := StateName} = State) -> -spec process_features(stream_features(), state()) -> state(). process_features(StreamFeatures, - #{stream_authenticated := true, mod := Mod} = State) -> - State1 = try Mod:handle_authenticated_features(StreamFeatures, State) - catch _:undef -> State + #{stream_authenticated := true} = State) -> + State1 = try callback(handle_authenticated_features, StreamFeatures, State) + catch _:{?MODULE, undef} -> State end, process_stream_established(State1); process_features(StreamFeatures, #{stream_encrypted := Encrypted, - mod := Mod, lang := Lang} = State) -> - State1 = try Mod:handle_unauthenticated_features(StreamFeatures, State) - catch _:undef -> State + lang := Lang} = State) -> + State1 = try callback(handle_unauthenticated_features, StreamFeatures, State) + catch _:{?MODULE, undef} -> State end, case is_disconnected(State1) of true -> State1; @@ -582,12 +582,12 @@ process_features(StreamFeatures, process_stream_established(#{stream_state := StateName} = State) when StateName == disconnected; StateName == established -> State; -process_stream_established(#{mod := Mod} = State) -> +process_stream_established(State) -> State1 = State#{stream_authenticated := true, stream_state => established, stream_timeout => infinity}, - try Mod:handle_stream_established(State1) - catch _:undef -> State1 + try callback(handle_stream_established, State1) + catch _:{?MODULE, undef} -> State1 end. -spec process_sasl_mechanisms([binary()], state()) -> state(). @@ -620,7 +620,7 @@ process_starttls(#{socket := Socket} = State) -> -spec process_stream_downgrade(stream_start(), state()) -> state(). process_stream_downgrade(StreamStart, - #{mod := Mod, lang := Lang, + #{lang := Lang, stream_encrypted := Encrypted} = State) -> TLSRequired = is_starttls_required(State), if not Encrypted and TLSRequired -> @@ -628,18 +628,17 @@ process_stream_downgrade(StreamStart, send_pkt(State, xmpp:serr_policy_violation(Txt, Lang)); true -> State1 = State#{stream_state => downgraded}, - try Mod:handle_stream_downgraded(StreamStart, State1) - catch _:undef -> + try callback(handle_stream_downgraded, StreamStart, State1) + catch _:{?MODULE, undef} -> send_pkt(State1, xmpp:serr_unsupported_version()) end end. -spec process_cert_verification(state()) -> state(). process_cert_verification(#{stream_encrypted := true, - stream_verified := false, - mod := Mod} = State) -> - case try Mod:tls_verify(State) - catch _:undef -> true + stream_verified := false} = State) -> + case try callback(tls_verify, State) + catch _:{?MODULE, undef} -> true end of true -> case xmpp_stream_pkix:authenticate(State) of @@ -655,8 +654,7 @@ process_cert_verification(State) -> State. -spec process_sasl_success(state()) -> state(). -process_sasl_success(#{mod := Mod, - socket := Socket} = State) -> +process_sasl_success(#{socket := Socket} = State) -> Socket1 = xmpp_socket:reset_stream(Socket), State0 = State#{socket => Socket1}, State1 = State0#{stream_id => new_id(), @@ -667,8 +665,8 @@ process_sasl_success(#{mod := Mod, case is_disconnected(State2) of true -> State2; false -> - try Mod:handle_auth_success(<<"EXTERNAL">>, State2) - catch _:undef -> State2 + try callback(handle_auth_success, <<"EXTERNAL">>, State2) + catch _:{?MODULE, undef} -> State2 end end. @@ -677,27 +675,27 @@ process_sasl_failure(#sasl_failure{} = Failure, State) -> Reason = format("Peer responded with error: ~s", [format_sasl_failure(Failure)]), process_sasl_failure(Reason, State); -process_sasl_failure(Reason, #{mod := Mod} = State) -> - try Mod:handle_auth_failure(<<"EXTERNAL">>, {auth, Reason}, State) - catch _:undef -> process_stream_end({auth, Reason}, State) +process_sasl_failure(Reason, State) -> + try callback(handle_auth_failure, <<"EXTERNAL">>, {auth, Reason}, State) + catch _:{?MODULE, undef} -> process_stream_end({auth, Reason}, State) end. -spec process_packet(xmpp_element(), state()) -> state(). -process_packet(Pkt, #{mod := Mod} = State) -> - try Mod:handle_packet(Pkt, State) - catch _:undef -> State +process_packet(Pkt, State) -> + try callback(handle_packet, Pkt, State) + catch _:{?MODULE, undef} -> State end. -spec is_starttls_required(state()) -> boolean(). -is_starttls_required(#{mod := Mod} = State) -> - try Mod:tls_required(State) - catch _:undef -> false +is_starttls_required(State) -> + try callback(tls_required, State) + catch _:{?MODULE, undef} -> false end. -spec is_starttls_available(state()) -> boolean(). -is_starttls_available(#{mod := Mod} = State) -> - try Mod:tls_enabled(State) - catch _:undef -> true +is_starttls_available(State) -> + try callback(tls_enabled, State) + catch _:{?MODULE, undef} -> true end. -spec send_header(state()) -> state(). @@ -731,10 +729,10 @@ send_header(#{remote_server := RemoteServer, end. -spec send_pkt(state(), xmpp_element() | xmlel()) -> state(). -send_pkt(#{mod := Mod} = State, Pkt) -> +send_pkt(State, Pkt) -> Result = socket_send(State, Pkt), - State1 = try Mod:handle_send(Pkt, Result, State) - catch _:undef -> State + State1 = try callback(handle_send, Pkt, Result, State) + catch _:{?MODULE, undef} -> State end, case Result of _ when is_record(Pkt, stream_error) -> @@ -795,10 +793,10 @@ close_socket(State) -> stream_state => disconnected}. -spec starttls(term(), state()) -> {ok, term()} | {error, tls_error_reason()}. -starttls(Socket, #{mod := Mod, xmlns := NS, +starttls(Socket, #{xmlns := NS, remote_server := RemoteServer} = State) -> - TLSOpts = try Mod:tls_options(State) - catch _:undef -> [] + TLSOpts = try callback(tls_options, State) + catch _:{?MODULE, undef} -> [] end, SNI = idna_to_ascii(RemoteServer), ALPN = case NS of @@ -1077,32 +1075,59 @@ get_addr_type({_, _, _, _}) -> inet; get_addr_type({_, _, _, _, _, _, _, _}) -> inet6. -spec get_dns_timeout(state()) -> timeout(). -get_dns_timeout(#{mod := Mod} = State) -> - try Mod:dns_timeout(State) - catch _:undef -> timer:seconds(10) +get_dns_timeout(State) -> + try callback(dns_timeout, State) + catch _:{?MODULE, undef} -> timer:seconds(10) end. -spec get_dns_retries(state()) -> non_neg_integer(). -get_dns_retries(#{mod := Mod} = State) -> - try Mod:dns_retries(State) - catch _:undef -> 2 +get_dns_retries(State) -> + try callback(dns_retries, State) + catch _:{?MODULE, undef} -> 2 end. -spec get_default_port(state()) -> inet:port_number(). -get_default_port(#{mod := Mod, xmlns := NS} = State) -> - try Mod:default_port(State) - catch _:undef when NS == ?NS_SERVER -> 5269; - _:undef when NS == ?NS_CLIENT -> 5222 +get_default_port(#{xmlns := NS} = State) -> + try callback(default_port, State) + catch _:{?MODULE, undef} when NS == ?NS_SERVER -> 5269; + _:{?MODULE, undef} when NS == ?NS_CLIENT -> 5222 end. -spec get_address_families(state()) -> [inet:address_family()]. -get_address_families(#{mod := Mod} = State) -> - try Mod:address_families(State) - catch _:undef -> [inet, inet6] +get_address_families(State) -> + try callback(address_families, State) + catch _:{?MODULE, undef} -> [inet, inet6] end. -spec get_connect_timeout(state()) -> timeout(). -get_connect_timeout(#{mod := Mod} = State) -> - try Mod:connect_timeout(State) - catch _:undef -> timer:seconds(10) +get_connect_timeout(State) -> + try callback(connect_timeout, State) + catch _:{?MODULE, undef} -> timer:seconds(10) + end. + +%%%=================================================================== +%%% Callbacks +%%%=================================================================== +callback(F, #{mod := Mod} = State) -> + case erlang:function_exported(Mod, F, 1) of + true -> Mod:F(State); + false -> erlang:error({?MODULE, undef}) + end. + +callback(F, Arg1, #{mod := Mod} = State) -> + case erlang:function_exported(Mod, F, 2) of + true -> Mod:F(Arg1, State); + false -> erlang:error({?MODULE, undef}) + end. + +callback(code_change, OldVsn, #{mod := Mod} = State, Extra) -> + %% code_change/3 callback is a special snowflake + case erlang:function_exported(Mod, code_change, 3) of + true -> Mod:code_change(OldVsn, State, Extra); + false -> {ok, State} + end; +callback(F, Arg1, Arg2, #{mod := Mod} = State) -> + case erlang:function_exported(Mod, F, 3) of + true -> Mod:F(Arg1, Arg2, State); + false -> erlang:error({?MODULE, undef}) end.