diff --git a/include/ejabberd_sql_pt.hrl b/include/ejabberd_sql_pt.hrl new file mode 100644 index 000000000..ca6df9ec9 --- /dev/null +++ b/include/ejabberd_sql_pt.hrl @@ -0,0 +1,27 @@ +%%%---------------------------------------------------------------------- +%%% +%%% ejabberd, Copyright (C) 2002-2016 ProcessOne +%%% +%%% This program is free software; you can redistribute it and/or +%%% modify it under the terms of the GNU General Public License as +%%% published by the Free Software Foundation; either version 2 of the +%%% License, or (at your option) any later version. +%%% +%%% This program is distributed in the hope that it will be useful, +%%% but WITHOUT ANY WARRANTY; without even the implied warranty of +%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +%%% General Public License for more details. +%%% +%%% You should have received a copy of the GNU General Public License along +%%% with this program; if not, write to the Free Software Foundation, Inc., +%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +%%% +%%%---------------------------------------------------------------------- + +-define(SQL_MARK, sql__mark_). +-define(SQL(SQL), ?SQL_MARK(SQL)). + +-record(sql_query, {hash, format_query, format_res, args, loc}). + +-record(sql_escape, {string, integer, boolean}). + diff --git a/src/ejabberd_odbc.erl b/src/ejabberd_odbc.erl index a15c66b5d..ef3c61d0a 100644 --- a/src/ejabberd_odbc.erl +++ b/src/ejabberd_odbc.erl @@ -63,6 +63,7 @@ -include("ejabberd.hrl"). -include("logger.hrl"). +-include("ejabberd_sql_pt.hrl"). -record(state, {db_ref = self() :: pid(), @@ -92,6 +93,8 @@ -define(KEEPALIVE_QUERY, [<<"SELECT 1;">>]). +-define(PREPARE_KEY, ejabberd_odbc_prepare). + %%-define(DBGFSM, true). -ifdef(DBGFSM). @@ -116,11 +119,12 @@ start_link(Host, StartInterval) -> [Host, StartInterval], fsm_limit_opts() ++ (?FSMOPTS)). --type sql_query() :: [sql_query() | binary()]. +-type sql_query() :: [sql_query() | binary()] | #sql_query{}. -type sql_query_result() :: {updated, non_neg_integer()} | {error, binary()} | {selected, [binary()], - [[binary()]]}. + [[binary()]]} | + {selected, [any]}. -spec sql_query(binary(), sql_query()) -> sql_query_result(). @@ -469,6 +473,52 @@ execute_bloc(F) -> Res -> {atomic, Res} end. +sql_query_internal(#sql_query{} = Query) -> + State = get(?STATE_KEY), + Res = + try + case State#state.db_type of + odbc -> + generic_sql_query(Query); + pgsql -> + Key = {?PREPARE_KEY, Query#sql_query.hash}, + case get(Key) of + undefined -> + case pgsql_prepare(Query, State) of + {ok, _, _, _} -> + put(Key, prepared); + {error, Error} -> + ?ERROR_MSG("PREPARE failed for SQL query " + "at ~p: ~p", + [Query#sql_query.loc, Error]), + put(Key, ignore) + end; + _ -> + ok + end, + case get(Key) of + prepared -> + pgsql_execute_sql_query(Query, State); + _ -> + generic_sql_query(Query) + end; + mysql -> + generic_sql_query(Query); + sqlite -> + generic_sql_query(Query) + end + catch + Class:Reason -> + ST = erlang:get_stacktrace(), + ?ERROR_MSG("Internal error while processing SQL query: ~p", + [{Class, Reason, ST}]), + {error, <<"internal error">>} + end, + case Res of + {error, <<"No SQL-driver information available.">>} -> + {updated, 0}; + _Else -> Res + end; sql_query_internal(Query) -> State = get(?STATE_KEY), ?DEBUG("SQL: \"~s\"", [Query]), @@ -495,6 +545,66 @@ sql_query_internal(Query) -> _Else -> Res end. +generic_sql_query(SQLQuery) -> + sql_query_format_res( + sql_query_internal(generic_sql_query_format(SQLQuery)), + SQLQuery). + +generic_sql_query_format(SQLQuery) -> + Args = (SQLQuery#sql_query.args)(generic_escape()), + (SQLQuery#sql_query.format_query)(Args). + +generic_escape() -> + #sql_escape{string = fun(X) -> <<"'", (escape(X))/binary, "'">> end, + integer = fun(X) -> integer_to_binary(X) end, + boolean = fun(true) -> <<"1">>; + (false) -> <<"0">> + end + }. + +pgsql_prepare(SQLQuery, State) -> + Escape = #sql_escape{_ = fun(X) -> X end}, + N = length((SQLQuery#sql_query.args)(Escape)), + Args = [<<$$, (integer_to_binary(I))/binary>> || I <- lists:seq(1, N)], + Query = (SQLQuery#sql_query.format_query)(Args), + pgsql:prepare(State#state.db_ref, SQLQuery#sql_query.hash, Query). + +pgsql_execute_escape() -> + #sql_escape{string = fun(X) -> X end, + integer = fun(X) -> integer_to_binary(X) end, + boolean = fun(true) -> <<"1">>; + (false) -> <<"0">> + end + }. + +pgsql_execute_sql_query(SQLQuery, State) -> + Args = (SQLQuery#sql_query.args)(pgsql_execute_escape()), + ExecuteRes = + pgsql:execute(State#state.db_ref, SQLQuery#sql_query.hash, Args), + Res = pgsql_execute_to_odbc(ExecuteRes), + sql_query_format_res(Res, SQLQuery). + + +sql_query_format_res({selected, _, Rows}, SQLQuery) -> + Res = + lists:flatmap( + fun(Row) -> + try + [(SQLQuery#sql_query.format_res)(Row)] + catch + Class:Reason -> + ST = erlang:get_stacktrace(), + ?ERROR_MSG("Error while processing " + "SQL query result: ~p~n" + "row: ~p", + [{Class, Reason, ST}, Row]), + [] + end + end, Rows), + {selected, Res}; +sql_query_format_res(Res, _SQLQuery) -> + Res. + %% Generate the OTP callback return tuple depending on the driver result. abort_on_driver_error({error, <<"query timed out">>} = Reply, @@ -606,6 +716,18 @@ pgsql_item_to_odbc(<<"UPDATE ", N/binary>>) -> pgsql_item_to_odbc({error, Error}) -> {error, Error}; pgsql_item_to_odbc(_) -> {updated, undefined}. +pgsql_execute_to_odbc({ok, {<<"SELECT", _/binary>>, Rows}}) -> + {selected, [], [[Field || {_, Field} <- Row] || Row <- Rows]}; +pgsql_execute_to_odbc({ok, {'INSERT', N}}) -> + {updated, N}; +pgsql_execute_to_odbc({ok, {'DELETE', N}}) -> + {updated, N}; +pgsql_execute_to_odbc({ok, {'UPDATE', N}}) -> + {updated, N}; +pgsql_execute_to_odbc({error, Error}) -> {error, Error}; +pgsql_execute_to_odbc(_) -> {updated, undefined}. + + %% == Native MySQL code %% part of init/1 @@ -800,6 +922,10 @@ fsm_limit_opts() -> _ -> [] end. +check_error({error, Why} = Err, #sql_query{} = Query) -> + ?ERROR_MSG("SQL query '~s' at ~p failed: ~p", + [Query#sql_query.hash, Query#sql_query.loc, Why]), + Err; check_error({error, Why} = Err, Query) -> ?ERROR_MSG("SQL query '~s' failed: ~p", [Query, Why]), Err; diff --git a/src/ejabberd_sql_pt.erl b/src/ejabberd_sql_pt.erl new file mode 100644 index 000000000..f9701a0be --- /dev/null +++ b/src/ejabberd_sql_pt.erl @@ -0,0 +1,255 @@ +%%%------------------------------------------------------------------- +%%% File : ejabberd_sql_pt.erl +%%% Author : Alexey Shchepin +%%% Description : Parse transform for SQL queries +%%% +%%% Created : 20 Jan 2016 by Alexey Shchepin +%%%------------------------------------------------------------------- +-module(ejabberd_sql_pt). + +%% API +-export([parse_transform/2]). + +-export([parse/2]). + +-include("ejabberd_sql_pt.hrl"). + +-record(state, {loc, + 'query' = [], + params = [], + param_pos = 0, + args = [], + res = [], + res_vars = [], + res_pos = 0}). + +-define(QUERY_RECORD, "sql_query"). + +-define(ESCAPE_RECORD, "sql_escape"). +-define(ESCAPE_VAR, "__SQLEscape"). + +-define(MOD, sql__module_). + +%%==================================================================== +%% API +%%==================================================================== +%%-------------------------------------------------------------------- +%% Function: +%% Description: +%%-------------------------------------------------------------------- +parse_transform(AST, _Options) -> + %io:format("PT: ~p~nOpts: ~p~n", [AST, Options]), + NewAST = top_transform(AST), + %io:format("NewPT: ~p~n", [NewAST]), + NewAST. + + + +%%==================================================================== +%% Internal functions +%%==================================================================== + + +transform(Form) -> + case erl_syntax:type(Form) of + application -> + case erl_syntax_lib:analyze_application(Form) of + {?SQL_MARK, 1} -> + case erl_syntax:application_arguments(Form) of + [Arg] -> + case erl_syntax:type(Arg) of + string -> + S = erl_syntax:string_value(Arg), + ParseRes = + parse(S, erl_syntax:get_pos(Arg)), + make_sql_query(ParseRes); + _ -> + throw({error, erl_syntax:get_pos(Form), + "?SQL argument must be " + "a constant string"}) + end; + _ -> + throw({error, erl_syntax:get_pos(Form), + "wrong number of ?SQL args"}) + end; + _ -> + Form + end; + attribute -> + case erl_syntax:atom_value(erl_syntax:attribute_name(Form)) of + module -> + case erl_syntax:attribute_arguments(Form) of + [M | _] -> + Module = erl_syntax:atom_value(M), + %io:format("module ~p~n", [Module]), + put(?MOD, Module), + Form; + _ -> + Form + end; + _ -> + Form + end; + _ -> + Form + end. + +top_transform(Forms) when is_list(Forms) -> + lists:map( + fun(Form) -> + try + Form2 = erl_syntax_lib:map( + fun(Node) -> + %io:format("asd ~p~n", [Node]), + transform(Node) + end, Form), + Form3 = erl_syntax:revert(Form2), + Form3 + catch + throw:{error, Line, Error} -> + {error, {Line, erl_parse, Error}} + end + end, Forms). + +parse(S, Loc) -> + parse1(S, [], #state{loc = Loc}). + +parse1([], Acc, State) -> + State1 = append_string(lists:reverse(Acc), State), + State1#state{'query' = lists:reverse(State1#state.'query'), + params = lists:reverse(State1#state.params), + args = lists:reverse(State1#state.args), + res = lists:reverse(State1#state.res), + res_vars = lists:reverse(State1#state.res_vars) + }; +parse1([$@, $( | S], Acc, State) -> + State1 = append_string(lists:reverse(Acc), State), + {Name, Type, S1, State2} = parse_name(S, State1), + Var = "__V" ++ integer_to_list(State2#state.res_pos), + EVar = erl_syntax:variable(Var), + Convert = + case Type of + integer -> + erl_syntax:application( + erl_syntax:atom(binary_to_integer), + [EVar]); + string -> + EVar; + boolean -> + erl_syntax:application( + erl_syntax:atom(ejabberd_odbc), + erl_syntax:atom(to_bool), + [EVar]) + end, + State3 = append_string(Name, State2), + State4 = State3#state{res_pos = State3#state.res_pos + 1, + res = [Convert | State3#state.res], + res_vars = [EVar | State3#state.res_vars]}, + parse1(S1, [], State4); +parse1([$%, $( | S], Acc, State) -> + State1 = append_string(lists:reverse(Acc), State), + {Name, Type, S1, State2} = parse_name(S, State1), + Var = "__V" ++ integer_to_list(State2#state.param_pos), + EVar = erl_syntax:variable(Var), + Convert = + erl_syntax:application( + erl_syntax:record_access( + erl_syntax:variable(?ESCAPE_VAR), + erl_syntax:atom(?ESCAPE_RECORD), + erl_syntax:atom(Type)), + [erl_syntax:variable(Name)]), + State3 = State2, + State4 = + State3#state{'query' = [{var, EVar} | State3#state.'query'], + args = [Convert | State3#state.args], + params = [EVar | State3#state.params], + param_pos = State3#state.param_pos + 1}, + parse1(S1, [], State4); +parse1([C | S], Acc, State) -> + parse1(S, [C | Acc], State). + +append_string([], State) -> + State; +append_string(S, State) -> + State#state{query = [{str, S} | State#state.query]}. + +parse_name(S, State) -> + parse_name(S, [], State). + +parse_name([], Acc, State) -> + % todo + error; +parse_name([$), T | S], Acc, State) -> + Type = + case T of + $d -> integer; + $s -> string; + $b -> boolean; + _ -> + % todo + error + end, + {lists:reverse(Acc), Type, S, State}; +parse_name([$) | _], Acc, State) -> + % todo + error; +parse_name([C | S], Acc, State) -> + parse_name(S, [C | Acc], State). + + +make_sql_query(State) -> + Hash = erlang:phash2(State#state{loc = undefined}), + SHash = <<"Q", (integer_to_binary(Hash))/binary>>, + Query = pack_query(State#state.'query'), + EQuery = + lists:map( + fun({str, S}) -> + erl_syntax:binary( + [erl_syntax:binary_field( + erl_syntax:string(S))]); + ({var, V}) -> V + end, Query), + erl_syntax:record_expr( + erl_syntax:atom(?QUERY_RECORD), + [erl_syntax:record_field( + erl_syntax:atom(hash), + %erl_syntax:abstract(SHash) + erl_syntax:binary( + [erl_syntax:binary_field( + erl_syntax:string(binary_to_list(SHash)))])), + erl_syntax:record_field( + erl_syntax:atom(args), + erl_syntax:fun_expr( + [erl_syntax:clause( + [erl_syntax:variable(?ESCAPE_VAR)], + none, + [erl_syntax:list(State#state.args)] + )])), + erl_syntax:record_field( + erl_syntax:atom(format_query), + erl_syntax:fun_expr( + [erl_syntax:clause( + [erl_syntax:list(State#state.params)], + none, + [erl_syntax:list(EQuery)] + )])), + erl_syntax:record_field( + erl_syntax:atom(format_res), + erl_syntax:fun_expr( + [erl_syntax:clause( + [erl_syntax:list(State#state.res_vars)], + none, + [erl_syntax:tuple(State#state.res)] + )])), + erl_syntax:record_field( + erl_syntax:atom(loc), + erl_syntax:abstract({get(?MOD), State#state.loc})) + ]). + +pack_query([]) -> + []; +pack_query([{str, S1}, {str, S2} | Rest]) -> + pack_query([{str, S1 ++ S2} | Rest]); +pack_query([X | Rest]) -> + [X | pack_query(Rest)]. +