%%%-------------------------------------------------------------------
%%% File    : ejabberd_riak.erl
%%% Author  : Alexey Shchepin
%%% Purpose : Interface for Riak database
%%% Created : 29 Dec 2011 by Alexey Shchepin
%%%
%%%
%%% ejabberd, Copyright (C) 2002-2019   ProcessOne
%%%
%%% This program is free software; you can redistribute it and/or
%%% modify it under the terms of the GNU General Public License as
%%% published by the Free Software Foundation; either version 2 of the
%%% License, or (at your option) any later version.
%%%
%%% This program is distributed in the hope that it will be useful,
%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
%%% General Public License for more details.
%%%
%%% You should have received a copy of the GNU General Public License along
%%% with this program; if not, write to the Free Software Foundation, Inc.,
%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
%%%
%%%-------------------------------------------------------------------
-module(ejabberd_riak).

-behaviour(gen_server).

%% API
-export([start_link/5, get_proc/1, make_bucket/1, put/2, put/3,
         get/2, get/3, get_by_index/4, delete/1, delete/2,
         count_by_index/3, get_by_index_range/5,
         get_keys/1, get_keys_by_index/3, is_connected/0,
         count/1, delete_by_index/3]).
%% For debugging
-export([get_tables/0]).
%% map/reduce exports
-export([map_key/3]).
%% gen_server callbacks
-export([init/1, handle_call/3, handle_cast/2,
         handle_info/2, terminate/2, code_change/3]).

-include("logger.hrl").

%% gen_server state. `pid' is set in init/1 to the pid returned by
%% riakc_pb_socket:start/3 and is handed out by the `get_pid' call.
-record(state, {pid = self() :: pid()}).

%% A secondary index entry: the index name and the value to index by.
-type index() :: {binary(), any()}.

-type index_info() :: [{i, any()} | {'2i', [index()]}].

%% The `record_schema()' is just a tuple:
%% {record_info(fields, some_record), #some_record{}}
-type record_schema() :: {[atom()], tuple()}.

%% The `index_info()' is used in put/delete functions:
%% `i' defines a primary index, `` '2i' '' defines secondary indexes.
%% There must be only one primary index. If `i' is not specified,
%% the first element of the record is assumed as a primary index,
%% i.e. `i' = element(2, Record).

-export_type([index_info/0]).

%%%===================================================================
%%% API
%%%===================================================================
%% @doc Starts pool worker number `Num', registered locally under the
%% name built by get_proc/1. The start interval argument is ignored.
%% @private
start_link(Num, Server, Port, _StartInterval, Options) ->
    RegName = {local, get_proc(Num)},
    InitArgs = [Server, Port, Options],
    gen_server:start_link(RegName, ?MODULE, InitArgs, []).

