mirror of
https://github.com/processone/ejabberd.git
synced 2024-11-28 16:34:13 +01:00
Don't call Mod:function() in xmpp_stream callbacks
If a callback function is not defined by the `Mod` then a call to code_server process is performed. Under heavy load this may cause code_server to get overloaded. We now avoid this.
This commit is contained in:
parent
bfe2545c01
commit
fc77051b68
@ -210,14 +210,14 @@ format_error(Err) ->
|
||||
%%%===================================================================
|
||||
%%% gen_server callbacks
|
||||
%%%===================================================================
|
||||
init([Module, {_SockMod, Socket}, Opts]) ->
|
||||
init([Mod, {_SockMod, Socket}, Opts]) ->
|
||||
Encrypted = proplists:get_bool(tls, Opts),
|
||||
SocketMonitor = xmpp_socket:monitor(Socket),
|
||||
case xmpp_socket:peername(Socket) of
|
||||
{ok, IP} ->
|
||||
Time = p1_time_compat:monotonic_time(milli_seconds),
|
||||
State = #{owner => self(),
|
||||
mod => Module,
|
||||
mod => Mod,
|
||||
socket => Socket,
|
||||
socket_monitor => SocketMonitor,
|
||||
stream_timeout => {timer:seconds(30), Time},
|
||||
@ -238,15 +238,15 @@ init([Module, {_SockMod, Socket}, Opts]) ->
|
||||
resource => <<"">>,
|
||||
lserver => <<"">>,
|
||||
ip => IP},
|
||||
case try Module:init([State, Opts])
|
||||
case try Mod:init([State, Opts])
|
||||
catch _:undef -> {ok, State}
|
||||
end of
|
||||
{ok, State1} when not Encrypted ->
|
||||
{_, State2, Timeout} = noreply(State1),
|
||||
{ok, State2, Timeout};
|
||||
{ok, State1} when Encrypted ->
|
||||
TLSOpts = try Module:tls_options(State1)
|
||||
catch _:undef -> []
|
||||
TLSOpts = try callback(tls_options, State1)
|
||||
catch _:{?MODULE, undef} -> []
|
||||
end,
|
||||
case xmpp_socket:starttls(Socket, TLSOpts) of
|
||||
{ok, TLSSocket} ->
|
||||
@ -276,14 +276,14 @@ handle_cast({close, Reason}, State) ->
|
||||
true -> State1;
|
||||
false -> process_stream_end({socket, Reason}, State)
|
||||
end);
|
||||
handle_cast(Cast, #{mod := Mod} = State) ->
|
||||
noreply(try Mod:handle_cast(Cast, State)
|
||||
catch _:undef -> State
|
||||
handle_cast(Cast, State) ->
|
||||
noreply(try callback(handle_cast, Cast, State)
|
||||
catch _:{?MODULE, undef} -> State
|
||||
end).
|
||||
|
||||
handle_call(Call, From, #{mod := Mod} = State) ->
|
||||
noreply(try Mod:handle_call(Call, From, State)
|
||||
catch _:undef -> State
|
||||
handle_call(Call, From, State) ->
|
||||
noreply(try callback(handle_call, Call, From, State)
|
||||
catch _:{?MODULE, undef} -> State
|
||||
end).
|
||||
|
||||
handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}},
|
||||
@ -343,20 +343,20 @@ handle_info({'$gen_event', El}, #{stream_state := wait_for_stream} = State) ->
|
||||
false -> send_pkt(State1, xmpp:serr_invalid_xml())
|
||||
end);
|
||||
handle_info({'$gen_event', {xmlstreamelement, El}},
|
||||
#{xmlns := NS, mod := Mod, codec_options := Opts} = State) ->
|
||||
#{xmlns := NS, codec_options := Opts} = State) ->
|
||||
noreply(
|
||||
try xmpp:decode(El, NS, Opts) of
|
||||
Pkt ->
|
||||
State1 = try Mod:handle_recv(El, Pkt, State)
|
||||
catch _:undef -> State
|
||||
State1 = try callback(handle_recv, El, Pkt, State)
|
||||
catch _:{?MODULE, undef} -> State
|
||||
end,
|
||||
case is_disconnected(State1) of
|
||||
true -> State1;
|
||||
false -> process_element(Pkt, State1)
|
||||
end
|
||||
catch _:{xmpp_codec, Why} ->
|
||||
State1 = try Mod:handle_recv(El, {error, Why}, State)
|
||||
catch _:undef -> State
|
||||
State1 = try callback(handle_recv, El, {error, Why}, State)
|
||||
catch _:{?MODULE, undef} -> State
|
||||
end,
|
||||
case is_disconnected(State1) of
|
||||
true -> State1;
|
||||
@ -364,17 +364,17 @@ handle_info({'$gen_event', {xmlstreamelement, El}},
|
||||
end
|
||||
end);
|
||||
handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}},
|
||||
#{mod := Mod} = State) ->
|
||||
noreply(try Mod:handle_cdata(Data, State)
|
||||
catch _:undef -> State
|
||||
State) ->
|
||||
noreply(try callback(handle_cdata, Data, State)
|
||||
catch _:{?MODULE, undef} -> State
|
||||
end);
|
||||
handle_info(timeout, #{mod := Mod, lang := Lang} = State) ->
|
||||
handle_info(timeout, #{lang := Lang} = State) ->
|
||||
Disconnected = is_disconnected(State),
|
||||
noreply(try Mod:handle_timeout(State)
|
||||
catch _:undef when not Disconnected ->
|
||||
noreply(try callback(handle_timeout, State)
|
||||
catch _:{?MODULE, undef} when not Disconnected ->
|
||||
Txt = <<"Idle connection">>,
|
||||
send_pkt(State, xmpp:serr_connection_timeout(Txt, Lang));
|
||||
_:undef ->
|
||||
_:{?MODULE, undef} ->
|
||||
stop(State)
|
||||
end);
|
||||
handle_info({'DOWN', MRef, _Type, _Object, _Info},
|
||||
@ -395,25 +395,25 @@ handle_info({tcp_closed, _}, State) ->
|
||||
handle_info({'$gen_event', closed}, State);
|
||||
handle_info({tcp_error, _, Reason}, State) ->
|
||||
noreply(process_stream_end({socket, Reason}, State));
|
||||
handle_info(Info, #{mod := Mod} = State) ->
|
||||
noreply(try Mod:handle_info(Info, State)
|
||||
catch _:undef -> State
|
||||
handle_info(Info, State) ->
|
||||
noreply(try callback(handle_info, Info, State)
|
||||
catch _:{?MODULE, undef} -> State
|
||||
end).
|
||||
|
||||
terminate(Reason, #{mod := Mod} = State) ->
|
||||
terminate(Reason, State) ->
|
||||
case get(already_terminated) of
|
||||
true ->
|
||||
State;
|
||||
_ ->
|
||||
put(already_terminated, true),
|
||||
try Mod:terminate(Reason, State)
|
||||
catch _:undef -> ok
|
||||
try callback(terminate, Reason, State)
|
||||
catch _:{?MODULE, undef} -> ok
|
||||
end,
|
||||
send_trailer(State)
|
||||
end.
|
||||
|
||||
code_change(OldVsn, #{mod := Mod} = State, Extra) ->
|
||||
Mod:code_change(OldVsn, State, Extra).
|
||||
code_change(OldVsn, State, Extra) ->
|
||||
callback(code_change, OldVsn, State, Extra).
|
||||
|
||||
%%%===================================================================
|
||||
%%% Internal functions
|
||||
@ -464,11 +464,11 @@ process_invalid_xml(#{lang := MyLang} = State, El, Reason) ->
|
||||
-spec process_stream_end(stop_reason(), state()) -> state().
|
||||
process_stream_end(_, #{stream_state := disconnected} = State) ->
|
||||
State;
|
||||
process_stream_end(Reason, #{mod := Mod} = State) ->
|
||||
process_stream_end(Reason, State) ->
|
||||
State1 = State#{stream_timeout => infinity,
|
||||
stream_state => disconnected},
|
||||
try Mod:handle_stream_end(Reason, State1)
|
||||
catch _:undef -> stop(State1)
|
||||
try callback(handle_stream_end, Reason, State1)
|
||||
catch _:{?MODULE, undef} -> stop(State1)
|
||||
end.
|
||||
|
||||
-spec process_stream(stream_start(), state()) -> state().
|
||||
@ -503,17 +503,17 @@ process_stream(#stream_start{to = #jid{luser = U, lresource = R}},
|
||||
Txt = <<"Improper 'to' attribute">>,
|
||||
send_pkt(State, xmpp:serr_improper_addressing(Txt, Lang));
|
||||
process_stream(#stream_start{to = #jid{lserver = RemoteServer}} = StreamStart,
|
||||
#{xmlns := ?NS_COMPONENT, mod := Mod} = State) ->
|
||||
#{xmlns := ?NS_COMPONENT} = State) ->
|
||||
State1 = State#{remote_server => RemoteServer,
|
||||
stream_state => wait_for_handshake},
|
||||
try Mod:handle_stream_start(StreamStart, State1)
|
||||
catch _:undef -> State1
|
||||
try callback(handle_stream_start, StreamStart, State1)
|
||||
catch _:{?MODULE, undef} -> State1
|
||||
end;
|
||||
process_stream(#stream_start{to = #jid{server = Server, lserver = LServer},
|
||||
from = From} = StreamStart,
|
||||
#{stream_authenticated := Authenticated,
|
||||
stream_restarted := StreamWasRestarted,
|
||||
mod := Mod, xmlns := NS, resource := Resource,
|
||||
xmlns := NS, resource := Resource,
|
||||
stream_encrypted := Encrypted} = State) ->
|
||||
State1 = if not StreamWasRestarted ->
|
||||
State#{server => Server, lserver => LServer};
|
||||
@ -526,8 +526,8 @@ process_stream(#stream_start{to = #jid{server = Server, lserver = LServer},
|
||||
_ ->
|
||||
State1
|
||||
end,
|
||||
State3 = try Mod:handle_stream_start(StreamStart, State2)
|
||||
catch _:undef -> State2
|
||||
State3 = try callback(handle_stream_start, StreamStart, State2)
|
||||
catch _:{?MODULE, undef} -> State2
|
||||
end,
|
||||
case is_disconnected(State3) of
|
||||
true -> State3;
|
||||
@ -604,21 +604,21 @@ process_element(Pkt, #{stream_state := StateName, lang := Lang} = State) ->
|
||||
end.
|
||||
|
||||
-spec process_unauthenticated_packet(xmpp_element(), state()) -> state().
|
||||
process_unauthenticated_packet(Pkt, #{mod := Mod} = State) ->
|
||||
process_unauthenticated_packet(Pkt, State) ->
|
||||
NewPkt = set_lang(Pkt, State),
|
||||
try Mod:handle_unauthenticated_packet(NewPkt, State)
|
||||
catch _:undef ->
|
||||
try callback(handle_unauthenticated_packet, NewPkt, State)
|
||||
catch _:{?MODULE, undef} ->
|
||||
Err = xmpp:serr_not_authorized(),
|
||||
send(State, Err)
|
||||
end.
|
||||
|
||||
-spec process_authenticated_packet(xmpp_element(), state()) -> state().
|
||||
process_authenticated_packet(Pkt, #{mod := Mod} = State) ->
|
||||
process_authenticated_packet(Pkt, State) ->
|
||||
Pkt1 = set_lang(Pkt, State),
|
||||
case set_from_to(Pkt1, State) of
|
||||
{ok, Pkt2} ->
|
||||
try Mod:handle_authenticated_packet(Pkt2, State)
|
||||
catch _:undef ->
|
||||
try callback(handle_authenticated_packet, Pkt2, State)
|
||||
catch _:{?MODULE, undef} ->
|
||||
Err = xmpp:err_service_unavailable(),
|
||||
send_error(State, Pkt, Err)
|
||||
end;
|
||||
@ -628,10 +628,10 @@ process_authenticated_packet(Pkt, #{mod := Mod} = State) ->
|
||||
|
||||
-spec process_bind(xmpp_element(), state()) -> state().
|
||||
process_bind(#iq{type = set, sub_els = [_]} = Pkt,
|
||||
#{xmlns := ?NS_CLIENT, mod := Mod, lang := MyLang} = State) ->
|
||||
#{xmlns := ?NS_CLIENT, lang := MyLang} = State) ->
|
||||
try xmpp:try_subtag(Pkt, #bind{}) of
|
||||
#bind{resource = R} ->
|
||||
case Mod:bind(R, State) of
|
||||
case callback(bind, R, State) of
|
||||
{ok, #{user := U, server := S, resource := NewR} = State1}
|
||||
when NewR /= <<"">> ->
|
||||
Reply = #bind{jid = jid:make(U, S, NewR)},
|
||||
@ -641,8 +641,8 @@ process_bind(#iq{type = set, sub_els = [_]} = Pkt,
|
||||
send_error(State1, Pkt, Err)
|
||||
end;
|
||||
_ ->
|
||||
try Mod:handle_unbinded_packet(Pkt, State)
|
||||
catch _:undef ->
|
||||
try callback(handle_unbinded_packet, Pkt, State)
|
||||
catch _:{?MODULE, undef} ->
|
||||
Err = xmpp:err_not_authorized(),
|
||||
send_error(State, Pkt, Err)
|
||||
end
|
||||
@ -652,19 +652,19 @@ process_bind(#iq{type = set, sub_els = [_]} = Pkt,
|
||||
Err = xmpp:err_bad_request(Txt, Lang),
|
||||
send_error(State, Pkt, Err)
|
||||
end;
|
||||
process_bind(Pkt, #{mod := Mod} = State) ->
|
||||
try Mod:handle_unbinded_packet(Pkt, State)
|
||||
catch _:undef ->
|
||||
process_bind(Pkt, State) ->
|
||||
try callback(handle_unbinded_packet, Pkt, State)
|
||||
catch _:{?MODULE, undef} ->
|
||||
Err = xmpp:err_not_authorized(),
|
||||
send_error(State, Pkt, Err)
|
||||
end.
|
||||
|
||||
-spec process_handshake(handshake(), state()) -> state().
|
||||
process_handshake(#handshake{data = Digest},
|
||||
#{mod := Mod, stream_id := StreamID,
|
||||
#{stream_id := StreamID,
|
||||
remote_server := RemoteServer} = State) ->
|
||||
GetPW = try Mod:get_password_fun(State)
|
||||
catch _:undef -> fun(_) -> {false, undefined} end
|
||||
GetPW = try callback(get_password_fun, State)
|
||||
catch _:{?MODULE, undef} -> fun(_) -> {false, undefined} end
|
||||
end,
|
||||
AuthRes = case GetPW(<<"">>) of
|
||||
{false, _} ->
|
||||
@ -674,9 +674,9 @@ process_handshake(#handshake{data = Digest},
|
||||
end,
|
||||
case AuthRes of
|
||||
true ->
|
||||
State1 = try Mod:handle_auth_success(
|
||||
State1 = try callback(handle_auth_success,
|
||||
RemoteServer, <<"handshake">>, undefined, State)
|
||||
catch _:undef -> State
|
||||
catch _:{?MODULE, undef} -> State
|
||||
end,
|
||||
case is_disconnected(State1) of
|
||||
true -> State1;
|
||||
@ -685,9 +685,9 @@ process_handshake(#handshake{data = Digest},
|
||||
process_stream_established(State2)
|
||||
end;
|
||||
false ->
|
||||
State1 = try Mod:handle_auth_failure(
|
||||
State1 = try callback(handle_auth_failure,
|
||||
RemoteServer, <<"handshake">>, <<"not authorized">>, State)
|
||||
catch _:undef -> State
|
||||
catch _:{?MODULE, undef} -> State
|
||||
end,
|
||||
case is_disconnected(State1) of
|
||||
true -> State1;
|
||||
@ -699,12 +699,12 @@ process_handshake(#handshake{data = Digest},
|
||||
process_stream_established(#{stream_state := StateName} = State)
|
||||
when StateName == disconnected; StateName == established ->
|
||||
State;
|
||||
process_stream_established(#{mod := Mod} = State) ->
|
||||
process_stream_established(State) ->
|
||||
State1 = State#{stream_authenticated => true,
|
||||
stream_state => established,
|
||||
stream_timeout => infinity},
|
||||
try Mod:handle_stream_established(State1)
|
||||
catch _:undef -> State1
|
||||
try callback(handle_stream_established, State1)
|
||||
catch _:{?MODULE, undef} -> State1
|
||||
end.
|
||||
|
||||
-spec process_compress(compress(), state()) -> state().
|
||||
@ -714,9 +714,9 @@ process_compress(#compress{},
|
||||
when Compressed or not Authenticated ->
|
||||
send_pkt(State, #compress_failure{reason = 'setup-failed'});
|
||||
process_compress(#compress{methods = HisMethods},
|
||||
#{socket := Socket, mod := Mod} = State) ->
|
||||
MyMethods = try Mod:compress_methods(State)
|
||||
catch _:undef -> []
|
||||
#{socket := Socket} = State) ->
|
||||
MyMethods = try callback(compress_methods, State)
|
||||
catch _:{?MODULE, undef} -> []
|
||||
end,
|
||||
CommonMethods = lists_intersection(MyMethods, HisMethods),
|
||||
case lists:member(<<"zlib">>, CommonMethods) of
|
||||
@ -745,12 +745,11 @@ process_compress(#compress{methods = HisMethods},
|
||||
-spec process_starttls(state()) -> state().
|
||||
process_starttls(#{stream_encrypted := true} = State) ->
|
||||
process_starttls_failure(already_encrypted, State);
|
||||
process_starttls(#{socket := Socket,
|
||||
mod := Mod} = State) ->
|
||||
process_starttls(#{socket := Socket} = State) ->
|
||||
case is_starttls_available(State) of
|
||||
true ->
|
||||
TLSOpts = try Mod:tls_options(State)
|
||||
catch _:undef -> []
|
||||
TLSOpts = try callback(tls_options, State)
|
||||
catch _:{?MODULE, undef} -> []
|
||||
end,
|
||||
case xmpp_socket:starttls(Socket, TLSOpts) of
|
||||
{ok, TLSSocket} ->
|
||||
@ -782,7 +781,7 @@ process_starttls_failure(Why, State) ->
|
||||
|
||||
-spec process_sasl_request(sasl_auth(), state()) -> state().
|
||||
process_sasl_request(#sasl_auth{mechanism = Mech, text = ClientIn},
|
||||
#{mod := Mod, lserver := LServer} = State) ->
|
||||
#{lserver := LServer} = State) ->
|
||||
State1 = State#{sasl_mech => Mech},
|
||||
Mechs = get_sasl_mechanisms(State1),
|
||||
case lists:member(Mech, Mechs) of
|
||||
@ -795,14 +794,14 @@ process_sasl_request(#sasl_auth{mechanism = Mech, text = ClientIn},
|
||||
end,
|
||||
process_sasl_result(Res, State1);
|
||||
true ->
|
||||
GetPW = try Mod:get_password_fun(State1)
|
||||
catch _:undef -> fun(_) -> false end
|
||||
GetPW = try callback(get_password_fun, State1)
|
||||
catch _:{?MODULE, undef} -> fun(_) -> false end
|
||||
end,
|
||||
CheckPW = try Mod:check_password_fun(State1)
|
||||
catch _:undef -> fun(_, _, _) -> false end
|
||||
CheckPW = try callback(check_password_fun, State1)
|
||||
catch _:{?MODULE, undef} -> fun(_, _, _) -> false end
|
||||
end,
|
||||
CheckPWDigest = try Mod:check_password_digest_fun(State1)
|
||||
catch _:undef -> fun(_, _, _, _, _) -> false end
|
||||
CheckPWDigest = try callback(check_password_digest_fun, State1)
|
||||
catch _:{?MODULE, undef} -> fun(_, _, _, _, _) -> false end
|
||||
end,
|
||||
SASLState = cyrsasl:server_new(<<"jabber">>, LServer, <<"">>, [],
|
||||
GetPW, CheckPW, CheckPWDigest),
|
||||
@ -831,13 +830,13 @@ process_sasl_result({error, Reason, User}, State) ->
|
||||
-spec process_sasl_success([cyrsasl:sasl_property()], binary(), state()) -> state().
|
||||
process_sasl_success(Props, ServerOut,
|
||||
#{socket := Socket,
|
||||
mod := Mod, sasl_mech := Mech} = State) ->
|
||||
sasl_mech := Mech} = State) ->
|
||||
User = identity(Props),
|
||||
AuthModule = proplists:get_value(auth_module, Props),
|
||||
Socket1 = xmpp_socket:reset_stream(Socket),
|
||||
State0 = State#{socket => Socket1},
|
||||
State1 = try Mod:handle_auth_success(User, Mech, AuthModule, State0)
|
||||
catch _:undef -> State
|
||||
State1 = try callback(handle_auth_success, User, Mech, AuthModule, State0)
|
||||
catch _:{?MODULE, undef} -> State
|
||||
end,
|
||||
case is_disconnected(State1) of
|
||||
true -> State1;
|
||||
@ -865,10 +864,10 @@ process_sasl_continue(ServerOut, NewSASLState, State) ->
|
||||
|
||||
-spec process_sasl_failure(atom(), binary(), state()) -> state().
|
||||
process_sasl_failure(Err, User,
|
||||
#{mod := Mod, sasl_mech := Mech, lang := Lang} = State) ->
|
||||
#{sasl_mech := Mech, lang := Lang} = State) ->
|
||||
{Reason, Text} = format_sasl_error(Mech, Err),
|
||||
State1 = try Mod:handle_auth_failure(User, Mech, Text, State)
|
||||
catch _:undef -> State
|
||||
State1 = try callback(handle_auth_failure, User, Mech, Text, State)
|
||||
catch _:{?MODULE, undef} -> State
|
||||
end,
|
||||
case is_disconnected(State1) of
|
||||
true -> State1;
|
||||
@ -906,21 +905,21 @@ send_features(State) ->
|
||||
State.
|
||||
|
||||
-spec get_sasl_mechanisms(state()) -> [cyrsasl:mechanism()].
|
||||
get_sasl_mechanisms(#{stream_encrypted := Encrypted, mod := Mod,
|
||||
get_sasl_mechanisms(#{stream_encrypted := Encrypted,
|
||||
xmlns := NS, lserver := LServer} = State) ->
|
||||
Mechs = if NS == ?NS_CLIENT -> cyrsasl:listmech(LServer);
|
||||
true -> []
|
||||
end,
|
||||
TLSVerify = try Mod:tls_verify(State)
|
||||
catch _:undef -> false
|
||||
TLSVerify = try callback(tls_verify, State)
|
||||
catch _:{?MODULE, undef} -> false
|
||||
end,
|
||||
Mechs1 = if Encrypted andalso (TLSVerify orelse NS == ?NS_SERVER) ->
|
||||
[<<"EXTERNAL">>|Mechs];
|
||||
true ->
|
||||
Mechs
|
||||
end,
|
||||
try Mod:sasl_mechanisms(Mechs1, State)
|
||||
catch _:undef -> Mechs1
|
||||
try callback(sasl_mechanisms, Mechs1, State)
|
||||
catch _:{?MODULE, undef} -> Mechs1
|
||||
end.
|
||||
|
||||
-spec get_sasl_feature(state()) -> [sasl_mechanisms()].
|
||||
@ -937,12 +936,12 @@ get_sasl_feature(_) ->
|
||||
[].
|
||||
|
||||
-spec get_compress_feature(state()) -> [compression()].
|
||||
get_compress_feature(#{stream_compressed := false, mod := Mod,
|
||||
get_compress_feature(#{stream_compressed := false,
|
||||
stream_authenticated := true} = State) ->
|
||||
try Mod:compress_methods(State) of
|
||||
try callback(compress_methods, State) of
|
||||
[] -> [];
|
||||
Ms -> [#compression{methods = Ms}]
|
||||
catch _:undef ->
|
||||
catch _:{?MODULE, undef} ->
|
||||
[]
|
||||
end;
|
||||
get_compress_feature(_) ->
|
||||
@ -978,25 +977,25 @@ get_session_feature(_) ->
|
||||
[].
|
||||
|
||||
-spec get_other_features(state()) -> [xmpp_element()].
|
||||
get_other_features(#{stream_authenticated := Auth, mod := Mod} = State) ->
|
||||
get_other_features(#{stream_authenticated := Auth} = State) ->
|
||||
try
|
||||
if Auth -> Mod:authenticated_stream_features(State);
|
||||
true -> Mod:unauthenticated_stream_features(State)
|
||||
if Auth -> callback(authenticated_stream_features, State);
|
||||
true -> callback(unauthenticated_stream_features, State)
|
||||
end
|
||||
catch _:undef ->
|
||||
catch _:{?MODULE, undef} ->
|
||||
[]
|
||||
end.
|
||||
|
||||
-spec is_starttls_available(state()) -> boolean().
|
||||
is_starttls_available(#{mod := Mod} = State) ->
|
||||
try Mod:tls_enabled(State)
|
||||
catch _:undef -> true
|
||||
is_starttls_available(State) ->
|
||||
try callback(tls_enabled, State)
|
||||
catch _:{?MODULE, undef} -> true
|
||||
end.
|
||||
|
||||
-spec is_starttls_required(state()) -> boolean().
|
||||
is_starttls_required(#{mod := Mod} = State) ->
|
||||
try Mod:tls_required(State)
|
||||
catch _:undef -> false
|
||||
is_starttls_required(State) ->
|
||||
try callback(tls_required, State)
|
||||
catch _:{?MODULE, undef} -> false
|
||||
end.
|
||||
|
||||
-spec set_from_to(xmpp_element(), state()) -> {ok, xmpp_element()} |
|
||||
@ -1076,10 +1075,10 @@ send_header(State, _) ->
|
||||
State.
|
||||
|
||||
-spec send_pkt(state(), xmpp_element() | xmlel()) -> state().
|
||||
send_pkt(#{mod := Mod} = State, Pkt) ->
|
||||
send_pkt(State, Pkt) ->
|
||||
Result = socket_send(State, Pkt),
|
||||
State1 = try Mod:handle_send(Pkt, Result, State)
|
||||
catch _:undef -> State
|
||||
State1 = try callback(handle_send, Pkt, Result, State)
|
||||
catch _:{?MODULE, undef} -> State
|
||||
end,
|
||||
case Result of
|
||||
_ when is_record(Pkt, stream_error) ->
|
||||
@ -1200,3 +1199,36 @@ identity(Props) ->
|
||||
<<>> -> proplists:get_value(username, Props, <<>>);
|
||||
AuthzId -> AuthzId
|
||||
end.
|
||||
|
||||
%%%===================================================================
|
||||
%%% Callbacks
|
||||
%%%===================================================================
|
||||
callback(F, #{mod := Mod} = State) ->
|
||||
case erlang:function_exported(Mod, F, 1) of
|
||||
true -> Mod:F(State);
|
||||
false -> erlang:error({?MODULE, undef})
|
||||
end.
|
||||
|
||||
callback(F, Arg1, #{mod := Mod} = State) ->
|
||||
case erlang:function_exported(Mod, F, 2) of
|
||||
true -> Mod:F(Arg1, State);
|
||||
false -> erlang:error({?MODULE, undef})
|
||||
end.
|
||||
|
||||
callback(code_change, OldVsn, #{mod := Mod} = State, Extra) ->
|
||||
%% code_change/3 callback is a special snowflake
|
||||
case erlang:function_exported(Mod, code_change, 3) of
|
||||
true -> Mod:code_change(OldVsn, State, Extra);
|
||||
false -> {ok, State}
|
||||
end;
|
||||
callback(F, Arg1, Arg2, #{mod := Mod} = State) ->
|
||||
case erlang:function_exported(Mod, F, 3) of
|
||||
true -> Mod:F(Arg1, Arg2, State);
|
||||
false -> erlang:error({?MODULE, undef})
|
||||
end.
|
||||
|
||||
callback(F, Arg1, Arg2, Arg3, #{mod := Mod} = State) ->
|
||||
case erlang:function_exported(Mod, F, 4) of
|
||||
true -> Mod:F(Arg1, Arg2, Arg3, State);
|
||||
false -> erlang:error({?MODULE, undef})
|
||||
end.
|
||||
|
@ -266,9 +266,9 @@ init([Mod, _SockMod, From, To, Opts]) ->
|
||||
end.
|
||||
|
||||
-spec handle_call(term(), term(), state()) -> noreply().
|
||||
handle_call(Call, From, #{mod := Mod} = State) ->
|
||||
noreply(try Mod:handle_call(Call, From, State)
|
||||
catch _:undef -> State
|
||||
handle_call(Call, From, State) ->
|
||||
noreply(try callback(handle_call, Call, From, State)
|
||||
catch _:{?MODULE, undef} -> State
|
||||
end).
|
||||
|
||||
-spec handle_cast(term(), state()) -> noreply().
|
||||
@ -311,9 +311,9 @@ handle_cast({close, Reason}, State) ->
|
||||
true -> State1;
|
||||
false -> process_stream_end({socket, Reason}, State)
|
||||
end);
|
||||
handle_cast(Cast, #{mod := Mod} = State) ->
|
||||
noreply(try Mod:handle_cast(Cast, State)
|
||||
catch _:undef -> State
|
||||
handle_cast(Cast, State) ->
|
||||
noreply(try callback(handle_cast, Cast, State)
|
||||
catch _:{?MODULE, undef} -> State
|
||||
end).
|
||||
|
||||
-spec handle_info(term(), state()) -> noreply().
|
||||
@ -348,20 +348,20 @@ handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) ->
|
||||
send_pkt(State1, Err)
|
||||
end);
|
||||
handle_info({'$gen_event', {xmlstreamelement, El}},
|
||||
#{xmlns := NS, mod := Mod, codec_options := Opts} = State) ->
|
||||
#{xmlns := NS, codec_options := Opts} = State) ->
|
||||
noreply(
|
||||
try xmpp:decode(El, NS, Opts) of
|
||||
Pkt ->
|
||||
State1 = try Mod:handle_recv(El, Pkt, State)
|
||||
catch _:undef -> State
|
||||
State1 = try callback(handle_recv, El, Pkt, State)
|
||||
catch _:{?MODULE, undef} -> State
|
||||
end,
|
||||
case is_disconnected(State1) of
|
||||
true -> State1;
|
||||
false -> process_element(Pkt, State1)
|
||||
end
|
||||
catch _:{xmpp_codec, Why} ->
|
||||
State1 = try Mod:handle_recv(El, {error, Why}, State)
|
||||
catch _:undef -> State
|
||||
State1 = try callback(handle_recv, El, {error, Why}, State)
|
||||
catch _:{?MODULE, undef} -> State
|
||||
end,
|
||||
case is_disconnected(State1) of
|
||||
true -> State1;
|
||||
@ -369,21 +369,21 @@ handle_info({'$gen_event', {xmlstreamelement, El}},
|
||||
end
|
||||
end);
|
||||
handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}},
|
||||
#{mod := Mod} = State) ->
|
||||
noreply(try Mod:handle_cdata(Data, State)
|
||||
catch _:undef -> State
|
||||
State) ->
|
||||
noreply(try callback(handle_cdata, Data, State)
|
||||
catch _:{?MODULE, undef} -> State
|
||||
end);
|
||||
handle_info({'$gen_event', {xmlstreamend, _}}, State) ->
|
||||
noreply(process_stream_end({stream, reset}, State));
|
||||
handle_info({'$gen_event', closed}, State) ->
|
||||
noreply(process_stream_end({socket, closed}, State));
|
||||
handle_info(timeout, #{mod := Mod, lang := Lang} = State) ->
|
||||
handle_info(timeout, #{lang := Lang} = State) ->
|
||||
Disconnected = is_disconnected(State),
|
||||
noreply(try Mod:handle_timeout(State)
|
||||
catch _:undef when not Disconnected ->
|
||||
noreply(try callback(handle_timeout, State)
|
||||
catch _:{?MODULE, undef} when not Disconnected ->
|
||||
Txt = <<"Idle connection">>,
|
||||
send_pkt(State, xmpp:serr_connection_timeout(Txt, Lang));
|
||||
_:undef ->
|
||||
_:{?MODULE, undef} ->
|
||||
stop(State)
|
||||
end);
|
||||
handle_info({'DOWN', MRef, _Type, _Object, _Info},
|
||||
@ -404,26 +404,26 @@ handle_info({tcp_closed, _}, State) ->
|
||||
handle_info({'$gen_event', closed}, State);
|
||||
handle_info({tcp_error, _, Reason}, State) ->
|
||||
noreply(process_stream_end({socket, Reason}, State));
|
||||
handle_info(Info, #{mod := Mod} = State) ->
|
||||
noreply(try Mod:handle_info(Info, State)
|
||||
catch _:undef -> State
|
||||
handle_info(Info, State) ->
|
||||
noreply(try callback(handle_info, Info, State)
|
||||
catch _:{?MODULE, undef} -> State
|
||||
end).
|
||||
|
||||
-spec terminate(term(), state()) -> any().
|
||||
terminate(Reason, #{mod := Mod} = State) ->
|
||||
terminate(Reason, State) ->
|
||||
case get(already_terminated) of
|
||||
true ->
|
||||
State;
|
||||
_ ->
|
||||
put(already_terminated, true),
|
||||
try Mod:terminate(Reason, State)
|
||||
catch _:undef -> ok
|
||||
try callback(terminate, Reason, State)
|
||||
catch _:{?MODULE, undef} -> ok
|
||||
end,
|
||||
send_trailer(State)
|
||||
end.
|
||||
|
||||
code_change(OldVsn, #{mod := Mod} = State, Extra) ->
|
||||
Mod:code_change(OldVsn, State, Extra).
|
||||
code_change(OldVsn, State, Extra) ->
|
||||
callback(code_change, OldVsn, State, Extra).
|
||||
|
||||
%%%===================================================================
|
||||
%%% Internal functions
|
||||
@ -458,11 +458,11 @@ process_invalid_xml(#{lang := MyLang} = State, El, Reason) ->
|
||||
-spec process_stream_end(stop_reason(), state()) -> state().
|
||||
process_stream_end(_, #{stream_state := disconnected} = State) ->
|
||||
State;
|
||||
process_stream_end(Reason, #{mod := Mod} = State) ->
|
||||
process_stream_end(Reason, State) ->
|
||||
State1 = State#{stream_timeout => infinity,
|
||||
stream_state => disconnected},
|
||||
try Mod:handle_stream_end(Reason, State1)
|
||||
catch _:undef -> stop(State1)
|
||||
try callback(handle_stream_end, Reason, State1)
|
||||
catch _:{?MODULE, undef} -> stop(State1)
|
||||
end.
|
||||
|
||||
-spec process_stream(stream_start(), state()) -> state().
|
||||
@ -475,10 +475,10 @@ process_stream(#stream_start{version = {N, _}}, State) when N > 1 ->
|
||||
send_pkt(State, xmpp:serr_unsupported_version());
|
||||
process_stream(#stream_start{lang = Lang, id = ID,
|
||||
version = Version} = StreamStart,
|
||||
#{mod := Mod} = State) ->
|
||||
State) ->
|
||||
State1 = State#{stream_remote_id => ID, lang => Lang},
|
||||
State2 = try Mod:handle_stream_start(StreamStart, State1)
|
||||
catch _:undef -> State1
|
||||
State2 = try callback(handle_stream_start, StreamStart, State1)
|
||||
catch _:{?MODULE, undef} -> State1
|
||||
end,
|
||||
case is_disconnected(State2) of
|
||||
true -> State2;
|
||||
@ -522,16 +522,16 @@ process_element(Pkt, #{stream_state := StateName} = State) ->
|
||||
|
||||
-spec process_features(stream_features(), state()) -> state().
|
||||
process_features(StreamFeatures,
|
||||
#{stream_authenticated := true, mod := Mod} = State) ->
|
||||
State1 = try Mod:handle_authenticated_features(StreamFeatures, State)
|
||||
catch _:undef -> State
|
||||
#{stream_authenticated := true} = State) ->
|
||||
State1 = try callback(handle_authenticated_features, StreamFeatures, State)
|
||||
catch _:{?MODULE, undef} -> State
|
||||
end,
|
||||
process_stream_established(State1);
|
||||
process_features(StreamFeatures,
|
||||
#{stream_encrypted := Encrypted,
|
||||
mod := Mod, lang := Lang} = State) ->
|
||||
State1 = try Mod:handle_unauthenticated_features(StreamFeatures, State)
|
||||
catch _:undef -> State
|
||||
lang := Lang} = State) ->
|
||||
State1 = try callback(handle_unauthenticated_features, StreamFeatures, State)
|
||||
catch _:{?MODULE, undef} -> State
|
||||
end,
|
||||
case is_disconnected(State1) of
|
||||
true -> State1;
|
||||
@ -582,12 +582,12 @@ process_features(StreamFeatures,
|
||||
process_stream_established(#{stream_state := StateName} = State)
|
||||
when StateName == disconnected; StateName == established ->
|
||||
State;
|
||||
process_stream_established(#{mod := Mod} = State) ->
|
||||
process_stream_established(State) ->
|
||||
State1 = State#{stream_authenticated := true,
|
||||
stream_state => established,
|
||||
stream_timeout => infinity},
|
||||
try Mod:handle_stream_established(State1)
|
||||
catch _:undef -> State1
|
||||
try callback(handle_stream_established, State1)
|
||||
catch _:{?MODULE, undef} -> State1
|
||||
end.
|
||||
|
||||
-spec process_sasl_mechanisms([binary()], state()) -> state().
|
||||
@ -620,7 +620,7 @@ process_starttls(#{socket := Socket} = State) ->
|
||||
|
||||
-spec process_stream_downgrade(stream_start(), state()) -> state().
|
||||
process_stream_downgrade(StreamStart,
|
||||
#{mod := Mod, lang := Lang,
|
||||
#{lang := Lang,
|
||||
stream_encrypted := Encrypted} = State) ->
|
||||
TLSRequired = is_starttls_required(State),
|
||||
if not Encrypted and TLSRequired ->
|
||||
@ -628,18 +628,17 @@ process_stream_downgrade(StreamStart,
|
||||
send_pkt(State, xmpp:serr_policy_violation(Txt, Lang));
|
||||
true ->
|
||||
State1 = State#{stream_state => downgraded},
|
||||
try Mod:handle_stream_downgraded(StreamStart, State1)
|
||||
catch _:undef ->
|
||||
try callback(handle_stream_downgraded, StreamStart, State1)
|
||||
catch _:{?MODULE, undef} ->
|
||||
send_pkt(State1, xmpp:serr_unsupported_version())
|
||||
end
|
||||
end.
|
||||
|
||||
-spec process_cert_verification(state()) -> state().
|
||||
process_cert_verification(#{stream_encrypted := true,
|
||||
stream_verified := false,
|
||||
mod := Mod} = State) ->
|
||||
case try Mod:tls_verify(State)
|
||||
catch _:undef -> true
|
||||
stream_verified := false} = State) ->
|
||||
case try callback(tls_verify, State)
|
||||
catch _:{?MODULE, undef} -> true
|
||||
end of
|
||||
true ->
|
||||
case xmpp_stream_pkix:authenticate(State) of
|
||||
@ -655,8 +654,7 @@ process_cert_verification(State) ->
|
||||
State.
|
||||
|
||||
-spec process_sasl_success(state()) -> state().
|
||||
process_sasl_success(#{mod := Mod,
|
||||
socket := Socket} = State) ->
|
||||
process_sasl_success(#{socket := Socket} = State) ->
|
||||
Socket1 = xmpp_socket:reset_stream(Socket),
|
||||
State0 = State#{socket => Socket1},
|
||||
State1 = State0#{stream_id => new_id(),
|
||||
@ -667,8 +665,8 @@ process_sasl_success(#{mod := Mod,
|
||||
case is_disconnected(State2) of
|
||||
true -> State2;
|
||||
false ->
|
||||
try Mod:handle_auth_success(<<"EXTERNAL">>, State2)
|
||||
catch _:undef -> State2
|
||||
try callback(handle_auth_success, <<"EXTERNAL">>, State2)
|
||||
catch _:{?MODULE, undef} -> State2
|
||||
end
|
||||
end.
|
||||
|
||||
@ -677,27 +675,27 @@ process_sasl_failure(#sasl_failure{} = Failure, State) ->
|
||||
Reason = format("Peer responded with error: ~s",
|
||||
[format_sasl_failure(Failure)]),
|
||||
process_sasl_failure(Reason, State);
|
||||
process_sasl_failure(Reason, #{mod := Mod} = State) ->
|
||||
try Mod:handle_auth_failure(<<"EXTERNAL">>, {auth, Reason}, State)
|
||||
catch _:undef -> process_stream_end({auth, Reason}, State)
|
||||
process_sasl_failure(Reason, State) ->
|
||||
try callback(handle_auth_failure, <<"EXTERNAL">>, {auth, Reason}, State)
|
||||
catch _:{?MODULE, undef} -> process_stream_end({auth, Reason}, State)
|
||||
end.
|
||||
|
||||
-spec process_packet(xmpp_element(), state()) -> state().
|
||||
process_packet(Pkt, #{mod := Mod} = State) ->
|
||||
try Mod:handle_packet(Pkt, State)
|
||||
catch _:undef -> State
|
||||
process_packet(Pkt, State) ->
|
||||
try callback(handle_packet, Pkt, State)
|
||||
catch _:{?MODULE, undef} -> State
|
||||
end.
|
||||
|
||||
-spec is_starttls_required(state()) -> boolean().
|
||||
is_starttls_required(#{mod := Mod} = State) ->
|
||||
try Mod:tls_required(State)
|
||||
catch _:undef -> false
|
||||
is_starttls_required(State) ->
|
||||
try callback(tls_required, State)
|
||||
catch _:{?MODULE, undef} -> false
|
||||
end.
|
||||
|
||||
-spec is_starttls_available(state()) -> boolean().
|
||||
is_starttls_available(#{mod := Mod} = State) ->
|
||||
try Mod:tls_enabled(State)
|
||||
catch _:undef -> true
|
||||
is_starttls_available(State) ->
|
||||
try callback(tls_enabled, State)
|
||||
catch _:{?MODULE, undef} -> true
|
||||
end.
|
||||
|
||||
-spec send_header(state()) -> state().
|
||||
@ -731,10 +729,10 @@ send_header(#{remote_server := RemoteServer,
|
||||
end.
|
||||
|
||||
-spec send_pkt(state(), xmpp_element() | xmlel()) -> state().
|
||||
send_pkt(#{mod := Mod} = State, Pkt) ->
|
||||
send_pkt(State, Pkt) ->
|
||||
Result = socket_send(State, Pkt),
|
||||
State1 = try Mod:handle_send(Pkt, Result, State)
|
||||
catch _:undef -> State
|
||||
State1 = try callback(handle_send, Pkt, Result, State)
|
||||
catch _:{?MODULE, undef} -> State
|
||||
end,
|
||||
case Result of
|
||||
_ when is_record(Pkt, stream_error) ->
|
||||
@ -795,10 +793,10 @@ close_socket(State) ->
|
||||
stream_state => disconnected}.
|
||||
|
||||
-spec starttls(term(), state()) -> {ok, term()} | {error, tls_error_reason()}.
|
||||
starttls(Socket, #{mod := Mod, xmlns := NS,
|
||||
starttls(Socket, #{xmlns := NS,
|
||||
remote_server := RemoteServer} = State) ->
|
||||
TLSOpts = try Mod:tls_options(State)
|
||||
catch _:undef -> []
|
||||
TLSOpts = try callback(tls_options, State)
|
||||
catch _:{?MODULE, undef} -> []
|
||||
end,
|
||||
SNI = idna_to_ascii(RemoteServer),
|
||||
ALPN = case NS of
|
||||
@ -1077,32 +1075,59 @@ get_addr_type({_, _, _, _}) -> inet;
|
||||
get_addr_type({_, _, _, _, _, _, _, _}) -> inet6.
|
||||
|
||||
-spec get_dns_timeout(state()) -> timeout().
|
||||
get_dns_timeout(#{mod := Mod} = State) ->
|
||||
try Mod:dns_timeout(State)
|
||||
catch _:undef -> timer:seconds(10)
|
||||
get_dns_timeout(State) ->
|
||||
try callback(dns_timeout, State)
|
||||
catch _:{?MODULE, undef} -> timer:seconds(10)
|
||||
end.
|
||||
|
||||
-spec get_dns_retries(state()) -> non_neg_integer().
|
||||
get_dns_retries(#{mod := Mod} = State) ->
|
||||
try Mod:dns_retries(State)
|
||||
catch _:undef -> 2
|
||||
get_dns_retries(State) ->
|
||||
try callback(dns_retries, State)
|
||||
catch _:{?MODULE, undef} -> 2
|
||||
end.
|
||||
|
||||
-spec get_default_port(state()) -> inet:port_number().
|
||||
get_default_port(#{mod := Mod, xmlns := NS} = State) ->
|
||||
try Mod:default_port(State)
|
||||
catch _:undef when NS == ?NS_SERVER -> 5269;
|
||||
_:undef when NS == ?NS_CLIENT -> 5222
|
||||
get_default_port(#{xmlns := NS} = State) ->
|
||||
try callback(default_port, State)
|
||||
catch _:{?MODULE, undef} when NS == ?NS_SERVER -> 5269;
|
||||
_:{?MODULE, undef} when NS == ?NS_CLIENT -> 5222
|
||||
end.
|
||||
|
||||
-spec get_address_families(state()) -> [inet:address_family()].
|
||||
get_address_families(#{mod := Mod} = State) ->
|
||||
try Mod:address_families(State)
|
||||
catch _:undef -> [inet, inet6]
|
||||
get_address_families(State) ->
|
||||
try callback(address_families, State)
|
||||
catch _:{?MODULE, undef} -> [inet, inet6]
|
||||
end.
|
||||
|
||||
-spec get_connect_timeout(state()) -> timeout().
|
||||
get_connect_timeout(#{mod := Mod} = State) ->
|
||||
try Mod:connect_timeout(State)
|
||||
catch _:undef -> timer:seconds(10)
|
||||
get_connect_timeout(State) ->
|
||||
try callback(connect_timeout, State)
|
||||
catch _:{?MODULE, undef} -> timer:seconds(10)
|
||||
end.
|
||||
|
||||
%%%===================================================================
|
||||
%%% Callbacks
|
||||
%%%===================================================================
|
||||
callback(F, #{mod := Mod} = State) ->
|
||||
case erlang:function_exported(Mod, F, 1) of
|
||||
true -> Mod:F(State);
|
||||
false -> erlang:error({?MODULE, undef})
|
||||
end.
|
||||
|
||||
callback(F, Arg1, #{mod := Mod} = State) ->
|
||||
case erlang:function_exported(Mod, F, 2) of
|
||||
true -> Mod:F(Arg1, State);
|
||||
false -> erlang:error({?MODULE, undef})
|
||||
end.
|
||||
|
||||
callback(code_change, OldVsn, #{mod := Mod} = State, Extra) ->
|
||||
%% code_change/3 callback is a special snowflake
|
||||
case erlang:function_exported(Mod, code_change, 3) of
|
||||
true -> Mod:code_change(OldVsn, State, Extra);
|
||||
false -> {ok, State}
|
||||
end;
|
||||
callback(F, Arg1, Arg2, #{mod := Mod} = State) ->
|
||||
case erlang:function_exported(Mod, F, 3) of
|
||||
true -> Mod:F(Arg1, Arg2, State);
|
||||
false -> erlang:error({?MODULE, undef})
|
||||
end.
|
||||
|
Loading…
Reference in New Issue
Block a user