diff --git a/src/odbc/ejabberd_odbc.erl b/src/odbc/ejabberd_odbc.erl index be1c315a8..5376694ae 100644 --- a/src/odbc/ejabberd_odbc.erl +++ b/src/odbc/ejabberd_odbc.erl @@ -70,8 +70,8 @@ start_link(Host, StartInterval) -> gen_server:start_link(ejabberd_odbc, [Host, StartInterval], []). sql_query(Host, Query) -> - gen_server:call(ejabberd_odbc_sup:get_random_pid(Host), - {sql_query, Query}, ?TRANSACTION_TIMEOUT). + Msg = {sql_query, Query}, + sql_call(Host, Msg,). %% SQL transaction based on a list of queries %% This function automatically @@ -85,13 +85,24 @@ sql_transaction(Host, Queries) when is_list(Queries) -> sql_transaction(Host, F); %% SQL transaction, based on a erlang anonymous function (F = fun) sql_transaction(Host, F) -> - gen_server:call(ejabberd_odbc_sup:get_random_pid(Host), - {sql_transaction, F}, ?TRANSACTION_TIMEOUT). + Msg = {sql_transaction, F}, + sql_call(Host, Msg). %% SQL bloc, based on a erlang anonymous function (F = fun) sql_bloc(Host, F) -> - gen_server:call(ejabberd_odbc_sup:get_random_pid(Host), - {sql_bloc, F}, ?TRANSACTION_TIMEOUT). + Msg = {sql_bloc, F}, + sql_call(Host, Msg). + +sql_call(Host, Msg) -> + case get(?STATE_KEY) of + undefined -> + gen_server:call(ejabberd_odbc_sup:get_random_pid(Host), + Msg, ?TRANSACTION_TIMEOUT); + State -> + %% Query, Transaction or Bloc nested inside transaction + nested_op(Msg, State) + end. + %% This function is intended to be used from inside an sql_transaction: sql_query_t(Query) -> @@ -176,42 +187,8 @@ init([Host, StartInterval]) -> %% {stop, Reason, Reply, State} | (terminate/2 is called) %% {stop, Reason, State} (terminate/2 is called) %%---------------------------------------------------------------------- -handle_call({sql_query, Query}, _From, State) -> - case sql_query_internal(State, Query) of - % error returned by MySQL driver - {error, "query timed out"} = Reply -> - {stop, timeout, Reply, State}; - % error returned by MySQL driver - {error, "Failed sending data on socket"++_} = Reply -> - {stop, closed, Reply, State}; - Reply -> - {reply, Reply, State} - end; -handle_call({sql_transaction, F}, _From, State) -> - case execute_transaction(State, F, ?MAX_TRANSACTION_RESTARTS, "") of - % error returned by MySQL driver - {error, "query timed out"} -> - {stop, timeout, State}; - % error returned by MySQL driver - {error, "Failed sending data on socket"++_} = Reply -> - {stop, closed, Reply, State}; - Reply -> - {reply, Reply, State} - end; -handle_call({sql_bloc, F}, _From, State) -> - case execute_bloc(State, F) of - % error returned by MySQL driver - {error, "query timed out"} -> - {stop, timeout, State}; - % error returned by MySQL driver - {error, "Failed sending data on socket"++_} = Reply -> - {stop, closed, Reply, State}; - Reply -> - {reply, Reply, State} - end; -handle_call(_Request, _From, State) -> - Reply = ok, - {reply, Reply, State}. +handle_call(Command, _From, State) -> + dispatch_sql_command(Command, State). %%---------------------------------------------------------------------- %% Func: handle_cast/2 @@ -259,14 +236,36 @@ terminate(_Reason, State) -> %%%---------------------------------------------------------------------- %%% Internal functions %%%---------------------------------------------------------------------- +dispatch_sql_command({sql_query, Query}, State) -> + abort_on_driver_error(sql_query_internal(State, Query), State); +dispatch_sql_command({sql_transaction, F}, State) -> + abort_on_driver_error( + execute_transaction(State, F, ?MAX_TRANSACTION_RESTARTS, ""), State); +dispatch_sql_command({sql_bloc, F}, State) -> + abort_on_driver_error(execute_bloc(State, F), State); +dispatch_sql_command(Request, State) -> + ?WARNING_MSG("Unexpected call ~p.", [Request]), + {reply, ok, State}. + sql_query_internal(State, Query) -> - case State#state.db_type of - odbc -> - odbc:sql_query(State#state.db_ref, Query); - pgsql -> - pgsql_to_odbc(pgsql:squery(State#state.db_ref, Query)); - mysql -> - mysql_to_odbc(mysql_conn:fetch(State#state.db_ref, Query, self())) + Nested = case get(?STATE_KEY) of + undefined -> put(?STATE_KEY, State), false; + _State -> true + end, + Result = case State#state.db_type of + odbc -> + odbc:sql_query(State#state.db_ref, Query); + pgsql -> + pgsql_to_odbc(pgsql:squery(State#state.db_ref, Query)); + mysql -> + ?DEBUG("MySQL, Send query~n~p~n", [Query]), + R = mysql_to_odbc(mysql_conn:fetch(State#state.db_ref, Query, self())), + ?DEBUG("MySQL, Received result~n~p~n", [R]), + R + end, + case Nested of + true -> Result; + false -> erase(?STATE_KEY), Result end. execute_transaction(State, _F, 0, Reason) -> @@ -280,30 +279,69 @@ execute_transaction(State, _F, 0, Reason) -> sql_query_internal(State, "rollback;"), {aborted, restarts_exceeded}; execute_transaction(State, F, NRestarts, _Reason) -> - put(?STATE_KEY, State), - sql_query_internal(State, "begin;"), - case catch F() of - {aborted, Reason} -> - execute_transaction(State, F, NRestarts - 1, Reason); - {'EXIT', Reason} -> - sql_query_internal(State, "rollback;"), - {aborted, Reason}; - Res -> - sql_query_internal(State, "commit;"), - {atomic, Res} + Nested = case get(?STATE_KEY) of + undefined -> + put(?STATE_KEY, State), + sql_query_internal(State, "begin;"), + false; + _State -> + true + end, + Result = case catch F() of + {aborted, Reason} -> + execute_transaction(State, F, NRestarts - 1, Reason); + {'EXIT', Reason} -> + sql_query_internal(State, "rollback;"), + {aborted, Reason}; + Res -> + {atomic, Res} + end, + case Nested of + true -> Result; + false -> sql_query_internal(State, "commit;"), erase(?STATE_KEY), Result end. execute_bloc(State, F) -> - put(?STATE_KEY, State), - case catch F() of - {aborted, Reason} -> - {aborted, Reason}; - {'EXIT', Reason} -> - {aborted, Reason}; - Res -> - {atomic, Res} + Nested = case get(?STATE_KEY) of + undefined -> put(?STATE_KEY, State), false; + _State -> true + end, + Result = case catch F() of + {aborted, Reason} -> + {aborted, Reason}; + {'EXIT', Reason} -> + {aborted, Reason}; + Res -> + {atomic, Res} + end, + case Nested of + true -> Result; + false -> erase(?STATE_KEY), Result end. +nested_op(Op, State) -> + case dispatch_sql_command(Op, State) of + {reply, Res, NewState} -> + put(?STATE_KEY, NewState), + Res; + {stop, _Reason, Reply, NewState} -> + put(?STATE_KEY, NewState), + throw({aborted, Reply}); + {noreply, NewState} -> + put(?STATE_KEY, NewState), + exit({bad_op_in_nested_txn, Op}) + end. + +abort_on_driver_error({error, "query timed out"} = Reply, State) -> + %% mysql driver error + {stop, timeout, Reply, State}; +abort_on_driver_error({error, "Failed sending data on socket"++_} = Reply, State) -> + %% mysql driver error + {stop, closed, Reply, State}; +abort_on_driver_error(Reply, State) -> + {reply, Reply, State}. + + %% == pure ODBC code %% part of init/1