diff --git a/src/ejabberd_oauth.erl b/src/ejabberd_oauth.erl index 3a0b276d1..8527c9271 100644 --- a/src/ejabberd_oauth.erl +++ b/src/ejabberd_oauth.erl @@ -28,6 +28,8 @@ -behaviour(gen_server). +-compile(export_all). + %% gen_server callbacks -export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]). @@ -46,6 +48,7 @@ check_token/2, scope_in_scope_list/2, process/2, + config_reloaded/0, opt_type/1]). -export([oauth_issue_token/3, oauth_list_tokens/0, oauth_revoke_token/1, oauth_list_scopes/0]). @@ -140,8 +143,14 @@ oauth_revoke_token(Token) -> oauth_list_scopes() -> [ {Scope, string:join([atom_to_list(Cmd) || Cmd <- Cmds], ",")} || {Scope, Cmds} <- dict:to_list(get_cmd_scopes())]. - - +config_reloaded() -> + DBMod = get_db_backend(), + case init_cache(DBMod) of + true -> + ets_cache:setopts(oauth_cache, cache_opts()); + false -> + ok + end. start_link() -> gen_server:start_link({local, ?MODULE}, ?MODULE, [], []). @@ -150,23 +159,13 @@ start_link() -> init([]) -> DBMod = get_db_backend(), DBMod:init(), - MaxSize = - ejabberd_config:get_option( - oauth_cache_size, - fun(I) when is_integer(I), I>0 -> I end, - 1000), - LifeTime = - ejabberd_config:get_option( - oauth_cache_life_time, - fun(I) when is_integer(I), I>0 -> I end, - timer:hours(1) div 1000), - cache_tab:new(oauth_token, - [{max_size, MaxSize}, {life_time, LifeTime}]), + init_cache(DBMod), Expire = expire(), application:set_env(oauth2, backend, ejabberd_oauth), application:set_env(oauth2, expiry_time, Expire), application:start(oauth2), ejabberd_commands:register_commands(get_commands_spec()), + ejabberd_hooks:add(config_reloaded, ?MODULE, config_reloaded, 50), erlang:send_after(expire() * 1000, self(), clean), {ok, ok}. @@ -371,24 +370,59 @@ check_token(ScopeList, Token) -> store(R) -> - cache_tab:insert( - oauth_token, R#oauth_token.token, R, - fun() -> - DBMod = get_db_backend(), - DBMod:store(R) - end). + DBMod = get_db_backend(), + case DBMod:store(R) of + ok -> + ets_cache:delete(oauth_cache, R#oauth_token.token, + ejabberd_cluster:get_nodes()); + {error, _} = Err -> + Err + end. lookup(Token) -> - cache_tab:lookup( - oauth_token, Token, - fun() -> - DBMod = get_db_backend(), - case DBMod:lookup(Token) of - #oauth_token{} = R -> {ok, R}; - _ -> error - end - end). + ets_cache:lookup(oauth_cache, Token, + fun() -> + DBMod = get_db_backend(), + DBMod:lookup(Token) + end). +-spec init_cache(module()) -> boolean(). +init_cache(DBMod) -> + UseCache = use_cache(DBMod), + case UseCache of + true -> + ets_cache:new(oauth_cache, cache_opts()); + false -> + ets_cache:delete(oauth_cache) + end, + UseCache. + +use_cache(DBMod) -> + case erlang:function_exported(DBMod, use_cache, 0) of + true -> DBMod:use_cache(); + false -> + ejabberd_config:get_option( + oauth_use_cache, opt_type(oauth_use_cache), + ejabberd_config:use_cache(global)) + end. + +cache_opts() -> + MaxSize = ejabberd_config:get_option( + oauth_cache_size, + opt_type(oauth_cache_size), + ejabberd_config:cache_size(global)), + CacheMissed = ejabberd_config:get_option( + oauth_cache_missed, + opt_type(oauth_cache_missed), + ejabberd_config:cache_missed(global)), + LifeTime = case ejabberd_config:get_option( + oauth_cache_life_time, + opt_type(oauth_cache_life_time), + ejabberd_config:cache_life_time(global)) of + infinity -> infinity; + I -> timer:seconds(I) + end, + [{max_size, MaxSize}, {life_time, LifeTime}, {cache_missed, CacheMissed}]. expire() -> ejabberd_config:get_option( @@ -746,8 +780,13 @@ opt_type(oauth_access) -> fun acl:access_rules_validator/1; opt_type(oauth_db_type) -> fun(T) -> ejabberd_config:v_db(?MODULE, T) end; -opt_type(oauth_cache_life_time) -> - fun (I) when is_integer(I), I > 0 -> I end; -opt_type(oauth_cache_size) -> - fun (I) when is_integer(I), I > 0 -> I end; -opt_type(_) -> [oauth_expire, oauth_access, oauth_db_type]. +opt_type(O) when O == oauth_cache_life_time; O == oauth_cache_size -> + fun (I) when is_integer(I), I > 0 -> I; + (infinity) -> infinity + end; +opt_type(O) when O == oauth_use_cache; O == oauth_cache_missed -> + fun (B) when is_boolean(B) -> B end; +opt_type(_) -> + [oauth_expire, oauth_access, oauth_db_type, + oauth_cache_life_time, oauth_cache_size, oauth_use_cache, + oauth_cache_missed]. diff --git a/src/ejabberd_oauth_mnesia.erl b/src/ejabberd_oauth_mnesia.erl index c9ef6dcac..8a9997929 100644 --- a/src/ejabberd_oauth_mnesia.erl +++ b/src/ejabberd_oauth_mnesia.erl @@ -47,9 +47,9 @@ store(R) -> lookup(Token) -> case catch mnesia:dirty_read(oauth_token, Token) of [R] -> - R; + {ok, R}; _ -> - false + error end. clean(TS) -> diff --git a/src/ejabberd_oauth_rest.erl b/src/ejabberd_oauth_rest.erl index b9614eb09..15e118a0b 100644 --- a/src/ejabberd_oauth_rest.erl +++ b/src/ejabberd_oauth_rest.erl @@ -58,7 +58,7 @@ store(R) -> ok; Err -> ?ERROR_MSG("failed to store oauth record ~p: ~p", [R, Err]), - {error, Err} + {error, db_failure} end. lookup(Token) -> @@ -72,15 +72,15 @@ lookup(Token) -> US = {JID#jid.luser, JID#jid.lserver}, Scope = proplists:get_value(<<"scope">>, Data, []), Expire = proplists:get_value(<<"expire">>, Data, 0), - #oauth_token{token = Token, - us = US, - scope = Scope, - expire = Expire}; + {ok, #oauth_token{token = Token, + us = US, + scope = Scope, + expire = Expire}}; {ok, 404, _Resp} -> - false; + error; Other -> ?ERROR_MSG("Unexpected response for oauth lookup: ~p", [Other]), - {error, rest_failed} + error end. clean(_TS) -> diff --git a/src/ejabberd_oauth_sql.erl b/src/ejabberd_oauth_sql.erl index 10ca49844..5c4a96641 100644 --- a/src/ejabberd_oauth_sql.erl +++ b/src/ejabberd_oauth_sql.erl @@ -37,6 +37,7 @@ -include("ejabberd.hrl"). -include("ejabberd_sql_pt.hrl"). -include("jid.hrl"). +-include("logger.hrl"). init() -> ok. @@ -47,13 +48,20 @@ store(R) -> SJID = jid:encode({User, Server, <<"">>}), Scope = str:join(R#oauth_token.scope, <<" ">>), Expire = R#oauth_token.expire, - ?SQL_UPSERT( - ?MYNAME, - "oauth_token", - ["!token=%(Token)s", - "jid=%(SJID)s", - "scope=%(Scope)s", - "expire=%(Expire)d"]). + case ?SQL_UPSERT( + ?MYNAME, + "oauth_token", + ["!token=%(Token)s", + "jid=%(SJID)s", + "scope=%(Scope)s", + "expire=%(Expire)d"]) of + ok -> + ok; + Err -> + ?ERROR_MSG("Failed to write to SQL 'oauth_token' table: ~p", + [Err]), + {error, db_failure} + end. lookup(Token) -> case ejabberd_sql:sql_query( @@ -63,12 +71,12 @@ lookup(Token) -> {selected, [{SJID, Scope, Expire}]} -> JID = jid:decode(SJID), US = {JID#jid.luser, JID#jid.lserver}, - #oauth_token{token = Token, - us = US, - scope = str:tokens(Scope, <<" ">>), - expire = Expire}; + {ok, #oauth_token{token = Token, + us = US, + scope = str:tokens(Scope, <<" ">>), + expire = Expire}}; _ -> - false + error end. clean(TS) ->