Speedup certificate chains creation and validation

This commit is contained in:
Evgeniy Khramtsov 2017-12-07 14:32:12 +03:00
parent d8ace67a50
commit a303373b0f
1 changed files with 166 additions and 77 deletions

View File

@ -40,6 +40,7 @@
notify = false :: boolean(), notify = false :: boolean(),
paths = [] :: [file:filename()], paths = [] :: [file:filename()],
certs = #{} :: map(), certs = #{} :: map(),
graph :: digraph:graph(),
keys = [] :: [public_key:private_key()]}). keys = [] :: [public_key:private_key()]}).
-type state() :: #state{}. -type state() :: #state{}.
@ -54,6 +55,8 @@
-type cert_error() :: not_cert | not_der | not_pem | encrypted. -type cert_error() :: not_cert | not_der | not_pem | encrypted.
-export_type([cert_error/0]). -export_type([cert_error/0]).
-define(CA_CACHE, ca_cache).
%%%=================================================================== %%%===================================================================
%%% API %%% API
%%%=================================================================== %%%===================================================================
@ -143,6 +146,10 @@ start_link() ->
gen_server:start_link({local, ?MODULE}, ?MODULE, [], []). gen_server:start_link({local, ?MODULE}, ?MODULE, [], []).
config_reloaded() -> config_reloaded() ->
case use_cache() of
true -> init_cache();
false -> delete_cache()
end,
gen_server:call(?MODULE, config_reloaded, 60000). gen_server:call(?MODULE, config_reloaded, 60000).
opt_type(ca_path) -> opt_type(ca_path) ->
@ -182,7 +189,9 @@ init([]) ->
if Validate -> check_ca(); if Validate -> check_ca();
true -> ok true -> ok
end, end,
State = #state{validate = Validate, notify = Notify}, G = digraph:new([acyclic]),
init_cache(),
State = #state{validate = Validate, notify = Notify, graph = G},
case filelib:ensure_dir(filename:join(certs_dir(), "foo")) of case filelib:ensure_dir(filename:join(certs_dir(), "foo")) of
ok -> ok ->
clean_dir(certs_dir()), clean_dir(certs_dir()),
@ -201,11 +210,15 @@ init([]) ->
handle_call({add_certfile, Path}, _, State) -> handle_call({add_certfile, Path}, _, State) ->
case add_certfile(Path, State) of case add_certfile(Path, State) of
{ok, State1} -> {ok, State1} ->
case build_chain_and_check(State1) of if State /= State1 ->
{ok, State2} -> case build_chain_and_check(State1) of
{reply, ok, State2}; {ok, State2} ->
Err -> {reply, ok, State2};
{reply, Err, State} Err ->
{reply, Err, State1}
end;
true ->
{reply, ok, State1}
end; end;
{Err, State1} -> {Err, State1} ->
{reply, Err, State1} {reply, Err, State1}
@ -297,6 +310,7 @@ get_certfiles_from_config_options(_State) ->
-spec add_certfiles(state()) -> {ok, state()} | {error, bad_cert()}. -spec add_certfiles(state()) -> {ok, state()} | {error, bad_cert()}.
add_certfiles(State) -> add_certfiles(State) ->
?DEBUG("Reading certificates", []),
Paths = get_certfiles_from_config_options(State), Paths = get_certfiles_from_config_options(State),
State1 = lists:foldl( State1 = lists:foldl(
fun(Path, Acc) -> fun(Path, Acc) ->
@ -353,18 +367,21 @@ add_certfile(Path, State) ->
-spec build_chain_and_check(state()) -> ok | {error, bad_cert()}. -spec build_chain_and_check(state()) -> ok | {error, bad_cert()}.
build_chain_and_check(State) -> build_chain_and_check(State) ->
?DEBUG("Rebuilding certificate chains from ~s", ?DEBUG("Building certificates graph", []),
[str:join(State#state.paths, <<", ">>)]), CertPaths = get_cert_paths(maps:keys(State#state.certs), State#state.graph),
CertPaths = get_cert_paths(maps:keys(State#state.certs)), ?DEBUG("Finding matched certificate keys", []),
case match_cert_keys(CertPaths, State#state.keys) of case match_cert_keys(CertPaths, State#state.keys) of
{ok, Chains} -> {ok, Chains} ->
?DEBUG("Storing certificate chains", []),
CertFilesWithDomains = store_certs(Chains, []), CertFilesWithDomains = store_certs(Chains, []),
ets:delete_all_objects(?MODULE), ets:delete_all_objects(?MODULE),
lists:foreach( lists:foreach(
fun({Path, Domain}) -> fun({Path, Domain}) ->
ets:insert(?MODULE, {Domain, Path}) ets:insert(?MODULE, {Domain, Path})
end, CertFilesWithDomains), end, CertFilesWithDomains),
?DEBUG("Validating certificates", []),
Errors = validate(CertPaths, State#state.validate), Errors = validate(CertPaths, State#state.validate),
?DEBUG("Subscribing to file events", []),
subscribe(State), subscribe(State),
lists:foreach( lists:foreach(
fun({Cert, Why}) -> fun({Cert, Why}) ->
@ -485,21 +502,43 @@ decode_certs(PemEntries) ->
-spec validate([{path, [cert()]}], boolean()) -> [{cert(), bad_cert()}]. -spec validate([{path, [cert()]}], boolean()) -> [{cert(), bad_cert()}].
validate(Paths, true) -> validate(Paths, true) ->
lists:flatmap( {ok, Re} = re:compile("^[a-f0-9]+\\.[0-9]+$", [unicode]),
Hashes = case file:list_dir(ca_dir()) of
{ok, Files} ->
lists:foldl(
fun(File, Acc) ->
try re:run(File, Re) of
{match, _} ->
[Hash|_] = string:tokens(File, "."),
Path = filename:join(ca_dir(), File),
dict:append(Hash, Path, Acc);
nomatch ->
Acc
catch _:badarg ->
?ERROR_MSG("Regexp failure on ~w", [File]),
Acc
end
end, dict:new(), Files);
{error, Why} ->
?ERROR_MSG("Failed to list directory ~s: ~s",
[ca_dir(), file:format_error(Why)]),
dict:new()
end,
lists:filtermap(
fun({path, Path}) -> fun({path, Path}) ->
case validate_path(Path) of case validate_path(Path, Hashes) of
ok -> ok ->
[]; false;
{error, Cert, Reason} -> {error, Cert, Reason} ->
[{Cert, Reason}] {true, {Cert, Reason}}
end end
end, Paths); end, Paths);
validate(_, _) -> validate(_, _) ->
[]. [].
-spec validate_path([cert()]) -> ok | {error, cert(), bad_cert()}. -spec validate_path([cert()], dict:dict()) -> ok | {error, cert(), bad_cert()}.
validate_path([Cert|_] = Certs) -> validate_path([Cert|_] = Certs, Cache) ->
case find_local_issuer(Cert) of case find_local_issuer(Cert, Cache) of
{ok, IssuerCert} -> {ok, IssuerCert} ->
try public_key:pkix_path_validation(IssuerCert, Certs, []) of try public_key:pkix_path_validation(IssuerCert, Certs, []) of
{ok, _} -> {ok, _} ->
@ -570,67 +609,88 @@ check_ca() ->
ok ok
end. end.
-spec find_local_issuer(cert()) -> {ok, cert()} | {error, {bad_cert, unknown_ca}}. -spec find_local_issuer(cert(), dict:dict()) -> {ok, cert()} |
find_local_issuer(Cert) -> {error, {bad_cert, unknown_ca}}.
case find_issuer_in_dir(Cert, ca_dir()) of find_local_issuer(Cert, Hashes) ->
case find_issuer_in_dir(Cert, Hashes) of
{ok, IssuerCert} -> {ok, IssuerCert} ->
{ok, IssuerCert}; {ok, IssuerCert};
{error, _} = Err -> {error, Reason} ->
case ca_file() of case ca_file() of
undefined -> Err; undefined -> {error, Reason};
CAFile -> find_issuer_in_file(Cert, CAFile) CAFile -> find_issuer_in_file(Cert, CAFile)
end end
end. end.
-spec find_issuer_in_dir(cert(), file:filename_all()) -spec find_issuer_in_dir(cert(), dict:dict())
-> {ok, cert()} | {error, {bad_cert, unknown_ca}}. -> {{ok, cert()} | {error, {bad_cert, unknown_ca}}, dict:dict()}.
find_issuer_in_dir(Cert, CADir) -> find_issuer_in_dir(Cert, Cache) ->
{ok, {_, IssuerID}} = public_key:pkix_issuer_id(Cert, self), {ok, {_, IssuerID}} = public_key:pkix_issuer_id(Cert, self),
Hash = short_name_hash(IssuerID), Hash = short_name_hash(IssuerID),
filelib:fold_files( Files = case dict:find(Hash, Cache) of
CADir, Hash ++ "\\.[0-9]+", false, {ok, L} -> L;
fun(_, {ok, IssuerCert}) -> error -> []
{ok, IssuerCert}; end,
(CertFile, Acc) -> lists:foldl(
try fun(_, {ok, _IssuerCert} = Acc) ->
{ok, Data} = file:read_file(CertFile), Acc;
{ok, [IssuerCert|_], _} = pem_decode(Data), (Path, Err) ->
case public_key:pkix_is_issuer(Cert, IssuerCert) of case read_ca_file(Path) of
true -> {ok, [IssuerCert|_]} ->
{ok, IssuerCert}; case public_key:pkix_is_issuer(Cert, IssuerCert) of
false -> true ->
Acc {ok, IssuerCert};
end false ->
catch _:{badmatch, {error, Why}} -> Err
?ERROR_MSG("failed to read CA certificate from \"~s\": ~s", end;
[CertFile, format_error(Why)]), error ->
Acc Err
end end
end, {error, {bad_cert, unknown_ca}}). end, {error, {bad_cert, unknown_ca}}, Files).
-spec find_issuer_in_file(cert(), file:filename_all() | undefined) -spec find_issuer_in_file(cert(), file:filename_all() | undefined)
-> {ok, cert()} | {error, {bad_cert, unknown_ca}}. -> {ok, cert()} | {error, {bad_cert, unknown_ca}}.
find_issuer_in_file(_Cert, undefined) -> find_issuer_in_file(_Cert, undefined) ->
{error, {bad_cert, unknown_ca}}; {error, {bad_cert, unknown_ca}};
find_issuer_in_file(Cert, CAFile) -> find_issuer_in_file(Cert, CAFile) ->
try case read_ca_file(CAFile) of
{ok, Data} = file:read_file(CAFile), {ok, IssuerCerts} ->
{ok, IssuerCerts, _} = pem_decode(Data), lists:foldl(
lists:foldl( fun(_, {ok, _} = Res) ->
fun(_, {ok, _} = Res) -> Res;
Res; (IssuerCert, Err) ->
(IssuerCert, Err) -> case public_key:pkix_is_issuer(Cert, IssuerCert) of
case public_key:pkix_is_issuer(Cert, IssuerCert) of true -> {ok, IssuerCert};
true -> {ok, IssuerCert}; false -> Err
false -> Err end
end end, {error, {bad_cert, unknown_ca}}, IssuerCerts);
end, {error, {bad_cert, unknown_ca}}, IssuerCerts) error ->
catch _:{badmatch, {error, Why}} ->
?ERROR_MSG("failed to read CA certificates from \"~s\": ~s",
[CAFile, format_error(Why)]),
{error, {bad_cert, unknown_ca}} {error, {bad_cert, unknown_ca}}
end. end.
-spec read_ca_file(file:filename_all()) -> {ok, [cert()]} | error.
read_ca_file(Path) ->
case use_cache() of
true ->
ets_cache:lookup(?CA_CACHE, Path,
fun() -> do_read_ca_file(Path) end);
false ->
do_read_ca_file(Path)
end.
-spec do_read_ca_file(file:filename_all()) -> {ok, [cert()]} | error.
do_read_ca_file(Path) ->
try
{ok, Data} = file:read_file(Path),
{ok, IssuerCerts, _} = pem_decode(Data),
{ok, IssuerCerts}
catch _:{badmatch, {error, Why}} ->
?ERROR_MSG("Failed to read CA certificate "
"from \"~s\": ~s",
[Path, format_error(Why)]),
error
end.
-spec match_cert_keys([{path, [cert()]}], [priv_key()]) -spec match_cert_keys([{path, [cert()]}], [priv_key()])
-> {ok, [{cert(), priv_key()}]} | {error, {bad_cert, missing_priv_key}}. -> {ok, [{cert(), priv_key()}]} | {error, {bad_cert, missing_priv_key}}.
match_cert_keys(CertPaths, PrivKeys) -> match_cert_keys(CertPaths, PrivKeys) ->
@ -680,13 +740,22 @@ pubkey_from_privkey(#'DSAPrivateKey'{p = P, q = Q, g = G, y = Y}) ->
pubkey_from_privkey(#'ECPrivateKey'{publicKey = Key}) -> pubkey_from_privkey(#'ECPrivateKey'{publicKey = Key}) ->
#'ECPoint'{point = Key}. #'ECPoint'{point = Key}.
-spec get_cert_paths([cert()]) -> [{path, [cert()]}]. -spec get_cert_paths([cert()], digraph:graph()) -> [{path, [cert()]}].
get_cert_paths(Certs) -> get_cert_paths(Certs, G) ->
G = digraph:new([acyclic]), {NewCerts, OldCerts} =
lists:foreach( lists:partition(
fun(Cert) -> fun(Cert) ->
digraph:add_vertex(G, Cert) case digraph:vertex(G, Cert) of
end, Certs), false ->
digraph:add_vertex(G, Cert),
true;
{_, _} ->
false
end
end, Certs),
CertPairs = [{C1, C2} || C1 <- NewCerts, C2 <- OldCerts] ++
[{C1, C2} || C1 <- OldCerts, C2 <- NewCerts] ++
[{C1, C2} || C1 <- NewCerts, C2 <- NewCerts],
lists:foreach( lists:foreach(
fun({Cert1, Cert2}) when Cert1 /= Cert2 -> fun({Cert1, Cert2}) when Cert1 /= Cert2 ->
case public_key:pkix_is_self_signed(Cert1) of case public_key:pkix_is_self_signed(Cert1) of
@ -702,18 +771,16 @@ get_cert_paths(Certs) ->
end; end;
(_) -> (_) ->
ok ok
end, [{Cert1, Cert2} || Cert1 <- Certs, Cert2 <- Certs]), end, CertPairs),
Paths = lists:flatmap( lists:flatmap(
fun(Cert) -> fun(Cert) ->
case digraph:in_degree(G, Cert) of case digraph:in_degree(G, Cert) of
0 -> 0 ->
get_cert_path(G, [Cert]); get_cert_path(G, [Cert]);
_ -> _ ->
[] []
end end
end, Certs), end, Certs).
digraph:delete(G),
Paths.
get_cert_path(G, [Root|_] = Acc) -> get_cert_path(G, [Root|_] = Acc) ->
case digraph:out_edges(G, Root) of case digraph:out_edges(G, Root) of
@ -783,3 +850,25 @@ wildcard(Path) when is_binary(Path) ->
wildcard(binary_to_list(Path)); wildcard(binary_to_list(Path));
wildcard(Path) -> wildcard(Path) ->
filelib:wildcard(Path). filelib:wildcard(Path).
-spec use_cache() -> boolean().
use_cache() ->
ejabberd_config:use_cache(global).
-spec init_cache() -> ok.
init_cache() ->
ets_cache:new(?CA_CACHE, cache_opts()).
-spec delete_cache() -> ok.
delete_cache() ->
ets_cache:delete(?CA_CACHE).
-spec cache_opts() -> [proplists:property()].
cache_opts() ->
MaxSize = ejabberd_config:cache_size(global),
CacheMissed = ejabberd_config:cache_missed(global),
LifeTime = case ejabberd_config:cache_life_time(global) of
infinity -> infinity;
I -> timer:seconds(I)
end,
[{max_size, MaxSize}, {cache_missed, CacheMissed}, {life_time, LifeTime}].