diff --git a/include/ejabberd_sql_pt.hrl b/include/ejabberd_sql_pt.hrl index ca6df9ec9..f189fdcf6 100644 --- a/include/ejabberd_sql_pt.hrl +++ b/include/ejabberd_sql_pt.hrl @@ -21,6 +21,12 @@ -define(SQL_MARK, sql__mark_). -define(SQL(SQL), ?SQL_MARK(SQL)). +-define(SQL_UPSERT_MARK, sql_upsert__mark_). +-define(SQL_UPSERT(Host, Table, Fields), + ejabberd_odbc:sql_query(Host, ?SQL_UPSERT_MARK(Table, Fields))). +-define(SQL_UPSERT_T(Table, Fields), + ejabberd_odbc:sql_query_t(Host, ?SQL_UPSERT_MARK(Table, Fields))). + -record(sql_query, {hash, format_query, format_res, args, loc}). -record(sql_escape, {string, integer, boolean}). diff --git a/src/ejabberd_odbc.erl b/src/ejabberd_odbc.erl index b430d920a..4f818f513 100644 --- a/src/ejabberd_odbc.erl +++ b/src/ejabberd_odbc.erl @@ -475,6 +475,12 @@ execute_bloc(F) -> Res -> {atomic, Res} end. +execute_fun(F) when is_function(F, 0) -> + F(); +execute_fun(F) when is_function(F, 2) -> + State = get(?STATE_KEY), + F(State#state.db_type, State#state.db_version). + sql_query_internal([{_, _} | _] = Queries) -> State = get(?STATE_KEY), case select_sql_query(Queries, State) of @@ -529,6 +535,11 @@ sql_query_internal(#sql_query{} = Query) -> {updated, 0}; _Else -> Res end; +sql_query_internal(F) when is_function(F) -> + case catch execute_fun(F) of + {'EXIT', Reason} -> {error, Reason}; + Res -> Res + end; sql_query_internal(Query) -> State = get(?STATE_KEY), ?DEBUG("SQL: \"~s\"", [Query]), @@ -615,6 +626,9 @@ pgsql_execute_sql_query(SQLQuery, State) -> Args = (SQLQuery#sql_query.args)(pgsql_execute_escape()), ExecuteRes = pgsql:execute(State#state.db_ref, SQLQuery#sql_query.hash, Args), +% {T, ExecuteRes} = +% timer:tc(pgsql, execute, [State#state.db_ref, SQLQuery#sql_query.hash, Args]), +% io:format("T ~s ~p~n", [SQLQuery#sql_query.hash, T]), Res = pgsql_execute_to_odbc(ExecuteRes), sql_query_format_res(Res, SQLQuery). diff --git a/src/ejabberd_sql_pt.erl b/src/ejabberd_sql_pt.erl index 6b26cbcd6..cb7a82e0f 100644 --- a/src/ejabberd_sql_pt.erl +++ b/src/ejabberd_sql_pt.erl @@ -72,6 +72,26 @@ transform(Form) -> throw({error, erl_syntax:get_pos(Form), "wrong number of ?SQL args"}) end; + {?SQL_UPSERT_MARK, 2} -> + case erl_syntax:application_arguments(Form) of + [TableArg, FieldsArg] -> + case {erl_syntax:type(TableArg), + erl_syntax:is_proper_list(FieldsArg)}of + {string, true} -> + Table = erl_syntax:string_value(TableArg), + ParseRes = + parse_upsert( + erl_syntax:list_elements(FieldsArg)), + make_sql_upsert(Table, ParseRes); + _ -> + throw({error, erl_syntax:get_pos(Form), + "?SQL_UPSERT arguments must be " + "a constant string and a list"}) + end; + _ -> + throw({error, erl_syntax:get_pos(Form), + "wrong number of ?SQL_UPSERT args"}) + end; _ -> Form end; @@ -114,6 +134,9 @@ top_transform(Forms) when is_list(Forms) -> parse(S, Loc) -> parse1(S, [], #state{loc = Loc}). +parse(S, ParamPos, Loc) -> + parse1(S, [], #state{loc = Loc, param_pos = ParamPos}). + parse1([], Acc, State) -> State1 = append_string(lists:reverse(Acc), State), State1#state{'query' = lists:reverse(State1#state.'query'), @@ -149,8 +172,7 @@ parse1([$@, $( | S], Acc, State) -> parse1([$%, $( | S], Acc, State) -> State1 = append_string(lists:reverse(Acc), State), {Name, Type, S1, State2} = parse_name(S, State1), - Var = "__V" ++ integer_to_list(State2#state.param_pos), - EVar = erl_syntax:variable(Var), + Var = State2#state.param_pos, Convert = erl_syntax:application( erl_syntax:record_access( @@ -160,9 +182,9 @@ parse1([$%, $( | S], Acc, State) -> [erl_syntax:variable(Name)]), State3 = State2, State4 = - State3#state{'query' = [{var, EVar} | State3#state.'query'], + State3#state{'query' = [{var, Var} | State3#state.'query'], args = [Convert | State3#state.args], - params = [EVar | State3#state.params], + params = [Var | State3#state.params], param_pos = State3#state.param_pos + 1}, parse1(S1, [], State4); parse1([C | S], Acc, State) -> @@ -190,7 +212,7 @@ parse_name([$), T | S], Acc, 0, State) -> ["unknown type specifier '", T, "'"]}) end, {lists:reverse(Acc), Type, S, State}; -parse_name([$)], Acc, 0, State) -> +parse_name([$)], _Acc, 0, State) -> throw({error, State#state.loc, "expected type specifier, found end of string"}); parse_name([$( = C | S], Acc, Depth, State) -> @@ -201,6 +223,11 @@ parse_name([C | S], Acc, Depth, State) -> parse_name(S, [C | Acc], Depth, State). +make_var(V) -> + Var = "__V" ++ integer_to_list(V), + erl_syntax:variable(Var). + + make_sql_query(State) -> Hash = erlang:phash2(State#state{loc = undefined}), SHash = <<"Q", (integer_to_binary(Hash))/binary>>, @@ -211,7 +238,7 @@ make_sql_query(State) -> erl_syntax:binary( [erl_syntax:binary_field( erl_syntax:string(S))]); - ({var, V}) -> V + ({var, V}) -> make_var(V) end, Query), erl_syntax:record_expr( erl_syntax:atom(?QUERY_RECORD), @@ -233,7 +260,7 @@ make_sql_query(State) -> erl_syntax:atom(format_query), erl_syntax:fun_expr( [erl_syntax:clause( - [erl_syntax:list(State#state.params)], + [erl_syntax:list(lists:map(fun make_var/1, State#state.params))], none, [erl_syntax:list(EQuery)] )])), @@ -257,3 +284,232 @@ pack_query([{str, S1}, {str, S2} | Rest]) -> pack_query([X | Rest]) -> [X | pack_query(Rest)]. + +parse_upsert(Fields) -> + {Fs, _} = + lists:foldr( + fun(F, {Acc, Param}) -> + case erl_syntax:type(F) of + string -> + V = erl_syntax:string_value(F), + {_, _, State} = Res = + parse_upsert_field( + V, Param, erl_syntax:get_pos(F)), + {[Res | Acc], State#state.param_pos}; + _ -> + throw({error, erl_syntax:get_pos(F), + "?SQL_UPSERT field must be " + "a constant string"}) + end + end, {[], 0}, Fields), + %io:format("asd ~p~n", [{Fields, Fs}]), + Fs. + +parse_upsert_field([$! | S], ParamPos, Loc) -> + {Name, ParseState} = parse_upsert_field1(S, [], ParamPos, Loc), + {Name, true, ParseState}; +parse_upsert_field(S, ParamPos, Loc) -> + {Name, ParseState} = parse_upsert_field1(S, [], ParamPos, Loc), + {Name, false, ParseState}. + +parse_upsert_field1([], _Acc, _ParamPos, Loc) -> + throw({error, Loc, + "?SQL_UPSERT fields must have the " + "following form: \"[!]name=value\""}); +parse_upsert_field1([$= | S], Acc, ParamPos, Loc) -> + {lists:reverse(Acc), parse(S, ParamPos, Loc)}; +parse_upsert_field1([C | S], Acc, ParamPos, Loc) -> + parse_upsert_field1(S, [C | Acc], ParamPos, Loc). + + +make_sql_upsert(Table, ParseRes) -> + erl_syntax:fun_expr( + [erl_syntax:clause( + [erl_syntax:atom(pgsql), erl_syntax:variable("__Version")], + [erl_syntax:infix_expr( + erl_syntax:variable("__Version"), + erl_syntax:operator('>='), + erl_syntax:integer(90100))], + [make_sql_upsert_pgsql901(Table, ParseRes), + erl_syntax:atom(ok)]), + erl_syntax:clause( + [erl_syntax:underscore(), erl_syntax:underscore()], + none, + [make_sql_upsert_generic(Table, ParseRes)]) + ]). + +make_sql_upsert_generic(Table, ParseRes) -> + Update = make_sql_query(make_sql_upsert_update(Table, ParseRes)), + Insert = make_sql_query(make_sql_upsert_insert(Table, ParseRes)), + InsertBranch = + erl_syntax:case_expr( + erl_syntax:application( + erl_syntax:atom(ejabberd_odbc), + erl_syntax:atom(sql_query_t), + [Insert]), + [erl_syntax:clause( + [erl_syntax:abstract({updated, 1})], + none, + [erl_syntax:atom(ok)]), + erl_syntax:clause( + [erl_syntax:variable("__UpdateRes")], + none, + [erl_syntax:variable("__UpdateRes")])]), + erl_syntax:case_expr( + erl_syntax:application( + erl_syntax:atom(ejabberd_odbc), + erl_syntax:atom(sql_query_t), + [Update]), + [erl_syntax:clause( + [erl_syntax:abstract({updated, 1})], + none, + [erl_syntax:atom(ok)]), + erl_syntax:clause( + [erl_syntax:underscore()], + none, + [InsertBranch])]). + +make_sql_upsert_update(Table, ParseRes) -> + WPairs = + lists:flatmap( + fun({_Field, false, _ST}) -> + []; + ({Field, true, ST}) -> + [ST#state{ + 'query' = [{str, Field}, {str, "="}] ++ ST#state.'query' + }] + end, ParseRes), + Where = join_states(WPairs, " AND "), + SPairs = + lists:flatmap( + fun({_Field, true, _ST}) -> + []; + ({Field, false, ST}) -> + [ST#state{ + 'query' = [{str, Field}, {str, "="}] ++ ST#state.'query' + }] + end, ParseRes), + Set = join_states(SPairs, ", "), + State = + concat_states( + [#state{'query' = [{str, "UPDATE "}, {str, Table}, {str, " SET "}]}, + Set, + #state{'query' = [{str, " WHERE "}]}, + Where + ]), + State. + +make_sql_upsert_insert(Table, ParseRes) -> + Vals = + lists:map( + fun({_Field, _, ST}) -> + ST + end, ParseRes), + Fields = + lists:map( + fun({Field, _, _ST}) -> + #state{'query' = [{str, Field}]} + end, ParseRes), + State = + concat_states( + [#state{'query' = [{str, "INSERT INTO "}, {str, Table}, {str, "("}]}, + join_states(Fields, ", "), + #state{'query' = [{str, ") VALUES ("}]}, + join_states(Vals, ", "), + #state{'query' = [{str, ")"}]} + ]), + State. + +make_sql_upsert_pgsql901(Table, ParseRes) -> + Update = make_sql_upsert_update(Table, ParseRes), + Vals = + lists:map( + fun({_Field, _, ST}) -> + ST + end, ParseRes), + Fields = + lists:map( + fun({Field, _, _ST}) -> + #state{'query' = [{str, Field}]} + end, ParseRes), + Insert = + concat_states( + [#state{'query' = [{str, "INSERT INTO "}, {str, Table}, {str, "("}]}, + join_states(Fields, ", "), + #state{'query' = [{str, ") SELECT "}]}, + join_states(Vals, ", "), + #state{'query' = [{str, " WHERE NOT EXISTS (SELECT * FROM upsert)"}]} + ]), + State = + concat_states( + [#state{'query' = [{str, "WITH upsert AS ("}]}, + Update, + #state{'query' = [{str, " RETURNING *) "}]}, + Insert + ]), + Upsert = make_sql_query(State), + erl_syntax:application( + erl_syntax:atom(ejabberd_odbc), + erl_syntax:atom(sql_query_t), + [Upsert]). + + +concat_states(States) -> + lists:foldr( + fun(ST11, ST2) -> + ST1 = resolve_vars(ST11, ST2), + ST1#state{ + 'query' = ST1#state.'query' ++ ST2#state.'query', + params = ST1#state.params ++ ST2#state.params, + args = ST1#state.args ++ ST2#state.args, + res = ST1#state.res ++ ST2#state.res, + res_vars = ST1#state.res_vars ++ ST2#state.res_vars, + loc = case ST1#state.loc of + undefined -> ST2#state.loc; + _ -> ST1#state.loc + end + } + end, #state{}, States). + +resolve_vars(ST1, ST2) -> + Max = lists:max([0 | ST1#state.params ++ ST2#state.params]), + {Map, _} = + lists:foldl( + fun(Var, {Acc, New}) -> + case lists:member(Var, ST2#state.params) of + true -> + {dict:store(Var, New, Acc), New + 1}; + false -> + {Acc, New} + end + end, {dict:new(), Max + 1}, ST1#state.params), + NewParams = + lists:map( + fun(Var) -> + case dict:find(Var, Map) of + {ok, New} -> + New; + error -> + Var + end + end, ST1#state.params), + NewQuery = + lists:map( + fun({var, Var}) -> + case dict:find(Var, Map) of + {ok, New} -> + {var, New}; + error -> + {var, Var} + end; + (S) -> S + end, ST1#state.'query'), + ST1#state{params = NewParams, 'query' = NewQuery}. + + + +join_states([], _Sep) -> + #state{}; +join_states([H | T], Sep) -> + J = [[H] | [[#state{'query' = [{str, Sep}]}, X] || X <- T]], + concat_states(lists:append(J)).