Fix handling of list arguments on pgsql

This commit is contained in:
Paweł Chmielowski 2019-04-23 17:46:14 +02:00
parent feb4c7f5e9
commit d2ea905926
3 changed files with 81 additions and 21 deletions

View File

@ -32,5 +32,4 @@
-record(sql_query, {hash, format_query, format_res, args, loc}). -record(sql_query, {hash, format_query, format_res, args, loc}).
-record(sql_escape, {string, integer, boolean}). -record(sql_escape, {string, integer, boolean, in_array_string}).

View File

@ -56,7 +56,8 @@
odbcinst_config/0, odbcinst_config/0,
init_mssql/1, init_mssql/1,
keep_alive/2, keep_alive/2,
to_list/2]). to_list/2,
to_array/2]).
%% gen_fsm callbacks %% gen_fsm callbacks
-export([init/1, handle_event/3, handle_sync_event/4, -export([init/1, handle_event/3, handle_sync_event/4,
@ -264,6 +265,10 @@ to_list(EscapeFun, Val) ->
Escaped = lists:join(<<",">>, lists:map(EscapeFun, Val)), Escaped = lists:join(<<",">>, lists:map(EscapeFun, Val)),
[<<"(">>, Escaped, <<")">>]. [<<"(">>, Escaped, <<")">>].
to_array(EscapeFun, Val) ->
Escaped = lists:join(<<",">>, lists:map(EscapeFun, Val)),
[<<"{">>, Escaped, <<"}">>].
encode_term(Term) -> encode_term(Term) ->
escape(list_to_binary( escape(list_to_binary(
erl_prettypr:format(erl_syntax:abstract(Term), erl_prettypr:format(erl_syntax:abstract(Term),
@ -676,10 +681,11 @@ generic_sql_query_format(SQLQuery) ->
generic_escape() -> generic_escape() ->
#sql_escape{string = fun(X) -> <<"'", (escape(X))/binary, "'">> end, #sql_escape{string = fun(X) -> <<"'", (escape(X))/binary, "'">> end,
integer = fun(X) -> misc:i2l(X) end, integer = fun(X) -> misc:i2l(X) end,
boolean = fun(true) -> <<"1">>; boolean = fun(true) -> <<"1">>;
(false) -> <<"0">> (false) -> <<"0">>
end end,
in_array_string = fun(X) -> <<"'", (escape(X))/binary, "'">> end
}. }.
sqlite_sql_query(SQLQuery) -> sqlite_sql_query(SQLQuery) ->
@ -693,10 +699,11 @@ sqlite_sql_query_format(SQLQuery) ->
sqlite_escape() -> sqlite_escape() ->
#sql_escape{string = fun(X) -> <<"'", (standard_escape(X))/binary, "'">> end, #sql_escape{string = fun(X) -> <<"'", (standard_escape(X))/binary, "'">> end,
integer = fun(X) -> misc:i2l(X) end, integer = fun(X) -> misc:i2l(X) end,
boolean = fun(true) -> <<"1">>; boolean = fun(true) -> <<"1">>;
(false) -> <<"0">> (false) -> <<"0">>
end end,
in_array_string = fun(X) -> <<"'", (standard_escape(X))/binary, "'">> end
}. }.
standard_escape(S) -> standard_escape(S) ->
@ -717,10 +724,11 @@ pgsql_prepare(SQLQuery, State) ->
pgsql_execute_escape() -> pgsql_execute_escape() ->
#sql_escape{string = fun(X) -> X end, #sql_escape{string = fun(X) -> X end,
integer = fun(X) -> [misc:i2l(X)] end, integer = fun(X) -> [misc:i2l(X)] end,
boolean = fun(true) -> "1"; boolean = fun(true) -> "1";
(false) -> "0" (false) -> "0"
end end,
in_array_string = fun(X) -> <<"\"", (escape(X))/binary, "\"">> end
}. }.
pgsql_execute_sql_query(SQLQuery, State) -> pgsql_execute_sql_query(SQLQuery, State) ->

View File

@ -42,7 +42,8 @@
res_pos = 0, res_pos = 0,
server_host_used = false, server_host_used = false,
used_vars = [], used_vars = [],
use_new_schema}). use_new_schema,
need_array_pass = false}).
-define(QUERY_RECORD, "sql_query"). -define(QUERY_RECORD, "sql_query").
@ -183,12 +184,24 @@ transform_sql(Arg) ->
Pos, no_server_host), Pos, no_server_host),
[] []
end, end,
set_pos( case ParseRes#state.need_array_pass of
make_schema_check( true ->
make_sql_query(ParseRes), {PR1, PR2} = perform_array_pass(ParseRes),
make_sql_query(ParseResOld) {PRO1, PRO2} = perform_array_pass(ParseResOld),
), set_pos(make_schema_check(
Pos). erl_syntax:list([erl_syntax:tuple([erl_syntax:atom(pgsql), make_sql_query(PR2)]),
erl_syntax:tuple([erl_syntax:atom(any), make_sql_query(PR1)])]),
erl_syntax:list([erl_syntax:tuple([erl_syntax:atom(pgsql), make_sql_query(PRO2)]),
erl_syntax:tuple([erl_syntax:atom(any), make_sql_query(PRO1)])])),
Pos);
false ->
set_pos(
make_schema_check(
make_sql_query(ParseRes),
make_sql_query(ParseResOld)
),
Pos)
end.
transform_upsert(Form, TableArg, FieldsArg) -> transform_upsert(Form, TableArg, FieldsArg) ->
Table = erl_syntax:string_value(TableArg), Table = erl_syntax:string_value(TableArg),
@ -315,8 +328,23 @@ parse1([$%, $( | S], Acc, State) ->
erl_syntax:atom(?ESCAPE_RECORD), erl_syntax:atom(?ESCAPE_RECORD),
erl_syntax:atom(InternalType)), erl_syntax:atom(InternalType)),
erl_syntax:variable(Name)]), erl_syntax:variable(Name)]),
State2#state{'query' = [{var, Var} | State2#state.'query'], IT2 = case InternalType of
args = [Convert | State2#state.args], string ->
in_array_string;
_ ->
InternalType
end,
ConvertArr = erl_syntax:application(
erl_syntax:atom(ejabberd_sql),
erl_syntax:atom(to_array),
[erl_syntax:record_access(
erl_syntax:variable(?ESCAPE_VAR),
erl_syntax:atom(?ESCAPE_RECORD),
erl_syntax:atom(IT2)),
erl_syntax:variable(Name)]),
State2#state{'query' = [[{var, Var}] | State2#state.'query'],
need_array_pass = true,
args = [[Convert, ConvertArr] | State2#state.args],
params = [Var | State2#state.params], params = [Var | State2#state.params],
param_pos = State2#state.param_pos + 1, param_pos = State2#state.param_pos + 1,
used_vars = [Name | State2#state.used_vars]}; used_vars = [Name | State2#state.used_vars]};
@ -389,6 +417,31 @@ make_var(V) ->
Var = "__V" ++ integer_to_list(V), Var = "__V" ++ integer_to_list(V),
erl_syntax:variable(Var). erl_syntax:variable(Var).
perform_array_pass(State) ->
{NQ, PQ, Rest} = lists:foldl(
fun([{var, _} = Var], {N, P, {str, Str} = Prev}) ->
Str2 = re:replace(Str, "(^|\s+)in\s*$", " = any(", [{return, list}]),
{[Var, Prev | N], [{str, ")"}, Var, {str, Str2} | P], none};
([{var, _}], _) ->
throw({error, State#state.loc, ["List variable not following 'in' operator"]});
(Other, {N, P, none}) ->
{N, P, Other};
(Other, {N, P, Prev}) ->
{[Prev | N], [Prev | P], Other}
end, {[], [], none}, State#state.query),
{NQ2, PQ2} = case Rest of
none ->
{NQ, PQ};
_ -> {[Rest | NQ], [Rest | PQ]}
end,
{NA, PA} = lists:foldl(
fun([V1, V2], {N, P}) ->
{[V1 | N], [V2 | P]};
(Other, {N, P}) ->
{[Other | N], [Other | P]}
end, {[], []}, State#state.args),
{State#state{query = lists:reverse(NQ2), args = lists:reverse(NA), need_array_pass = false},
State#state{query = lists:reverse(PQ2), args = lists:reverse(PA), need_array_pass = false}}.
make_sql_query(State) -> make_sql_query(State) ->
Hash = erlang:phash2(State#state{loc = undefined, use_new_schema = true}), Hash = erlang:phash2(State#state{loc = undefined, use_new_schema = true}),