diff --git a/rebar.config b/rebar.config index c91b16251..69dbba775 100644 --- a/rebar.config +++ b/rebar.config @@ -34,7 +34,7 @@ {if_var_true, mysql, {p1_mysql, ".*", {git, "https://github.com/processone/p1_mysql", {tag, "1.0.2"}}}}, {if_var_true, pgsql, {p1_pgsql, ".*", {git, "https://github.com/processone/p1_pgsql", - {tag, "1.1.1"}}}}, + {tag, "1.1.2"}}}}, {if_var_true, sqlite, {sqlite3, ".*", {git, "https://github.com/processone/erlang-sqlite3", {tag, "1.1.5"}}}}, {if_var_true, pam, {p1_pam, ".*", {git, "https://github.com/processone/epam", diff --git a/src/ejabberd_sql.erl b/src/ejabberd_sql.erl index 53e812063..46dd68297 100644 --- a/src/ejabberd_sql.erl +++ b/src/ejabberd_sql.erl @@ -279,8 +279,8 @@ init([Host, StartInterval]) -> connecting(connect, #state{host = Host} = State) -> ConnectRes = case db_opts(Host) of - [mysql | Args] -> apply(fun mysql_connect/5, Args); - [pgsql | Args] -> apply(fun pgsql_connect/5, Args); + [mysql | Args] -> apply(fun mysql_connect/7, Args); + [pgsql | Args] -> apply(fun pgsql_connect/7, Args); [sqlite | Args] -> apply(fun sqlite_connect/1, Args); [mssql | Args] -> apply(fun odbc_connect/1, Args); [odbc | Args] -> apply(fun odbc_connect/1, Args) @@ -782,13 +782,14 @@ sqlite_to_odbc(_Host, _) -> %% part of init/1 %% Open a database connection to PostgreSQL -pgsql_connect(Server, Port, DB, Username, Password) -> +pgsql_connect(Server, Port, DB, Username, Password, Transport, SSLOpts) -> case pgsql:connect([{host, Server}, {database, DB}, {user, Username}, {password, Password}, {port, Port}, - {as_binary, true}]) of + {transport, Transport}, + {as_binary, true}|SSLOpts]) of {ok, Ref} -> pgsql:squery(Ref, [<<"alter database \"">>, DB, <<"\" set ">>, <<"standard_conforming_strings='off';">>]), @@ -837,7 +838,7 @@ pgsql_execute_to_odbc(_) -> {updated, undefined}. %% part of init/1 %% Open a database connection to MySQL -mysql_connect(Server, Port, DB, Username, Password) -> +mysql_connect(Server, Port, DB, Username, Password, _, _) -> case p1_mysql_conn:start(binary_to_list(Server), Port, binary_to_list(Username), binary_to_list(Password), @@ -921,6 +922,14 @@ db_opts(Host) -> Server = ejabberd_config:get_option({sql_server, Host}, fun iolist_to_binary/1, <<"localhost">>), + Transport = case ejabberd_config:get_option( + {sql_ssl, Host}, + fun(B) when is_boolean(B) -> B end, + false) of + false -> tcp; + true -> ssl + end, + warn_if_ssl_unsupported(Transport, Type), case Type of odbc -> [odbc, Server]; @@ -944,15 +953,54 @@ db_opts(Host) -> Pass = ejabberd_config:get_option({sql_password, Host}, fun iolist_to_binary/1, <<"">>), + SSLOpts = get_ssl_opts(Transport, Host), case Type of mssql -> [mssql, <<"DSN=", Host/binary, ";UID=", User/binary, ";PWD=", Pass/binary>>]; _ -> - [Type, Server, Port, DB, User, Pass] + [Type, Server, Port, DB, User, Pass, Transport, SSLOpts] end end. +warn_if_ssl_unsupported(tcp, _) -> + ok; +warn_if_ssl_unsupported(ssl, pgsql) -> + ok; +warn_if_ssl_unsupported(ssl, Type) -> + ?WARNING_MSG("SSL connection is not supported for ~s", [Type]). + +get_ssl_opts(ssl, Host) -> + Opts1 = case ejabberd_config:get_option({sql_ssl_certfile, Host}, + fun iolist_to_binary/1) of + undefined -> []; + CertFile -> [{certfile, CertFile}] + end, + Opts2 = case ejabberd_config:get_option({sql_ssl_cafile, Host}, + fun iolist_to_binary/1) of + undefined -> Opts1; + CAFile -> [{cacertfile, CAFile}|Opts1] + end, + case ejabberd_config:get_option({sql_ssl_verify, Host}, + fun(B) when is_boolean(B) -> B end, + false) of + true -> + case lists:keymember(cacertfile, 1, Opts2) of + true -> + [{verify, verify_peer}|Opts2]; + false -> + ?WARNING_MSG("SSL verification is enabled for " + "SQL connection, but option " + "'sql_ssl_cafile' is not set; " + "verification will be disabled", []), + Opts2 + end; + false -> + Opts2 + end; +get_ssl_opts(tcp, _) -> + []. + init_mssql(Host) -> Server = ejabberd_config:get_option({sql_server, Host}, fun iolist_to_binary/1, @@ -1061,7 +1109,12 @@ opt_type(sql_type) -> (odbc) -> odbc end; opt_type(sql_username) -> fun iolist_to_binary/1; +opt_type(sql_ssl) -> fun(B) when is_boolean(B) -> B end; +opt_type(sql_ssl_verify) -> fun(B) when is_boolean(B) -> B end; +opt_type(sql_ssl_certfile) -> fun iolist_to_binary/1; +opt_type(sql_ssl_cafile) -> fun iolist_to_binary/1; opt_type(_) -> [max_fsm_queue, sql_database, sql_keepalive_interval, sql_password, sql_port, sql_server, sql_type, - sql_username]. + sql_username, sql_ssl, sql_ssl_verify, sql_ssl_cerfile, + sql_ssl_cafile].