diff --git a/src/ejabberd_pkix.erl b/src/ejabberd_pkix.erl index 229492bea..1a4f89dff 100644 --- a/src/ejabberd_pkix.erl +++ b/src/ejabberd_pkix.erl @@ -329,7 +329,8 @@ get_certfiles_from_config_options(_State) -> Host <- ejabberd_config:get_myhosts()]), [iolist_to_binary(P) || P <- lists:usort(Local ++ Global)]. --spec add_certfiles(state()) -> {ok, state()} | {error, bad_cert()}. +-spec add_certfiles(state()) -> {ok, state()} | + {error, bad_cert() | file:posix()}. add_certfiles(State) -> ?DEBUG("Reading certificates", []), Paths = get_certfiles_from_config_options(State), @@ -343,7 +344,8 @@ add_certfiles(State) -> {error, _} = Err -> Err end. --spec add_certfiles(binary(), state()) -> {ok, state()} | {error, bad_cert()}. +-spec add_certfiles(binary(), state()) -> {ok, state()} | + {error, bad_cert() | file:posix()}. add_certfiles(Host, State) -> State1 = lists:foldl( fun(Opt, AccState) -> @@ -363,8 +365,8 @@ add_certfiles(Host, State) -> {ok, State} end. --spec add_certfile(file:filename_all(), state()) -> {ok, state()} | - {{error, cert_error()}, state()}. +-spec add_certfile(file:filename_all(), state()) -> + {ok, state()} | {{error, cert_error() | file:posix()}, state()}. add_certfile(Path, State) -> case lists:member(Path, State#state.paths) of true -> @@ -386,30 +388,14 @@ add_certfile(Path, State) -> end end. --spec build_chain_and_check(state()) -> ok | {error, bad_cert()}. +-spec build_chain_and_check(state()) -> ok | {error, bad_cert() | file:posix()}. build_chain_and_check(State) -> - ?DEBUG("Building certificates graph", []), CertPaths = get_cert_paths(maps:keys(State#state.certs), State#state.graph), - ?DEBUG("Finding matched certificate keys", []), case match_cert_keys(CertPaths, State#state.keys) of {ok, Chains} -> - ?DEBUG("Storing certificate chains", []), - CertFilesWithDomains = store_certs(Chains, []), - ets:delete_all_objects(?MODULE), - lists:foreach( - fun({Path, Domain}) -> - fast_tls:add_certfile(Domain, Path), - ets:insert(?MODULE, {Domain, Path}) - end, CertFilesWithDomains), - ?DEBUG("Validating certificates", []), - Errors = validate(CertPaths, State#state.validate), - ?DEBUG("Subscribing to file events", []), - lists:foreach( - fun({Cert, Why}) -> - Path = maps:get(Cert, State#state.certs), - ?WARNING_MSG("Failed to validate certificate from ~s: ~s", - [Path, format_error(Why)]) - end, Errors); + InvalidCerts = validate(CertPaths, State), + SortedChains = sort_chains(Chains, InvalidCerts), + store_certs(SortedChains); {error, Cert, Why} -> Path = maps:get(Cert, State#state.certs), ?ERROR_MSG("Failed to build certificate chain for ~s: ~s", @@ -417,9 +403,35 @@ build_chain_and_check(State) -> {error, Why} end. --spec store_certs([{[cert()], priv_key()}], - [{binary(), binary()}]) -> [{binary(), binary()}]. -store_certs([{Certs, Key}|Chains], Acc) -> +-spec store_certs([{[cert()], priv_key()}]) -> ok | {error, file:posix()}. +store_certs(Chains) -> + ?DEBUG("Storing certificate chains", []), + Res = lists:foldl( + fun(_, {error, _} = Err) -> + Err; + ({Certs, Key}, Acc) -> + case store_cert(Certs, Key) of + {ok, FileDoms} -> + Acc ++ FileDoms; + {error, _} = Err -> + Err + end + end, [], Chains), + case Res of + {error, Why} -> + {error, Why}; + FileDomains -> + ets:delete_all_objects(?MODULE), + lists:foreach( + fun({Path, Domain}) -> + fast_tls:add_certfile(Domain, Path), + ets:insert(?MODULE, {Domain, Path}) + end, FileDomains) + end. + +-spec store_cert([cert()], priv_key()) -> {ok, [{binary(), binary()}]} | + {error, file:posix()}. +store_cert(Certs, Key) -> CertPEMs = public_key:pem_encode( lists:map( fun(Cert) -> @@ -438,15 +450,39 @@ store_certs([{Certs, Key}|Chains], Acc) -> case file:write_file(FileName, PEMs) of ok -> file:change_mode(FileName, 8#600), - NewAcc = [{FileName, Domain} || Domain <- Domains] ++ Acc, - store_certs(Chains, NewAcc); - {error, Why} -> + {ok, [{FileName, Domain} || Domain <- Domains]}; + {error, Why} = Err -> ?ERROR_MSG("Failed to write to ~s: ~s", [FileName, file:format_error(Why)]), - store_certs(Chains, []) - end; -store_certs([], Acc) -> - Acc. + Err + end. + +-spec sort_chains([{[cert()], priv_key()}], [cert()]) -> [{[cert()], priv_key()}]. +sort_chains(Chains, InvalidCerts) -> + lists:sort( + fun({[Cert1|_], _}, {[Cert2|_], _}) -> + IsValid1 = not lists:member(Cert1, InvalidCerts), + IsValid2 = not lists:member(Cert2, InvalidCerts), + if IsValid1 and not IsValid2 -> + false; + IsValid2 and not IsValid1 -> + true; + true -> + compare_expiration_date(Cert1, Cert2) + end + end, Chains). + +%% Returns true if the first certificate has sooner expiration date +-spec compare_expiration_date(cert(), cert()) -> boolean(). +compare_expiration_date(#'OTPCertificate'{ + tbsCertificate = + #'OTPTBSCertificate'{ + validity = #'Validity'{notAfter = After1}}}, + #'OTPCertificate'{ + tbsCertificate = + #'OTPTBSCertificate'{ + validity = #'Validity'{notAfter = After2}}}) -> + get_timestamp(After1) =< get_timestamp(After2). -spec load_certfile(file:filename_all()) -> {ok, [cert()], [priv_key()]} | {error, cert_error() | file:posix()}. @@ -521,8 +557,9 @@ decode_certs(PemEntries) -> {error, not_der} end. --spec validate([{path, [cert()]}], boolean()) -> [{cert(), bad_cert()}]. -validate(Paths, true) -> +-spec validate([{path, [cert()]}], state()) -> [cert()]. +validate(Paths, #state{validate = true} = State) -> + ?DEBUG("Validating certificates", []), {ok, Re} = re:compile("^[a-f0-9]+\\.[0-9]+$", [unicode]), Hashes = case file:list_dir(ca_dir()) of {ok, Files} -> @@ -551,7 +588,10 @@ validate(Paths, true) -> ok -> false; {error, Cert, Reason} -> - {true, {Cert, Reason}} + File = maps:get(Cert, State#state.certs), + ?WARNING_MSG("Failed to validate certificate from ~s: ~s", + [File, format_error(Reason)]), + {true, Cert} end end, Paths); validate(_, _) -> @@ -715,6 +755,7 @@ do_read_ca_file(Path) -> -spec match_cert_keys([{path, [cert()]}], [priv_key()]) -> {ok, [{cert(), priv_key()}]} | {error, {bad_cert, missing_priv_key}}. match_cert_keys(CertPaths, PrivKeys) -> + ?DEBUG("Finding matched certificate keys", []), KeyPairs = [{pubkey_from_privkey(PrivKey), PrivKey} || PrivKey <- PrivKeys], match_cert_keys(CertPaths, KeyPairs, []). @@ -763,6 +804,7 @@ pubkey_from_privkey(#'ECPrivateKey'{publicKey = Key}) -> -spec get_cert_paths([cert()], digraph:graph()) -> [{path, [cert()]}]. get_cert_paths(Certs, G) -> + ?DEBUG("Building certificates graph", []), {NewCerts, OldCerts} = lists:partition( fun(Cert) -> @@ -838,6 +880,15 @@ short_name_hash(_) -> "". -endif. +-spec get_timestamp({utcTime | generalTime, string()}) -> string(). +get_timestamp({utcTime, [Y1,Y2|T]}) -> + case list_to_integer([Y1,Y2]) of + N when N >= 50 -> [$1,$9,Y1,Y2|T]; + _ -> [$2,$0,Y1,Y2|T] + end; +get_timestamp({generalTime, TS}) -> + TS. + wildcard(Path) when is_binary(Path) -> wildcard(binary_to_list(Path)); wildcard(Path) ->