diff --git a/src/xmpp_stream_out.erl b/src/xmpp_stream_out.erl index 5401d3073..6cf762025 100644 --- a/src/xmpp_stream_out.erl +++ b/src/xmpp_stream_out.erl @@ -29,7 +29,7 @@ %% API -export([start/3, start_link/3, call/3, cast/2, reply/2, connect/1, - stop/1, send/2, close/1, close/2, establish/1, format_error/1, + stop/1, send/2, close/1, close/2, bind/2, establish/1, format_error/1, set_timeout/2, get_transport/1, change_shaper/2]). %% gen_server callbacks -export([init/1, handle_call/3, handle_cast/2, handle_info/2, @@ -61,6 +61,7 @@ {tls, tls_error_reason()} | {pkix, binary()} | {auth, atom() | binary() | string()} | + {bind, stanza_error()} | {socket, socket_error_reason()} | internal_failure. -export_type([state/0, stop_reason/0]). @@ -82,6 +83,8 @@ -callback handle_unauthenticated_features(stream_features(), state()) -> state(). -callback handle_auth_success(cyrsasl:mechanism(), state()) -> state(). -callback handle_auth_failure(cyrsasl:mechanism(), binary(), state()) -> state(). +-callback handle_bind_success(state()) -> state(). +-callback handle_bind_failure(stanza_error(), state()) -> state(). -callback handle_packet(xmpp_element(), state()) -> state(). -callback tls_options(state()) -> [proplists:property()]. -callback tls_required(state()) -> boolean(). @@ -111,6 +114,8 @@ handle_unauthenticated_features/2, handle_auth_success/2, handle_auth_failure/3, + handle_bind_success/1, + handle_bind_failure/2, handle_packet/2, tls_options/1, tls_required/1, @@ -176,6 +181,10 @@ close(_) -> close(Pid, Reason) -> cast(Pid, {close, Reason}). +-spec bind(state(), stream_features()) -> state(). +bind(#{stream_authenticated := true} = State, StreamFeatures) -> + process_bind(StreamFeatures, State). + -spec establish(state()) -> state(). establish(State) -> process_stream_established(State). @@ -221,6 +230,8 @@ format_error({stream, {in, #stream_error{} = Err}}) -> format("Stream closed by peer: ~s", [xmpp:format_stream_error(Err)]); format_error({stream, {out, #stream_error{} = Err}}) -> format("Stream closed by us: ~s", [xmpp:format_stream_error(Err)]); +format_error({bind, #stanza_error{} = Err}) -> + format("Resource binding failure: ~s", [xmpp:format_stanza_error(Err)]); format_error({tls, Reason}) -> format("TLS failed: ~s", [format_tls_error(Reason)]); format_error({auth, Reason}) -> @@ -515,6 +526,8 @@ process_element(Pkt, #{stream_state := StateName} = State) -> is_record(Pkt, handshake) -> %% Do not pass this crap upstream State; + _ when StateName == wait_for_bind_response -> + process_bind_response(Pkt, State); _ -> process_packet(Pkt, State) end. @@ -522,10 +535,9 @@ process_element(Pkt, #{stream_state := StateName} = State) -> -spec process_features(stream_features(), state()) -> state(). process_features(StreamFeatures, #{stream_authenticated := true} = State) -> - State1 = try callback(handle_authenticated_features, StreamFeatures, State) - catch _:{?MODULE, undef} -> State - end, - process_stream_established(State1); + try callback(handle_authenticated_features, StreamFeatures, State) + catch _:{?MODULE, undef} -> process_bind(StreamFeatures, State) + end; process_features(StreamFeatures, #{stream_encrypted := Encrypted, lang := Lang, xmlns := NS} = State) -> @@ -679,6 +691,59 @@ process_sasl_failure(Reason, State) -> catch _:{?MODULE, undef} -> process_stream_end({auth, Reason}, State) end. +-spec process_bind(stream_features(), state()) -> state(). +process_bind(StreamFeatures, #{lang := Lang, xmlns := ?NS_CLIENT, + user := U, server := S, resource := R, + stream_state := StateName} = State) + when StateName /= established, StateName /= disconnected -> + case xmpp:has_subtag(StreamFeatures, #bind{}) of + true -> + JID = jid:make(U, S, R), + ID = new_id(), + Pkt = #iq{from = JID, to = jid:remove_resource(JID), + id = ID, type = set, + sub_els = [#bind{resource = R}]}, + State1 = State#{stream_state => wait_for_bind_response, + bind_id => ID}, + send_pkt(State1, Pkt); + false -> + Txt = <<"Missing resource binding feature">>, + send_pkt(State, xmpp:serr_invalid_xml(Txt, Lang)) + end; +process_bind(_, State) -> + process_stream_established(State). + +-spec process_bind_response(xmpp_element(), state()) -> state(). +process_bind_response(#iq{type = result, id = ID} = IQ, + #{lang := Lang, bind_id := ID} = State) -> + State1 = maps:remove(bind_id, State), + try xmpp:try_subtag(IQ, #bind{}) of + #bind{jid = #jid{user = U, server = S, resource = R}} -> + State2 = State1#{user => U, server => S, resource => R}, + State3 = try callback(handle_bind_success, State2) + catch _:{?MODULE, undef} -> State2 + end, + process_stream_established(State3); + #bind{} -> + Txt = <<"Missing element in resource binding response">>, + send_pkt(State1, xmpp:serr_invalid_xml(Txt, Lang)); + false -> + Txt = <<"Missing element in resource binding response">>, + send_pkt(State1, xmpp:serr_invalid_xml(Txt, Lang)) + catch _:{xmpp_codec, Why} -> + Txt = xmpp:io_format_error(Why), + send_pkt(State1, xmpp:serr_invalid_xml(Txt, Lang)) + end; +process_bind_response(#iq{type = error, id = ID} = IQ, + #{bind_id := ID} = State) -> + Err = xmpp:get_error(IQ), + State1 = maps:remove(bind_id, State), + try callback(handle_bind_failure, Err, State1) + catch _:{?MODULE, undef} -> process_stream_end({bind, Err}, State1) + end; +process_bind_response(Pkt, State) -> + process_packet(Pkt, State). + -spec process_packet(xmpp_element(), state()) -> state(). process_packet(Pkt, State) -> try callback(handle_packet, Pkt, State)