New parse transform for SQL queries, use prepare/execute calls with Postgres

This commit is contained in:
Alexey Shchepin 2016-02-09 19:23:15 +03:00
parent eeac7f9b02
commit 6374ef4866
3 changed files with 410 additions and 2 deletions

View File

@ -0,0 +1,27 @@
%%%----------------------------------------------------------------------
%%%
%%% ejabberd, Copyright (C) 2002-2016 ProcessOne
%%%
%%% This program is free software; you can redistribute it and/or
%%% modify it under the terms of the GNU General Public License as
%%% published by the Free Software Foundation; either version 2 of the
%%% License, or (at your option) any later version.
%%%
%%% This program is distributed in the hope that it will be useful,
%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
%%% General Public License for more details.
%%%
%%% You should have received a copy of the GNU General Public License along
%%% with this program; if not, write to the Free Software Foundation, Inc.,
%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
%%%
%%%----------------------------------------------------------------------
-define(SQL_MARK, sql__mark_).
-define(SQL(SQL), ?SQL_MARK(SQL)).
-record(sql_query, {hash, format_query, format_res, args, loc}).
-record(sql_escape, {string, integer, boolean}).

View File

@ -63,6 +63,7 @@
-include("ejabberd.hrl").
-include("logger.hrl").
-include("ejabberd_sql_pt.hrl").
-record(state,
{db_ref = self() :: pid(),
@ -92,6 +93,8 @@
-define(KEEPALIVE_QUERY, [<<"SELECT 1;">>]).
-define(PREPARE_KEY, ejabberd_odbc_prepare).
%%-define(DBGFSM, true).
-ifdef(DBGFSM).
@ -116,11 +119,12 @@ start_link(Host, StartInterval) ->
[Host, StartInterval],
fsm_limit_opts() ++ (?FSMOPTS)).
-type sql_query() :: [sql_query() | binary()].
-type sql_query() :: [sql_query() | binary()] | #sql_query{}.
-type sql_query_result() :: {updated, non_neg_integer()} |
{error, binary()} |
{selected, [binary()],
[[binary()]]}.
[[binary()]]} |
{selected, [any]}.
-spec sql_query(binary(), sql_query()) -> sql_query_result().
@ -469,6 +473,52 @@ execute_bloc(F) ->
Res -> {atomic, Res}
end.
sql_query_internal(#sql_query{} = Query) ->
State = get(?STATE_KEY),
Res =
try
case State#state.db_type of
odbc ->
generic_sql_query(Query);
pgsql ->
Key = {?PREPARE_KEY, Query#sql_query.hash},
case get(Key) of
undefined ->
case pgsql_prepare(Query, State) of
{ok, _, _, _} ->
put(Key, prepared);
{error, Error} ->
?ERROR_MSG("PREPARE failed for SQL query "
"at ~p: ~p",
[Query#sql_query.loc, Error]),
put(Key, ignore)
end;
_ ->
ok
end,
case get(Key) of
prepared ->
pgsql_execute_sql_query(Query, State);
_ ->
generic_sql_query(Query)
end;
mysql ->
generic_sql_query(Query);
sqlite ->
generic_sql_query(Query)
end
catch
Class:Reason ->
ST = erlang:get_stacktrace(),
?ERROR_MSG("Internal error while processing SQL query: ~p",
[{Class, Reason, ST}]),
{error, <<"internal error">>}
end,
case Res of
{error, <<"No SQL-driver information available.">>} ->
{updated, 0};
_Else -> Res
end;
sql_query_internal(Query) ->
State = get(?STATE_KEY),
?DEBUG("SQL: \"~s\"", [Query]),
@ -495,6 +545,66 @@ sql_query_internal(Query) ->
_Else -> Res
end.
generic_sql_query(SQLQuery) ->
sql_query_format_res(
sql_query_internal(generic_sql_query_format(SQLQuery)),
SQLQuery).
generic_sql_query_format(SQLQuery) ->
Args = (SQLQuery#sql_query.args)(generic_escape()),
(SQLQuery#sql_query.format_query)(Args).
generic_escape() ->
#sql_escape{string = fun(X) -> <<"'", (escape(X))/binary, "'">> end,
integer = fun(X) -> integer_to_binary(X) end,
boolean = fun(true) -> <<"1">>;
(false) -> <<"0">>
end
}.
pgsql_prepare(SQLQuery, State) ->
Escape = #sql_escape{_ = fun(X) -> X end},
N = length((SQLQuery#sql_query.args)(Escape)),
Args = [<<$$, (integer_to_binary(I))/binary>> || I <- lists:seq(1, N)],
Query = (SQLQuery#sql_query.format_query)(Args),
pgsql:prepare(State#state.db_ref, SQLQuery#sql_query.hash, Query).
pgsql_execute_escape() ->
#sql_escape{string = fun(X) -> X end,
integer = fun(X) -> integer_to_binary(X) end,
boolean = fun(true) -> <<"1">>;
(false) -> <<"0">>
end
}.
pgsql_execute_sql_query(SQLQuery, State) ->
Args = (SQLQuery#sql_query.args)(pgsql_execute_escape()),
ExecuteRes =
pgsql:execute(State#state.db_ref, SQLQuery#sql_query.hash, Args),
Res = pgsql_execute_to_odbc(ExecuteRes),
sql_query_format_res(Res, SQLQuery).
sql_query_format_res({selected, _, Rows}, SQLQuery) ->
Res =
lists:flatmap(
fun(Row) ->
try
[(SQLQuery#sql_query.format_res)(Row)]
catch
Class:Reason ->
ST = erlang:get_stacktrace(),
?ERROR_MSG("Error while processing "
"SQL query result: ~p~n"
"row: ~p",
[{Class, Reason, ST}, Row]),
[]
end
end, Rows),
{selected, Res};
sql_query_format_res(Res, _SQLQuery) ->
Res.
%% Generate the OTP callback return tuple depending on the driver result.
abort_on_driver_error({error, <<"query timed out">>} =
Reply,
@ -606,6 +716,18 @@ pgsql_item_to_odbc(<<"UPDATE ", N/binary>>) ->
pgsql_item_to_odbc({error, Error}) -> {error, Error};
pgsql_item_to_odbc(_) -> {updated, undefined}.
pgsql_execute_to_odbc({ok, {<<"SELECT", _/binary>>, Rows}}) ->
{selected, [], [[Field || {_, Field} <- Row] || Row <- Rows]};
pgsql_execute_to_odbc({ok, {'INSERT', N}}) ->
{updated, N};
pgsql_execute_to_odbc({ok, {'DELETE', N}}) ->
{updated, N};
pgsql_execute_to_odbc({ok, {'UPDATE', N}}) ->
{updated, N};
pgsql_execute_to_odbc({error, Error}) -> {error, Error};
pgsql_execute_to_odbc(_) -> {updated, undefined}.
%% == Native MySQL code
%% part of init/1
@ -800,6 +922,10 @@ fsm_limit_opts() ->
_ -> []
end.
check_error({error, Why} = Err, #sql_query{} = Query) ->
?ERROR_MSG("SQL query '~s' at ~p failed: ~p",
[Query#sql_query.hash, Query#sql_query.loc, Why]),
Err;
check_error({error, Why} = Err, Query) ->
?ERROR_MSG("SQL query '~s' failed: ~p", [Query, Why]),
Err;

255
src/ejabberd_sql_pt.erl Normal file
View File

@ -0,0 +1,255 @@
%%%-------------------------------------------------------------------
%%% File : ejabberd_sql_pt.erl
%%% Author : Alexey Shchepin <alexey@process-one.net>
%%% Description : Parse transform for SQL queries
%%%
%%% Created : 20 Jan 2016 by Alexey Shchepin <alexey@process-one.net>
%%%-------------------------------------------------------------------
-module(ejabberd_sql_pt).
%% API
-export([parse_transform/2]).
-export([parse/2]).
-include("ejabberd_sql_pt.hrl").
-record(state, {loc,
'query' = [],
params = [],
param_pos = 0,
args = [],
res = [],
res_vars = [],
res_pos = 0}).
-define(QUERY_RECORD, "sql_query").
-define(ESCAPE_RECORD, "sql_escape").
-define(ESCAPE_VAR, "__SQLEscape").
-define(MOD, sql__module_).
%%====================================================================
%% API
%%====================================================================
%%--------------------------------------------------------------------
%% Function:
%% Description:
%%--------------------------------------------------------------------
parse_transform(AST, _Options) ->
%io:format("PT: ~p~nOpts: ~p~n", [AST, Options]),
NewAST = top_transform(AST),
%io:format("NewPT: ~p~n", [NewAST]),
NewAST.
%%====================================================================
%% Internal functions
%%====================================================================
transform(Form) ->
case erl_syntax:type(Form) of
application ->
case erl_syntax_lib:analyze_application(Form) of
{?SQL_MARK, 1} ->
case erl_syntax:application_arguments(Form) of
[Arg] ->
case erl_syntax:type(Arg) of
string ->
S = erl_syntax:string_value(Arg),
ParseRes =
parse(S, erl_syntax:get_pos(Arg)),
make_sql_query(ParseRes);
_ ->
throw({error, erl_syntax:get_pos(Form),
"?SQL argument must be "
"a constant string"})
end;
_ ->
throw({error, erl_syntax:get_pos(Form),
"wrong number of ?SQL args"})
end;
_ ->
Form
end;
attribute ->
case erl_syntax:atom_value(erl_syntax:attribute_name(Form)) of
module ->
case erl_syntax:attribute_arguments(Form) of
[M | _] ->
Module = erl_syntax:atom_value(M),
%io:format("module ~p~n", [Module]),
put(?MOD, Module),
Form;
_ ->
Form
end;
_ ->
Form
end;
_ ->
Form
end.
top_transform(Forms) when is_list(Forms) ->
lists:map(
fun(Form) ->
try
Form2 = erl_syntax_lib:map(
fun(Node) ->
%io:format("asd ~p~n", [Node]),
transform(Node)
end, Form),
Form3 = erl_syntax:revert(Form2),
Form3
catch
throw:{error, Line, Error} ->
{error, {Line, erl_parse, Error}}
end
end, Forms).
parse(S, Loc) ->
parse1(S, [], #state{loc = Loc}).
parse1([], Acc, State) ->
State1 = append_string(lists:reverse(Acc), State),
State1#state{'query' = lists:reverse(State1#state.'query'),
params = lists:reverse(State1#state.params),
args = lists:reverse(State1#state.args),
res = lists:reverse(State1#state.res),
res_vars = lists:reverse(State1#state.res_vars)
};
parse1([$@, $( | S], Acc, State) ->
State1 = append_string(lists:reverse(Acc), State),
{Name, Type, S1, State2} = parse_name(S, State1),
Var = "__V" ++ integer_to_list(State2#state.res_pos),
EVar = erl_syntax:variable(Var),
Convert =
case Type of
integer ->
erl_syntax:application(
erl_syntax:atom(binary_to_integer),
[EVar]);
string ->
EVar;
boolean ->
erl_syntax:application(
erl_syntax:atom(ejabberd_odbc),
erl_syntax:atom(to_bool),
[EVar])
end,
State3 = append_string(Name, State2),
State4 = State3#state{res_pos = State3#state.res_pos + 1,
res = [Convert | State3#state.res],
res_vars = [EVar | State3#state.res_vars]},
parse1(S1, [], State4);
parse1([$%, $( | S], Acc, State) ->
State1 = append_string(lists:reverse(Acc), State),
{Name, Type, S1, State2} = parse_name(S, State1),
Var = "__V" ++ integer_to_list(State2#state.param_pos),
EVar = erl_syntax:variable(Var),
Convert =
erl_syntax:application(
erl_syntax:record_access(
erl_syntax:variable(?ESCAPE_VAR),
erl_syntax:atom(?ESCAPE_RECORD),
erl_syntax:atom(Type)),
[erl_syntax:variable(Name)]),
State3 = State2,
State4 =
State3#state{'query' = [{var, EVar} | State3#state.'query'],
args = [Convert | State3#state.args],
params = [EVar | State3#state.params],
param_pos = State3#state.param_pos + 1},
parse1(S1, [], State4);
parse1([C | S], Acc, State) ->
parse1(S, [C | Acc], State).
append_string([], State) ->
State;
append_string(S, State) ->
State#state{query = [{str, S} | State#state.query]}.
parse_name(S, State) ->
parse_name(S, [], State).
parse_name([], Acc, State) ->
% todo
error;
parse_name([$), T | S], Acc, State) ->
Type =
case T of
$d -> integer;
$s -> string;
$b -> boolean;
_ ->
% todo
error
end,
{lists:reverse(Acc), Type, S, State};
parse_name([$) | _], Acc, State) ->
% todo
error;
parse_name([C | S], Acc, State) ->
parse_name(S, [C | Acc], State).
make_sql_query(State) ->
Hash = erlang:phash2(State#state{loc = undefined}),
SHash = <<"Q", (integer_to_binary(Hash))/binary>>,
Query = pack_query(State#state.'query'),
EQuery =
lists:map(
fun({str, S}) ->
erl_syntax:binary(
[erl_syntax:binary_field(
erl_syntax:string(S))]);
({var, V}) -> V
end, Query),
erl_syntax:record_expr(
erl_syntax:atom(?QUERY_RECORD),
[erl_syntax:record_field(
erl_syntax:atom(hash),
%erl_syntax:abstract(SHash)
erl_syntax:binary(
[erl_syntax:binary_field(
erl_syntax:string(binary_to_list(SHash)))])),
erl_syntax:record_field(
erl_syntax:atom(args),
erl_syntax:fun_expr(
[erl_syntax:clause(
[erl_syntax:variable(?ESCAPE_VAR)],
none,
[erl_syntax:list(State#state.args)]
)])),
erl_syntax:record_field(
erl_syntax:atom(format_query),
erl_syntax:fun_expr(
[erl_syntax:clause(
[erl_syntax:list(State#state.params)],
none,
[erl_syntax:list(EQuery)]
)])),
erl_syntax:record_field(
erl_syntax:atom(format_res),
erl_syntax:fun_expr(
[erl_syntax:clause(
[erl_syntax:list(State#state.res_vars)],
none,
[erl_syntax:tuple(State#state.res)]
)])),
erl_syntax:record_field(
erl_syntax:atom(loc),
erl_syntax:abstract({get(?MOD), State#state.loc}))
]).
pack_query([]) ->
[];
pack_query([{str, S1}, {str, S2} | Rest]) ->
pack_query([{str, S1 ++ S2} | Rest]);
pack_query([X | Rest]) ->
[X | pack_query(Rest)].