diff --git a/src/odbc/ejabberd_odbc.erl b/src/odbc/ejabberd_odbc.erl index b2c1c20fc..de38260c4 100644 --- a/src/odbc/ejabberd_odbc.erl +++ b/src/odbc/ejabberd_odbc.erl @@ -27,7 +27,9 @@ -module(ejabberd_odbc). -author('alexey@process-one.net'). --behaviour(gen_server). +-define(GEN_FSM, p1_fsm). + +-behaviour(?GEN_FSM). %% External exports -export([start/1, start_link/2, @@ -39,17 +41,28 @@ escape_like/1, keep_alive/1]). -%% gen_server callbacks +%% gen_fsm callbacks -export([init/1, - handle_call/3, - handle_cast/2, - code_change/3, - handle_info/2, - terminate/2]). + handle_event/3, + handle_sync_event/4, + handle_info/3, + terminate/3, + code_change/4]). + +%% gen_fsm states +-export([connecting/2, + connecting/3, + session_established/2, + session_established/3]). -include("ejabberd.hrl"). --record(state, {db_ref, db_type}). +-record(state, {db_ref, + db_type, + start_interval, + host, + max_pending_requests_len, + pending_requests}). -define(STATE_KEY, ejabberd_odbc_state). -define(NESTING_KEY, ejabberd_odbc_nesting_level). @@ -62,14 +75,23 @@ -define(KEEPALIVE_TIMEOUT, 60000). -define(KEEPALIVE_QUERY, "SELECT 1;"). +%%-define(DBGFSM, true). + +-ifdef(DBGFSM). +-define(FSMOPTS, [{debug, [trace]}]). +-else. +-define(FSMOPTS, []). +-endif. + %%%---------------------------------------------------------------------- %%% API %%%---------------------------------------------------------------------- start(Host) -> - gen_server:start(ejabberd_odbc, [Host], []). + ?GEN_FSM:start(ejabberd_odbc, [Host], fsm_limit_opts() ++ ?FSMOPTS). start_link(Host, StartInterval) -> - gen_server:start_link(ejabberd_odbc, [Host, StartInterval], []). + ?GEN_FSM:start_link(ejabberd_odbc, [Host, StartInterval], + fsm_limit_opts() ++ ?FSMOPTS). sql_query(Host, Query) -> sql_call(Host, {sql_query, Query}). @@ -95,12 +117,16 @@ sql_bloc(Host, F) -> sql_call(Host, Msg) -> case get(?STATE_KEY) of undefined -> - gen_server:call(ejabberd_odbc_sup:get_random_pid(Host), - {sql_cmd, Msg}, ?TRANSACTION_TIMEOUT); + ?GEN_FSM:sync_send_event(ejabberd_odbc_sup:get_random_pid(Host), + {sql_cmd, Msg}, ?TRANSACTION_TIMEOUT); _State -> nested_op(Msg) end. +% perform a harmless query on all opened connexions to avoid connexion close. +keep_alive(PID) -> + ?GEN_FSM:sync_send_event(PID, {sql_cmd, {sql_query, ?KEEPALIVE_QUERY}}, + ?KEEPALIVE_TIMEOUT). %% This function is intended to be used from inside an sql_transaction: sql_query_t(Query) -> @@ -134,16 +160,8 @@ escape_like(C) -> odbc_queries:escape(C). %%%---------------------------------------------------------------------- -%%% Callback functions from gen_server +%%% Callback functions from gen_fsm %%%---------------------------------------------------------------------- - -%%---------------------------------------------------------------------- -%% Func: init/1 -%% Returns: {ok, State} | -%% {ok, State, Timeout} | -%% ignore | -%% {stop, Reason} -%%---------------------------------------------------------------------- init([Host, StartInterval]) -> case ejabberd_config:get_local_option({odbc_keepalive_interval, Host}) of KeepaliveInterval when is_integer(KeepaliveInterval) -> @@ -155,80 +173,114 @@ init([Host, StartInterval]) -> ?ERROR_MSG("Wrong odbc_keepalive_interval definition '~p'" " for host ~p.~n", [_Other, Host]) end, - SQLServer = ejabberd_config:get_local_option({odbc_server, Host}), - case SQLServer of - %% Default pgsql port - {pgsql, Server, DB, Username, Password} -> - pgsql_connect(Server, ?PGSQL_PORT, DB, Username, Password, - StartInterval); - {pgsql, Server, Port, DB, Username, Password} when is_integer(Port) -> - pgsql_connect(Server, Port, DB, Username, Password, - StartInterval); - %% Default mysql port - {mysql, Server, DB, Username, Password} -> - mysql_connect(Server, ?MYSQL_PORT, DB, Username, Password, - StartInterval); - {mysql, Server, Port, DB, Username, Password} when is_integer(Port) -> - mysql_connect(Server, Port, DB, Username, Password, - StartInterval); - _ when is_list(SQLServer) -> - odbc_connect(SQLServer, StartInterval) - end. + [DBType | _] = db_opts(Host), + ?GEN_FSM:send_event(self(), connect), + {ok, connecting, #state{db_type = DBType, + host = Host, + max_pending_requests_len = max_fsm_queue(), + pending_requests = {0, queue:new()}, + start_interval = StartInterval}}. -%%---------------------------------------------------------------------- -%% Func: handle_call/3 -%% Returns: {reply, Reply, State} | -%% {reply, Reply, State, Timeout} | -%% {noreply, State} | -%% {noreply, State, Timeout} | -%% {stop, Reason, Reply, State} | (terminate/2 is called) -%% {stop, Reason, State} (terminate/2 is called) -%%---------------------------------------------------------------------- -handle_call({sql_cmd, Command}, _From, State) -> +connecting(connect, #state{host = Host} = State) -> + ConnectRes = case db_opts(Host) of + [mysql | Args] -> + apply(fun mysql_connect/5, Args); + [pgsql | Args] -> + apply(fun pgsql_connect/5, Args); + [odbc | Args] -> + apply(fun odbc_connect/1, Args) + end, + {_, PendingRequests} = State#state.pending_requests, + case ConnectRes of + {ok, Ref} -> + erlang:monitor(process, Ref), + queue:filter( + fun(Req) -> + ?GEN_FSM:send_event(self(), Req), + false + end, PendingRequests), + {next_state, session_established, + State#state{db_ref = Ref, + pending_requests = {0, queue:new()}}}; + {error, Reason} -> + ?INFO_MSG("~p connection failed:~n" + "** Reason: ~p~n" + "** Retry after: ~p seconds", + [State#state.db_type, Reason, + State#state.start_interval div 1000]), + ?GEN_FSM:send_event_after(State#state.start_interval, + connect), + {next_state, connecting, State} + end; +connecting(Event, State) -> + ?WARNING_MSG("unexpected event in 'connecting': ~p", [Event]), + {next_state, connecting, State}. + +connecting({sql_cmd, {sql_query, ?KEEPALIVE_QUERY}}, From, State) -> + ?GEN_FSM:reply(From, {error, "SQL connection failed"}), + {next_state, connecting, State}; +connecting({sql_cmd, Command} = Req, From, State) -> + ?DEBUG("queueing pending request while connecting:~n\t~p", [Req]), + {Len, PendingRequests} = State#state.pending_requests, + NewPendingRequests = + if Len < State#state.max_pending_requests_len -> + {Len + 1, queue:in({sql_cmd, Command, From}, PendingRequests)}; + true -> + queue:filter( + fun({sql_cmd, _, To}) -> + ?GEN_FSM:reply(To, + {error, "SQL connection failed"}), + false + end, PendingRequests), + {1, queue:from_list([{sql_cmd, Command, From}])} + end, + {next_state, connecting, + State#state{pending_requests = NewPendingRequests}}; +connecting(Request, {Who, _Ref}, State) -> + ?WARNING_MSG("unexpected call ~p from ~p in 'connecting'", + [Request, Who]), + {reply, {error, badarg}, connecting, State}. + +session_established({sql_cmd, Command}, From, State) -> put(?NESTING_KEY, ?TOP_LEVEL_TXN), put(?STATE_KEY, State), - abort_on_driver_error(outer_op(Command)); -handle_call(Request, {Who, _Ref}, State) -> - ?WARNING_MSG("Unexpected call ~p from ~p.", [Request, Who]), - {reply, ok, State}. + abort_on_driver_error(outer_op(Command), From); +session_established(Request, {Who, _Ref}, State) -> + ?WARNING_MSG("unexpected call ~p from ~p in 'session_established'", + [Request, Who]), + {reply, {error, badarg}, session_established, State}. -%%---------------------------------------------------------------------- -%% Func: handle_cast/2 -%% Returns: {noreply, State} | -%% {noreply, State, Timeout} | -%% {stop, Reason, State} (terminate/2 is called) -%%---------------------------------------------------------------------- -handle_cast(_Msg, State) -> - {noreply, State}. +session_established({sql_cmd, Command, From}, State) -> + put(?NESTING_KEY, ?TOP_LEVEL_TXN), + put(?STATE_KEY, State), + abort_on_driver_error(outer_op(Command), From); +session_established(Event, State) -> + ?WARNING_MSG("unexpected event in 'session_established': ~p", [Event]), + {next_state, session_established, State}. +handle_event(_Event, StateName, State) -> + {next_state, StateName, State}. -code_change(_OldVsn, State, _Extra) -> - {ok, State}. +handle_sync_event(_Event, _From, StateName, State) -> + {reply, {error, badarg}, StateName, State}. + +code_change(_OldVsn, StateName, State, _Extra) -> + {ok, StateName, State}. -%%---------------------------------------------------------------------- -%% Func: handle_info/2 -%% Returns: {noreply, State} | -%% {noreply, State, Timeout} | -%% {stop, Reason, State} (terminate/2 is called) -%%---------------------------------------------------------------------- %% We receive the down signal when we loose the MySQL connection (we are %% monitoring the connection) -%% => We exit and let the supervisor restart the connection. -handle_info({'DOWN', _MonitorRef, process, _Pid, _Info}, State) -> - {stop, connection_dropped, State}; -handle_info(_Info, State) -> - {noreply, State}. +handle_info({'DOWN', _MonitorRef, process, _Pid, _Info}, _StateName, State) -> + ?GEN_FSM:send_event(self(), connect), + {next_state, connecting, State}; +handle_info(Info, StateName, State) -> + ?WARNING_MSG("unexpected info in ~p: ~p", [StateName, Info]), + {next_state, StateName, State}. -%%---------------------------------------------------------------------- -%% Func: terminate/2 -%% Purpose: Shutdown the server -%% Returns: any (ignored by gen_server) -%%---------------------------------------------------------------------- -terminate(_Reason, State) -> +terminate(_Reason, _StateName, State) -> case State#state.db_type of mysql -> - % old versions of mysql driver don't have the stop function - % so the catch + %% old versions of mysql driver don't have the stop function + %% so the catch catch mysql_conn:stop(State#state.db_ref); _ -> ok @@ -367,50 +419,34 @@ sql_query_internal(Query) -> end. %% Generate the OTP callback return tuple depending on the driver result. -abort_on_driver_error({error, "query timed out"} = Reply) -> +abort_on_driver_error({error, "query timed out"} = Reply, From) -> %% mysql driver error - {stop, timeout, Reply, get(?STATE_KEY)}; -abort_on_driver_error({error, "Failed sending data on socket"++_} = Reply) -> + ?GEN_FSM:reply(From, Reply), + {stop, timeout, get(?STATE_KEY)}; +abort_on_driver_error({error, "Failed sending data on socket" ++ _} = Reply, + From) -> %% mysql driver error - {stop, closed, Reply, get(?STATE_KEY)}; -abort_on_driver_error(Reply) -> - {reply, Reply, get(?STATE_KEY)}. + ?GEN_FSM:reply(From, Reply), + {stop, closed, get(?STATE_KEY)}; +abort_on_driver_error(Reply, From) -> + ?GEN_FSM:reply(From, Reply), + {next_state, session_established, get(?STATE_KEY)}. %% == pure ODBC code %% part of init/1 %% Open an ODBC database connection -odbc_connect(SQLServer, StartInterval) -> +odbc_connect(SQLServer) -> application:start(odbc), - case odbc:connect(SQLServer,[{scrollable_cursors, off}]) of - {ok, Ref} -> - erlang:monitor(process, Ref), - {ok, #state{db_ref = Ref, db_type = odbc}}; - {error, Reason} -> - ?ERROR_MSG("ODBC connection (~s) failed: ~p~n", - [SQLServer, Reason]), - %% If we can't connect we wait before retrying - timer:sleep(StartInterval), - {stop, odbc_connection_failed} - end. - + odbc:connect(SQLServer, [{scrollable_cursors, off}]). %% == Native PostgreSQL code %% part of init/1 %% Open a database connection to PostgreSQL -pgsql_connect(Server, Port, DB, Username, Password, StartInterval) -> - case pgsql:connect(Server, DB, Username, Password, Port) of - {ok, Ref} -> - erlang:monitor(process, Ref), - {ok, #state{db_ref = Ref, db_type = pgsql}}; - {error, Reason} -> - ?ERROR_MSG("PostgreSQL connection failed: ~p~n", [Reason]), - %% If we can't connect we wait before retrying - timer:sleep(StartInterval), - {stop, pgsql_connection_failed} - end. +pgsql_connect(Server, Port, DB, Username, Password) -> + pgsql:connect(Server, DB, Username, Password, Port). %% Convert PostgreSQL query result to Erlang ODBC result formalism pgsql_to_odbc({ok, PGSQLResult}) -> @@ -441,19 +477,13 @@ pgsql_item_to_odbc(_) -> %% part of init/1 %% Open a database connection to MySQL -mysql_connect(Server, Port, DB, Username, Password, StartInterval) -> +mysql_connect(Server, Port, DB, Username, Password) -> case mysql_conn:start(Server, Port, Username, Password, DB, fun log/3) of {ok, Ref} -> - erlang:monitor(process, Ref), mysql_conn:fetch(Ref, ["set names 'utf8';"], self()), - {ok, #state{db_ref = Ref, db_type = mysql}}; - {error, Reason} -> - ?ERROR_MSG("MySQL connection failed: ~p~n" - "Waiting ~p seconds before retrying...~n", - [Reason, StartInterval div 1000]), - %% If we can't connect we wait before retrying - timer:sleep(StartInterval), - {stop, mysql_connection_failed} + {ok, Ref}; + Err -> + Err end. %% Convert MySQL query result to Erlang ODBC result formalism @@ -475,11 +505,6 @@ mysql_item_to_odbc(Columns, Recs) -> [element(2, Column) || Column <- Columns], [list_to_tuple(Rec) || Rec <- Recs]}. -% perform a harmless query on all opened connexions to avoid connexion close. -keep_alive(PID) -> - gen_server:call(PID, {sql_cmd, {sql_query, ?KEEPALIVE_QUERY}}, - ?KEEPALIVE_TIMEOUT). - % log function used by MySQL driver log(Level, Format, Args) -> case Level of @@ -490,3 +515,35 @@ log(Level, Format, Args) -> error -> ?ERROR_MSG(Format, Args) end. + +db_opts(Host) -> + case ejabberd_config:get_local_option({odbc_server, Host}) of + %% Default pgsql port + {pgsql, Server, DB, User, Pass} -> + [pgsql, Server, ?PGSQL_PORT, DB, User, Pass]; + {pgsql, Server, Port, DB, User, Pass} when is_integer(Port) -> + [pgsql, Server, Port, DB, User, Pass]; + %% Default mysql port + {mysql, Server, DB, User, Pass} -> + [mysql, Server, ?MYSQL_PORT, DB, User, Pass]; + {mysql, Server, Port, DB, User, Pass} when is_integer(Port) -> + [mysql, Server, Port, DB, User, Pass]; + SQLServer when is_list(SQLServer) -> + [odbc, SQLServer] + end. + +max_fsm_queue() -> + case ejabberd_config:get_local_option(max_fsm_queue) of + N when is_integer(N), N>0 -> + N; + _ -> + undefined + end. + +fsm_limit_opts() -> + case max_fsm_queue() of + N when is_integer(N) -> + [{max_queue, N}]; + _ -> + [] + end.