%% @doc Returns `true' when every child of `ejabberd_riak_sup' has a pid
%% and its Riak socket reports being connected (vacuously `true' when
%% the supervisor has no children).
%% @private
is_connected() ->
    IsUp = fun({_Id, Pid, _Type, _Modules}) when is_pid(Pid) ->
                   %% A crash, a throw or any answer other than `true'
                   %% counts as "not connected".
                   (catch riakc_pb_socket:is_connected(get_riak_pid(Pid)))
                       =:= true;
              (_) ->
                   false
           end,
    lists:all(IsUp, supervisor:which_children(ejabberd_riak_sup)).

%% @doc Builds the registered name of worker number `I': the module
%% name, an underscore and `I' in decimal.
%% @private
get_proc(I) ->
    NameParts = [atom_to_list(?MODULE), $_, integer_to_list(I)],
    misc:binary_to_atom(iolist_to_binary(NameParts)).

-spec make_bucket(atom()) -> binary().
%% @doc Converts a table name into the UTF-8 binary used as bucket name
%% @private
make_bucket(Table) ->
    atom_to_binary(Table, utf8).

-spec put(tuple(), record_schema()) -> ok | {error, any()}.
%% @equiv put(Record, RecFields, [])
put(Record, RecFields) ->
    NoIndexInfo = [],
    ?MODULE:put(Record, RecFields, NoIndexInfo).

-spec put(tuple(), record_schema(), index_info()) -> ok | {error, any()}.
%% @doc Stores a record `Rec' with indexes described in ``IndexInfo''.
%% The bucket is derived from the record name; the primary key is the
%% `i' entry of ``IndexInfo'' or, when absent, the first record field.
put(Rec, RecSchema, IndexInfo) ->
    PrimaryKey = encode_key(proplists:get_value(i, IndexInfo, element(2, Rec))),
    Indexes = [encode_index_key(Name, Val)
               || {Name, Val} <- proplists:get_value('2i', IndexInfo, [])],
    Bucket = element(1, Rec),
    Payload = encode_record(Rec, RecSchema),
    Result = put_raw(Bucket, PrimaryKey, Payload, Indexes),
    case Result of
        ok ->
            ok;
        {error, _} ->
            log_error(Result, put, [{record, Rec}, {index_info, IndexInfo}]),
            Result
    end.
%% Builds a Riak object for `Table'/`Key' holding `Value' and stores it.
%% A non-empty `Indexes' list is attached as the object's `index'
%% metadata. Returns the result of the caught riakc_pb_socket:put/2 call.
put_raw(Table, Key, Value, Indexes) ->
    Obj0 = riakc_obj:new(make_bucket(Table), Key, Value,
                         "application/x-erlang-term"),
    Obj = case Indexes of
              [] ->
                  Obj0;
              _ ->
                  Meta = dict:store(<<"index">>, Indexes, dict:new()),
                  riakc_obj:update_metadata(Obj0, Meta)
          end,
    catch riakc_pb_socket:put(get_random_pid(), Obj).

%% Fetches the raw Riak object stored under `Key' in the bucket of `Table'.
get_object_raw(Table, Key) ->
    BucketName = make_bucket(Table),
    catch riakc_pb_socket:get(get_random_pid(), BucketName, Key).

-spec get(atom(), record_schema()) -> {ok, [any()]} | {error, any()}.
%% @doc Returns all objects from table `Table'.
%% Objects that fail to decode are logged and left out of the result.
get(Table, RecSchema) ->
    Bucket = make_bucket(Table),
    Query = [{map, {modfun, riak_kv_mapreduce, map_object_value},
              none, true}],
    Decode = fun(Obj) ->
                     case catch decode_record(Obj, RecSchema) of
                         {'EXIT', _} ->
                             Invalid = {error, make_invalid_object(Obj)},
                             log_error(Invalid, get, [{table, Table}]),
                             [];
                         Rec ->
                             [Rec]
                     end
             end,
    case catch riakc_pb_socket:mapred(get_random_pid(), Bucket, Query) of
        {ok, [{_, Objs}]} ->
            {ok, lists:flatmap(Decode, Objs)};
        {ok, []} ->
            {ok, []};
        {error, notfound} ->
            {ok, []};
        {error, _} = Error ->
            Error
    end.

-spec get(atom(), record_schema(), any()) -> {ok, any()} | {error, any()}.
%% @doc Reads record by `Key' from table `Table'.
%% A stored value that cannot be decoded is logged and reported as
%% `{error, notfound}'.
get(Table, RecSchema, Key) ->
    Context = [{table, Table}, {key, Key}],
    case get_raw(Table, encode_key(Key)) of
        {ok, Bin} ->
            case catch decode_record(Bin, RecSchema) of
                {'EXIT', _} ->
                    Invalid = {error, make_invalid_object(Bin)},
                    log_error(Invalid, get, Context),
                    {error, notfound};
                Rec ->
                    {ok, Rec}
            end;
        {error, _} = Error ->
            log_error(Error, get, Context),
            Error
    end.

-spec get_by_index(atom(), record_schema(), binary(), any()) ->
                          {ok, [any()]} | {error, any()}.
%% @doc Reads records by `Index' and value `Key' from `Table'.
%% Values that fail to decode are logged and skipped; a `notfound'
%% error is reported as an empty list.
get_by_index(Table, RecSchema, Index, Key) ->
    {RiakIndex, RiakKey} = encode_index_key(Index, Key),
    Context = [{table, Table}, {index, Index}, {key, Key}],
    Decode = fun(Val) ->
                     case catch decode_record(Val, RecSchema) of
                         {'EXIT', _} ->
                             Invalid = {error, make_invalid_object(Val)},
                             log_error(Invalid, get_by_index, Context),
                             [];
                         Rec ->
                             [Rec]
                     end
             end,
    case get_by_index_raw(Table, RiakIndex, RiakKey) of
        {ok, Vals} ->
            {ok, lists:flatmap(Decode, Vals)};
        {error, notfound} ->
            {ok, []};
        {error, _} = Error ->
            log_error(Error, get_by_index, Context),
            Error
    end.

-spec get_by_index_range(atom(), record_schema(), binary(), any(), any()) ->
                                {ok, [any()]} | {error, any()}.
%% @doc Reads records by `Index' in the range `FromKey'..`ToKey' from `Table'.
%% Values that fail to decode are logged and skipped; a `notfound'
%% error is reported as an empty list.
get_by_index_range(Table, RecSchema, Index, FromKey, ToKey) ->
    {RiakIndex, RiakFrom} = encode_index_key(Index, FromKey),
    %% RiakIndex is already bound: this match also asserts that both
    %% range ends encode to the same Riak index name.
    {RiakIndex, RiakTo} = encode_index_key(Index, ToKey),
    Context = [{table, Table}, {index, Index},
               {start_key, FromKey}, {end_key, ToKey}],
    Decode = fun(Val) ->
                     case catch decode_record(Val, RecSchema) of
                         {'EXIT', _} ->
                             Invalid = {error, make_invalid_object(Val)},
                             log_error(Invalid, get_by_index_range, Context),
                             [];
                         Rec ->
                             [Rec]
                     end
             end,
    case get_by_index_range_raw(Table, RiakIndex, RiakFrom, RiakTo) of
        {ok, Vals} ->
            {ok, lists:flatmap(Decode, Vals)};
        {error, notfound} ->
            {ok, []};
        {error, _} = Error ->
            log_error(Error, get_by_index_range, Context),
            Error
    end.

%% Fetches the object stored under the (already encoded) `Key' and
%% returns just its value.
get_raw(Table, Key) ->
    case get_object_raw(Table, Key) of
        {ok, RiakObj} ->
            {ok, riakc_obj:get_value(RiakObj)};
        {error, _} = Failure ->
            Failure
    end.

-spec get_keys(atom()) -> {ok, [any()]} | {error, any()}.
%% @doc Returns the primary keys of all objects in `Table', decoded by
%% the map_key/3 map phase.
get_keys(Table) ->
    Bucket = make_bucket(Table),
    case catch riakc_pb_socket:mapred(
                 get_random_pid(),
                 Bucket,
                 [{map, {modfun, ?MODULE, map_key}, none, true}]) of
        {ok, [{_, Keys}]} ->
            {ok, Keys};
        {ok, []} ->
            {ok, []};
        {error, _} = Error ->
            log_error(Error, get_keys, [{table, Table}]),
            Error
    end.

-spec get_keys_by_index(atom(), binary(),
                        any()) -> {ok, [any()]} | {error, any()}.
%% @doc Returns a list of primary keys of objects indexed by `Key'.
get_keys_by_index(Table, Index, Key) ->
    {NewIndex, NewKey} = encode_index_key(Index, Key),
    Bucket = make_bucket(Table),
    case catch riakc_pb_socket:mapred(
                 get_random_pid(),
                 {index, Bucket, NewIndex, NewKey},
                 [{map, {modfun, ?MODULE, map_key}, none, true}]) of
        {ok, [{_, Keys}]} ->
            {ok, Keys};
        {ok, []} ->
            {ok, []};
        {error, _} = Error ->
            log_error(Error, get_keys_by_index,
                      [{table, Table}, {index, Index}, {key, Key}]),
            Error
    end.

%% @doc Lists the buckets known to Riak (debugging aid).
%% @hidden
get_tables() ->
    catch riakc_pb_socket:list_buckets(get_random_pid()).

%% Returns the raw values of all objects in `Table' whose (already
%% encoded) secondary index `Index' equals `Key'.
%%
%% The Riak call is wrapped in `catch', like every other Riak call in
%% this module: get_riak_pid/1 signals failures (e.g. a timeout) by
%% throwing `{error, Reason}', and `catch' turns that throw into a
%% return value which the last clause below hands back to the caller.
get_by_index_raw(Table, Index, Key) ->
    Bucket = make_bucket(Table),
    case catch riakc_pb_socket:mapred(
                 get_random_pid(),
                 {index, Bucket, Index, Key},
                 [{map, {modfun, riak_kv_mapreduce, map_object_value},
                   none, true}]) of
        {ok, [{_, Objs}]} ->
            {ok, Objs};
        {ok, []} ->
            {ok, []};
        {error, _} = Error ->
            Error
    end.

