From 00c76003cbd3ad81762bd581e878f15b6fc51036 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Chmielowski?= Date: Fri, 18 Aug 2023 11:46:26 +0200 Subject: [PATCH] Add ability to force alternative upsert implementation in mysql --- src/ejabberd_option.erl | 8 ++++++++ src/ejabberd_options.erl | 3 +++ src/ejabberd_sql.erl | 12 ++++++++---- 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/src/ejabberd_option.erl b/src/ejabberd_option.erl index 81f4bab7f..6ea63e561 100644 --- a/src/ejabberd_option.erl +++ b/src/ejabberd_option.erl @@ -143,6 +143,7 @@ -export([sm_use_cache/0, sm_use_cache/1]). -export([sql_connect_timeout/0, sql_connect_timeout/1]). -export([sql_database/0, sql_database/1]). +-export([sql_flags/0, sql_flags/1]). -export([sql_keepalive_interval/0, sql_keepalive_interval/1]). -export([sql_odbc_driver/0, sql_odbc_driver/1]). -export([sql_password/0, sql_password/1]). @@ -964,6 +965,13 @@ sql_database() -> sql_database(Host) -> ejabberd_config:get_option({sql_database, Host}). +-spec sql_flags() -> ['mysql_alternative_upsert']. +sql_flags() -> + sql_flags(global). +-spec sql_flags(global | binary()) -> ['mysql_alternative_upsert']. +sql_flags(Host) -> + ejabberd_config:get_option({sql_flags, Host}). + -spec sql_keepalive_interval() -> 'undefined' | pos_integer(). sql_keepalive_interval() -> sql_keepalive_interval(global). diff --git a/src/ejabberd_options.erl b/src/ejabberd_options.erl index 06087921d..9f48839bb 100644 --- a/src/ejabberd_options.erl +++ b/src/ejabberd_options.erl @@ -424,6 +424,8 @@ opt_type(sql_username) -> econf:binary(); opt_type(sql_prepared_statements) -> econf:bool(); +opt_type(sql_flags) -> + econf:list_or_single(econf:enum([mysql_alternative_upsert]), [sorted, unique]); opt_type(trusted_proxies) -> econf:either(all, econf:list(econf:ip_mask())); opt_type(use_cache) -> @@ -708,6 +710,7 @@ options() -> {sql_start_interval, timer:seconds(30)}, {sql_username, <<"ejabberd">>}, {sql_prepared_statements, true}, + {sql_flags, []}, {trusted_proxies, []}, {validate_stream, false}, {websocket_origin, []}, diff --git a/src/ejabberd_sql.erl b/src/ejabberd_sql.erl index c5c06d078..a07dac67c 100644 --- a/src/ejabberd_sql.erl +++ b/src/ejabberd_sql.erl @@ -1148,7 +1148,11 @@ get_db_version(#state{db_type = pgsql} = State) -> ?WARNING_MSG("Error getting pgsql version: ~p", [Res]), State end; -get_db_version(#state{db_type = mysql} = State) -> +get_db_version(#state{db_type = mysql, host = Host} = State) -> + DefaultUpsert = case lists:member(mysql_alternative_upsert, ejabberd_option:sql_flags(Host)) of + true -> 1; + _ -> 0 + end, case mysql_to_odbc(p1_mysql_conn:squery(State#state.db_ref, [<<"select version();">>], self(), [{timeout, 5000}, @@ -1160,10 +1164,10 @@ get_db_version(#state{db_type = mysql} = State) -> V = ((bin_to_int(V1)*1000)+bin_to_int(V2))*1000+bin_to_int(V3), TypeA = binary_to_atom(Type, utf8), Flags = case TypeA of - 'MariaDB' -> 0; + 'MariaDB' -> DefaultUpsert; _ when V >= 5007026 andalso V < 8000000 -> 1; _ when V >= 8000020 -> 1; - _ -> 0 + _ -> DefaultUpsert end, State#state{db_version = {V, TypeA, Flags}}; {match, [V1, V2, V3]} -> @@ -1171,7 +1175,7 @@ get_db_version(#state{db_type = mysql} = State) -> Flags = case V of _ when V >= 5007026 andalso V < 8000000 -> 1; _ when V >= 8000020 -> 1; - _ -> 0 + _ -> DefaultUpsert end, State#state{db_version = {V, unknown, Flags}}; _ ->