From 3d329c7e8f43bc2a8ac05ad2c54b13df184894ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Chmielowski?= Date: Thu, 28 Jul 2022 13:17:35 +0200 Subject: [PATCH] Make connection close errors bubble up from inside sql transaction --- src/ejabberd_sql.erl | 72 ++++++++++++++++++++++++++++++-------------- 1 file changed, 50 insertions(+), 22 deletions(-) diff --git a/src/ejabberd_sql.erl b/src/ejabberd_sql.erl index 1e32a87c8..d4d7c74a8 100644 --- a/src/ejabberd_sql.erl +++ b/src/ejabberd_sql.erl @@ -479,6 +479,12 @@ handle_reconnect(Reason, #state{host = Host, reconnect_count = RC} = State) -> "** Retry after: ~B seconds", [State#state.db_type, Reason, StartInterval div 1000]), + case State#state.db_type of + mysql -> catch p1_mysql_conn:stop(State#state.db_ref); + sqlite -> catch sqlite3:close(sqlite_db(State#state.host)); + pgsql -> catch pgsql:terminate(State#state.db_ref); + _ -> ok + end, p1_fsm:send_event_after(StartInterval, connect), {next_state, connecting, State#state{reconnect_count = RC + 1}}. @@ -562,9 +568,7 @@ outer_transaction(F, NRestarts, _Reason) -> {atomic, Res} catch ?EX_RULE(throw, {aborted, Reason}, _) when NRestarts > 0 -> - sql_rollback(), - put(?NESTING_KEY, ?TOP_LEVEL_TXN), - outer_transaction(F, NRestarts - 1, Reason); + rollback_transaction(F, NRestarts, Reason); ?EX_RULE(throw, {aborted, Reason}, Stack) when NRestarts =:= 0 -> StackTrace = ?EX_STACK(Stack), ?ERROR_MSG("SQL transaction restarts exceeded~n** " @@ -573,11 +577,36 @@ outer_transaction(F, NRestarts, _Reason) -> "== ~p", [?MAX_TRANSACTION_RESTARTS, Reason, StackTrace, get(?STATE_KEY)]), - sql_rollback(), - {aborted, Reason}; + rollback_transaction(F, NRestarts, Reason); ?EX_RULE(exit, Reason, _) -> - sql_rollback(), - {aborted, Reason} + rollback_transaction(F, 0, Reason) + end. + +rollback_transaction(F, NRestarts, Reason) -> + Res = case driver_restart_required(Reason) of + true -> + {aborted, Reason}; + _ -> + case sql_rollback() of + {Tag, Reason2} when Tag == error; Tag == aborted -> + case driver_restart_required(Reason2) of + true -> + {aborted, Reason2}; + _ -> + continue + end; + _ -> + continue + end + end, + case Res of + continue when NRestarts > 0 -> + put(?NESTING_KEY, ?TOP_LEVEL_TXN), + outer_transaction(F, NRestarts - 1, Reason); + continue -> + {aborted, Reason}; + Other -> + Other end. execute_bloc(F) -> @@ -865,23 +894,22 @@ sql_rollback() -> [{mssql, [<<"rollback transaction;">>]}, {any, [<<"rollback;">>]}]). +driver_restart_required(<<"query timed out">>) -> true; +driver_restart_required(<<"connection closed">>) -> true; +driver_restart_required(<<"Failed sending data on socket", _/binary>>) -> true; +driver_restart_required(<<"SQL connection failed">>) -> true; +driver_restart_required(<<"Communication link failure">>) -> true; +driver_restart_required(_) -> false. + %% Generate the OTP callback return tuple depending on the driver result. -abort_on_driver_error({error, <<"query timed out">>} = Reply, From, Timestamp) -> +abort_on_driver_error({Tag, Msg} = Reply, From, Timestamp) when Tag == error; Tag == aborted -> reply(From, Reply, Timestamp), - {stop, timeout, get(?STATE_KEY)}; -abort_on_driver_error({error, <<"connection closed">>} = Reply, From, Timestamp) -> - reply(From, Reply, Timestamp), - handle_reconnect(<<"connection closed">>, get(?STATE_KEY)); -abort_on_driver_error({error, <<"Failed sending data on socket", _/binary>>} = Reply, - From, Timestamp) -> - reply(From, Reply, Timestamp), - {stop, closed, get(?STATE_KEY)}; -abort_on_driver_error({error, <<"SQL connection failed">>} = Reply, From, Timestamp) -> - reply(From, Reply, Timestamp), - {stop, timeout, get(?STATE_KEY)}; -abort_on_driver_error({error, <<"Communication link failure">>} = Reply, From, Timestamp) -> - reply(From, Reply, Timestamp), - {stop, closed, get(?STATE_KEY)}; + case driver_restart_required(Msg) of + true -> + handle_reconnect(Msg, get(?STATE_KEY)); + _ -> + {next_state, session_established, get(?STATE_KEY)} + end; abort_on_driver_error(Reply, From, Timestamp) -> reply(From, Reply, Timestamp), {next_state, session_established, get(?STATE_KEY)}.