%% Returns the raw values of all objects in `Table' whose (already
%% encoded) secondary index `Index' lies between `FromKey' and `ToKey'.
get_by_index_range_raw(Table, Index, FromKey, ToKey) ->
    Bucket = make_bucket(Table),
    case catch riakc_pb_socket:mapred(
                 get_random_pid(),
                 {index, Bucket, Index, FromKey, ToKey},
                 [{map, {modfun, riak_kv_mapreduce, map_object_value},
                   none, true}]) of
        {ok, [{_, Objs}]} ->
            {ok, Objs};
        {ok, []} ->
            {ok, []};
        {error, _} = Error ->
            Error
    end.

-spec count(atom()) -> {ok, non_neg_integer()} | {error, any()}.
%% @doc Returns the number of objects in the `Table'
count(Table) ->
    Bucket = make_bucket(Table),
    Query = [{reduce, {modfun, riak_kv_mapreduce, reduce_count_inputs},
              none, true}],
    case catch riakc_pb_socket:mapred(get_random_pid(), Bucket, Query) of
        {ok, [{_, [Total]}]} ->
            {ok, Total};
        {error, _} = Failure ->
            log_error(Failure, count, [{table, Table}]),
            Failure
    end.

-spec count_by_index(atom(), binary(), any()) ->
                            {ok, non_neg_integer()} | {error, any()}.
