25
1
mirror of https://github.com/processone/ejabberd.git synced 2024-11-28 16:34:13 +01:00

Process 'Contact' headers more accurately (as per RFC3261)

This commit is contained in:
Evgeniy Khramtsov 2014-05-30 23:11:46 +04:00
parent 7261cb29ac
commit 32998f7e18
2 changed files with 183 additions and 139 deletions

View File

@ -66,15 +66,15 @@ wait_for_request({#sip{type = request} = Req, TrID}, State) ->
Opts = State#state.opts, Opts = State#state.opts,
Req1 = prepare_request(State#state.host, Req), Req1 = prepare_request(State#state.host, Req),
case connect(Req1, Opts) of case connect(Req1, Opts) of
{ok, SIPSockets} -> {ok, SIPSocketsWithURIs} ->
NewState = NewState =
lists:foldl( lists:foldl(
fun(_SIPSocket, {error, _} = Err) -> fun(_SIPSocketWithURI, {error, _} = Err) ->
Err; Err;
(SIPSocket, #state{tr_ids = TrIDs} = AccState) -> ({SIPSocket, URI}, #state{tr_ids = TrIDs} = AccState) ->
Req2 = add_record_route(SIPSocket, State#state.host, Req1), Req2 = add_record_route(SIPSocket, State#state.host, Req1),
Req3 = add_via(SIPSocket, State#state.host, Req2), Req3 = add_via(SIPSocket, State#state.host, Req2),
case esip:request(SIPSocket, Req3, case esip:request(SIPSocket, Req3#sip{uri = URI},
{?MODULE, route, [self()]}) of {?MODULE, route, [self()]}) of
{ok, ClientTrID} -> {ok, ClientTrID} ->
NewTrIDs = [ClientTrID|TrIDs], NewTrIDs = [ClientTrID|TrIDs],
@ -83,7 +83,7 @@ wait_for_request({#sip{type = request} = Req, TrID}, State) ->
cancel_pending_transactions(AccState), cancel_pending_transactions(AccState),
Err Err
end end
end, State, SIPSockets), end, State, SIPSocketsWithURIs),
case NewState of case NewState of
{error, _} = Err -> {error, _} = Err ->
{Status, Reason} = esip:error_status(Err), {Status, Reason} = esip:error_status(Err),
@ -214,7 +214,7 @@ connect(#sip{hdrs = Hdrs} = Req, Opts) ->
false -> false ->
case esip:connect(Req, Opts) of case esip:connect(Req, Opts) of
{ok, SIPSock} -> {ok, SIPSock} ->
{ok, [SIPSock]}; {ok, [{SIPSock, Req#sip.uri}]};
{error, _} = Err -> {error, _} = Err ->
Err Err
end end

View File

@ -23,11 +23,13 @@
-include("esip.hrl"). -include("esip.hrl").
-define(CALL_TIMEOUT, timer:seconds(30)). -define(CALL_TIMEOUT, timer:seconds(30)).
-define(DEFAULT_EXPIRES, 3600).
-record(binding, {socket = #sip_socket{}, -record(binding, {socket = #sip_socket{},
call_id = <<"">> :: binary(), call_id = <<"">> :: binary(),
cseq = 0 :: non_neg_integer(), cseq = 0 :: non_neg_integer(),
timestamp = now() :: erlang:timestamp(), timestamp = now() :: erlang:timestamp(),
contact :: {binary(), #uri{}, [{binary(), binary()}]},
tref = make_ref() :: reference(), tref = make_ref() :: reference(),
expires = 0 :: non_neg_integer()}). expires = 0 :: non_neg_integer()}).
@ -50,20 +52,19 @@ request(#sip{hdrs = Hdrs} = Req, SIPSock) ->
US = {LUser, LServer}, US = {LUser, LServer},
CallID = esip:get_hdr('call-id', Hdrs), CallID = esip:get_hdr('call-id', Hdrs),
CSeq = esip:get_hdr('cseq', Hdrs), CSeq = esip:get_hdr('cseq', Hdrs),
Expires = esip:get_hdr('expires', Hdrs, 0), Expires = esip:get_hdr('expires', Hdrs, ?DEFAULT_EXPIRES),
case esip:get_hdrs('contact', Hdrs) of case esip:get_hdrs('contact', Hdrs) of
[<<"*">>] when Expires == 0 -> [<<"*">>] when Expires == 0 ->
case unregister_session(US, SIPSock, CallID, CSeq) of case unregister_session(US, CallID, CSeq) of
ok -> {ok, ContactsWithExpires} ->
?INFO_MSG("unregister SIP session for user ~s@~s from ~s", ?INFO_MSG("unregister SIP session for user ~s@~s from ~s",
[LUser, LServer, inet_parse:ntoa(PeerIP)]), [LUser, LServer, inet_parse:ntoa(PeerIP)]),
Contact = {<<"">>, #uri{user = LUser, host = LServer}, Cs = prepare_contacts_to_send(ContactsWithExpires),
[{<<"expires">>, <<"0">>}]},
mod_sip:make_response( mod_sip:make_response(
Req, Req,
#sip{type = response, #sip{type = response,
status = 200, status = 200,
hdrs = [{'contact', [Contact]}]}); hdrs = [{'contact', Cs}]});
{error, Why} -> {error, Why} ->
{Status, Reason} = make_status(Why), {Status, Reason} = make_status(Why),
mod_sip:make_response( mod_sip:make_response(
@ -72,51 +73,35 @@ request(#sip{hdrs = Hdrs} = Req, SIPSock) ->
reason = Reason}) reason = Reason})
end; end;
[{_, _URI, _Params}|_] = Contacts -> [{_, _URI, _Params}|_] = Contacts ->
ExpiresList = lists:map( ContactsWithExpires = make_contacts_with_expires(Contacts, Expires),
fun({_, _, Params}) -> Expires1 = lists:max([E || {_, E} <- ContactsWithExpires]),
case to_integer(
esip:get_param(
<<"expires">>, Params),
0, (1 bsl 32)-1) of
{ok, E} -> E;
_ -> Expires
end
end, Contacts),
Expires1 = lists:max(ExpiresList),
Contact = {<<"">>, #uri{user = LUser, host = LServer},
[{<<"expires">>, jlib:integer_to_binary(Expires1)}]},
MinExpires = min_expires(), MinExpires = min_expires(),
if Expires1 >= MinExpires -> if Expires1 > 0, Expires1 < MinExpires ->
case register_session(US, SIPSock, CallID, CSeq, Expires1) of
ok ->
?INFO_MSG("register SIP session for user ~s@~s from ~s",
[LUser, LServer, inet_parse:ntoa(PeerIP)]),
mod_sip:make_response(
Req,
#sip{type = response,
status = 200,
hdrs = [{'contact', [Contact]}]});
{error, Why} ->
{Status, Reason} = make_status(Why),
mod_sip:make_response(
Req, #sip{type = response,
status = Status,
reason = Reason})
end;
Expires1 > 0, Expires1 < MinExpires ->
mod_sip:make_response( mod_sip:make_response(
Req, #sip{type = response, Req, #sip{type = response,
status = 423, status = 423,
hdrs = [{'min-expires', MinExpires}]}); hdrs = [{'min-expires', MinExpires}]});
true -> true ->
case unregister_session(US, SIPSock, CallID, CSeq) of case register_session(US, SIPSock, CallID, CSeq,
ok -> ContactsWithExpires) of
?INFO_MSG("unregister SIP session for user ~s@~s from ~s", {ok, Res} ->
[LUser, LServer, inet_parse:ntoa(PeerIP)]), if Res == updated ->
?INFO_MSG("register SIP session for user "
"~s@~s from ~s",
[LUser, LServer,
inet_parse:ntoa(PeerIP)]);
Res == deleted ->
?INFO_MSG("unregister SIP session for user "
"~s@~s from ~s",
[LUser, LServer,
inet_parse:ntoa(PeerIP)])
end,
Cs = prepare_contacts_to_send(ContactsWithExpires),
mod_sip:make_response( mod_sip:make_response(
Req, Req,
#sip{type = response, status = 200, #sip{type = response,
hdrs = [{'contact', [Contact]}]}); status = 200,
hdrs = [{'contact', Cs}]});
{error, Why} -> {error, Why} ->
{Status, Reason} = make_status(Why), {Status, Reason} = make_status(Why),
mod_sip:make_response( mod_sip:make_response(
@ -128,22 +113,15 @@ request(#sip{hdrs = Hdrs} = Req, SIPSock) ->
[] -> [] ->
case mnesia:dirty_read(sip_session, US) of case mnesia:dirty_read(sip_session, US) of
[#sip_session{bindings = Bindings}] -> [#sip_session{bindings = Bindings}] ->
case pop_previous_binding(SIPSock, Bindings) of ContactsWithExpires =
{ok, #binding{expires = Expires1}, _} -> lists:map(
Contact = {<<"">>, fun(#binding{contact = Contact, expires = Es}) ->
#uri{user = LUser, host = LServer}, {Contact, Es}
[{<<"expires">>, end, Bindings),
jlib:integer_to_binary(Expires1)}]}, Cs = prepare_contacts_to_send(ContactsWithExpires),
mod_sip:make_response( mod_sip:make_response(
Req, #sip{type = response, status = 200, Req, #sip{type = response, status = 200,
hdrs = [{'contact', [Contact]}]}); hdrs = [{'contact', Cs}]});
{error, notfound} ->
{Status, Reason} = make_status(notfound),
mod_sip:make_response(
Req, #sip{type = response,
status = Status,
reason = Reason})
end;
[] -> [] ->
{Status, Reason} = make_status(notfound), {Status, Reason} = make_status(notfound),
mod_sip:make_response( mod_sip:make_response(
@ -158,7 +136,11 @@ request(#sip{hdrs = Hdrs} = Req, SIPSock) ->
find_sockets(U, S) -> find_sockets(U, S) ->
case mnesia:dirty_read(sip_session, {U, S}) of case mnesia:dirty_read(sip_session, {U, S}) of
[#sip_session{bindings = Bindings}] -> [#sip_session{bindings = Bindings}] ->
[Binding#binding.socket || Binding <- Bindings]; lists:map(
fun(#binding{contact = {_, URI, _},
socket = Socket}) ->
{Socket, URI}
end, Bindings);
[] -> [] ->
[] []
end. end.
@ -176,8 +158,8 @@ init([]) ->
handle_call({write, Session}, _From, State) -> handle_call({write, Session}, _From, State) ->
Res = write_session(Session), Res = write_session(Session),
{reply, Res, State}; {reply, Res, State};
handle_call({delete, US, SIPSocket, CallID, CSeq}, _From, State) -> handle_call({delete, US, CallID, CSeq}, _From, State) ->
Res = delete_session(US, SIPSocket, CallID, CSeq), Res = delete_session(US, CallID, CSeq),
{reply, Res, State}; {reply, Res, State};
handle_call(_Request, _From, State) -> handle_call(_Request, _From, State) ->
Reply = ok, Reply = ok,
@ -189,8 +171,8 @@ handle_cast(_Msg, State) ->
handle_info({write, Session}, State) -> handle_info({write, Session}, State) ->
write_session(Session), write_session(Session),
{noreply, State}; {noreply, State};
handle_info({delete, US, SIPSocket, CallID, CSeq}, State) -> handle_info({delete, US, CallID, CSeq}, State) ->
delete_session(US, SIPSocket, CallID, CSeq), delete_session(US, CallID, CSeq),
{noreply, State}; {noreply, State};
handle_info({timeout, TRef, US}, State) -> handle_info({timeout, TRef, US}, State) ->
delete_expired_session(US, TRef), delete_expired_session(US, TRef),
@ -208,70 +190,102 @@ code_change(_OldVsn, State, _Extra) ->
%%%=================================================================== %%%===================================================================
%%% Internal functions %%% Internal functions
%%%=================================================================== %%%===================================================================
register_session(US, SIPSocket, CallID, CSeq, Expires) -> register_session(US, SIPSocket, CallID, CSeq, ContactsWithExpires) ->
Session = #sip_session{us = US, Bindings = lists:map(
bindings = [#binding{socket = SIPSocket, fun({Contact, Expires}) ->
#binding{socket = SIPSocket,
call_id = CallID, call_id = CallID,
cseq = CSeq, cseq = CSeq,
timestamp = now(), timestamp = now(),
expires = Expires}]}, contact = Contact,
expires = Expires}
end, ContactsWithExpires),
Session = #sip_session{us = US, bindings = Bindings},
call({write, Session}). call({write, Session}).
unregister_session(US, SIPSocket, CallID, CSeq) -> unregister_session(US, CallID, CSeq) ->
Msg = {delete, US, SIPSocket, CallID, CSeq}, Msg = {delete, US, CallID, CSeq},
call(Msg). call(Msg).
write_session(#sip_session{us = {U, S} = US, write_session(#sip_session{us = {U, S} = US, bindings = NewBindings}) ->
bindings = [#binding{socket = SIPSocket, PrevBindings = case mnesia:dirty_read(sip_session, US) of
call_id = CallID, [#sip_session{bindings = PrevBindings1}] ->
PrevBindings1;
[] ->
[]
end,
Res = lists:foldl(
fun(_, {error, _} = Err) ->
Err;
(#binding{call_id = CallID,
expires = Expires, expires = Expires,
cseq = CSeq} = Binding]}) -> cseq = CSeq} = Binding, {Add, Del}) ->
case mnesia:dirty_read(sip_session, US) of case find_binding(Binding, PrevBindings) of
[#sip_session{bindings = Bindings}] -> {ok, #binding{call_id = CallID, cseq = PrevCSeq}}
case pop_previous_binding(SIPSocket, Bindings) of
{ok, #binding{call_id = CallID, cseq = PrevCSeq}, _}
when PrevCSeq > CSeq -> when PrevCSeq > CSeq ->
{error, cseq_out_of_order}; {error, cseq_out_of_order};
{ok, #binding{tref = Tref}, Bindings1} -> {ok, PrevBinding} when Expires == 0 ->
erlang:cancel_timer(Tref), {Add, [PrevBinding|Del]};
NewTRef = erlang:start_timer(Expires * 1000, self(), US), {ok, _} ->
NewBindings = [Binding#binding{tref = NewTRef}|Bindings1], {[Binding|Add], Del};
mnesia:dirty_write( {error, notfound} when Expires == 0 ->
#sip_session{us = US, bindings = NewBindings}); {error, notfound};
{error, notfound} -> {error, notfound} ->
MaxSessions = ejabberd_sm:get_max_user_sessions(U, S), {[Binding|Add], Del}
if length(Bindings) < MaxSessions ->
NewTRef = erlang:start_timer(Expires * 1000, self(), US),
NewBindings = [Binding#binding{tref = NewTRef}|Bindings],
mnesia:dirty_write(
#sip_session{us = US, bindings = NewBindings});
true ->
{error, too_many_sessions}
end end
end; end, {[], []}, NewBindings),
MaxSessions = ejabberd_sm:get_max_user_sessions(U, S),
case Res of
{error, Why} ->
{error, Why};
{AddBindings, _} when length(AddBindings) > MaxSessions ->
{error, too_many_sessions};
{AddBindings, DelBindings} ->
lists:foreach(
fun(#binding{tref = TRef}) ->
erlang:cancel_timer(TRef)
end, DelBindings),
Bindings = lists:map(
fun(#binding{tref = TRef,
expires = Expires} = Binding) ->
erlang:cancel_timer(TRef),
NewTRef = erlang:start_timer(
Expires * 1000, self(), US),
Binding#binding{tref = NewTRef}
end, AddBindings),
case Bindings of
[] -> [] ->
NewTRef = erlang:start_timer(Expires * 1000, self(), US), mnesia:dirty_delete(sip_session, US),
NewBindings = [Binding#binding{tref = NewTRef}], {ok, deleted};
mnesia:dirty_write(#sip_session{us = US, bindings = NewBindings}) _ ->
mnesia:dirty_write(
#sip_session{us = US, bindings = Bindings}),
{ok, updated}
end
end. end.
delete_session(US, SIPSocket, CallID, CSeq) -> delete_session(US, CallID, CSeq) ->
case mnesia:dirty_read(sip_session, US) of case mnesia:dirty_read(sip_session, US) of
[#sip_session{bindings = Bindings}] -> [#sip_session{bindings = Bindings}] ->
case pop_previous_binding(SIPSocket, Bindings) of case lists:all(
{ok, #binding{call_id = CallID, cseq = PrevCSeq}, _} fun(B) when B#binding.call_id == CallID,
when PrevCSeq > CSeq -> B#binding.cseq > CSeq ->
{error, cseq_out_of_order}; false;
{ok, #binding{tref = TRef}, []} -> (_) ->
true
end, Bindings) of
true ->
ContactsWithExpires =
lists:map(
fun(#binding{contact = Contact,
tref = TRef}) ->
erlang:cancel_timer(TRef), erlang:cancel_timer(TRef),
mnesia:dirty_delete(sip_session, US); {Contact, 0}
{ok, #binding{tref = TRef}, NewBindings} -> end, Bindings),
erlang:cancel_timer(TRef), mnesia:dirty_delete(sip_session, US),
mnesia:dirty_write(sip_session, {ok, ContactsWithExpires};
#sip_session{us = US, false ->
bindings = NewBindings}); {error, cseq_out_of_order}
{error, notfound} ->
{error, notfound}
end; end;
[] -> [] ->
{error, notfound} {error, notfound}
@ -308,17 +322,6 @@ to_integer(Bin, Min, Max) ->
error error
end. end.
pop_previous_binding(#sip_socket{peer = Peer}, Bindings) ->
case lists:partition(
fun(#binding{socket = #sip_socket{peer = Peer1}}) ->
Peer1 == Peer
end, Bindings) of
{[Binding], RestBindings} ->
{ok, Binding, RestBindings};
_ ->
{error, notfound}
end.
call(Msg) -> call(Msg) ->
case catch ?GEN_SERVER:call(?MODULE, Msg, ?CALL_TIMEOUT) of case catch ?GEN_SERVER:call(?MODULE, Msg, ?CALL_TIMEOUT) of
{'EXIT', {timeout, _}} -> {'EXIT', {timeout, _}} ->
@ -329,6 +332,47 @@ call(Msg) ->
Reply Reply
end. end.
make_contacts_with_expires(Contacts, Expires) ->
lists:map(
fun({Name, URI, Params}) ->
E1 = case to_integer(esip:get_param(<<"expires">>, Params),
0, (1 bsl 32)-1) of
{ok, E} -> E;
_ -> Expires
end,
Params1 = lists:keydelete(<<"expires">>, 1, Params),
{{Name, URI, Params1}, E1}
end, Contacts).
prepare_contacts_to_send(ContactsWithExpires) ->
lists:map(
fun({{Name, URI, Params}, Expires}) ->
Params1 = esip:set_param(<<"expires">>,
list_to_binary(
integer_to_list(Expires)),
Params),
{Name, URI, Params1}
end, ContactsWithExpires).
find_binding(#binding{contact = {_, URI1, _}} = OrigBinding,
[#binding{contact = {_, URI2, _}} = Binding|Bindings]) ->
case cmp_uri(URI1, URI2) of
true ->
{ok, Binding};
false ->
find_binding(OrigBinding, Bindings)
end;
find_binding(_, []) ->
{error, notfound}.
%% TODO: this is *totally* wrong.
%% Rewrite this using URI comparison rules
cmp_uri(#uri{user = U, host = H, port = P},
#uri{user = U, host = H, port = P}) ->
true;
cmp_uri(_, _) ->
false.
make_status(notfound) -> make_status(notfound) ->
{404, esip:reason(404)}; {404, esip:reason(404)};
make_status(cseq_out_of_order) -> make_status(cseq_out_of_order) ->