25
1
mirror of https://github.com/processone/ejabberd.git synced 2024-11-24 16:23:40 +01:00

improved SQL reconnect behaviour

SVN Revision: 2947
This commit is contained in:
Evgeniy Khramtsov 2010-01-31 11:41:28 +00:00
parent 1cced15c0d
commit 03454c7f1d

View File

@ -27,7 +27,9 @@
-module(ejabberd_odbc). -module(ejabberd_odbc).
-author('alexey@process-one.net'). -author('alexey@process-one.net').
-behaviour(gen_server). -define(GEN_FSM, p1_fsm).
-behaviour(?GEN_FSM).
%% External exports %% External exports
-export([start/1, start_link/2, -export([start/1, start_link/2,
@ -39,17 +41,28 @@
escape_like/1, escape_like/1,
keep_alive/1]). keep_alive/1]).
%% gen_server callbacks %% gen_fsm callbacks
-export([init/1, -export([init/1,
handle_call/3, handle_event/3,
handle_cast/2, handle_sync_event/4,
code_change/3, handle_info/3,
handle_info/2, terminate/3,
terminate/2]). code_change/4]).
%% gen_fsm states
-export([connecting/2,
connecting/3,
session_established/2,
session_established/3]).
-include("ejabberd.hrl"). -include("ejabberd.hrl").
-record(state, {db_ref, db_type}). -record(state, {db_ref,
db_type,
start_interval,
host,
max_pending_requests_len,
pending_requests}).
-define(STATE_KEY, ejabberd_odbc_state). -define(STATE_KEY, ejabberd_odbc_state).
-define(NESTING_KEY, ejabberd_odbc_nesting_level). -define(NESTING_KEY, ejabberd_odbc_nesting_level).
@ -62,14 +75,23 @@
-define(KEEPALIVE_TIMEOUT, 60000). -define(KEEPALIVE_TIMEOUT, 60000).
-define(KEEPALIVE_QUERY, "SELECT 1;"). -define(KEEPALIVE_QUERY, "SELECT 1;").
%%-define(DBGFSM, true).
-ifdef(DBGFSM).
-define(FSMOPTS, [{debug, [trace]}]).
-else.
-define(FSMOPTS, []).
-endif.
%%%---------------------------------------------------------------------- %%%----------------------------------------------------------------------
%%% API %%% API
%%%---------------------------------------------------------------------- %%%----------------------------------------------------------------------
start(Host) -> start(Host) ->
gen_server:start(ejabberd_odbc, [Host], []). ?GEN_FSM:start(ejabberd_odbc, [Host], fsm_limit_opts() ++ ?FSMOPTS).
start_link(Host, StartInterval) -> start_link(Host, StartInterval) ->
gen_server:start_link(ejabberd_odbc, [Host, StartInterval], []). ?GEN_FSM:start_link(ejabberd_odbc, [Host, StartInterval],
fsm_limit_opts() ++ ?FSMOPTS).
sql_query(Host, Query) -> sql_query(Host, Query) ->
sql_call(Host, {sql_query, Query}). sql_call(Host, {sql_query, Query}).
@ -95,12 +117,16 @@ sql_bloc(Host, F) ->
sql_call(Host, Msg) -> sql_call(Host, Msg) ->
case get(?STATE_KEY) of case get(?STATE_KEY) of
undefined -> undefined ->
gen_server:call(ejabberd_odbc_sup:get_random_pid(Host), ?GEN_FSM:sync_send_event(ejabberd_odbc_sup:get_random_pid(Host),
{sql_cmd, Msg}, ?TRANSACTION_TIMEOUT); {sql_cmd, Msg}, ?TRANSACTION_TIMEOUT);
_State -> _State ->
nested_op(Msg) nested_op(Msg)
end. end.
% perform a harmless query on all opened connexions to avoid connexion close.
keep_alive(PID) ->
?GEN_FSM:sync_send_event(PID, {sql_cmd, {sql_query, ?KEEPALIVE_QUERY}},
?KEEPALIVE_TIMEOUT).
%% This function is intended to be used from inside an sql_transaction: %% This function is intended to be used from inside an sql_transaction:
sql_query_t(Query) -> sql_query_t(Query) ->
@ -134,16 +160,8 @@ escape_like(C) -> odbc_queries:escape(C).
%%%---------------------------------------------------------------------- %%%----------------------------------------------------------------------
%%% Callback functions from gen_server %%% Callback functions from gen_fsm
%%%---------------------------------------------------------------------- %%%----------------------------------------------------------------------
%%----------------------------------------------------------------------
%% Func: init/1
%% Returns: {ok, State} |
%% {ok, State, Timeout} |
%% ignore |
%% {stop, Reason}
%%----------------------------------------------------------------------
init([Host, StartInterval]) -> init([Host, StartInterval]) ->
case ejabberd_config:get_local_option({odbc_keepalive_interval, Host}) of case ejabberd_config:get_local_option({odbc_keepalive_interval, Host}) of
KeepaliveInterval when is_integer(KeepaliveInterval) -> KeepaliveInterval when is_integer(KeepaliveInterval) ->
@ -155,80 +173,114 @@ init([Host, StartInterval]) ->
?ERROR_MSG("Wrong odbc_keepalive_interval definition '~p'" ?ERROR_MSG("Wrong odbc_keepalive_interval definition '~p'"
" for host ~p.~n", [_Other, Host]) " for host ~p.~n", [_Other, Host])
end, end,
SQLServer = ejabberd_config:get_local_option({odbc_server, Host}), [DBType | _] = db_opts(Host),
case SQLServer of ?GEN_FSM:send_event(self(), connect),
%% Default pgsql port {ok, connecting, #state{db_type = DBType,
{pgsql, Server, DB, Username, Password} -> host = Host,
pgsql_connect(Server, ?PGSQL_PORT, DB, Username, Password, max_pending_requests_len = max_fsm_queue(),
StartInterval); pending_requests = {0, queue:new()},
{pgsql, Server, Port, DB, Username, Password} when is_integer(Port) -> start_interval = StartInterval}}.
pgsql_connect(Server, Port, DB, Username, Password,
StartInterval);
%% Default mysql port
{mysql, Server, DB, Username, Password} ->
mysql_connect(Server, ?MYSQL_PORT, DB, Username, Password,
StartInterval);
{mysql, Server, Port, DB, Username, Password} when is_integer(Port) ->
mysql_connect(Server, Port, DB, Username, Password,
StartInterval);
_ when is_list(SQLServer) ->
odbc_connect(SQLServer, StartInterval)
end.
%%---------------------------------------------------------------------- connecting(connect, #state{host = Host} = State) ->
%% Func: handle_call/3 ConnectRes = case db_opts(Host) of
%% Returns: {reply, Reply, State} | [mysql | Args] ->
%% {reply, Reply, State, Timeout} | apply(fun mysql_connect/5, Args);
%% {noreply, State} | [pgsql | Args] ->
%% {noreply, State, Timeout} | apply(fun pgsql_connect/5, Args);
%% {stop, Reason, Reply, State} | (terminate/2 is called) [odbc | Args] ->
%% {stop, Reason, State} (terminate/2 is called) apply(fun odbc_connect/1, Args)
%%---------------------------------------------------------------------- end,
handle_call({sql_cmd, Command}, _From, State) -> {_, PendingRequests} = State#state.pending_requests,
case ConnectRes of
{ok, Ref} ->
erlang:monitor(process, Ref),
queue:filter(
fun(Req) ->
?GEN_FSM:send_event(self(), Req),
false
end, PendingRequests),
{next_state, session_established,
State#state{db_ref = Ref,
pending_requests = {0, queue:new()}}};
{error, Reason} ->
?INFO_MSG("~p connection failed:~n"
"** Reason: ~p~n"
"** Retry after: ~p seconds",
[State#state.db_type, Reason,
State#state.start_interval div 1000]),
?GEN_FSM:send_event_after(State#state.start_interval,
connect),
{next_state, connecting, State}
end;
connecting(Event, State) ->
?WARNING_MSG("unexpected event in 'connecting': ~p", [Event]),
{next_state, connecting, State}.
connecting({sql_cmd, {sql_query, ?KEEPALIVE_QUERY}}, From, State) ->
?GEN_FSM:reply(From, {error, "SQL connection failed"}),
{next_state, connecting, State};
connecting({sql_cmd, Command} = Req, From, State) ->
?DEBUG("queueing pending request while connecting:~n\t~p", [Req]),
{Len, PendingRequests} = State#state.pending_requests,
NewPendingRequests =
if Len < State#state.max_pending_requests_len ->
{Len + 1, queue:in({sql_cmd, Command, From}, PendingRequests)};
true ->
queue:filter(
fun({sql_cmd, _, To}) ->
?GEN_FSM:reply(To,
{error, "SQL connection failed"}),
false
end, PendingRequests),
{1, queue:from_list([{sql_cmd, Command, From}])}
end,
{next_state, connecting,
State#state{pending_requests = NewPendingRequests}};
connecting(Request, {Who, _Ref}, State) ->
?WARNING_MSG("unexpected call ~p from ~p in 'connecting'",
[Request, Who]),
{reply, {error, badarg}, connecting, State}.
session_established({sql_cmd, Command}, From, State) ->
put(?NESTING_KEY, ?TOP_LEVEL_TXN), put(?NESTING_KEY, ?TOP_LEVEL_TXN),
put(?STATE_KEY, State), put(?STATE_KEY, State),
abort_on_driver_error(outer_op(Command)); abort_on_driver_error(outer_op(Command), From);
handle_call(Request, {Who, _Ref}, State) -> session_established(Request, {Who, _Ref}, State) ->
?WARNING_MSG("Unexpected call ~p from ~p.", [Request, Who]), ?WARNING_MSG("unexpected call ~p from ~p in 'session_established'",
{reply, ok, State}. [Request, Who]),
{reply, {error, badarg}, session_established, State}.
%%---------------------------------------------------------------------- session_established({sql_cmd, Command, From}, State) ->
%% Func: handle_cast/2 put(?NESTING_KEY, ?TOP_LEVEL_TXN),
%% Returns: {noreply, State} | put(?STATE_KEY, State),
%% {noreply, State, Timeout} | abort_on_driver_error(outer_op(Command), From);
%% {stop, Reason, State} (terminate/2 is called) session_established(Event, State) ->
%%---------------------------------------------------------------------- ?WARNING_MSG("unexpected event in 'session_established': ~p", [Event]),
handle_cast(_Msg, State) -> {next_state, session_established, State}.
{noreply, State}.
handle_event(_Event, StateName, State) ->
{next_state, StateName, State}.
code_change(_OldVsn, State, _Extra) -> handle_sync_event(_Event, _From, StateName, State) ->
{ok, State}. {reply, {error, badarg}, StateName, State}.
code_change(_OldVsn, StateName, State, _Extra) ->
{ok, StateName, State}.
%%----------------------------------------------------------------------
%% Func: handle_info/2
%% Returns: {noreply, State} |
%% {noreply, State, Timeout} |
%% {stop, Reason, State} (terminate/2 is called)
%%----------------------------------------------------------------------
%% We receive the down signal when we loose the MySQL connection (we are %% We receive the down signal when we loose the MySQL connection (we are
%% monitoring the connection) %% monitoring the connection)
%% => We exit and let the supervisor restart the connection. handle_info({'DOWN', _MonitorRef, process, _Pid, _Info}, _StateName, State) ->
handle_info({'DOWN', _MonitorRef, process, _Pid, _Info}, State) -> ?GEN_FSM:send_event(self(), connect),
{stop, connection_dropped, State}; {next_state, connecting, State};
handle_info(_Info, State) -> handle_info(Info, StateName, State) ->
{noreply, State}. ?WARNING_MSG("unexpected info in ~p: ~p", [StateName, Info]),
{next_state, StateName, State}.
%%---------------------------------------------------------------------- terminate(_Reason, _StateName, State) ->
%% Func: terminate/2
%% Purpose: Shutdown the server
%% Returns: any (ignored by gen_server)
%%----------------------------------------------------------------------
terminate(_Reason, State) ->
case State#state.db_type of case State#state.db_type of
mysql -> mysql ->
% old versions of mysql driver don't have the stop function %% old versions of mysql driver don't have the stop function
% so the catch %% so the catch
catch mysql_conn:stop(State#state.db_ref); catch mysql_conn:stop(State#state.db_ref);
_ -> _ ->
ok ok
@ -367,50 +419,34 @@ sql_query_internal(Query) ->
end. end.
%% Generate the OTP callback return tuple depending on the driver result. %% Generate the OTP callback return tuple depending on the driver result.
abort_on_driver_error({error, "query timed out"} = Reply) -> abort_on_driver_error({error, "query timed out"} = Reply, From) ->
%% mysql driver error %% mysql driver error
{stop, timeout, Reply, get(?STATE_KEY)}; ?GEN_FSM:reply(From, Reply),
abort_on_driver_error({error, "Failed sending data on socket"++_} = Reply) -> {stop, timeout, get(?STATE_KEY)};
abort_on_driver_error({error, "Failed sending data on socket" ++ _} = Reply,
From) ->
%% mysql driver error %% mysql driver error
{stop, closed, Reply, get(?STATE_KEY)}; ?GEN_FSM:reply(From, Reply),
abort_on_driver_error(Reply) -> {stop, closed, get(?STATE_KEY)};
{reply, Reply, get(?STATE_KEY)}. abort_on_driver_error(Reply, From) ->
?GEN_FSM:reply(From, Reply),
{next_state, session_established, get(?STATE_KEY)}.
%% == pure ODBC code %% == pure ODBC code
%% part of init/1 %% part of init/1
%% Open an ODBC database connection %% Open an ODBC database connection
odbc_connect(SQLServer, StartInterval) -> odbc_connect(SQLServer) ->
application:start(odbc), application:start(odbc),
case odbc:connect(SQLServer,[{scrollable_cursors, off}]) of odbc:connect(SQLServer, [{scrollable_cursors, off}]).
{ok, Ref} ->
erlang:monitor(process, Ref),
{ok, #state{db_ref = Ref, db_type = odbc}};
{error, Reason} ->
?ERROR_MSG("ODBC connection (~s) failed: ~p~n",
[SQLServer, Reason]),
%% If we can't connect we wait before retrying
timer:sleep(StartInterval),
{stop, odbc_connection_failed}
end.
%% == Native PostgreSQL code %% == Native PostgreSQL code
%% part of init/1 %% part of init/1
%% Open a database connection to PostgreSQL %% Open a database connection to PostgreSQL
pgsql_connect(Server, Port, DB, Username, Password, StartInterval) -> pgsql_connect(Server, Port, DB, Username, Password) ->
case pgsql:connect(Server, DB, Username, Password, Port) of pgsql:connect(Server, DB, Username, Password, Port).
{ok, Ref} ->
erlang:monitor(process, Ref),
{ok, #state{db_ref = Ref, db_type = pgsql}};
{error, Reason} ->
?ERROR_MSG("PostgreSQL connection failed: ~p~n", [Reason]),
%% If we can't connect we wait before retrying
timer:sleep(StartInterval),
{stop, pgsql_connection_failed}
end.
%% Convert PostgreSQL query result to Erlang ODBC result formalism %% Convert PostgreSQL query result to Erlang ODBC result formalism
pgsql_to_odbc({ok, PGSQLResult}) -> pgsql_to_odbc({ok, PGSQLResult}) ->
@ -441,19 +477,13 @@ pgsql_item_to_odbc(_) ->
%% part of init/1 %% part of init/1
%% Open a database connection to MySQL %% Open a database connection to MySQL
mysql_connect(Server, Port, DB, Username, Password, StartInterval) -> mysql_connect(Server, Port, DB, Username, Password) ->
case mysql_conn:start(Server, Port, Username, Password, DB, fun log/3) of case mysql_conn:start(Server, Port, Username, Password, DB, fun log/3) of
{ok, Ref} -> {ok, Ref} ->
erlang:monitor(process, Ref),
mysql_conn:fetch(Ref, ["set names 'utf8';"], self()), mysql_conn:fetch(Ref, ["set names 'utf8';"], self()),
{ok, #state{db_ref = Ref, db_type = mysql}}; {ok, Ref};
{error, Reason} -> Err ->
?ERROR_MSG("MySQL connection failed: ~p~n" Err
"Waiting ~p seconds before retrying...~n",
[Reason, StartInterval div 1000]),
%% If we can't connect we wait before retrying
timer:sleep(StartInterval),
{stop, mysql_connection_failed}
end. end.
%% Convert MySQL query result to Erlang ODBC result formalism %% Convert MySQL query result to Erlang ODBC result formalism
@ -475,11 +505,6 @@ mysql_item_to_odbc(Columns, Recs) ->
[element(2, Column) || Column <- Columns], [element(2, Column) || Column <- Columns],
[list_to_tuple(Rec) || Rec <- Recs]}. [list_to_tuple(Rec) || Rec <- Recs]}.
% perform a harmless query on all opened connexions to avoid connexion close.
keep_alive(PID) ->
gen_server:call(PID, {sql_cmd, {sql_query, ?KEEPALIVE_QUERY}},
?KEEPALIVE_TIMEOUT).
% log function used by MySQL driver % log function used by MySQL driver
log(Level, Format, Args) -> log(Level, Format, Args) ->
case Level of case Level of
@ -490,3 +515,35 @@ log(Level, Format, Args) ->
error -> error ->
?ERROR_MSG(Format, Args) ?ERROR_MSG(Format, Args)
end. end.
db_opts(Host) ->
case ejabberd_config:get_local_option({odbc_server, Host}) of
%% Default pgsql port
{pgsql, Server, DB, User, Pass} ->
[pgsql, Server, ?PGSQL_PORT, DB, User, Pass];
{pgsql, Server, Port, DB, User, Pass} when is_integer(Port) ->
[pgsql, Server, Port, DB, User, Pass];
%% Default mysql port
{mysql, Server, DB, User, Pass} ->
[mysql, Server, ?MYSQL_PORT, DB, User, Pass];
{mysql, Server, Port, DB, User, Pass} when is_integer(Port) ->
[mysql, Server, Port, DB, User, Pass];
SQLServer when is_list(SQLServer) ->
[odbc, SQLServer]
end.
max_fsm_queue() ->
case ejabberd_config:get_local_option(max_fsm_queue) of
N when is_integer(N), N>0 ->
N;
_ ->
undefined
end.
fsm_limit_opts() ->
case max_fsm_queue() of
N when is_integer(N) ->
[{max_queue, N}];
_ ->
[]
end.