diff --git a/src/ejabberd_sql.erl b/src/ejabberd_sql.erl index 5a94dcf8e..e0f1e9e10 100644 --- a/src/ejabberd_sql.erl +++ b/src/ejabberd_sql.erl @@ -30,7 +30,7 @@ -behaviour(p1_fsm). %% External exports --export([start/1, start_link/2, +-export([start_link/2, sql_query/2, sql_query_t/1, sql_transaction/2, @@ -73,7 +73,6 @@ {db_ref = self() :: pid(), db_type = odbc :: pgsql | mysql | sqlite | odbc | mssql, db_version = undefined :: undefined | non_neg_integer(), - start_interval = 0 :: non_neg_integer(), host = <<"">> :: binary(), pending_requests :: p1_queue:queue()}). @@ -104,14 +103,11 @@ %%%---------------------------------------------------------------------- %%% API %%%---------------------------------------------------------------------- -start(Host) -> - p1_fsm:start(ejabberd_sql, [Host], - fsm_limit_opts() ++ (?FSMOPTS)). - -start_link(Host, StartInterval) -> - p1_fsm:start_link(ejabberd_sql, - [Host, StartInterval], - fsm_limit_opts() ++ (?FSMOPTS)). +-spec start_link(binary(), pos_integer()) -> {ok, pid()} | {error, term()}. +start_link(Host, I) -> + Proc = binary_to_atom(get_worker_name(Host, I), utf8), + p1_fsm:start_link({local, Proc}, ?MODULE, [Host], + fsm_limit_opts() ++ ?FSMOPTS). -type sql_query_simple() :: [sql_query() | binary()] | #sql_query{} | fun(() -> any()) | fun((atom(), _) -> any()). @@ -154,19 +150,17 @@ sql_bloc(Host, F) -> sql_call(Host, {sql_bloc, F}). sql_call(Host, Msg) -> case get(?STATE_KEY) of - undefined -> - case ejabberd_sql_sup:get_random_pid(Host) of - none -> {error, <<"Unknown Host">>}; - Pid -> - sync_send_event(Pid,{sql_cmd, Msg, - erlang:monotonic_time(millisecond)}, - query_timeout(Host)) - end; - _State -> nested_op(Msg) + undefined -> + Proc = get_worker(Host), + sync_send_event(Proc, {sql_cmd, Msg, + erlang:monotonic_time(millisecond)}, + query_timeout(Host)); + _State -> + nested_op(Msg) end. -keep_alive(Host, PID) -> - case sync_send_event(PID, +keep_alive(Host, Proc) -> + case sync_send_event(Proc, {sql_cmd, {sql_query, ?KEEPALIVE_QUERY}, erlang:monotonic_time(millisecond)}, query_timeout(Host)) of @@ -174,11 +168,11 @@ keep_alive(Host, PID) -> ok; _Err -> ?ERROR_MSG("Keep alive query failed, closing connection: ~p", [_Err]), - sync_send_event(PID, force_timeout, query_timeout(Host)) + sync_send_event(Proc, force_timeout, query_timeout(Host)) end. -sync_send_event(Pid, Msg, Timeout) -> - try p1_fsm:sync_send_event(Pid, Msg, Timeout) +sync_send_event(Proc, Msg, Timeout) -> + try p1_fsm:sync_send_event(Proc, Msg, Timeout) catch _:{Reason, {p1_fsm, _, _}} -> {error, Reason} end. @@ -310,10 +304,20 @@ sqlite_file(Host) -> use_new_schema() -> ejabberd_option:new_sql_schema(). +-spec get_worker(binary()) -> atom(). +get_worker(Host) -> + PoolSize = ejabberd_option:sql_pool_size(Host), + I = p1_rand:round_robin(PoolSize) + 1, + binary_to_existing_atom(get_worker_name(Host, I), utf8). + +-spec get_worker_name(binary(), pos_integer()) -> binary(). +get_worker_name(Host, I) -> + <<"ejabberd_sql_", Host/binary, $_, (integer_to_binary(I))/binary>>. + %%%---------------------------------------------------------------------- %%% Callback functions from gen_fsm %%%---------------------------------------------------------------------- -init([Host, StartInterval]) -> +init([Host]) -> process_flag(trap_exit, true), case ejabberd_option:sql_keepalive_interval(Host) of undefined -> @@ -324,12 +328,10 @@ init([Host, StartInterval]) -> end, [DBType | _] = db_opts(Host), p1_fsm:send_event(self(), connect), - ejabberd_sql_sup:add_pid(Host, self()), QueueType = ejabberd_option:sql_queue_type(Host), {ok, connecting, #state{db_type = DBType, host = Host, - pending_requests = p1_queue:new(QueueType, max_fsm_queue()), - start_interval = StartInterval}}. + pending_requests = p1_queue:new(QueueType, max_fsm_queue())}}. connecting(connect, #state{host = Host} = State) -> ConnectRes = case db_opts(Host) of @@ -359,13 +361,13 @@ connecting(connect, #state{host = Host} = State) -> State2 = get_db_version(State1), {next_state, session_established, State2}; {error, Reason} -> - ?WARNING_MSG("~p connection failed:~n** Reason: ~p~n** " - "Retry after: ~p seconds", - [State#state.db_type, Reason, - State#state.start_interval div 1000]), - p1_fsm:send_event_after(State#state.start_interval, - connect), - {next_state, connecting, State} + StartInterval = ejabberd_option:sql_start_interval(Host), + ?WARNING_MSG("~p connection failed:~n** Reason: ~p~n** " + "Retry after: ~B seconds", + [State#state.db_type, Reason, + StartInterval div 1000]), + p1_fsm:send_event_after(StartInterval, connect), + {next_state, connecting, State} end; connecting(Event, State) -> ?WARNING_MSG("Unexpected event in 'connecting': ~p", @@ -441,7 +443,6 @@ handle_info(Info, StateName, State) -> {next_state, StateName, State}. terminate(_Reason, _StateName, State) -> - ejabberd_sql_sup:remove_pid(State#state.host, self()), case State#state.db_type of mysql -> catch p1_mysql_conn:stop(State#state.db_ref); sqlite -> catch sqlite3:close(sqlite_db(State#state.host)); diff --git a/src/ejabberd_sql_sup.erl b/src/ejabberd_sql_sup.erl index fc16b784b..6d1e63b48 100644 --- a/src/ejabberd_sql_sup.erl +++ b/src/ejabberd_sql_sup.erl @@ -27,22 +27,11 @@ -author('alexey@process-one.net'). --export([start_link/1, init/1, add_pid/2, remove_pid/2, - get_pids/1, get_random_pid/1, reload/1]). +-export([start_link/1, init/1, reload/1, is_started/1]). -include("logger.hrl"). --include_lib("stdlib/include/ms_transform.hrl"). - --record(sql_pool, {host :: binary(), - pid :: pid()}). start_link(Host) -> - ejabberd_mnesia:create(?MODULE, sql_pool, - [{ram_copies, [node()]}, {type, bag}, - {local_content, true}, - {attributes, record_info(fields, sql_pool)}]), - F = fun () -> mnesia:delete({sql_pool, Host}) end, - mnesia:ets(F), supervisor:start_link({local, gen_mod:get_module_proc(Host, ?MODULE)}, ?MODULE, [Host]). @@ -58,61 +47,35 @@ init([Host]) -> _ -> ok end, - {ok, {{one_for_one, PoolSize * 10, 1}, - [child_spec(I, Host) || I <- lists:seq(1, PoolSize)]}}. + {ok, {{one_for_one, PoolSize * 10, 1}, child_specs(Host, PoolSize)}}. +-spec reload(binary()) -> ok. reload(Host) -> - Type = ejabberd_option:sql_type(Host), - NewPoolSize = get_pool_size(Type, Host), - OldPoolSize = ets:select_count( - sql_pool, - ets:fun2ms( - fun(#sql_pool{host = H}) when H == Host -> - true - end)), - reload(Host, NewPoolSize, OldPoolSize). - -reload(Host, NewPoolSize, OldPoolSize) -> - Sup = gen_mod:get_module_proc(Host, ?MODULE), - if NewPoolSize == OldPoolSize -> - ok; - NewPoolSize > OldPoolSize -> + case is_started(Host) of + true -> + Sup = gen_mod:get_module_proc(Host, ?MODULE), + Type = ejabberd_option:sql_type(Host), + PoolSize = get_pool_size(Type, Host), lists:foreach( - fun(I) -> - Spec = child_spec(I, Host), + fun(Spec) -> supervisor:start_child(Sup, Spec) - end, lists:seq(OldPoolSize+1, NewPoolSize)); - OldPoolSize > NewPoolSize -> + end, child_specs(Host, PoolSize)), lists:foreach( - fun(I) -> - supervisor:terminate_child(Sup, I), - supervisor:delete_child(Sup, I) - end, lists:seq(NewPoolSize+1, OldPoolSize)) + fun({Id, _, _, _}) when Id > PoolSize -> + case supervisor:terminate_child(Sup, Id) of + ok -> supervisor:delete_child(Sup, Id); + _ -> ok + end; + (_) -> + ok + end, supervisor:which_children(Sup)); + false -> + ok end. -get_pids(Host) -> - Rs = mnesia:dirty_read(sql_pool, Host), - [R#sql_pool.pid || R <- Rs, is_process_alive(R#sql_pool.pid)]. - -get_random_pid(Host) -> - case get_pids(Host) of - [] -> none; - Pids -> - I = p1_rand:round_robin(length(Pids)) + 1, - lists:nth(I, Pids) - end. - -add_pid(Host, Pid) -> - F = fun () -> - mnesia:write(#sql_pool{host = Host, pid = Pid}) - end, - mnesia:ets(F). - -remove_pid(Host, Pid) -> - F = fun () -> - mnesia:delete_object(#sql_pool{host = Host, pid = Pid}) - end, - mnesia:ets(F). +-spec is_started(binary()) -> boolean(). +is_started(Host) -> + whereis(gen_mod:get_module_proc(Host, ?MODULE)) /= undefined. -spec get_pool_size(atom(), binary()) -> pos_integer(). get_pool_size(SQLType, Host) -> @@ -125,10 +88,18 @@ get_pool_size(SQLType, Host) -> end, PoolSize. -child_spec(I, Host) -> - StartInterval = ejabberd_option:sql_start_interval(Host), - {I, {ejabberd_sql, start_link, [Host, StartInterval]}, - transient, 2000, worker, [?MODULE]}. +-spec child_spec(binary(), pos_integer()) -> supervisor:child_spec(). +child_spec(Host, I) -> + #{id => I, + start => {ejabberd_sql, start_link, [Host, I]}, + restart => transient, + shutdown => 2000, + type => worker, + modules => [?MODULE]}. + +-spec child_specs(binary(), pos_integer()) -> [supervisor:child_spec()]. +child_specs(Host, PoolSize) -> + [child_spec(Host, I) || I <- lists:seq(1, PoolSize)]. check_sqlite_db(Host) -> DB = ejabberd_sql:sqlite_db(Host), diff --git a/src/mod_admin_update_sql.erl b/src/mod_admin_update_sql.erl index 4f45afa94..c5d40f56d 100644 --- a/src/mod_admin_update_sql.erl +++ b/src/mod_admin_update_sql.erl @@ -75,14 +75,13 @@ get_commands_spec() -> update_sql() -> lists:foreach( fun(Host) -> - case ejabberd_sql_sup:get_pids(Host) of - [] -> + case ejabberd_sql_sup:is_started(Host) of + false -> ok; - _ -> + true -> update_sql(Host) end - end, ejabberd_option:hosts()), - ok. + end, ejabberd_option:hosts()). -record(state, {host :: binary(), dbtype :: mysql | pgsql | sqlite | mssql | odbc,