diff --git a/src/ejabberd_pkix.erl b/src/ejabberd_pkix.erl index 6bf3e420f..0f23e6871 100644 --- a/src/ejabberd_pkix.erl +++ b/src/ejabberd_pkix.erl @@ -49,9 +49,10 @@ -type bad_cert_reason() :: cert_expired | invalid_issuer | invalid_signature | name_not_permitted | missing_basic_constraint | invalid_key_usage | selfsigned_peer | unknown_sig_algo | - unknown_ca | missing_priv_key. --type bad_cert() :: {bad_cert, bad_cert_reason()}. --type cert_error() :: not_cert | not_der | not_pem | encrypted. + unknown_ca | missing_priv_key | unknown_key_algo | + unknown_key_type | encrypted | not_der | not_cert | + not_pem. +-type cert_error() :: {bad_cert, bad_cert_reason()}. -export_type([cert_error/0]). -define(CA_CACHE, ca_cache). @@ -76,13 +77,13 @@ route_registered(Route) -> gen_server:call(?MODULE, {route_registered, Route}). -spec format_error(cert_error() | file:posix()) -> string(). -format_error(not_cert) -> +format_error({bad_cert, not_cert}) -> "no PEM encoded certificates found"; -format_error(not_pem) -> +format_error({bad_cert, not_pem}) -> "failed to decode from PEM format"; -format_error(not_der) -> +format_error({bad_cert, not_der}) -> "failed to decode from DER format"; -format_error(encrypted) -> +format_error({bad_cert, encrypted}) -> "encrypted certificate"; format_error({bad_cert, cert_expired}) -> "certificate is no longer valid as its expiration date has passed"; @@ -103,6 +104,10 @@ format_error({bad_cert, selfsigned_peer}) -> "self-signed certificate"; format_error({bad_cert, unknown_sig_algo}) -> "certificate is signed using unknown algorithm"; +format_error({bad_cert, unknown_key_algo}) -> + "unknown private key algorithm"; +format_error({bad_cert, unknown_key_type}) -> + "private key is of unknown type"; format_error({bad_cert, unknown_ca}) -> "certificate is signed by unknown CA"; format_error({bad_cert, missing_priv_key}) -> @@ -330,7 +335,7 @@ get_certfiles_from_config_options(_State) -> [iolist_to_binary(P) || P <- lists:usort(Local ++ Global)]. -spec add_certfiles(state()) -> {ok, state()} | - {error, bad_cert() | file:posix()}. + {error, cert_error() | file:posix()}. add_certfiles(State) -> ?DEBUG("Reading certificates", []), Paths = get_certfiles_from_config_options(State), @@ -345,7 +350,7 @@ add_certfiles(State) -> end. -spec add_certfiles(binary(), state()) -> {ok, state()} | - {error, bad_cert() | file:posix()}. + {error, cert_error() | file:posix()}. add_certfiles(Host, State) -> State1 = lists:foldl( fun(Opt, AccState) -> @@ -388,7 +393,7 @@ add_certfile(Path, State) -> end end. --spec build_chain_and_check(state()) -> ok | {error, bad_cert() | file:posix()}. +-spec build_chain_and_check(state()) -> ok | {error, cert_error() | file:posix()}. build_chain_and_check(State) -> CertPaths = get_cert_paths(maps:keys(State#state.certs), State#state.graph), case match_cert_keys(CertPaths, State#state.keys) of @@ -515,55 +520,66 @@ pem_decode(Data) -> (_) -> false end, Objects) of {[], []} -> - {error, not_cert}; + {error, {bad_cert, not_cert}}; {Certs, PrivKeys} -> {ok, Certs, PrivKeys} end end - catch _:_ -> - {error, not_pem} + catch E:R -> + St = erlang:get_stacktrace(), + ?DEBUG("PEM decoding stacktrace: ~p", [{E, {R, St}}]), + {error, {bad_cert, not_pem}} end. --spec decode_certs([public_key:pem_entry()]) -> {[cert()], [priv_key()]} | - {error, not_der | encrypted}. +-spec decode_certs([public_key:pem_entry()]) -> [cert() | priv_key()] | + {error, cert_error()}. decode_certs(PemEntries) -> - try lists:foldr( - fun(_, {error, _} = Err) -> - Err; - ({_, _, Flag}, _) when Flag /= not_encrypted -> - {error, encrypted}; - ({'Certificate', Der, _}, Acc) -> - [public_key:pkix_decode_cert(Der, otp)|Acc]; - ({'PrivateKeyInfo', Der, not_encrypted}, Acc) -> - #'PrivateKeyInfo'{privateKeyAlgorithm = - #'PrivateKeyInfo_privateKeyAlgorithm'{ - algorithm = Algo}, - privateKey = Key} = - public_key:der_decode('PrivateKeyInfo', Der), - case Algo of - ?'rsaEncryption' -> - [public_key:der_decode( - 'RSAPrivateKey', iolist_to_binary(Key))|Acc]; - ?'id-dsa' -> - [public_key:der_decode( - 'DSAPrivateKey', iolist_to_binary(Key))|Acc]; - ?'id-ecPublicKey' -> - [public_key:der_decode( - 'ECPrivateKey', iolist_to_binary(Key))|Acc]; - _ -> - Acc - end; - ({Tag, Der, _}, Acc) when Tag == 'RSAPrivateKey'; - Tag == 'DSAPrivateKey'; - Tag == 'ECPrivateKey' -> - [public_key:der_decode(Tag, Der)|Acc]; - (_, Acc) -> - Acc - end, [], PemEntries) - catch _:_ -> - {error, not_der} + try lists:flatmap( + fun({Tag, Der, Flag}) -> + decode_cert(Tag, Der, Flag) + end, PemEntries) + catch _:{bad_cert, _} = Err -> + {error, Err}; + E:R -> + St = erlang:get_stacktrace(), + ?DEBUG("DER decoding stacktrace: ~p", [{E, {R, St}}]), + {error, {bad_cert, not_der}} end. +-spec decode_cert(atom(), binary(), atom()) -> [cert() | priv_key()]. +decode_cert(_, _, Flag) when Flag /= not_encrypted -> + erlang:error({bad_cert, encrypted}); +decode_cert('Certificate', Der, _) -> + [public_key:pkix_decode_cert(Der, otp)]; +decode_cert('PrivateKeyInfo', Der, not_encrypted) -> + case public_key:der_decode('PrivateKeyInfo', Der) of + #'PrivateKeyInfo'{privateKeyAlgorithm = + #'PrivateKeyInfo_privateKeyAlgorithm'{ + algorithm = Algo}, + privateKey = Key} -> + KeyBin = iolist_to_binary(Key), + case Algo of + ?'rsaEncryption' -> + [public_key:der_decode('RSAPrivateKey', KeyBin)]; + ?'id-dsa' -> + [public_key:der_decode('DSAPrivateKey', KeyBin)]; + ?'id-ecPublicKey' -> + [public_key:der_decode('ECPrivateKey', KeyBin)]; + _ -> + erlang:error({bad_cert, unknown_key_algo}) + end; + #'RSAPrivateKey'{} = Key -> [Key]; + #'DSAPrivateKey'{} = Key -> [Key]; + #'ECPrivateKey'{} = Key -> [Key]; + _ -> erlang:error({bad_cert, unknown_key_type}) + end; +decode_cert(Tag, Der, _) when Tag == 'RSAPrivateKey'; + Tag == 'DSAPrivateKey'; + Tag == 'ECPrivateKey' -> + [public_key:der_decode(Tag, Der)]; +decode_cert(_, _, _) -> + []. + -spec validate([{path, [cert()]}], state()) -> [cert()]. validate(Paths, #state{validate = true} = State) -> ?DEBUG("Validating certificates", []), @@ -604,7 +620,7 @@ validate(Paths, #state{validate = true} = State) -> validate(_, _) -> []. --spec validate_path([cert()], dict:dict()) -> ok | {error, cert(), bad_cert()}. +-spec validate_path([cert()], dict:dict()) -> ok | {error, cert(), cert_error()}. validate_path([Cert|_] = Certs, Cache) -> case find_local_issuer(Cert, Cache) of {ok, IssuerCert} ->