%% @doc Returns the number of objects in the `Table' by index.
%% A `notfound' error is reported as a count of zero.
count_by_index(Tab, Index, Key) ->
    {RiakIndex, RiakKey} = encode_index_key(Index, Key),
    case count_by_index_raw(Tab, RiakIndex, RiakKey) of
        {ok, Total} ->
            {ok, Total};
        {error, notfound} ->
            {ok, 0};
        {error, _} = Failure ->
            log_error(Failure, count_by_index,
                      [{table, Tab}, {index, Index}, {key, Key}]),
            Failure
    end.

%% Counts the objects of `Table' matching the (already encoded)
%% secondary index `Index' = `Key' with a reduce phase.
count_by_index_raw(Table, Index, Key) ->
    Inputs = {index, make_bucket(Table), Index, Key},
    Query = [{reduce, {modfun, riak_kv_mapreduce, reduce_count_inputs},
              none, true}],
    case catch riakc_pb_socket:mapred(get_random_pid(), Inputs, Query) of
        {ok, [{_, [Total]}]} ->
            {ok, Total};
        {error, _} = Failure ->
            Failure
    end.

-spec delete(tuple() | atom()) -> ok | {error, any()}.
%% @doc Same as delete(T, []) when T is record.
%% Or deletes all elements from table if T is atom.
delete(Rec) when is_tuple(Rec) ->
    delete(Rec, []);
delete(Table) when is_atom(Table) ->
    DeleteOne = fun(PrimaryKey) -> ok = delete(Table, PrimaryKey) end,
    try
        {ok, AllKeys} = ?MODULE:get_keys(Table),
        lists:foreach(DeleteOne, AllKeys)
    catch
        %% The first non-matching result aborts the sweep and is
        %% returned as is.
        _:{badmatch, Unexpected} ->
            Unexpected
    end.

