mirror of
https://github.com/processone/ejabberd.git
synced 2024-11-24 16:23:40 +01:00
adfca08e43
SVN Revision: 2557
496 lines
16 KiB
Erlang
496 lines
16 KiB
Erlang
%%%----------------------------------------------------------------------
|
|
%%% File : ejabberd_odbc.erl
|
|
%%% Author : Alexey Shchepin <alexey@process-one.net>
|
|
%%% Purpose : Serve ODBC connection
|
|
%%% Created : 8 Dec 2004 by Alexey Shchepin <alexey@process-one.net>
|
|
%%%
|
|
%%%
|
|
%%% ejabberd, Copyright (C) 2002-2009 ProcessOne
|
|
%%%
|
|
%%% This program is free software; you can redistribute it and/or
|
|
%%% modify it under the terms of the GNU General Public License as
|
|
%%% published by the Free Software Foundation; either version 2 of the
|
|
%%% License, or (at your option) any later version.
|
|
%%%
|
|
%%% This program is distributed in the hope that it will be useful,
|
|
%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
|
%%% General Public License for more details.
|
|
%%%
|
|
%%% You should have received a copy of the GNU General Public License
|
|
%%% along with this program; if not, write to the Free Software
|
|
%%% Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
|
|
%%% 02111-1307 USA
|
|
%%%
|
|
%%%----------------------------------------------------------------------
|
|
|
|
-module(ejabberd_odbc).
-author('alexey@process-one.net').

-behaviour(gen_server).

%% Public API
-export([start/1,
         start_link/2,
         sql_query/2,
         sql_query_t/1,
         sql_transaction/2,
         sql_bloc/2,
         escape/1,
         escape_like/1,
         keep_alive/1]).

%% gen_server callbacks
-export([init/1,
         handle_call/3,
         handle_cast/2,
         code_change/3,
         handle_info/2,
         terminate/2]).

-include("ejabberd.hrl").

%% Server state: db_ref is the driver connection handle, db_type is one of
%% odbc | pgsql | mysql (set by the matching *_connect function).
-record(state, {db_ref, db_type}).

%% Process dictionary keys used while a command is being served:
%% the current #state{} and the current transaction nesting depth.
-define(STATE_KEY, ejabberd_odbc_state).
-define(NESTING_KEY, ejabberd_odbc_nesting_level).

%% Nesting depth meaning "not inside any SQL transaction".
-define(TOP_LEVEL_TXN, 0).

%% How many times an aborted outer transaction is retried.
-define(MAX_TRANSACTION_RESTARTS, 10).

%% Ports used when the odbc_server option does not name one.
-define(PGSQL_PORT, 5432).
-define(MYSQL_PORT, 3306).

%% gen_server:call timeouts, in milliseconds.
-define(TRANSACTION_TIMEOUT, 60000).
-define(KEEPALIVE_TIMEOUT, 60000).

%% Query sent by keep_alive/1.
-define(KEEPALIVE_QUERY, "SELECT 1;").
|
|
|
|
%%%----------------------------------------------------------------------
|
|
%%% API
|
|
%%%----------------------------------------------------------------------
|
|
%% Start an unlinked connection process for Host.
%%
%% init/1 only accepts the two-element argument list [Host, StartInterval];
%% passing [Host] alone made every start/1 call fail in init/1 with a
%% function_clause error. StartInterval is only used as the time (in
%% milliseconds) slept after a failed connection attempt, so 0 is passed:
%% a process started here is not linked to any caller, so the delay has
%% nothing to throttle.
%%
%% Returns whatever gen_server:start/3 returns.
start(Host) ->
    gen_server:start(ejabberd_odbc, [Host, 0], []).
|
|
|
|
%% Start a connection process for Host, linked to the calling process.
%% StartInterval is forwarded to init/1, which hands it to the driver
%% specific connect function; there it is the time (in milliseconds) slept
%% after a failed connection attempt.
start_link(Host, StartInterval) ->
    InitArgs = [Host, StartInterval],
    gen_server:start_link(?MODULE, InitArgs, []).
