diff --git a/src/xmpp_stream_in.erl b/src/xmpp_stream_in.erl index 4f8be911e..55fa3a4bf 100644 --- a/src/xmpp_stream_in.erl +++ b/src/xmpp_stream_in.erl @@ -836,13 +836,13 @@ process_sasl_success(Props, ServerOut, AuthModule = proplists:get_value(auth_module, Props), Socket1 = xmpp_socket:reset_stream(Socket), State0 = State#{socket => Socket1}, - State1 = send_pkt(State0, #sasl_success{text = ServerOut}), + State1 = try Mod:handle_auth_success(User, Mech, AuthModule, State0) + catch _:undef -> State + end, case is_disconnected(State1) of true -> State1; false -> - State2 = try Mod:handle_auth_success(User, Mech, AuthModule, State1) - catch _:undef -> State1 - end, + State2 = send_pkt(State1, #sasl_success{text = ServerOut}), case is_disconnected(State2) of true -> State2; false -> @@ -867,16 +867,22 @@ process_sasl_continue(ServerOut, NewSASLState, State) -> process_sasl_failure(Err, User, #{mod := Mod, sasl_mech := Mech, lang := Lang} = State) -> {Reason, Text} = format_sasl_error(Mech, Err), - State1 = send_pkt(State, #sasl_failure{reason = Reason, - text = xmpp:mk_text(Text, Lang)}), + State1 = try Mod:handle_auth_failure(User, Mech, Text, State) + catch _:undef -> State + end, case is_disconnected(State1) of true -> State1; false -> - State2 = try Mod:handle_auth_failure(User, Mech, Text, State1) - catch _:undef -> State1 - end, - State3 = maps:remove(sasl_state, maps:remove(sasl_mech, State2)), - State3#{stream_state => wait_for_sasl_request} + State2 = send_pkt(State1, + #sasl_failure{reason = Reason, + text = xmpp:mk_text(Text, Lang)}), + case is_disconnected(State2) of + true -> State2; + false -> + State3 = maps:remove(sasl_state, + maps:remove(sasl_mech, State2)), + State3#{stream_state => wait_for_sasl_request} + end end. -spec process_sasl_abort(state()) -> state().