Don't replace valid certificates with invalid ones

When building the certificates chains, if several certificates
are found matching the same domain their validity is checked:

* the invalid one is ignored and the valid one is picked
* if both are valid or both are invalid, then the one with
  sooner expiration is ignored.

Fixes #2454
This commit is contained in:
Evgeniy Khramtsov 2018-06-27 10:55:37 +03:00
parent 881e02632b
commit 7881c5670c
1 changed files with 88 additions and 37 deletions

View File

@ -329,7 +329,8 @@ get_certfiles_from_config_options(_State) ->
Host <- ejabberd_config:get_myhosts()]),
[iolist_to_binary(P) || P <- lists:usort(Local ++ Global)].
-spec add_certfiles(state()) -> {ok, state()} | {error, bad_cert()}.
-spec add_certfiles(state()) -> {ok, state()} |
{error, bad_cert() | file:posix()}.
add_certfiles(State) ->
?DEBUG("Reading certificates", []),
Paths = get_certfiles_from_config_options(State),
@ -343,7 +344,8 @@ add_certfiles(State) ->
{error, _} = Err -> Err
end.
-spec add_certfiles(binary(), state()) -> {ok, state()} | {error, bad_cert()}.
-spec add_certfiles(binary(), state()) -> {ok, state()} |
{error, bad_cert() | file:posix()}.
add_certfiles(Host, State) ->
State1 = lists:foldl(
fun(Opt, AccState) ->
@ -363,8 +365,8 @@ add_certfiles(Host, State) ->
{ok, State}
end.
-spec add_certfile(file:filename_all(), state()) -> {ok, state()} |
{{error, cert_error()}, state()}.
-spec add_certfile(file:filename_all(), state()) ->
{ok, state()} | {{error, cert_error() | file:posix()}, state()}.
add_certfile(Path, State) ->
case lists:member(Path, State#state.paths) of
true ->
@ -386,30 +388,14 @@ add_certfile(Path, State) ->
end
end.
-spec build_chain_and_check(state()) -> ok | {error, bad_cert()}.
-spec build_chain_and_check(state()) -> ok | {error, bad_cert() | file:posix()}.
build_chain_and_check(State) ->
?DEBUG("Building certificates graph", []),
CertPaths = get_cert_paths(maps:keys(State#state.certs), State#state.graph),
?DEBUG("Finding matched certificate keys", []),
case match_cert_keys(CertPaths, State#state.keys) of
{ok, Chains} ->
?DEBUG("Storing certificate chains", []),
CertFilesWithDomains = store_certs(Chains, []),
ets:delete_all_objects(?MODULE),
lists:foreach(
fun({Path, Domain}) ->
fast_tls:add_certfile(Domain, Path),
ets:insert(?MODULE, {Domain, Path})
end, CertFilesWithDomains),
?DEBUG("Validating certificates", []),
Errors = validate(CertPaths, State#state.validate),
?DEBUG("Subscribing to file events", []),
lists:foreach(
fun({Cert, Why}) ->
Path = maps:get(Cert, State#state.certs),
?WARNING_MSG("Failed to validate certificate from ~s: ~s",
[Path, format_error(Why)])
end, Errors);
InvalidCerts = validate(CertPaths, State),
SortedChains = sort_chains(Chains, InvalidCerts),
store_certs(SortedChains);
{error, Cert, Why} ->
Path = maps:get(Cert, State#state.certs),
?ERROR_MSG("Failed to build certificate chain for ~s: ~s",
@ -417,9 +403,35 @@ build_chain_and_check(State) ->
{error, Why}
end.
-spec store_certs([{[cert()], priv_key()}],
[{binary(), binary()}]) -> [{binary(), binary()}].
store_certs([{Certs, Key}|Chains], Acc) ->
-spec store_certs([{[cert()], priv_key()}]) -> ok | {error, file:posix()}.
store_certs(Chains) ->
?DEBUG("Storing certificate chains", []),
Res = lists:foldl(
fun(_, {error, _} = Err) ->
Err;
({Certs, Key}, Acc) ->
case store_cert(Certs, Key) of
{ok, FileDoms} ->
Acc ++ FileDoms;
{error, _} = Err ->
Err
end
end, [], Chains),
case Res of
{error, Why} ->
{error, Why};
FileDomains ->
ets:delete_all_objects(?MODULE),
lists:foreach(
fun({Path, Domain}) ->
fast_tls:add_certfile(Domain, Path),
ets:insert(?MODULE, {Domain, Path})
end, FileDomains)
end.
-spec store_cert([cert()], priv_key()) -> {ok, [{binary(), binary()}]} |
{error, file:posix()}.
store_cert(Certs, Key) ->
CertPEMs = public_key:pem_encode(
lists:map(
fun(Cert) ->
@ -438,15 +450,39 @@ store_certs([{Certs, Key}|Chains], Acc) ->
case file:write_file(FileName, PEMs) of
ok ->
file:change_mode(FileName, 8#600),
NewAcc = [{FileName, Domain} || Domain <- Domains] ++ Acc,
store_certs(Chains, NewAcc);
{error, Why} ->
{ok, [{FileName, Domain} || Domain <- Domains]};
{error, Why} = Err ->
?ERROR_MSG("Failed to write to ~s: ~s",
[FileName, file:format_error(Why)]),
store_certs(Chains, [])
end;
store_certs([], Acc) ->
Acc.
Err
end.
-spec sort_chains([{[cert()], priv_key()}], [cert()]) -> [{[cert()], priv_key()}].
sort_chains(Chains, InvalidCerts) ->
lists:sort(
fun({[Cert1|_], _}, {[Cert2|_], _}) ->
IsValid1 = not lists:member(Cert1, InvalidCerts),
IsValid2 = not lists:member(Cert2, InvalidCerts),
if IsValid1 and not IsValid2 ->
false;
IsValid2 and not IsValid1 ->
true;
true ->
compare_expiration_date(Cert1, Cert2)
end
end, Chains).
%% Returns true if the first certificate has sooner expiration date
-spec compare_expiration_date(cert(), cert()) -> boolean().
compare_expiration_date(#'OTPCertificate'{
tbsCertificate =
#'OTPTBSCertificate'{
validity = #'Validity'{notAfter = After1}}},
#'OTPCertificate'{
tbsCertificate =
#'OTPTBSCertificate'{
validity = #'Validity'{notAfter = After2}}}) ->
get_timestamp(After1) =< get_timestamp(After2).
-spec load_certfile(file:filename_all()) -> {ok, [cert()], [priv_key()]} |
{error, cert_error() | file:posix()}.
@ -521,8 +557,9 @@ decode_certs(PemEntries) ->
{error, not_der}
end.
-spec validate([{path, [cert()]}], boolean()) -> [{cert(), bad_cert()}].
validate(Paths, true) ->
-spec validate([{path, [cert()]}], state()) -> [cert()].
validate(Paths, #state{validate = true} = State) ->
?DEBUG("Validating certificates", []),
{ok, Re} = re:compile("^[a-f0-9]+\\.[0-9]+$", [unicode]),
Hashes = case file:list_dir(ca_dir()) of
{ok, Files} ->
@ -551,7 +588,10 @@ validate(Paths, true) ->
ok ->
false;
{error, Cert, Reason} ->
{true, {Cert, Reason}}
File = maps:get(Cert, State#state.certs),
?WARNING_MSG("Failed to validate certificate from ~s: ~s",
[File, format_error(Reason)]),
{true, Cert}
end
end, Paths);
validate(_, _) ->
@ -715,6 +755,7 @@ do_read_ca_file(Path) ->
-spec match_cert_keys([{path, [cert()]}], [priv_key()])
-> {ok, [{cert(), priv_key()}]} | {error, {bad_cert, missing_priv_key}}.
match_cert_keys(CertPaths, PrivKeys) ->
?DEBUG("Finding matched certificate keys", []),
KeyPairs = [{pubkey_from_privkey(PrivKey), PrivKey} || PrivKey <- PrivKeys],
match_cert_keys(CertPaths, KeyPairs, []).
@ -763,6 +804,7 @@ pubkey_from_privkey(#'ECPrivateKey'{publicKey = Key}) ->
-spec get_cert_paths([cert()], digraph:graph()) -> [{path, [cert()]}].
get_cert_paths(Certs, G) ->
?DEBUG("Building certificates graph", []),
{NewCerts, OldCerts} =
lists:partition(
fun(Cert) ->
@ -838,6 +880,15 @@ short_name_hash(_) ->
"".
-endif.
-spec get_timestamp({utcTime | generalTime, string()}) -> string().
get_timestamp({utcTime, [Y1,Y2|T]}) ->
case list_to_integer([Y1,Y2]) of
N when N >= 50 -> [$1,$9,Y1,Y2|T];
_ -> [$2,$0,Y1,Y2|T]
end;
get_timestamp({generalTime, TS}) ->
TS.
wildcard(Path) when is_binary(Path) ->
wildcard(binary_to_list(Path));
wildcard(Path) ->