diff --git a/src/xmpp_stream_out.erl b/src/xmpp_stream_out.erl index 135972717..d0360d0c8 100644 --- a/src/xmpp_stream_out.erl +++ b/src/xmpp_stream_out.erl @@ -90,9 +90,11 @@ -callback tls_required(state()) -> boolean(). -callback tls_verify(state()) -> boolean(). -callback tls_enabled(state()) -> boolean(). +-callback resolve(string(), state()) -> [host_port()]. -callback dns_timeout(state()) -> timeout(). -callback dns_retries(state()) -> non_neg_integer(). -callback default_port(state()) -> inet:port_number(). +-callback connect_options(inet:ip_address(), list(), state()) -> list(). -callback address_families(state()) -> [inet:address_family()]. -callback connect_timeout(state()) -> timeout(). @@ -121,9 +123,11 @@ tls_required/1, tls_verify/1, tls_enabled/1, + resolve/2, dns_timeout/1, dns_retries/1, default_port/1, + connect_options/3, address_families/1, connect_timeout/1]). @@ -394,8 +398,7 @@ handle_info({'$gen_event', {xmlstreamelement, El}}, false -> process_invalid_xml(State1, El, Why) end end); -handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}}, - State) -> +handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}}, State) -> noreply(try callback(handle_cdata, Data, State) catch _:{?MODULE, undef} -> State end); @@ -934,6 +937,17 @@ idna_to_ascii(Host) -> -spec resolve(string(), state()) -> {ok, [ip_port()]} | network_error(). resolve(Host, State) -> + try callback(resolve, Host, State) of + [] -> + do_resolve(Host, State); + HostPorts -> + a_lookup(HostPorts, State) + catch _:{?MODULE, undef} -> + do_resolve(Host, State) + end. + +-spec do_resolve(string(), state()) -> {ok, [ip_port()]} | network_error(). +do_resolve(Host, State) -> case srv_lookup(Host, State) of {error, _Reason} -> DefaultPort = get_default_port(State), @@ -967,10 +981,14 @@ srv_lookup(Host, State) -> end end. -srv_lookup(Host, State, Timeout, Retries) -> +srv_lookup(Host, #{xmlns := NS} = State, Timeout, Retries) -> + SRVType = case NS of + ?NS_SERVER -> "-server._tcp."; + ?NS_CLIENT -> "-client._tcp." + end, TLSAddrs = case is_starttls_available(State) of true -> - case srv_lookup("_xmpps-server._tcp." ++ Host, + case srv_lookup("_xmpps" ++ SRVType ++ Host, Timeout, Retries) of {ok, HostEnt} -> [{A, true} || A <- HostEnt#hostent.h_addr_list]; @@ -980,7 +998,7 @@ srv_lookup(Host, State, Timeout, Retries) -> false -> [] end, - case srv_lookup("_xmpp-server._tcp." ++ Host, Timeout, Retries) of + case srv_lookup("_xmpp" ++ SRVType ++ Host, Timeout, Retries) of {ok, HostEntry} -> Addrs = [{A, false} || A <- HostEntry#hostent.h_addr_list], {ok, TLSAddrs ++ Addrs}; @@ -1033,24 +1051,35 @@ a_lookup([], _State, Acc, _) -> a_lookup(_Host, _Port, _TLS, _Family, _Timeout, Retries) when Retries < 1 -> {error, timeout}; a_lookup(Host, Port, TLS, Family, Timeout, Retries) -> - Start = p1_time_compat:monotonic_time(milli_seconds), - case inet:gethostbyname(Host, Family, Timeout) of - {error, nxdomain} = Err -> - %% inet:gethostbyname/3 doesn't return {error, timeout}, - %% so we should check if 'nxdomain' is in fact a result - %% of a timeout. - %% We also cannot use inet_res:gethostbyname/3 because - %% it ignores DNS configuration settings (/etc/hosts, etc) - End = p1_time_compat:monotonic_time(milli_seconds), - if (End - Start) >= Timeout -> - a_lookup(Host, Port, TLS, Family, Timeout, Retries - 1); + case inet:parse_address(Host) of + {ok, Addr} -> + if tuple_size(Addr) == 4 andalso Family == inet -> + {ok, [{Addr, Port, TLS}]}; + tuple_size(Addr) == 8 andalso Family == inet6 -> + {ok, [{Addr, Port, TLS}]}; true -> - Err + {error, nxdomain} end; - {error, _} = Err -> - Err; - {ok, HostEntry} -> - host_entry_to_addr_ports(HostEntry, Port, TLS) + {error, _} -> + Start = p1_time_compat:monotonic_time(milli_seconds), + case inet:gethostbyname(Host, Family, Timeout) of + {error, nxdomain} = Err -> + %% inet:gethostbyname/3 doesn't return {error, timeout}, + %% so we should check if 'nxdomain' is in fact a result + %% of a timeout. + %% We also cannot use inet_res:gethostbyname/3 because + %% it ignores DNS configuration settings (/etc/hosts, etc) + End = p1_time_compat:monotonic_time(milli_seconds), + if (End - Start) >= Timeout -> + a_lookup(Host, Port, TLS, Family, Timeout, Retries - 1); + true -> + Err + end; + {error, _} = Err -> + Err; + {ok, HostEntry} -> + host_entry_to_addr_ports(HostEntry, Port, TLS) + end end. -spec h_addr_list_to_host_ports(h_addr_list()) -> {ok, [host_port()]} | @@ -1094,7 +1123,7 @@ host_entry_to_addr_ports(#hostent{h_addr_list = AddrList}, Port, TLS) -> {error, {tls, tls_error_reason()}}. connect(AddrPorts, State) -> Timeout = get_connect_timeout(State), - case connect(AddrPorts, Timeout, {error, nxdomain}) of + case connect(AddrPorts, Timeout, State, {error, nxdomain}) of {ok, Socket, {Addr, Port, TLS = true}} -> case starttls(Socket, State) of {ok, TLSSocket} -> {ok, TLSSocket, {Addr, Port, TLS}}; @@ -1106,24 +1135,26 @@ connect(AddrPorts, State) -> {error, {socket, Why}} end. --spec connect([ip_port()], timeout(), network_error()) -> +-spec connect([ip_port()], timeout(), state(), network_error()) -> {ok, term(), ip_port()} | network_error(). -connect([{Addr, Port, TLS}|AddrPorts], Timeout, _) -> +connect([{Addr, Port, TLS}|AddrPorts], Timeout, State, _) -> Type = get_addr_type(Addr), - try xmpp_socket:connect(Addr, Port, - [binary, {packet, 0}, - {send_timeout, ?TCP_SEND_TIMEOUT}, - {send_timeout_close, true}, - {active, false}, Type], - Timeout) of + Opts = [binary, {packet, 0}, + {send_timeout, ?TCP_SEND_TIMEOUT}, + {send_timeout_close, true}, + {active, false}, Type], + Opts1 = try callback(connect_options, Addr, Opts, State) + catch _:{?MODULE, undef} -> Opts + end, + try xmpp_socket:connect(Addr, Port, Opts1, Timeout) of {ok, Socket} -> {ok, Socket, {Addr, Port, TLS}}; Err -> - connect(AddrPorts, Timeout, Err) + connect(AddrPorts, Timeout, State, Err) catch _:badarg -> - connect(AddrPorts, Timeout, {error, einval}) + connect(AddrPorts, Timeout, State, {error, einval}) end; -connect([], _Timeout, Err) -> +connect([], _Timeout, _State, Err) -> Err. -spec get_addr_type(inet:ip_address()) -> inet:address_family().