|
|
|
|
%% Run a single SQL query for Host and return the driver result.
sql_query(Host, Query) ->
    Command = {sql_query, Query},
    sql_call(Host, Command).
|
|
|
|
%% Run an SQL transaction for Host.
%%
%% The second argument is either
%%   - a list of queries, each executed in order with sql_query_t/1, or
%%   - an Erlang fun whose body is executed inside the transaction.
sql_transaction(Host, Queries) when is_list(Queries) ->
    RunQueries = fun() -> lists:foreach(fun sql_query_t/1, Queries) end,
    sql_transaction(Host, RunQueries);
sql_transaction(Host, F) when is_function(F) ->
    Command = {sql_transaction, F},
    sql_call(Host, Command).
|
|
|
|
%% Run an SQL bloc for Host: the Erlang fun F is executed by the connection
%% process without opening a transaction around it (see execute_bloc/1).
sql_bloc(Host, F) ->
    Command = {sql_bloc, F},
    sql_call(Host, Command).
|
|
|
|
%% Dispatch an SQL command.
%%
%% handle_call/3 stores the server state under ?STATE_KEY in the process
%% dictionary, so a defined value means we are already running inside a
%% connection process and the command is executed in place as a nested
%% operation. Otherwise the command is sent to a connection process picked
%% by ejabberd_odbc_sup:get_random_pid/1.
sql_call(Host, Msg) ->
    InsideConnection = get(?STATE_KEY) =/= undefined,
    if
        InsideConnection ->
            nested_op(Msg);
        true ->
            Pid = ejabberd_odbc_sup:get_random_pid(Host),
            gen_server:call(Pid, {sql_cmd, Msg}, ?TRANSACTION_TIMEOUT)
    end.
|
|
|
|
|
|
%% Run a query from inside an sql_transaction fun.
%% Driver errors are turned into a throw of {aborted, Reason}; any other
%% result is returned unchanged.
sql_query_t(Query) ->
    check_query_result(sql_query_internal(Query)).

%% Throw {aborted, Reason} when the result is {error, Reason} or is a list
%% containing an {error, Reason} tuple; otherwise hand the result back.
check_query_result({error, Reason}) ->
    throw({aborted, Reason});
check_query_result(Results) when is_list(Results) ->
    case lists:keysearch(error, 1, Results) of
        {value, {error, Reason}} ->
            throw({aborted, Reason});
        _NoError ->
            Results
    end;
check_query_result(Result) ->
    Result.
|
|
|
|
%% Escape the characters that would confuse an SQL engine.
%% Accepts a string (list of characters) or a binary; each character / byte
%% is passed through odbc_queries:escape/1 and the results are returned as
%% a list in input order.
escape(S) when is_list(S) ->
    lists:map(fun odbc_queries:escape/1, S);
escape(S) when is_binary(S) ->
    lists:map(fun odbc_queries:escape/1, binary_to_list(S)).
|
|
|
|
%% Escape the characters that would confuse an SQL engine, for use inside a
%% LIKE pattern. On top of what odbc_queries:escape/1 handles, percent and
%% underscore are escaped too, because they are wildcards in LIKE patterns.
%%
%% Accepts a string, a binary, or a single character. The binary clause
%% mirrors escape/1: without it a binary fell through to the
%% single-character clause and was handed whole to odbc_queries:escape/1,
%% so its '%' and '_' bytes were never escaped by this function.
%%
%% Returns a list with one escaped element per input character / byte
%% (or the escaped form of the single character).
escape_like(S) when is_list(S) ->
    [escape_like(C) || C <- S];
escape_like(S) when is_binary(S) ->
    [escape_like(C) || <<C>> <= S];
