diff --git a/src/ejabberd_receiver.erl b/src/ejabberd_receiver.erl index cd7c1d31a..71ae8e409 100644 --- a/src/ejabberd_receiver.erl +++ b/src/ejabberd_receiver.erl @@ -127,20 +127,10 @@ init([Socket, SockMod, Shaper, MaxStanzaSize]) -> shaper_state = ShaperState, max_stanza_size = MaxStanzaSize, timeout = Timeout}}. -handle_call({starttls, TLSSocket}, _From, - #state{xml_stream_state = XMLStreamState, - c2s_pid = C2SPid, - max_stanza_size = MaxStanzaSize} = State) -> - close_stream(XMLStreamState), - NewXMLStreamState = case C2SPid of - undefined -> - XMLStreamState; - _ -> - xml_stream:new(C2SPid, MaxStanzaSize) - end, - NewState = State#state{socket = TLSSocket, - sock_mod = p1_tls, - xml_stream_state = NewXMLStreamState}, +handle_call({starttls, TLSSocket}, _From, State) -> + State1 = reset_parser(State), + NewState = State1#state{socket = TLSSocket, + sock_mod = p1_tls}, case p1_tls:recv_data(TLSSocket, <<"">>) of {ok, TLSData} -> {reply, ok, @@ -149,20 +139,16 @@ handle_call({starttls, TLSSocket}, _From, {stop, normal, ok, NewState} end; handle_call({compress, Data}, _From, - #state{xml_stream_state = XMLStreamState, - c2s_pid = C2SPid, socket = Socket, sock_mod = SockMod, - max_stanza_size = MaxStanzaSize} = + #state{socket = Socket, sock_mod = SockMod} = State) -> {ok, ZlibSocket} = ezlib:enable_zlib(SockMod, Socket), if Data /= undefined -> do_send(State, Data); true -> ok end, - close_stream(XMLStreamState), - NewXMLStreamState = xml_stream:new(C2SPid, MaxStanzaSize), - NewState = State#state{socket = ZlibSocket, - sock_mod = ezlib, - xml_stream_state = NewXMLStreamState}, + State1 = reset_parser(State), + NewState = State1#state{socket = ZlibSocket, + sock_mod = ezlib}, case ezlib:recv_data(ZlibSocket, <<"">>) of {ok, ZlibData} -> {reply, {ok, ZlibSocket}, @@ -170,16 +156,10 @@ handle_call({compress, Data}, _From, {error, _Reason} -> {stop, normal, ok, NewState} end; -handle_call(reset_stream, _From, - #state{xml_stream_state = XMLStreamState, - c2s_pid = C2SPid, max_stanza_size = MaxStanzaSize} = - State) -> - close_stream(XMLStreamState), - NewXMLStreamState = xml_stream:new(C2SPid, MaxStanzaSize), +handle_call(reset_stream, _From, State) -> + NewState = reset_parser(State), Reply = ok, - {reply, Reply, - State#state{xml_stream_state = NewXMLStreamState}, - ?HIBERNATE_TIMEOUT}; + {reply, Reply, NewState, ?HIBERNATE_TIMEOUT}; handle_call({become_controller, C2SPid}, _From, State) -> XMLStreamState = xml_stream:new(C2SPid, State#state.max_stanza_size), NewState = State#state{c2s_pid = C2SPid, @@ -332,6 +312,24 @@ close_stream(undefined) -> ok; close_stream(XMLStreamState) -> xml_stream:close(XMLStreamState). +reset_parser(#state{xml_stream_state = undefined} = State) -> + State; +reset_parser(#state{c2s_pid = C2SPid, + max_stanza_size = MaxStanzaSize, + xml_stream_state = XMLStreamState} + = State) -> + NewStreamState = try xml_stream:reset(XMLStreamState) + catch error:_ -> + close_stream(XMLStreamState), + case C2SPid of + undefined -> + undefined; + _ -> + xml_stream:new(C2SPid, MaxStanzaSize) + end + end, + State#state{xml_stream_state = NewStreamState}. + do_send(State, Data) -> (State#state.sock_mod):send(State#state.socket, Data).