25
1
mirror of https://github.com/processone/ejabberd.git synced 2024-12-22 17:28:25 +01:00
xmpp.chapril.org-ejabberd/src/ejabberd_sql.erl
2018-12-13 11:46:53 +01:00

1186 lines
38 KiB
Erlang

%%%----------------------------------------------------------------------
%%% File : ejabberd_sql.erl
%%% Author : Alexey Shchepin <alexey@process-one.net>
%%% Purpose : Serve SQL connection
%%% Created : 8 Dec 2004 by Alexey Shchepin <alexey@process-one.net>
%%%
%%%
%%% ejabberd, Copyright (C) 2002-2018 ProcessOne
%%%
%%% This program is free software; you can redistribute it and/or
%%% modify it under the terms of the GNU General Public License as
%%% published by the Free Software Foundation; either version 2 of the
%%% License, or (at your option) any later version.
%%%
%%% This program is distributed in the hope that it will be useful,
%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
%%% General Public License for more details.
%%%
%%% You should have received a copy of the GNU General Public License along
%%% with this program; if not, write to the Free Software Foundation, Inc.,
%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
%%%
%%%----------------------------------------------------------------------
-module(ejabberd_sql).
-behaviour(ejabberd_config).
-author('alexey@process-one.net').
-behaviour(p1_fsm).
%% External exports
-export([start/1, start_link/2,
sql_query/2,
sql_query_t/1,
sql_transaction/2,
sql_bloc/2,
abort/1,
restart/1,
use_new_schema/0,
sql_query_to_iolist/1,
escape/1,
standard_escape/1,
escape_like/1,
escape_like_arg/1,
escape_like_arg_circumflex/1,
to_bool/1,
sqlite_db/1,
sqlite_file/1,
encode_term/1,
decode_term/1,
odbc_config/0,
freetds_config/0,
odbcinst_config/0,
init_mssql/1,
keep_alive/2,
to_list/2]).
%% gen_fsm callbacks
-export([init/1, handle_event/3, handle_sync_event/4,
handle_info/3, terminate/3, print_state/1,
code_change/4]).
-export([connecting/2, connecting/3,
session_established/2, session_established/3,
opt_type/1]).
-include("logger.hrl").
-include("ejabberd_sql_pt.hrl").
-include("ejabberd_stacktrace.hrl").
-record(state,
{db_ref = self() :: pid(),
db_type = odbc :: pgsql | mysql | sqlite | odbc | mssql,
db_version = undefined :: undefined | non_neg_integer(),
start_interval = 0 :: non_neg_integer(),
host = <<"">> :: binary(),
pending_requests :: p1_queue:queue()}).
-define(STATE_KEY, ejabberd_sql_state).
-define(NESTING_KEY, ejabberd_sql_nesting_level).
-define(TOP_LEVEL_TXN, 0).
-define(PGSQL_PORT, 5432).
-define(MYSQL_PORT, 3306).
-define(MSSQL_PORT, 1433).
-define(MAX_TRANSACTION_RESTARTS, 10).
-define(KEEPALIVE_QUERY, [<<"SELECT 1;">>]).
-define(PREPARE_KEY, ejabberd_sql_prepare).
-ifdef(NEW_SQL_SCHEMA).
-define(USE_NEW_SCHEMA_DEFAULT, true).
-else.
-define(USE_NEW_SCHEMA_DEFAULT, false).
-endif.
%%-define(DBGFSM, true).
-ifdef(DBGFSM).
-define(FSMOPTS, [{debug, [trace]}]).
-else.
-define(FSMOPTS, []).
-endif.
%%%----------------------------------------------------------------------
%%% API
%%%----------------------------------------------------------------------
start(Host) ->
p1_fsm:start(ejabberd_sql, [Host],
fsm_limit_opts() ++ (?FSMOPTS)).
start_link(Host, StartInterval) ->
p1_fsm:start_link(ejabberd_sql,
[Host, StartInterval],
fsm_limit_opts() ++ (?FSMOPTS)).
-type sql_query() :: [sql_query() | binary()] | #sql_query{} |
fun(() -> any()) | fun((atom(), _) -> any()).
-type sql_query_result() :: {updated, non_neg_integer()} |
{error, binary()} |
{selected, [binary()],
[[binary()]]} |
{selected, [any()]}.
-spec sql_query(binary(), sql_query()) -> sql_query_result().
sql_query(Host, Query) ->
sql_call(Host, {sql_query, Query}).
%% SQL transaction based on a list of queries
%% This function automatically
-spec sql_transaction(binary(), [sql_query()] | fun(() -> any())) ->
{atomic, any()} |
{aborted, any()}.
sql_transaction(Host, Queries)
when is_list(Queries) ->
F = fun () ->
lists:foreach(fun (Query) -> sql_query_t(Query) end,
Queries)
end,
sql_transaction(Host, F);
%% SQL transaction, based on a erlang anonymous function (F = fun)
sql_transaction(Host, F) when is_function(F) ->
sql_call(Host, {sql_transaction, F}).
%% SQL bloc, based on a erlang anonymous function (F = fun)
sql_bloc(Host, F) -> sql_call(Host, {sql_bloc, F}).
sql_call(Host, Msg) ->
case get(?STATE_KEY) of
undefined ->
case ejabberd_sql_sup:get_random_pid(Host) of
none -> {error, <<"Unknown Host">>};
Pid ->
sync_send_event(Pid,{sql_cmd, Msg,
p1_time_compat:monotonic_time(milli_seconds)},
query_timeout(Host))
end;
_State -> nested_op(Msg)
end.
keep_alive(Host, PID) ->
case sync_send_event(PID,
{sql_cmd, {sql_query, ?KEEPALIVE_QUERY},
p1_time_compat:monotonic_time(milli_seconds)},
query_timeout(Host)) of
{selected,_,[[<<"1">>]]} ->
ok;
_Err ->
?ERROR_MSG("keep alive query failed, closing connection: ~p", [_Err]),
sync_send_event(PID, force_timeout, query_timeout(Host))
end.
sync_send_event(Pid, Msg, Timeout) ->
try p1_fsm:sync_send_event(Pid, Msg, Timeout)
catch _:{Reason, {p1_fsm, _, _}} ->
{error, Reason}
end.
-spec sql_query_t(sql_query()) -> sql_query_result().
%% This function is intended to be used from inside an sql_transaction:
sql_query_t(Query) ->
QRes = sql_query_internal(Query),
case QRes of
{error, Reason} -> throw({aborted, Reason});
Rs when is_list(Rs) ->
case lists:keysearch(error, 1, Rs) of
{value, {error, Reason}} -> throw({aborted, Reason});
_ -> QRes
end;
_ -> QRes
end.
abort(Reason) ->
exit(Reason).
restart(Reason) ->
throw({aborted, Reason}).
-spec escape_char(char()) -> binary().
escape_char($\000) -> <<"\\0">>;
escape_char($\n) -> <<"\\n">>;
escape_char($\t) -> <<"\\t">>;
escape_char($\b) -> <<"\\b">>;
escape_char($\r) -> <<"\\r">>;
escape_char($') -> <<"''">>;
escape_char($") -> <<"\\\"">>;
escape_char($\\) -> <<"\\\\">>;
escape_char(C) -> <<C>>.
-spec escape(binary()) -> binary().
escape(S) ->
<< <<(escape_char(Char))/binary>> || <<Char>> <= S >>.
%% Escape character that will confuse an SQL engine
%% Percent and underscore only need to be escaped for pattern matching like
%% statement
escape_like(S) when is_binary(S) ->
<< <<(escape_like(C))/binary>> || <<C>> <= S >>;
escape_like($%) -> <<"\\%">>;
escape_like($_) -> <<"\\_">>;
escape_like($\\) -> <<"\\\\\\\\">>;
escape_like(C) when is_integer(C), C >= 0, C =< 255 -> escape_char(C).
escape_like_arg(S) when is_binary(S) ->
<< <<(escape_like_arg(C))/binary>> || <<C>> <= S >>;
escape_like_arg($%) -> <<"\\%">>;
escape_like_arg($_) -> <<"\\_">>;
escape_like_arg($\\) -> <<"\\\\">>;
escape_like_arg(C) when is_integer(C), C >= 0, C =< 255 -> <<C>>.
escape_like_arg_circumflex(S) when is_binary(S) ->
<< <<(escape_like_arg_circumflex(C))/binary>> || <<C>> <= S >>;
escape_like_arg_circumflex($%) -> <<"^%">>;
escape_like_arg_circumflex($_) -> <<"^_">>;
escape_like_arg_circumflex($^) -> <<"^^">>;
escape_like_arg_circumflex($[) -> <<"^[">>; % For MSSQL
escape_like_arg_circumflex($]) -> <<"^]">>;
escape_like_arg_circumflex(C) when is_integer(C), C >= 0, C =< 255 -> <<C>>.
to_bool(<<"t">>) -> true;
to_bool(<<"true">>) -> true;
to_bool(<<"1">>) -> true;
to_bool(true) -> true;
to_bool(1) -> true;
to_bool(_) -> false.
to_list(EscapeFun, Val) ->
Escaped = lists:join(<<",">>, lists:map(EscapeFun, Val)),
[<<"(">>, Escaped, <<")">>].
encode_term(Term) ->
escape(list_to_binary(
erl_prettypr:format(erl_syntax:abstract(Term),
[{paper, 65535}, {ribbon, 65535}]))).
decode_term(Bin) ->
Str = binary_to_list(<<Bin/binary, ".">>),
{ok, Tokens, _} = erl_scan:string(Str),
{ok, Term} = erl_parse:parse_term(Tokens),
Term.
-spec sqlite_db(binary()) -> atom().
sqlite_db(Host) ->
list_to_atom("ejabberd_sqlite_" ++ binary_to_list(Host)).
-spec sqlite_file(binary()) -> string().
sqlite_file(Host) ->
case ejabberd_config:get_option({sql_database, Host}) of
undefined ->
{ok, Cwd} = file:get_cwd(),
filename:join([Cwd, "sqlite", atom_to_list(node()),
binary_to_list(Host), "ejabberd.db"]);
File ->
binary_to_list(File)
end.
use_new_schema() ->
ejabberd_config:get_option(new_sql_schema, ?USE_NEW_SCHEMA_DEFAULT).
%%%----------------------------------------------------------------------
%%% Callback functions from gen_fsm
%%%----------------------------------------------------------------------
init([Host, StartInterval]) ->
process_flag(trap_exit, true),
case ejabberd_config:get_option({sql_keepalive_interval, Host}) of
undefined ->
ok;
KeepaliveInterval ->
timer:apply_interval(KeepaliveInterval * 1000, ?MODULE,
keep_alive, [Host, self()])
end,
[DBType | _] = db_opts(Host),
p1_fsm:send_event(self(), connect),
ejabberd_sql_sup:add_pid(Host, self()),
QueueType = case ejabberd_config:get_option({sql_queue_type, Host}) of
undefined ->
ejabberd_config:default_queue_type(Host);
Type ->
Type
end,
{ok, connecting,
#state{db_type = DBType, host = Host,
pending_requests = p1_queue:new(QueueType, max_fsm_queue()),
start_interval = StartInterval}}.
connecting(connect, #state{host = Host} = State) ->
ConnectRes = case db_opts(Host) of
[mysql | Args] -> apply(fun mysql_connect/8, Args);
[pgsql | Args] -> apply(fun pgsql_connect/8, Args);
[sqlite | Args] -> apply(fun sqlite_connect/1, Args);
[mssql | Args] -> apply(fun odbc_connect/2, Args);
[odbc | Args] -> apply(fun odbc_connect/2, Args)
end,
case ConnectRes of
{ok, Ref} ->
erlang:monitor(process, Ref),
lists:foreach(
fun({{?PREPARE_KEY, _} = Key, _}) ->
erase(Key);
(_) ->
ok
end, get()),
PendingRequests =
p1_queue:dropwhile(
fun(Req) ->
p1_fsm:send_event(self(), Req),
true
end, State#state.pending_requests),
State1 = State#state{db_ref = Ref,
pending_requests = PendingRequests},
State2 = get_db_version(State1),
{next_state, session_established, State2};
{error, Reason} ->
?WARNING_MSG("~p connection failed:~n** Reason: ~p~n** "
"Retry after: ~p seconds",
[State#state.db_type, Reason,
State#state.start_interval div 1000]),
p1_fsm:send_event_after(State#state.start_interval,
connect),
{next_state, connecting, State}
end;
connecting(Event, State) ->
?WARNING_MSG("unexpected event in 'connecting': ~p",
[Event]),
{next_state, connecting, State}.
connecting({sql_cmd, {sql_query, ?KEEPALIVE_QUERY},
_Timestamp},
From, State) ->
p1_fsm:reply(From,
{error, <<"SQL connection failed">>}),
{next_state, connecting, State};
connecting({sql_cmd, Command, Timestamp} = Req, From,
State) ->
?DEBUG("queuing pending request while connecting:~n\t~p",
[Req]),
PendingRequests =
try p1_queue:in({sql_cmd, Command, From, Timestamp},
State#state.pending_requests)
catch error:full ->
Q = p1_queue:dropwhile(
fun({sql_cmd, _, To, _Timestamp}) ->
p1_fsm:reply(
To, {error, <<"SQL connection failed">>}),
true
end, State#state.pending_requests),
p1_queue:in({sql_cmd, Command, From, Timestamp}, Q)
end,
{next_state, connecting,
State#state{pending_requests = PendingRequests}};
connecting(Request, {Who, _Ref}, State) ->
?WARNING_MSG("unexpected call ~p from ~p in 'connecting'",
[Request, Who]),
{reply, {error, badarg}, connecting, State}.
session_established({sql_cmd, Command, Timestamp}, From,
State) ->
run_sql_cmd(Command, From, State, Timestamp);
session_established(Request, {Who, _Ref}, State) ->
?WARNING_MSG("unexpected call ~p from ~p in 'session_establ"
"ished'",
[Request, Who]),
{reply, {error, badarg}, session_established, State}.
session_established({sql_cmd, Command, From, Timestamp},
State) ->
run_sql_cmd(Command, From, State, Timestamp);
session_established(force_timeout, State) ->
{stop, timeout, State};
session_established(Event, State) ->
?WARNING_MSG("unexpected event in 'session_established': ~p",
[Event]),
{next_state, session_established, State}.
handle_event(_Event, StateName, State) ->
{next_state, StateName, State}.
handle_sync_event(_Event, _From, StateName, State) ->
{reply, {error, badarg}, StateName, State}.
code_change(_OldVsn, StateName, State, _Extra) ->
{ok, StateName, State}.
%% We receive the down signal when we loose the MySQL connection (we are
%% monitoring the connection)
handle_info({'DOWN', _MonitorRef, process, _Pid, _Info},
_StateName, State) ->
p1_fsm:send_event(self(), connect),
{next_state, connecting, State};
handle_info(Info, StateName, State) ->
?WARNING_MSG("unexpected info in ~p: ~p",
[StateName, Info]),
{next_state, StateName, State}.
terminate(_Reason, _StateName, State) ->
ejabberd_sql_sup:remove_pid(State#state.host, self()),
case State#state.db_type of
mysql -> catch p1_mysql_conn:stop(State#state.db_ref);
sqlite -> catch sqlite3:close(sqlite_db(State#state.host));
_ -> ok
end,
ok.
%%----------------------------------------------------------------------
%% Func: print_state/1
%% Purpose: Prepare the state to be printed on error log
%% Returns: State to print
%%----------------------------------------------------------------------
print_state(State) -> State.
%%%----------------------------------------------------------------------
%%% Internal functions
%%%----------------------------------------------------------------------
run_sql_cmd(Command, From, State, Timestamp) ->
QueryTimeout = query_timeout(State#state.host),
case p1_time_compat:monotonic_time(milli_seconds) - Timestamp of
Age when Age < QueryTimeout ->
put(?NESTING_KEY, ?TOP_LEVEL_TXN),
put(?STATE_KEY, State),
abort_on_driver_error(outer_op(Command), From);
Age ->
?ERROR_MSG("Database was not available or too slow, "
"discarding ~p milliseconds old request~n~p~n",
[Age, Command]),
{next_state, session_established, State}
end.
%% Only called by handle_call, only handles top level operations.
%% @spec outer_op(Op) -> {error, Reason} | {aborted, Reason} | {atomic, Result}
outer_op({sql_query, Query}) ->
sql_query_internal(Query);
outer_op({sql_transaction, F}) ->
outer_transaction(F, ?MAX_TRANSACTION_RESTARTS, <<"">>);
outer_op({sql_bloc, F}) -> execute_bloc(F).
%% Called via sql_query/transaction/bloc from client code when inside a
%% nested operation
nested_op({sql_query, Query}) ->
sql_query_internal(Query);
nested_op({sql_transaction, F}) ->
NestingLevel = get(?NESTING_KEY),
if NestingLevel =:= (?TOP_LEVEL_TXN) ->
outer_transaction(F, ?MAX_TRANSACTION_RESTARTS, <<"">>);
true -> inner_transaction(F)
end;
nested_op({sql_bloc, F}) -> execute_bloc(F).
%% Never retry nested transactions - only outer transactions
inner_transaction(F) ->
PreviousNestingLevel = get(?NESTING_KEY),
case get(?NESTING_KEY) of
?TOP_LEVEL_TXN ->
{backtrace, T} = process_info(self(), backtrace),
?ERROR_MSG("inner transaction called at outer txn "
"level. Trace: ~s",
[T]),
erlang:exit(implementation_faulty);
_N -> ok
end,
put(?NESTING_KEY, PreviousNestingLevel + 1),
Result = (catch F()),
put(?NESTING_KEY, PreviousNestingLevel),
case Result of
{aborted, Reason} -> {aborted, Reason};
{'EXIT', Reason} -> {'EXIT', Reason};
{atomic, Res} -> {atomic, Res};
Res -> {atomic, Res}
end.
outer_transaction(F, NRestarts, _Reason) ->
PreviousNestingLevel = get(?NESTING_KEY),
case get(?NESTING_KEY) of
?TOP_LEVEL_TXN -> ok;
_N ->
{backtrace, T} = process_info(self(), backtrace),
?ERROR_MSG("outer transaction called at inner txn "
"level. Trace: ~s",
[T]),
erlang:exit(implementation_faulty)
end,
sql_query_internal([<<"begin;">>]),
put(?NESTING_KEY, PreviousNestingLevel + 1),
try F() of
Res ->
sql_query_internal([<<"commit;">>]),
{atomic, Res}
catch
?EX_RULE(throw, {aborted, Reason}, _) when NRestarts > 0 ->
sql_query_internal([<<"rollback;">>]),
outer_transaction(F, NRestarts - 1, Reason);
?EX_RULE(throw, {aborted, Reason}, Stack) when NRestarts =:= 0 ->
?ERROR_MSG("SQL transaction restarts exceeded~n** "
"Restarts: ~p~n** Last abort reason: "
"~p~n** Stacktrace: ~p~n** When State "
"== ~p",
[?MAX_TRANSACTION_RESTARTS, Reason,
?EX_STACK(Stack), get(?STATE_KEY)]),
sql_query_internal([<<"rollback;">>]),
{aborted, Reason};
?EX_RULE(exit, Reason, _) ->
sql_query_internal([<<"rollback;">>]),
{aborted, Reason}
end.
execute_bloc(F) ->
case catch F() of
{aborted, Reason} -> {aborted, Reason};
{'EXIT', Reason} -> {aborted, Reason};
Res -> {atomic, Res}
end.
execute_fun(F) when is_function(F, 0) ->
F();
execute_fun(F) when is_function(F, 2) ->
State = get(?STATE_KEY),
F(State#state.db_type, State#state.db_version).
sql_query_internal([{_, _} | _] = Queries) ->
State = get(?STATE_KEY),
case select_sql_query(Queries, State) of
undefined ->
{error, <<"no matching query for the current DBMS found">>};
Query ->
sql_query_internal(Query)
end;
sql_query_internal(#sql_query{} = Query) ->
State = get(?STATE_KEY),
Res =
try
case State#state.db_type of
odbc ->
generic_sql_query(Query);
mssql ->
mssql_sql_query(Query);
pgsql ->
Key = {?PREPARE_KEY, Query#sql_query.hash},
case get(Key) of
undefined ->
case pgsql_prepare(Query, State) of
{ok, _, _, _} ->
put(Key, prepared);
{error, Error} ->
?ERROR_MSG("PREPARE failed for SQL query "
"at ~p: ~p",
[Query#sql_query.loc, Error]),
put(Key, ignore)
end;
_ ->
ok
end,
case get(Key) of
prepared ->
pgsql_execute_sql_query(Query, State);
_ ->
generic_sql_query(Query)
end;
mysql ->
generic_sql_query(Query);
sqlite ->
sqlite_sql_query(Query)
end
catch exit:{timeout, _} ->
{error, <<"timed out">>};
exit:{killed, _} ->
{error, <<"killed">>};
exit:{normal, _} ->
{error, <<"terminated unexpectedly">>};
?EX_RULE(Class, Reason, Stack) ->
?ERROR_MSG("Internal error while processing SQL query: ~p",
[{Class, Reason, ?EX_STACK(Stack)}]),
{error, <<"internal error">>}
end,
check_error(Res, Query);
sql_query_internal(F) when is_function(F) ->
case catch execute_fun(F) of
{'EXIT', Reason} -> {error, Reason};
Res -> Res
end;
sql_query_internal(Query) ->
State = get(?STATE_KEY),
?DEBUG("SQL: \"~s\"", [Query]),
QueryTimeout = query_timeout(State#state.host),
Res = case State#state.db_type of
odbc ->
to_odbc(odbc:sql_query(State#state.db_ref, [Query],
QueryTimeout - 1000));
mssql ->
to_odbc(odbc:sql_query(State#state.db_ref, [Query],
QueryTimeout - 1000));
pgsql ->
pgsql_to_odbc(pgsql:squery(State#state.db_ref, Query,
QueryTimeout - 1000));
mysql ->
R = mysql_to_odbc(p1_mysql_conn:squery(State#state.db_ref,
[Query], self(),
[{timeout, QueryTimeout - 1000},
{result_type, binary}])),
R;
sqlite ->
Host = State#state.host,
sqlite_to_odbc(Host, sqlite3:sql_exec(sqlite_db(Host), Query))
end,
check_error(Res, Query).
select_sql_query(Queries, State) ->
select_sql_query(
Queries, State#state.db_type, State#state.db_version, undefined).
select_sql_query([], _Type, _Version, undefined) ->
undefined;
select_sql_query([], _Type, _Version, Query) ->
Query;
select_sql_query([{any, Query} | _], _Type, _Version, _) ->
Query;
select_sql_query([{Type, Query} | _], Type, _Version, _) ->
Query;
select_sql_query([{{Type, _Version1}, Query1} | Rest], Type, undefined, _) ->
select_sql_query(Rest, Type, undefined, Query1);
select_sql_query([{{Type, Version1}, Query1} | Rest], Type, Version, Query) ->
if
Version >= Version1 ->
Query1;
true ->
select_sql_query(Rest, Type, Version, Query)
end;
select_sql_query([{_, _} | Rest], Type, Version, Query) ->
select_sql_query(Rest, Type, Version, Query).
generic_sql_query(SQLQuery) ->
sql_query_format_res(
sql_query_internal(generic_sql_query_format(SQLQuery)),
SQLQuery).
generic_sql_query_format(SQLQuery) ->
Args = (SQLQuery#sql_query.args)(generic_escape()),
(SQLQuery#sql_query.format_query)(Args).
generic_escape() ->
#sql_escape{string = fun(X) -> <<"'", (escape(X))/binary, "'">> end,
integer = fun(X) -> misc:i2l(X) end,
boolean = fun(true) -> <<"1">>;
(false) -> <<"0">>
end
}.
sqlite_sql_query(SQLQuery) ->
sql_query_format_res(
sql_query_internal(sqlite_sql_query_format(SQLQuery)),
SQLQuery).
sqlite_sql_query_format(SQLQuery) ->
Args = (SQLQuery#sql_query.args)(sqlite_escape()),
(SQLQuery#sql_query.format_query)(Args).
sqlite_escape() ->
#sql_escape{string = fun(X) -> <<"'", (standard_escape(X))/binary, "'">> end,
integer = fun(X) -> misc:i2l(X) end,
boolean = fun(true) -> <<"1">>;
(false) -> <<"0">>
end
}.
standard_escape(S) ->
<< <<(case Char of
$' -> << "''" >>;
_ -> << Char >>
end)/binary>> || <<Char>> <= S >>.
mssql_sql_query(SQLQuery) ->
sqlite_sql_query(SQLQuery).
pgsql_prepare(SQLQuery, State) ->
Escape = #sql_escape{_ = fun(X) -> X end},
N = length((SQLQuery#sql_query.args)(Escape)),
Args = [<<$$, (integer_to_binary(I))/binary>> || I <- lists:seq(1, N)],
Query = (SQLQuery#sql_query.format_query)(Args),
pgsql:prepare(State#state.db_ref, SQLQuery#sql_query.hash, Query).
pgsql_execute_escape() ->
#sql_escape{string = fun(X) -> X end,
integer = fun(X) -> [misc:i2l(X)] end,
boolean = fun(true) -> "1";
(false) -> "0"
end
}.
pgsql_execute_sql_query(SQLQuery, State) ->
Args = (SQLQuery#sql_query.args)(pgsql_execute_escape()),
ExecuteRes =
pgsql:execute(State#state.db_ref, SQLQuery#sql_query.hash, Args),
% {T, ExecuteRes} =
% timer:tc(pgsql, execute, [State#state.db_ref, SQLQuery#sql_query.hash, Args]),
% io:format("T ~s ~p~n", [SQLQuery#sql_query.hash, T]),
Res = pgsql_execute_to_odbc(ExecuteRes),
sql_query_format_res(Res, SQLQuery).
sql_query_format_res({selected, _, Rows}, SQLQuery) ->
Res =
lists:flatmap(
fun(Row) ->
try
[(SQLQuery#sql_query.format_res)(Row)]
catch
?EX_RULE(Class, Reason, Stack) ->
?ERROR_MSG("Error while processing "
"SQL query result: ~p~n"
"row: ~p",
[{Class, Reason, ?EX_STACK(Stack)}, Row]),
[]
end
end, Rows),
{selected, Res};
sql_query_format_res(Res, _SQLQuery) ->
Res.
sql_query_to_iolist(SQLQuery) ->
generic_sql_query_format(SQLQuery).
%% Generate the OTP callback return tuple depending on the driver result.
abort_on_driver_error({error,
<<"query timed out">>} = Reply,
From) ->
p1_fsm:reply(From, Reply),
{stop, timeout, get(?STATE_KEY)};
abort_on_driver_error({error,
<<"Failed sending data on socket", _/binary>>} = Reply,
From) ->
p1_fsm:reply(From, Reply),
{stop, closed, get(?STATE_KEY)};
abort_on_driver_error({error,
<<"SQL connection failed">>} = Reply,
From) ->
p1_fsm:reply(From, Reply),
{stop, timeout, get(?STATE_KEY)};
abort_on_driver_error({error,
<<"Communication link failure">>} = Reply,
From) ->
p1_fsm:reply(From, Reply),
{stop, closed, get(?STATE_KEY)};
abort_on_driver_error(Reply, From) ->
p1_fsm:reply(From, Reply),
{next_state, session_established, get(?STATE_KEY)}.
%% == pure ODBC code
%% part of init/1
%% Open an ODBC database connection
odbc_connect(SQLServer, Timeout) ->
ejabberd:start_app(odbc),
odbc:connect(binary_to_list(SQLServer),
[{scrollable_cursors, off},
{extended_errors, on},
{tuple_row, off},
{timeout, Timeout},
{binary_strings, on}]).
%% == Native SQLite code
%% part of init/1
%% Open a database connection to SQLite
sqlite_connect(Host) ->
File = sqlite_file(Host),
case filelib:ensure_dir(File) of
ok ->
case sqlite3:open(sqlite_db(Host), [{file, File}]) of
{ok, Ref} ->
sqlite3:sql_exec(
sqlite_db(Host), "pragma foreign_keys = on"),
{ok, Ref};
{error, {already_started, Ref}} ->
{ok, Ref};
{error, Reason} ->
{error, Reason}
end;
Err ->
Err
end.
%% Convert SQLite query result to Erlang ODBC result formalism
sqlite_to_odbc(Host, ok) ->
{updated, sqlite3:changes(sqlite_db(Host))};
sqlite_to_odbc(Host, {rowid, _}) ->
{updated, sqlite3:changes(sqlite_db(Host))};
sqlite_to_odbc(_Host, [{columns, Columns}, {rows, TRows}]) ->
Rows = [lists:map(
fun(I) when is_integer(I) ->
integer_to_binary(I);
(B) ->
B
end, tuple_to_list(Row)) || Row <- TRows],
{selected, [list_to_binary(C) || C <- Columns], Rows};
sqlite_to_odbc(_Host, {error, _Code, Reason}) ->
{error, Reason};
sqlite_to_odbc(_Host, _) ->
{updated, undefined}.
%% == Native PostgreSQL code
%% part of init/1
%% Open a database connection to PostgreSQL
pgsql_connect(Server, Port, DB, Username, Password, ConnectTimeout,
Transport, SSLOpts) ->
case pgsql:connect([{host, Server},
{database, DB},
{user, Username},
{password, Password},
{port, Port},
{transport, Transport},
{connect_timeout, ConnectTimeout},
{as_binary, true}|SSLOpts]) of
{ok, Ref} ->
pgsql:squery(Ref, [<<"alter database \"">>, DB, <<"\" set ">>,
<<"standard_conforming_strings='off';">>]),
pgsql:squery(Ref, [<<"set standard_conforming_strings to 'off';">>]),
{ok, Ref};
Err ->
Err
end.
%% Convert PostgreSQL query result to Erlang ODBC result formalism
pgsql_to_odbc({ok, PGSQLResult}) ->
case PGSQLResult of
[Item] -> pgsql_item_to_odbc(Item);
Items -> [pgsql_item_to_odbc(Item) || Item <- Items]
end.
pgsql_item_to_odbc({<<"SELECT", _/binary>>, Rows,
Recs}) ->
{selected, [element(1, Row) || Row <- Rows], Recs};
pgsql_item_to_odbc({<<"FETCH", _/binary>>, Rows,
Recs}) ->
{selected, [element(1, Row) || Row <- Rows], Recs};
pgsql_item_to_odbc(<<"INSERT ", OIDN/binary>>) ->
[_OID, N] = str:tokens(OIDN, <<" ">>),
{updated, binary_to_integer(N)};
pgsql_item_to_odbc(<<"DELETE ", N/binary>>) ->
{updated, binary_to_integer(N)};
pgsql_item_to_odbc(<<"UPDATE ", N/binary>>) ->
{updated, binary_to_integer(N)};
pgsql_item_to_odbc({error, Error}) -> {error, Error};
pgsql_item_to_odbc(_) -> {updated, undefined}.
pgsql_execute_to_odbc({ok, {<<"SELECT", _/binary>>, Rows}}) ->
{selected, [], [[Field || {_, Field} <- Row] || Row <- Rows]};
pgsql_execute_to_odbc({ok, {'INSERT', N}}) ->
{updated, N};
pgsql_execute_to_odbc({ok, {'DELETE', N}}) ->
{updated, N};
pgsql_execute_to_odbc({ok, {'UPDATE', N}}) ->
{updated, N};
pgsql_execute_to_odbc({error, Error}) -> {error, Error};
pgsql_execute_to_odbc(_) -> {updated, undefined}.
%% == Native MySQL code
%% part of init/1
%% Open a database connection to MySQL
mysql_connect(Server, Port, DB, Username, Password, ConnectTimeout, _, _) ->
case p1_mysql_conn:start(binary_to_list(Server), Port,
binary_to_list(Username),
binary_to_list(Password),
binary_to_list(DB),
ConnectTimeout, fun log/3)
of
{ok, Ref} ->
p1_mysql_conn:fetch(
Ref, [<<"set names 'utf8mb4' collate 'utf8mb4_bin';">>], self()),
{ok, Ref};
Err -> Err
end.
%% Convert MySQL query result to Erlang ODBC result formalism
mysql_to_odbc({updated, MySQLRes}) ->
{updated, p1_mysql:get_result_affected_rows(MySQLRes)};
mysql_to_odbc({data, MySQLRes}) ->
mysql_item_to_odbc(p1_mysql:get_result_field_info(MySQLRes),
p1_mysql:get_result_rows(MySQLRes));
mysql_to_odbc({error, MySQLRes})
when is_binary(MySQLRes) ->
{error, MySQLRes};
mysql_to_odbc({error, MySQLRes})
when is_list(MySQLRes) ->
{error, list_to_binary(MySQLRes)};
mysql_to_odbc({error, MySQLRes}) ->
{error, p1_mysql:get_result_reason(MySQLRes)};
mysql_to_odbc(ok) ->
ok.
%% When tabular data is returned, convert it to the ODBC formalism
mysql_item_to_odbc(Columns, Recs) ->
{selected, [element(2, Column) || Column <- Columns], Recs}.
to_odbc({selected, Columns, Recs}) ->
Rows = [lists:map(
fun(I) when is_integer(I) ->
integer_to_binary(I);
(B) ->
B
end, Row) || Row <- Recs],
{selected, [list_to_binary(C) || C <- Columns], Rows};
to_odbc({error, Reason}) when is_list(Reason) ->
{error, list_to_binary(Reason)};
to_odbc(Res) ->
Res.
get_db_version(#state{db_type = pgsql} = State) ->
case pgsql:squery(State#state.db_ref,
<<"select current_setting('server_version_num')">>) of
{ok, [{_, _, [[SVersion]]}]} ->
case catch binary_to_integer(SVersion) of
Version when is_integer(Version) ->
State#state{db_version = Version};
Error ->
?WARNING_MSG("error getting pgsql version: ~p", [Error]),
State
end;
Res ->
?WARNING_MSG("error getting pgsql version: ~p", [Res]),
State
end;
get_db_version(State) ->
State.
log(Level, Format, Args) ->
case Level of
debug -> ?DEBUG(Format, Args);
normal -> ?INFO_MSG(Format, Args);
error -> ?ERROR_MSG(Format, Args)
end.
db_opts(Host) ->
Type = ejabberd_config:get_option({sql_type, Host}, odbc),
Server = ejabberd_config:get_option({sql_server, Host}, <<"localhost">>),
Timeout = timer:seconds(
ejabberd_config:get_option({sql_connect_timeout, Host}, 5)),
Transport = case ejabberd_config:get_option({sql_ssl, Host}, false) of
false -> tcp;
true -> ssl
end,
warn_if_ssl_unsupported(Transport, Type),
case Type of
odbc ->
[odbc, Server, Timeout];
sqlite ->
[sqlite, Host];
_ ->
Port = ejabberd_config:get_option(
{sql_port, Host},
case Type of
mssql -> ?MSSQL_PORT;
mysql -> ?MYSQL_PORT;
pgsql -> ?PGSQL_PORT
end),
DB = ejabberd_config:get_option({sql_database, Host},
<<"ejabberd">>),
User = ejabberd_config:get_option({sql_username, Host},
<<"ejabberd">>),
Pass = ejabberd_config:get_option({sql_password, Host},
<<"">>),
SSLOpts = get_ssl_opts(Transport, Host),
case Type of
mssql ->
[mssql, <<"DSN=", Host/binary, ";UID=", User/binary,
";PWD=", Pass/binary>>, Timeout];
_ ->
[Type, Server, Port, DB, User, Pass, Timeout, Transport, SSLOpts]
end
end.
warn_if_ssl_unsupported(tcp, _) ->
ok;
warn_if_ssl_unsupported(ssl, pgsql) ->
ok;
warn_if_ssl_unsupported(ssl, Type) ->
?WARNING_MSG("SSL connection is not supported for ~s", [Type]).
get_ssl_opts(ssl, Host) ->
Opts1 = case ejabberd_config:get_option({sql_ssl_certfile, Host}) of
undefined -> [];
CertFile -> [{certfile, CertFile}]
end,
Opts2 = case ejabberd_config:get_option({sql_ssl_cafile, Host}) of
undefined -> Opts1;
CAFile -> [{cacertfile, CAFile}|Opts1]
end,
case ejabberd_config:get_option({sql_ssl_verify, Host}, false) of
true ->
case lists:keymember(cacertfile, 1, Opts2) of
true ->
[{verify, verify_peer}|Opts2];
false ->
?WARNING_MSG("SSL verification is enabled for "
"SQL connection, but option "
"'sql_ssl_cafile' is not set; "
"verification will be disabled", []),
Opts2
end;
false ->
Opts2
end;
get_ssl_opts(tcp, _) ->
[].
init_mssql(Host) ->
Server = ejabberd_config:get_option({sql_server, Host}, <<"localhost">>),
Port = ejabberd_config:get_option({sql_port, Host}, ?MSSQL_PORT),
DB = ejabberd_config:get_option({sql_database, Host}, <<"ejabberd">>),
FreeTDS = io_lib:fwrite("[~s]~n"
"\thost = ~s~n"
"\tport = ~p~n"
"\tclient charset = UTF-8~n"
"\ttds version = 7.1~n",
[Host, Server, Port]),
ODBCINST = io_lib:fwrite("[freetds]~n"
"Description = MSSQL connection~n"
"Driver = libtdsodbc.so~n"
"Setup = libtdsS.so~n"
"UsageCount = 1~n"
"FileUsage = 1~n", []),
ODBCINI = io_lib:fwrite("[~s]~n"
"Description = MS SQL~n"
"Driver = freetds~n"
"Servername = ~s~n"
"Database = ~s~n"
"Port = ~p~n",
[Host, Host, DB, Port]),
?DEBUG("~s:~n~s", [freetds_config(), FreeTDS]),
?DEBUG("~s:~n~s", [odbcinst_config(), ODBCINST]),
?DEBUG("~s:~n~s", [odbc_config(), ODBCINI]),
case filelib:ensure_dir(freetds_config()) of
ok ->
try
ok = file:write_file(freetds_config(), FreeTDS, [append]),
ok = file:write_file(odbcinst_config(), ODBCINST),
ok = file:write_file(odbc_config(), ODBCINI, [append]),
os:putenv("ODBCSYSINI", tmp_dir()),
os:putenv("FREETDS", freetds_config()),
os:putenv("FREETDSCONF", freetds_config()),
ok
catch error:{badmatch, {error, Reason} = Err} ->
?ERROR_MSG("failed to create temporary files in ~s: ~s",
[tmp_dir(), file:format_error(Reason)]),
Err
end;
{error, Reason} = Err ->
?ERROR_MSG("failed to create temporary directory ~s: ~s",
[tmp_dir(), file:format_error(Reason)]),
Err
end.
tmp_dir() ->
case os:type() of
{win32, _} -> filename:join([os:getenv("HOME"), "conf"]);
_ -> filename:join(["/tmp", "ejabberd"])
end.
odbc_config() ->
filename:join(tmp_dir(), "odbc.ini").
freetds_config() ->
filename:join(tmp_dir(), "freetds.conf").
odbcinst_config() ->
filename:join(tmp_dir(), "odbcinst.ini").
max_fsm_queue() ->
proplists:get_value(max_queue, fsm_limit_opts(), unlimited).
fsm_limit_opts() ->
ejabberd_config:fsm_limit_opts([]).
query_timeout(LServer) ->
timer:seconds(
ejabberd_config:get_option({sql_query_timeout, LServer}, 60)).
%% ***IMPORTANT*** This error format requires extended_errors turned on.
extended_error({"08S01", _, Reason}) ->
% TCP Provider: The specified network name is no longer available
?DEBUG("ODBC Link Failure: ~s", [Reason]),
<<"Communication link failure">>;
extended_error({"08001", _, Reason}) ->
% Login timeout expired
?DEBUG("ODBC Connect Timeout: ~s", [Reason]),
<<"SQL connection failed">>;
extended_error({"IMC01", _, Reason}) ->
% The connection is broken and recovery is not possible
?DEBUG("ODBC Link Failure: ~s", [Reason]),
<<"Communication link failure">>;
extended_error({"IMC06", _, Reason}) ->
% The connection is broken and recovery is not possible
?DEBUG("ODBC Link Failure: ~s", [Reason]),
<<"Communication link failure">>;
extended_error({Code, _, Reason}) ->
?DEBUG("ODBC Error ~s: ~s", [Code, Reason]),
iolist_to_binary(Reason);
extended_error(Error) ->
Error.
check_error({error, Why} = Err, _Query) when Why == killed ->
Err;
check_error({error, Why}, #sql_query{} = Query) ->
Err = extended_error(Why),
?ERROR_MSG("SQL query '~s' at ~p failed: ~p",
[Query#sql_query.hash, Query#sql_query.loc, Err]),
{error, Err};
check_error({error, Why}, Query) ->
Err = extended_error(Why),
case catch iolist_to_binary(Query) of
SQuery when is_binary(SQuery) ->
?ERROR_MSG("SQL query '~s' failed: ~p", [SQuery, Err]);
_ ->
?ERROR_MSG("SQL query ~p failed: ~p", [Query, Err])
end,
{error, Err};
check_error(Result, _Query) ->
Result.
-spec opt_type(atom()) -> fun((any()) -> any()) | [atom()].
opt_type(sql_database) -> fun iolist_to_binary/1;
opt_type(sql_keepalive_interval) ->
fun (I) when is_integer(I), I > 0 -> I end;
opt_type(sql_password) -> fun iolist_to_binary/1;
opt_type(sql_port) ->
fun (P) when is_integer(P), P > 0, P < 65536 -> P end;
opt_type(sql_server) -> fun iolist_to_binary/1;
opt_type(sql_username) -> fun iolist_to_binary/1;
opt_type(sql_ssl) -> fun(B) when is_boolean(B) -> B end;
opt_type(sql_ssl_verify) -> fun(B) when is_boolean(B) -> B end;
opt_type(sql_ssl_certfile) -> fun ejabberd_pkix:try_certfile/1;
opt_type(sql_ssl_cafile) -> fun ejabberd_pkix:try_certfile/1;
opt_type(sql_query_timeout) ->
fun (I) when is_integer(I), I > 0 -> I end;
opt_type(sql_connect_timeout) ->
fun (I) when is_integer(I), I > 0 -> I end;
opt_type(sql_queue_type) ->
fun(ram) -> ram; (file) -> file end;
opt_type(new_sql_schema) -> fun(B) when is_boolean(B) -> B end;
opt_type(_) ->
[sql_database, sql_keepalive_interval,
sql_password, sql_port, sql_server,
sql_username, sql_ssl, sql_ssl_verify, sql_ssl_certfile,
sql_ssl_cafile, sql_queue_type, sql_query_timeout,
sql_connect_timeout,
new_sql_schema].