escape_like($%) -> "\\%";
escape_like($_) -> "\\_";
escape_like(C) -> odbc_queries:escape(C).
|
|
|
|
|
|
%%%----------------------------------------------------------------------
|
|
%%% Callback functions from gen_server
|
|
%%%----------------------------------------------------------------------
|
|
|
|
%%----------------------------------------------------------------------
|
|
%% Func: init/1
|
|
%% Returns: {ok, State} |
|
|
%% {ok, State, Timeout} |
|
|
%% ignore |
|
|
%% {stop, Reason}
|
|
%%----------------------------------------------------------------------
|
|
%% gen_server init callback.
%% Optionally schedules the periodic keep-alive query, then opens the
%% database connection described by the {odbc_server, Host} option:
%%   {pgsql, Server, DB, User, Pass}        PostgreSQL on the default port
%%   {pgsql, Server, Port, DB, User, Pass}  PostgreSQL on an explicit port
%%   {mysql, Server, DB, User, Pass}        MySQL on the default port
%%   {mysql, Server, Port, DB, User, Pass}  MySQL on an explicit port
%%   a list (connection string)             generic ODBC
%% The return value is the one produced by the chosen *_connect function.
init([Host, StartInterval]) ->
    schedule_keep_alive(Host),
    SQLServer = ejabberd_config:get_local_option({odbc_server, Host}),
    case SQLServer of
        {pgsql, Server, DB, Username, Password} ->
            pgsql_connect(Server, ?PGSQL_PORT, DB, Username, Password,
                          StartInterval);
        {pgsql, Server, Port, DB, Username, Password} when is_integer(Port) ->
            pgsql_connect(Server, Port, DB, Username, Password,
                          StartInterval);
        {mysql, Server, DB, Username, Password} ->
            mysql_connect(Server, ?MYSQL_PORT, DB, Username, Password,
                          StartInterval);
        {mysql, Server, Port, DB, Username, Password} when is_integer(Port) ->
            mysql_connect(Server, Port, DB, Username, Password,
                          StartInterval);
        _ when is_list(SQLServer) ->
            odbc_connect(SQLServer, StartInterval)
    end.

%% Read {odbc_keepalive_interval, Host} (a number of seconds). When it is
%% an integer, have keep_alive/1 called on this process at that interval;
%% when it is undefined do nothing; anything else is logged as a
%% configuration error and otherwise ignored.
schedule_keep_alive(Host) ->
    case ejabberd_config:get_local_option({odbc_keepalive_interval, Host}) of
        undefined ->
            ok;
        Seconds when is_integer(Seconds) ->
            timer:apply_interval(Seconds * 1000, ?MODULE,
                                 keep_alive, [self()]);
        Invalid ->
            ?ERROR_MSG("Wrong odbc_keepalive_interval definition '~p'"
                       " for host ~p.~n", [Invalid, Host])
    end.
|
|
|
|
%%----------------------------------------------------------------------
|
|
%% Func: handle_call/3
|
|
%% Returns: {reply, Reply, State} |
|
|
%% {reply, Reply, State, Timeout} |
|
|
%% {noreply, State} |
|
|
%% {noreply, State, Timeout} |
|
|
%% {stop, Reason, Reply, State} | (terminate/2 is called)
|
|
%% {stop, Reason, State} (terminate/2 is called)
|
|
%%----------------------------------------------------------------------
|
|
%% {sql_cmd, Command}: publish the server state and a fresh top-level
%% nesting depth in the process dictionary (the internal functions read
%% them from there), run the command, and let abort_on_driver_error/1
%% decide whether to reply normally or stop the server.
%% Any other request is logged and answered with ok.
handle_call({sql_cmd, Command}, _From, State) ->
    put(?STATE_KEY, State),
    put(?NESTING_KEY, ?TOP_LEVEL_TXN),
    Reply = outer_op(Command),
    abort_on_driver_error(Reply);
handle_call(Request, {Who, _Ref}, State) ->
    ?WARNING_MSG("Unexpected call ~p from ~p.", [Request, Who]),
    {reply, ok, State}.
|
|
|
|
%%----------------------------------------------------------------------
|
|
%% Func: handle_cast/2
|
|
%% Returns: {noreply, State} |
|
|
%% {noreply, State, Timeout} |
|
|
%% {stop, Reason, State} (terminate/2 is called)
|
|
%%----------------------------------------------------------------------
|
|
%% No casts are part of this server's protocol: ignore them all and keep
%% the state as it is.
handle_cast(_Request, State) ->
    {noreply, State}.