-spec delete(tuple() | atom(), index_info() | any()) -> ok | {error, any()}.
%% @doc Delete an object.
%% Given a record, the table is the record name and the key is the `i'
%% option or, when absent, the first record field. Given a table name,
%% the second argument is the key itself.
delete(Rec, Opts) when is_tuple(Rec) ->
    RecName = element(1, Rec),
    PrimaryKey = proplists:get_value(i, Opts, element(2, Rec)),
    delete(RecName, PrimaryKey);
delete(Table, Key) when is_atom(Table) ->
    Outcome = delete_raw(Table, encode_key(Key)),
    case Outcome of
        ok ->
            ok;
        _ ->
            log_error(Outcome, delete, [{table, Table}, {key, Key}]),
            Outcome
    end.

%% Deletes the object stored under the (already encoded) `Key'.
delete_raw(Table, Key) ->
    BucketName = make_bucket(Table),
    catch riakc_pb_socket:delete(get_random_pid(), BucketName, Key).
-spec delete_by_index(atom(), binary(), any()) -> ok | {error, any()}.
%% @doc Deletes objects by index.
%% Stops at the first failure and returns the offending result.
delete_by_index(Table, Index, Key) ->
    try
        {ok, Keys} = get_keys_by_index(Table, Index, Key),
        lists:foreach(
          fun(K) ->
                  ok = delete(Table, K)
          end, Keys)
    catch _:{badmatch, Err} ->
            Err
    end.

%%%===================================================================
%%% map/reduce functions
%%%===================================================================
%% @doc Map phase function: returns the key of `Obj', decoded back from
%% the form produced by encode_key/1, as a one-element list.
%% @private
map_key(Obj, _, _) ->
    [case riak_object:key(Obj) of
         <<"b_", B/binary>> ->
             B;
         <<"i_", B/binary>> ->
             binary_to_integer(B);
         B ->
             erlang:binary_to_term(B)
     end].

%%%===================================================================
%%% gen_server API
%%%===================================================================
%% @doc Opens the Riak connection with riakc_pb_socket:start/3 and
%% monitors it; any other start result stops the server.
%% @private
init([Server, Port, Options]) ->
    case riakc_pb_socket:start(Server, Port, Options) of
        {ok, Pid} ->
            erlang:monitor(process, Pid),
            {ok, #state{pid = Pid}};
        Err ->
            {stop, Err}
    end.

%% @doc `get_pid' replies with the Riak socket pid; any other request
%% is answered with `ok'.
%% @private
handle_call(get_pid, _From, #state{pid = Pid} = State) ->
    {reply, {ok, Pid}, State};
handle_call(_Request, _From, State) ->
    Reply = ok,
    {reply, Reply, State}.

%% @private
handle_cast(_Msg, State) ->
    {noreply, State}.

%% @doc Stops the worker when a monitored process (the Riak socket
%% monitored in init/1) goes down; other messages are logged.
%% @private
handle_info({'DOWN', _MonitorRef, _Type, _Object, _Info}, State) ->
    {stop, normal, State};
handle_info(Info, State) ->
    ?ERROR_MSG("unexpected info: ~p", [Info]),
    {noreply, State}.

%% @private
terminate(_Reason, _State) ->
    ok.

%% @private
code_change(_OldVsn, State, _Extra) ->
    {ok, State}.

%%%===================================================================
%%% Internal functions
%%%===================================================================
%% Builds the Riak secondary-index name and value for index `Idx'.
%% Riak tells index types apart by name suffix: integer values go to
%% an `_int' index and are passed through unchanged, everything else
%% goes to a `_bin' index with the value encoded by encode_key/1.
encode_index_key(Idx, Key) when is_integer(Key) ->
    {<<Idx/binary, "_int">>, Key};
encode_index_key(Idx, Key) ->
    {<<Idx/binary, "_bin">>, encode_key(Key)}.

%% Encodes a key as a binary: binaries get a `b_' prefix, integers an
%% `i_' prefix followed by their decimal text, any other term is
%% serialized with term_to_binary/1. map_key/3 performs the reverse.
encode_key(Bin) when is_binary(Bin) ->
    <<"b_", Bin/binary>>;
encode_key(Int) when is_integer(Int) ->
    <<"i_", (integer_to_binary(Int))/binary>>;
encode_key(Term) ->
    erlang:term_to_binary(Term).
%% Logs a database error together with the context given in `Opts'.
%% `{error, notfound}' and anything that is not an `{error, _}' tuple
%% are silently ignored.
log_error({error, notfound}, _, _) ->
    ok;
