diff --git a/src/ejabberd_pkix.erl b/src/ejabberd_pkix.erl index a67df1288..037fc9e9e 100644 --- a/src/ejabberd_pkix.erl +++ b/src/ejabberd_pkix.erl @@ -40,6 +40,7 @@ notify = false :: boolean(), paths = [] :: [file:filename()], certs = #{} :: map(), + graph :: digraph:graph(), keys = [] :: [public_key:private_key()]}). -type state() :: #state{}. @@ -54,6 +55,8 @@ -type cert_error() :: not_cert | not_der | not_pem | encrypted. -export_type([cert_error/0]). +-define(CA_CACHE, ca_cache). + %%%=================================================================== %%% API %%%=================================================================== @@ -143,6 +146,10 @@ start_link() -> gen_server:start_link({local, ?MODULE}, ?MODULE, [], []). config_reloaded() -> + case use_cache() of + true -> init_cache(); + false -> delete_cache() + end, gen_server:call(?MODULE, config_reloaded, 60000). opt_type(ca_path) -> @@ -182,7 +189,9 @@ init([]) -> if Validate -> check_ca(); true -> ok end, - State = #state{validate = Validate, notify = Notify}, + G = digraph:new([acyclic]), + init_cache(), + State = #state{validate = Validate, notify = Notify, graph = G}, case filelib:ensure_dir(filename:join(certs_dir(), "foo")) of ok -> clean_dir(certs_dir()), @@ -201,11 +210,15 @@ init([]) -> handle_call({add_certfile, Path}, _, State) -> case add_certfile(Path, State) of {ok, State1} -> - case build_chain_and_check(State1) of - {ok, State2} -> - {reply, ok, State2}; - Err -> - {reply, Err, State} + if State /= State1 -> + case build_chain_and_check(State1) of + {ok, State2} -> + {reply, ok, State2}; + Err -> + {reply, Err, State1} + end; + true -> + {reply, ok, State1} end; {Err, State1} -> {reply, Err, State1} @@ -297,6 +310,7 @@ get_certfiles_from_config_options(_State) -> -spec add_certfiles(state()) -> {ok, state()} | {error, bad_cert()}. add_certfiles(State) -> + ?DEBUG("Reading certificates", []), Paths = get_certfiles_from_config_options(State), State1 = lists:foldl( fun(Path, Acc) -> @@ -353,18 +367,21 @@ add_certfile(Path, State) -> -spec build_chain_and_check(state()) -> ok | {error, bad_cert()}. build_chain_and_check(State) -> - ?DEBUG("Rebuilding certificate chains from ~s", - [str:join(State#state.paths, <<", ">>)]), - CertPaths = get_cert_paths(maps:keys(State#state.certs)), + ?DEBUG("Building certificates graph", []), + CertPaths = get_cert_paths(maps:keys(State#state.certs), State#state.graph), + ?DEBUG("Finding matched certificate keys", []), case match_cert_keys(CertPaths, State#state.keys) of {ok, Chains} -> + ?DEBUG("Storing certificate chains", []), CertFilesWithDomains = store_certs(Chains, []), ets:delete_all_objects(?MODULE), lists:foreach( fun({Path, Domain}) -> ets:insert(?MODULE, {Domain, Path}) end, CertFilesWithDomains), + ?DEBUG("Validating certificates", []), Errors = validate(CertPaths, State#state.validate), + ?DEBUG("Subscribing to file events", []), subscribe(State), lists:foreach( fun({Cert, Why}) -> @@ -485,21 +502,43 @@ decode_certs(PemEntries) -> -spec validate([{path, [cert()]}], boolean()) -> [{cert(), bad_cert()}]. validate(Paths, true) -> - lists:flatmap( + {ok, Re} = re:compile("^[a-f0-9]+\\.[0-9]+$", [unicode]), + Hashes = case file:list_dir(ca_dir()) of + {ok, Files} -> + lists:foldl( + fun(File, Acc) -> + try re:run(File, Re) of + {match, _} -> + [Hash|_] = string:tokens(File, "."), + Path = filename:join(ca_dir(), File), + dict:append(Hash, Path, Acc); + nomatch -> + Acc + catch _:badarg -> + ?ERROR_MSG("Regexp failure on ~w", [File]), + Acc + end + end, dict:new(), Files); + {error, Why} -> + ?ERROR_MSG("Failed to list directory ~s: ~s", + [ca_dir(), file:format_error(Why)]), + dict:new() + end, + lists:filtermap( fun({path, Path}) -> - case validate_path(Path) of + case validate_path(Path, Hashes) of ok -> - []; + false; {error, Cert, Reason} -> - [{Cert, Reason}] + {true, {Cert, Reason}} end end, Paths); validate(_, _) -> []. --spec validate_path([cert()]) -> ok | {error, cert(), bad_cert()}. -validate_path([Cert|_] = Certs) -> - case find_local_issuer(Cert) of +-spec validate_path([cert()], dict:dict()) -> ok | {error, cert(), bad_cert()}. +validate_path([Cert|_] = Certs, Cache) -> + case find_local_issuer(Cert, Cache) of {ok, IssuerCert} -> try public_key:pkix_path_validation(IssuerCert, Certs, []) of {ok, _} -> @@ -570,67 +609,88 @@ check_ca() -> ok end. --spec find_local_issuer(cert()) -> {ok, cert()} | {error, {bad_cert, unknown_ca}}. -find_local_issuer(Cert) -> - case find_issuer_in_dir(Cert, ca_dir()) of +-spec find_local_issuer(cert(), dict:dict()) -> {ok, cert()} | + {error, {bad_cert, unknown_ca}}. +find_local_issuer(Cert, Hashes) -> + case find_issuer_in_dir(Cert, Hashes) of {ok, IssuerCert} -> {ok, IssuerCert}; - {error, _} = Err -> + {error, Reason} -> case ca_file() of - undefined -> Err; + undefined -> {error, Reason}; CAFile -> find_issuer_in_file(Cert, CAFile) end end. --spec find_issuer_in_dir(cert(), file:filename_all()) - -> {ok, cert()} | {error, {bad_cert, unknown_ca}}. -find_issuer_in_dir(Cert, CADir) -> +-spec find_issuer_in_dir(cert(), dict:dict()) + -> {{ok, cert()} | {error, {bad_cert, unknown_ca}}, dict:dict()}. +find_issuer_in_dir(Cert, Cache) -> {ok, {_, IssuerID}} = public_key:pkix_issuer_id(Cert, self), Hash = short_name_hash(IssuerID), - filelib:fold_files( - CADir, Hash ++ "\\.[0-9]+", false, - fun(_, {ok, IssuerCert}) -> - {ok, IssuerCert}; - (CertFile, Acc) -> - try - {ok, Data} = file:read_file(CertFile), - {ok, [IssuerCert|_], _} = pem_decode(Data), - case public_key:pkix_is_issuer(Cert, IssuerCert) of - true -> - {ok, IssuerCert}; - false -> - Acc - end - catch _:{badmatch, {error, Why}} -> - ?ERROR_MSG("failed to read CA certificate from \"~s\": ~s", - [CertFile, format_error(Why)]), - Acc + Files = case dict:find(Hash, Cache) of + {ok, L} -> L; + error -> [] + end, + lists:foldl( + fun(_, {ok, _IssuerCert} = Acc) -> + Acc; + (Path, Err) -> + case read_ca_file(Path) of + {ok, [IssuerCert|_]} -> + case public_key:pkix_is_issuer(Cert, IssuerCert) of + true -> + {ok, IssuerCert}; + false -> + Err + end; + error -> + Err end - end, {error, {bad_cert, unknown_ca}}). + end, {error, {bad_cert, unknown_ca}}, Files). -spec find_issuer_in_file(cert(), file:filename_all() | undefined) -> {ok, cert()} | {error, {bad_cert, unknown_ca}}. find_issuer_in_file(_Cert, undefined) -> {error, {bad_cert, unknown_ca}}; find_issuer_in_file(Cert, CAFile) -> - try - {ok, Data} = file:read_file(CAFile), - {ok, IssuerCerts, _} = pem_decode(Data), - lists:foldl( - fun(_, {ok, _} = Res) -> - Res; - (IssuerCert, Err) -> - case public_key:pkix_is_issuer(Cert, IssuerCert) of - true -> {ok, IssuerCert}; - false -> Err - end - end, {error, {bad_cert, unknown_ca}}, IssuerCerts) - catch _:{badmatch, {error, Why}} -> - ?ERROR_MSG("failed to read CA certificates from \"~s\": ~s", - [CAFile, format_error(Why)]), + case read_ca_file(CAFile) of + {ok, IssuerCerts} -> + lists:foldl( + fun(_, {ok, _} = Res) -> + Res; + (IssuerCert, Err) -> + case public_key:pkix_is_issuer(Cert, IssuerCert) of + true -> {ok, IssuerCert}; + false -> Err + end + end, {error, {bad_cert, unknown_ca}}, IssuerCerts); + error -> {error, {bad_cert, unknown_ca}} end. +-spec read_ca_file(file:filename_all()) -> {ok, [cert()]} | error. +read_ca_file(Path) -> + case use_cache() of + true -> + ets_cache:lookup(?CA_CACHE, Path, + fun() -> do_read_ca_file(Path) end); + false -> + do_read_ca_file(Path) + end. + +-spec do_read_ca_file(file:filename_all()) -> {ok, [cert()]} | error. +do_read_ca_file(Path) -> + try + {ok, Data} = file:read_file(Path), + {ok, IssuerCerts, _} = pem_decode(Data), + {ok, IssuerCerts} + catch _:{badmatch, {error, Why}} -> + ?ERROR_MSG("Failed to read CA certificate " + "from \"~s\": ~s", + [Path, format_error(Why)]), + error + end. + -spec match_cert_keys([{path, [cert()]}], [priv_key()]) -> {ok, [{cert(), priv_key()}]} | {error, {bad_cert, missing_priv_key}}. match_cert_keys(CertPaths, PrivKeys) -> @@ -680,13 +740,22 @@ pubkey_from_privkey(#'DSAPrivateKey'{p = P, q = Q, g = G, y = Y}) -> pubkey_from_privkey(#'ECPrivateKey'{publicKey = Key}) -> #'ECPoint'{point = Key}. --spec get_cert_paths([cert()]) -> [{path, [cert()]}]. -get_cert_paths(Certs) -> - G = digraph:new([acyclic]), - lists:foreach( - fun(Cert) -> - digraph:add_vertex(G, Cert) - end, Certs), +-spec get_cert_paths([cert()], digraph:graph()) -> [{path, [cert()]}]. +get_cert_paths(Certs, G) -> + {NewCerts, OldCerts} = + lists:partition( + fun(Cert) -> + case digraph:vertex(G, Cert) of + false -> + digraph:add_vertex(G, Cert), + true; + {_, _} -> + false + end + end, Certs), + CertPairs = [{C1, C2} || C1 <- NewCerts, C2 <- OldCerts] ++ + [{C1, C2} || C1 <- OldCerts, C2 <- NewCerts] ++ + [{C1, C2} || C1 <- NewCerts, C2 <- NewCerts], lists:foreach( fun({Cert1, Cert2}) when Cert1 /= Cert2 -> case public_key:pkix_is_self_signed(Cert1) of @@ -702,18 +771,16 @@ get_cert_paths(Certs) -> end; (_) -> ok - end, [{Cert1, Cert2} || Cert1 <- Certs, Cert2 <- Certs]), - Paths = lists:flatmap( - fun(Cert) -> - case digraph:in_degree(G, Cert) of - 0 -> - get_cert_path(G, [Cert]); - _ -> - [] - end - end, Certs), - digraph:delete(G), - Paths. + end, CertPairs), + lists:flatmap( + fun(Cert) -> + case digraph:in_degree(G, Cert) of + 0 -> + get_cert_path(G, [Cert]); + _ -> + [] + end + end, Certs). get_cert_path(G, [Root|_] = Acc) -> case digraph:out_edges(G, Root) of @@ -783,3 +850,25 @@ wildcard(Path) when is_binary(Path) -> wildcard(binary_to_list(Path)); wildcard(Path) -> filelib:wildcard(Path). + +-spec use_cache() -> boolean(). +use_cache() -> + ejabberd_config:use_cache(global). + +-spec init_cache() -> ok. +init_cache() -> + ets_cache:new(?CA_CACHE, cache_opts()). + +-spec delete_cache() -> ok. +delete_cache() -> + ets_cache:delete(?CA_CACHE). + +-spec cache_opts() -> [proplists:property()]. +cache_opts() -> + MaxSize = ejabberd_config:cache_size(global), + CacheMissed = ejabberd_config:cache_missed(global), + LifeTime = case ejabberd_config:cache_life_time(global) of + infinity -> infinity; + I -> timer:seconds(I) + end, + [{max_size, MaxSize}, {cache_missed, CacheMissed}, {life_time, LifeTime}].