|
|
|
|
|
|
%% Code upgrade callback: the state needs no conversion.
code_change(_FromVersion, State, _Extra) ->
    {ok, State}.
|
|
|
|
%%----------------------------------------------------------------------
|
|
%% Func: handle_info/2
|
|
%% Returns: {noreply, State} |
|
|
%% {noreply, State, Timeout} |
|
|
%% {stop, Reason, State} (terminate/2 is called)
|
|
%%----------------------------------------------------------------------
|
|
%% The connect functions monitor the driver connection process, so a 'DOWN'
%% message means the database connection is lost. The server then stops
%% with reason connection_dropped so that its supervisor can start a fresh
%% connection. Every other message is ignored.
handle_info({'DOWN', _Monitor, process, _Connection, _Why}, State) ->
    {stop, connection_dropped, State};
handle_info(_Other, State) ->
    {noreply, State}.
|
|
|
|
%%----------------------------------------------------------------------
|
|
%% Func: terminate/2
|
|
%% Purpose: Shutdown the server
|
|
%% Returns: any (ignored by gen_server)
|
|
%%----------------------------------------------------------------------
|
|
%% Shut the server down, closing the driver connection where the driver
%% needs an explicit stop. Always returns ok.
terminate(_Reason, State) ->
    close_connection(State#state.db_type, State),
    ok.

%% Only the MySQL driver is stopped explicitly. Old versions of that driver
%% have no stop function, hence the catch.
close_connection(mysql, State) ->
    catch mysql_conn:stop(State#state.db_ref);
close_connection(_OtherType, _State) ->
    ok.
|
|
|
|
%%%----------------------------------------------------------------------
|
|
%%% Internal functions
|
|
%%%----------------------------------------------------------------------
|
|
|
|
%% Execute a top-level operation. Only handle_call/3 calls this.
%%   {sql_query, Query}   -> result of the query
%%   {sql_transaction, F} -> {atomic, Result} | {aborted, Reason}
%%   {sql_bloc, F}        -> {atomic, Result} | {aborted, Reason}
outer_op({sql_query, SQL}) ->
    sql_query_internal(SQL);
outer_op({sql_transaction, Fun}) ->
    outer_transaction(Fun, ?MAX_TRANSACTION_RESTARTS, "");
outer_op({sql_bloc, Fun}) ->
    execute_bloc(Fun).
|
|
|
|
%% Execute an operation requested (through sql_query/2, sql_transaction/2
%% or sql_bloc/2) by code that is already running inside a connection
%% process.
nested_op({sql_query, Query}) ->
    %% XXX - use sql_query_t here instead? That would most likely break
    %% callers expecting {error, _} tuples, because sql_query_t turns
    %% those into throws.
    sql_query_internal(Query);
nested_op({sql_transaction, F}) ->
    case get(?NESTING_KEY) of
        ?TOP_LEVEL_TXN ->
            %% First transaction inside a (series of) sql_blocs: it is the
            %% outermost one, so it really begins/commits.
            outer_transaction(F, ?MAX_TRANSACTION_RESTARTS, "");
        _Deeper ->
            %% Transaction inside a transaction.
            inner_transaction(F)
    end;
nested_op({sql_bloc, F}) ->
    execute_bloc(F).
|
|
|
|
%% Run F as a transaction nested inside another one. No begin/commit is
%% issued and nothing is retried here; only outer transactions are retried.
%% Being called at top level is an implementation error: it is logged with
%% a backtrace and the process exits with implementation_faulty.
%%
%% Results tagged aborted, 'EXIT' or atomic are returned as they are; any
%% other value Res returned by F becomes {atomic, Res}.
inner_transaction(F) ->
    Level = get(?NESTING_KEY),
    if
        Level =:= ?TOP_LEVEL_TXN ->
            {backtrace, T} = process_info(self(), backtrace),
            ?ERROR_MSG("inner transaction called at outer txn level. Trace: ~s",
                       [T]),
            erlang:exit(implementation_faulty);
        true ->
            ok
    end,
    %% Run F one level deeper and restore the level afterwards, whatever
    %% happened inside F.
    put(?NESTING_KEY, Level + 1),
    Result = (catch F()),
    put(?NESTING_KEY, Level),
    case Result of
        {Tag, _} when Tag =:= aborted; Tag =:= 'EXIT'; Tag =:= atomic ->
            Result;
        Value ->
            {atomic, Value}
    end.
|
|
|
|
%% Run F as the outermost SQL transaction: "begin;", F, then "commit;" or
%% "rollback;".
%%
%%   F          fun executed inside the transaction
%%   NRestarts  how many more times the transaction may be restarted after
%%              F yields {aborted, Reason}
%%   _Reason    abort reason of the previous attempt (not used)
%%
%% Returns {atomic, Result} after a commit, or {aborted, Reason} after the
%% restarts are used up or after F exited. Being called while already
%% inside a transaction is an implementation error: it is logged with a
%% backtrace and the process exits with implementation_faulty.
outer_transaction(F, NRestarts, _Reason) ->
    Level = get(?NESTING_KEY),
    if
        Level =:= ?TOP_LEVEL_TXN ->
            ok;
        true ->
            {backtrace, T} = process_info(self(), backtrace),
            ?ERROR_MSG("outer transaction called at inner txn level. Trace: ~s",
                       [T]),
            erlang:exit(implementation_faulty)
    end,
    sql_query_internal("begin;"),
    %% Run F one level deeper and restore the level afterwards, whatever
    %% happened inside F.
    put(?NESTING_KEY, Level + 1),
    Result = (catch F()),
    put(?NESTING_KEY, Level),
    case Result of
        {aborted, Reason} when NRestarts > 0 ->
            %% Roll back and retry, up to NRestarts times.
            sql_query_internal("rollback;"),
            outer_transaction(F, NRestarts - 1, Reason);
        {aborted, Reason} when NRestarts =:= 0 ->
            %% No restarts left: report, roll back and give up.
            ?ERROR_MSG("SQL transaction restarts exceeded~n"
                       "** Restarts: ~p~n"
                       "** Last abort reason: ~p~n"
                       "** Stacktrace: ~p~n"
                       "** When State == ~p",
                       [?MAX_TRANSACTION_RESTARTS, Reason,
                        erlang:get_stacktrace(), get(?STATE_KEY)]),
            sql_query_internal("rollback;"),
            {aborted, Reason};
        {'EXIT', Reason} ->
            %% F crashed or exited: roll back, no retry.
            sql_query_internal("rollback;"),
            {aborted, Reason};
        Res ->
            %% F completed: commit.
            sql_query_internal("commit;"),
            {atomic, Res}
    end.
|
|
|
|
%% Run F without touching the transaction nesting level (only SQL
%% transactions alter it) and normalise the outcome:
%%   F yields {aborted, Reason}   -> {aborted, Reason}
%%   F exits or crashes           -> {aborted, ExitReason}
%%   F yields any other value Res -> {atomic, Res}
execute_bloc(F) ->
    bloc_result(catch F()).

%% Map the value of `catch F()' to the bloc result.
bloc_result({aborted, _Reason} = Aborted) ->
    Aborted;
bloc_result({'EXIT', Reason}) ->
    {aborted, Reason};
bloc_result(Value) ->
    {atomic, Value}.
|
|
|
|
%% Send Query to the database through the driver selected by the state
%% stored under ?STATE_KEY, and return the result in the Erlang ODBC result
%% format (native PostgreSQL and MySQL results are converted).
sql_query_internal(Query) ->
    State = get(?STATE_KEY),
    DbType = State#state.db_type,
    DbRef = State#state.db_ref,
    Res = case DbType of
              odbc ->
                  odbc:sql_query(DbRef, Query);
              pgsql ->
                  pgsql_to_odbc(pgsql:squery(DbRef, Query));
              mysql ->
                  ?DEBUG("MySQL, Send query~n~p~n", [Query]),
                  mysql_to_odbc(mysql_conn:fetch(DbRef, Query, self()))
          end,
    odbc_bug_workaround(Res).

%% Workaround for an odbc bug: this particular error is reported as a
%% successful update of zero rows. Everything else passes through.
odbc_bug_workaround({error, "No SQL-driver information available."}) ->
    {updated, 0};
odbc_bug_workaround(Res) ->
    Res.
|
|
|
|
%% Build the handle_call/3 return tuple for a command result.
%% Two MySQL driver errors stop the server (the reply is still sent to the
%% caller); every other result is a plain reply. The state is read back
%% from the process dictionary.
abort_on_driver_error(Reply) ->
    State = get(?STATE_KEY),
    case driver_failure(Reply) of
        {true, StopReason} ->
            {stop, StopReason, Reply, State};
        false ->
            {reply, Reply, State}
    end.

%% Recognise the MySQL driver errors after which the server must stop, and
%% name the stop reason.
driver_failure({error, "query timed out"}) ->
    {true, timeout};
driver_failure({error, "Failed sending data on socket" ++ _}) ->
    {true, closed};
driver_failure(_Reply) ->
    false.
|
|
|
|
|
|
%% == pure ODBC code
|
|
|
|
%% Part of init/1: open a generic ODBC connection.
%% SQLServer is the connection string from the configuration. On success
%% the connection process is monitored and {ok, State} is returned. On
%% failure the error is logged, the process sleeps StartInterval
%% milliseconds (so a restart does not retry immediately) and
%% {stop, odbc_connection_failed} is returned.
%% NOTE(review): the failure log prints the whole connection string, which
%% may embed credentials - confirm whether that is acceptable.
odbc_connect(SQLServer, StartInterval) ->
    application:start(odbc),
    Connected = odbc:connect(SQLServer, [{scrollable_cursors, off}]),
    case Connected of
        {error, Reason} ->
            ?ERROR_MSG("ODBC connection (~s) failed: ~p~n",
                       [SQLServer, Reason]),
            timer:sleep(StartInterval),
            {stop, odbc_connection_failed};
        {ok, Ref} ->
            erlang:monitor(process, Ref),
            {ok, #state{db_ref = Ref, db_type = odbc}}
    end.
|
|
|
|
|
|
%% == Native PostgreSQL code
|
|
|
|
%% Part of init/1: open a native PostgreSQL connection.
%% On success the connection process is monitored and {ok, State} is
%% returned. On failure the error is logged, the process sleeps
%% StartInterval milliseconds (so a restart does not retry immediately) and
%% {stop, pgsql_connection_failed} is returned.
pgsql_connect(Server, Port, DB, Username, Password, StartInterval) ->
    Connected = pgsql:connect(Server, DB, Username, Password, Port),
    case Connected of
        {error, Reason} ->
            ?ERROR_MSG("PostgreSQL connection failed: ~p~n", [Reason]),
            timer:sleep(StartInterval),
            {stop, pgsql_connection_failed};
        {ok, Ref} ->
            erlang:monitor(process, Ref),
            {ok, #state{db_ref = Ref, db_type = pgsql}}
    end.
|
|
|
|
%% Convert a successful PostgreSQL driver result to the Erlang ODBC result
%% format. A result holding exactly one item yields that item converted;
%% otherwise a list with every item converted is returned.
pgsql_to_odbc({ok, [Single]}) ->
    pgsql_item_to_odbc(Single);
pgsql_to_odbc({ok, Items}) ->
    [pgsql_item_to_odbc(Item) || Item <- Items].
|
|
|
|
%% Convert one PostgreSQL result item to the Erlang ODBC result format.
%%   {"SELECT", Rows, Recs} -> {selected, ColumnNames, RowTuples}, where the
%%                             names are the first element of each entry of
%%                             Rows and each record list becomes a tuple
%%   "INSERT Oid N"         -> {updated, N}
%%   "DELETE N"             -> {updated, N}
%%   "UPDATE N"             -> {updated, N}
%%   {error, Error}         -> {error, Error}
%%   anything else          -> {updated, undefined}
pgsql_item_to_odbc({"SELECT", Rows, Recs}) ->
    ColumnNames = lists:map(fun(Row) -> element(1, Row) end, Rows),
    RowTuples = lists:map(fun erlang:list_to_tuple/1, Recs),
    {selected, ColumnNames, RowTuples};
pgsql_item_to_odbc("INSERT " ++ OidAndCount) ->
    [_Oid, Count] = string:tokens(OidAndCount, " "),
    {updated, list_to_integer(Count)};
pgsql_item_to_odbc("DELETE " ++ Count) ->
    {updated, list_to_integer(Count)};
pgsql_item_to_odbc("UPDATE " ++ Count) ->
    {updated, list_to_integer(Count)};
pgsql_item_to_odbc({error, _} = Error) ->
    Error;
pgsql_item_to_odbc(_Other) ->
    {updated, undefined}.
|
|
|
|
%% == Native MySQL code
|
|
|
|
%% Part of init/1: open a native MySQL connection.
%% On success the connection process is monitored, the connection is
%% switched to utf8 with "set names", and {ok, State} is returned. On
%% failure the error is logged, the process sleeps StartInterval
%% milliseconds (so a restart does not retry immediately) and
%% {stop, mysql_connection_failed} is returned.
mysql_connect(Server, Port, DB, Username, Password, StartInterval) ->
    Started = mysql_conn:start(Server, Port, Username, Password, DB,
                               fun log/3),
    case Started of
        {error, Reason} ->
            ?ERROR_MSG("MySQL connection failed: ~p~n"
                       "Waiting ~p seconds before retrying...~n",
                       [Reason, StartInterval div 1000]),
            timer:sleep(StartInterval),
            {stop, mysql_connection_failed};
        {ok, Ref} ->
            erlang:monitor(process, Ref),
            mysql_conn:fetch(Ref, ["set names 'utf8';"], self()),
            {ok, #state{db_ref = Ref, db_type = mysql}}
    end.
|
|
|
|
%% Convert a MySQL driver result to the Erlang ODBC result format.
%%   {updated, Res} -> {updated, AffectedRows}
%%   {data, Res}    -> {selected, ColumnNames, RowTuples}
%%   {error, Text}  -> {error, Text} when the reason is already a list
%%   {error, Res}   -> {error, Reason} extracted from the driver result
mysql_to_odbc({updated, MySQLRes}) ->
    AffectedRows = mysql:get_result_affected_rows(MySQLRes),
    {updated, AffectedRows};
mysql_to_odbc({data, MySQLRes}) ->
    FieldInfo = mysql:get_result_field_info(MySQLRes),
    Rows = mysql:get_result_rows(MySQLRes),
    mysql_item_to_odbc(FieldInfo, Rows);
mysql_to_odbc({error, Text} = Error) when is_list(Text) ->
    Error;
mysql_to_odbc({error, MySQLRes}) ->
    Reason = mysql:get_result_reason(MySQLRes),
    {error, Reason}.
|
|
|
|
%% Convert tabular MySQL data to the ODBC form
%% {selected, ColumnNames, RowTuples}: the column name is taken from the
%% second element of each column description, and each record (a list of
%% values) becomes a tuple.
mysql_item_to_odbc(Columns, Recs) ->
    ColumnNames = lists:map(fun(Column) -> element(2, Column) end, Columns),
    RowTuples = lists:map(fun erlang:list_to_tuple/1, Recs),
    {selected, ColumnNames, RowTuples}.
|
|
|
|
%% Send a harmless query to the connection process PID so that the database
%% connection is not closed for being idle. init/1 schedules this call
%% periodically when a keep-alive interval is configured.
keep_alive(PID) ->
    Ping = {sql_cmd, {sql_query, ?KEEPALIVE_QUERY}},
    gen_server:call(PID, Ping, ?KEEPALIVE_TIMEOUT).
|
|
|
|
%% Log callback handed to the MySQL driver in mysql_connect/6: forwards the
%% message to the ejabberd logging macro matching the driver's level
%% (debug, normal or error).
log(Level, Format, Args) ->
    case Level of
        error ->
            ?ERROR_MSG(Format, Args);
        normal ->
            ?INFO_MSG(Format, Args);
        debug ->
            ?DEBUG(Format, Args)
    end.
|