log_error({error, Why} = Err, Function, Opts) ->
    Txt = lists:map(
            fun({table, Table}) ->
                    io_lib:fwrite("** Table: ~p~n", [Table]);
               ({key, Key}) ->
                    io_lib:fwrite("** Key: ~p~n", [Key]);
               ({index, Index}) ->
                    io_lib:fwrite("** Index = ~p~n", [Index]);
               ({start_key, Key}) ->
                    io_lib:fwrite("** Start Key: ~p~n", [Key]);
               ({end_key, Key}) ->
                    io_lib:fwrite("** End Key: ~p~n", [Key]);
               ({record, Rec}) ->
                    io_lib:fwrite("** Record = ~p~n", [Rec]);
               ({index_info, IdxInfo}) ->
                    io_lib:fwrite("** Index info = ~p~n", [IdxInfo]);
               (_) ->
                    ""
            end, Opts),
    ErrTxt = if is_binary(Why) ->
                     io_lib:fwrite("** Error: ~s", [Why]);
                true ->
                     io_lib:fwrite("** Error: ~p", [Err])
             end,
    ?ERROR_MSG("database error:~n** Function: ~p~n~s~s",
               [Function, Txt, ErrTxt]);
log_error(_, _, _) ->
    ok.

%% Formats the error reason used when a stored value cannot be decoded.
make_invalid_object(Val) ->
    str:format("Invalid object: ~p", [Val]).

%% Returns the pid of a Riak socket taken from a random pool worker.
%%
%% Throws `{error, Reason}' when the pool cannot be started, in the
%% same way get_riak_pid/1 reports its failures. Callers wrap their
%% Riak calls in `catch', so the thrown tuple becomes their return
%% value. Returning the error tuple here instead would hand it to
%% riakc_pb_socket in place of a pid.
get_random_pid() ->
    case ejabberd_riak_sup:start() of
        ok ->
            PoolPid = ejabberd_riak_sup:get_random_pid(),
            get_riak_pid(PoolPid);
        {error, _} = Err ->
            throw(Err)
    end.

%% Asks pool worker `PoolPid' for its Riak socket pid.
%% Throws `{error, timeout}' or `{error, Reason}' when the call fails.
get_riak_pid(PoolPid) ->
    case catch gen_server:call(PoolPid, get_pid) of
        {ok, Pid} ->
            Pid;
        {'EXIT', {timeout, _}} ->
            throw({error, timeout});
        {'EXIT', Err} ->
            throw({error, Err})
    end.

%% Serializes `Rec' as a term_to_binary/1-encoded list of
%% `{FieldName, Value}' pairs, leaving out the fields that hold the
%% default value found in the schema's default record.
encode_record(Rec, {Fields, DefRec}) ->
    term_to_binary(encode_record(Rec, Fields, DefRec, 2)).

encode_record(Rec, [FieldName|Fields], DefRec, Pos) ->
    Value = element(Pos, Rec),
    DefValue = element(Pos, DefRec),
    %% Exact comparison: with `==' a field holding 1.0 would be dropped
    %% when the default is 1 and would come back as an integer.
    if Value =:= DefValue ->
            encode_record(Rec, Fields, DefRec, Pos+1);
       true ->
            [{FieldName, Value}|encode_record(Rec, Fields, DefRec, Pos+1)]
    end;
encode_record(_, [], _, _) ->
    [].

%% Rebuilds a record from a binary produced by encode_record/2: starts
%% from the schema's default record and overwrites the stored fields.
%% NOTE(review): binary_to_term/1 is called without the `safe' option;
%% this assumes the stored data is trusted.
decode_record(Bin, {Fields, DefRec}) ->
    decode_record(binary_to_term(Bin), Fields, DefRec, 2).

decode_record(KeyVals, [FieldName|Fields], Rec, Pos) ->
    case lists:keyfind(FieldName, 1, KeyVals) of
        {_, Value} ->
            NewRec = setelement(Pos, Rec, Value),
            decode_record(KeyVals, Fields, NewRec, Pos+1);
        false ->
            decode_record(KeyVals, Fields, Rec, Pos+1)
    end;
decode_record(_, [], Rec, _) ->
    Rec.