diff --git a/src/ejabberd_service.erl b/src/ejabberd_service.erl index 9d9f0294e..f80675c8e 100644 --- a/src/ejabberd_service.erl +++ b/src/ejabberd_service.erl @@ -53,8 +53,8 @@ {socket :: ejabberd_socket:socket_state(), sockmod = ejabberd_socket :: ejabberd_socket | ejabberd_frontend_socket, streamid = <<"">> :: binary(), - hosts = [] :: [binary()], - password = <<"">> :: binary(), + host_opts = dict:new() :: ?TDICT, + host = <<"">> :: binary(), access :: atom(), check_from = true :: boolean()}). @@ -126,18 +126,21 @@ init([{SockMod, Socket}, Opts]) -> {value, {_, A}} -> A; _ -> all end, - %% This should be improved probably - {Hosts, HostOpts} = case lists:keyfind(hosts, 1, Opts) of - {_, HOpts} -> - {[H || {H, _} <- HOpts], - lists:flatten( - [O || {_, O} <- HOpts])}; - _ -> - {[], []} - end, - Password = gen_mod:get_opt(password, HostOpts, - fun iolist_to_binary/1, - p1_sha:sha(crypto:rand_bytes(20))), + HostOpts = case lists:keyfind(hosts, 1, Opts) of + {hosts, HOpts} -> + lists:foldl( + fun({H, Os}, D) -> + P = proplists:get_value( + password, Os, + p1_sha:sha(crypto:rand_bytes(20))), + dict:store(H, P, D) + end, dict:new(), HOpts); + false -> + Pass = proplists:get_value( + password, Opts, + p1_sha:sha(crypto:rand_bytes(20))), + dict:from_list([{global, Pass}]) + end, Shaper = case lists:keysearch(shaper_rule, 1, Opts) of {value, {_, S}} -> S; _ -> none @@ -151,7 +154,7 @@ init([{SockMod, Socket}, Opts]) -> SockMod:change_shaper(Socket, Shaper), {ok, wait_for_stream, #state{socket = Socket, sockmod = SockMod, - streamid = new_id(), hosts = Hosts, password = Password, + streamid = new_id(), host_opts = HostOpts, access = Access, check_from = CheckFrom}}. %%---------------------------------------------------------------------- @@ -166,10 +169,33 @@ wait_for_stream({xmlstreamstart, _Name, Attrs}, case xml:get_attr_s(<<"xmlns">>, Attrs) of <<"jabber:component:accept">> -> To = xml:get_attr_s(<<"to">>, Attrs), - Header = io_lib:format(?STREAM_HEADER, - [StateData#state.streamid, xml:crypt(To)]), - send_text(StateData, Header), - {next_state, wait_for_handshake, StateData}; + Host = jid:nameprep(To), + if Host == error -> + Header = io_lib:format(?STREAM_HEADER, + [<<"none">>, ?MYNAME]), + send_text(StateData, + <<(list_to_binary(Header))/binary, + (?INVALID_XML_ERR)/binary, + (?STREAM_TRAILER)/binary>>), + {stop, normal, StateData}; + true -> + Header = io_lib:format(?STREAM_HEADER, + [StateData#state.streamid, xml:crypt(To)]), + send_text(StateData, Header), + HostOpts = case dict:is_key(Host, StateData#state.host_opts) of + true -> + StateData#state.host_opts; + false -> + case dict:find(global, StateData#state.host_opts) of + {ok, GlobalPass} -> + dict:from_list([{Host, GlobalPass}]); + error -> + StateData#state.host_opts + end + end, + {next_state, wait_for_handshake, + StateData#state{host = Host, host_opts = HostOpts}} + end; _ -> send_text(StateData, ?INVALID_HEADER_ERR), {stop, normal, StateData} @@ -188,21 +214,26 @@ wait_for_handshake({xmlstreamelement, El}, StateData) -> #xmlel{name = Name, children = Els} = El, case {Name, xml:get_cdata(Els)} of {<<"handshake">>, Digest} -> - case p1_sha:sha(<<(StateData#state.streamid)/binary, - (StateData#state.password)/binary>>) - of - Digest -> - send_text(StateData, <<"">>), - lists:foreach(fun (H) -> - ejabberd_router:register_route(H), - ?INFO_MSG("Route registered for service ~p~n", - [H]) - end, - StateData#state.hosts), - {next_state, stream_established, StateData}; - _ -> - send_text(StateData, ?INVALID_HANDSHAKE_ERR), - {stop, normal, StateData} + case dict:find(StateData#state.host, StateData#state.host_opts) of + {ok, Password} -> + case p1_sha:sha(<<(StateData#state.streamid)/binary, + Password/binary>>) of + Digest -> + send_text(StateData, <<"">>), + lists:foreach( + fun (H) -> + ejabberd_router:register_route(H), + ?INFO_MSG("Route registered for service ~p~n", + [H]) + end, dict:fetch_keys(StateData#state.host_opts)), + {next_state, stream_established, StateData}; + _ -> + send_text(StateData, ?INVALID_HANDSHAKE_ERR), + {stop, normal, StateData} + end; + _ -> + send_text(StateData, ?INVALID_HANDSHAKE_ERR), + {stop, normal, StateData} end; _ -> {next_state, wait_for_handshake, StateData} end; @@ -231,7 +262,7 @@ stream_established({xmlstreamelement, El}, StateData) -> FromJID1 = jid:from_string(From), case FromJID1 of #jid{lserver = Server} -> - case lists:member(Server, StateData#state.hosts) of + case dict:is_key(Server, StateData#state.host_opts) of true -> FromJID1; false -> error end; @@ -349,7 +380,7 @@ terminate(Reason, StateName, StateData) -> lists:foreach(fun (H) -> ejabberd_router:unregister_route(H) end, - StateData#state.hosts); + dict:fetch_keys(StateData#state.host_opts)); _ -> ok end, (StateData#state.sockmod):close(StateData#state.socket),