diff --git a/src/mod_sip_proxy.erl b/src/mod_sip_proxy.erl index b05c49061..185d72afe 100644 --- a/src/mod_sip_proxy.erl +++ b/src/mod_sip_proxy.erl @@ -66,15 +66,15 @@ wait_for_request({#sip{type = request} = Req, TrID}, State) -> Opts = State#state.opts, Req1 = prepare_request(State#state.host, Req), case connect(Req1, Opts) of - {ok, SIPSockets} -> + {ok, SIPSocketsWithURIs} -> NewState = lists:foldl( - fun(_SIPSocket, {error, _} = Err) -> + fun(_SIPSocketWithURI, {error, _} = Err) -> Err; - (SIPSocket, #state{tr_ids = TrIDs} = AccState) -> + ({SIPSocket, URI}, #state{tr_ids = TrIDs} = AccState) -> Req2 = add_record_route(SIPSocket, State#state.host, Req1), Req3 = add_via(SIPSocket, State#state.host, Req2), - case esip:request(SIPSocket, Req3, + case esip:request(SIPSocket, Req3#sip{uri = URI}, {?MODULE, route, [self()]}) of {ok, ClientTrID} -> NewTrIDs = [ClientTrID|TrIDs], @@ -83,7 +83,7 @@ wait_for_request({#sip{type = request} = Req, TrID}, State) -> cancel_pending_transactions(AccState), Err end - end, State, SIPSockets), + end, State, SIPSocketsWithURIs), case NewState of {error, _} = Err -> {Status, Reason} = esip:error_status(Err), @@ -214,7 +214,7 @@ connect(#sip{hdrs = Hdrs} = Req, Opts) -> false -> case esip:connect(Req, Opts) of {ok, SIPSock} -> - {ok, [SIPSock]}; + {ok, [{SIPSock, Req#sip.uri}]}; {error, _} = Err -> Err end diff --git a/src/mod_sip_registrar.erl b/src/mod_sip_registrar.erl index 689efe48e..5080cf4a7 100644 --- a/src/mod_sip_registrar.erl +++ b/src/mod_sip_registrar.erl @@ -23,11 +23,13 @@ -include("esip.hrl"). -define(CALL_TIMEOUT, timer:seconds(30)). +-define(DEFAULT_EXPIRES, 3600). -record(binding, {socket = #sip_socket{}, call_id = <<"">> :: binary(), cseq = 0 :: non_neg_integer(), timestamp = now() :: erlang:timestamp(), + contact :: {binary(), #uri{}, [{binary(), binary()}]}, tref = make_ref() :: reference(), expires = 0 :: non_neg_integer()}). @@ -50,20 +52,19 @@ request(#sip{hdrs = Hdrs} = Req, SIPSock) -> US = {LUser, LServer}, CallID = esip:get_hdr('call-id', Hdrs), CSeq = esip:get_hdr('cseq', Hdrs), - Expires = esip:get_hdr('expires', Hdrs, 0), + Expires = esip:get_hdr('expires', Hdrs, ?DEFAULT_EXPIRES), case esip:get_hdrs('contact', Hdrs) of [<<"*">>] when Expires == 0 -> - case unregister_session(US, SIPSock, CallID, CSeq) of - ok -> + case unregister_session(US, CallID, CSeq) of + {ok, ContactsWithExpires} -> ?INFO_MSG("unregister SIP session for user ~s@~s from ~s", [LUser, LServer, inet_parse:ntoa(PeerIP)]), - Contact = {<<"">>, #uri{user = LUser, host = LServer}, - [{<<"expires">>, <<"0">>}]}, + Cs = prepare_contacts_to_send(ContactsWithExpires), mod_sip:make_response( Req, #sip{type = response, status = 200, - hdrs = [{'contact', [Contact]}]}); + hdrs = [{'contact', Cs}]}); {error, Why} -> {Status, Reason} = make_status(Why), mod_sip:make_response( @@ -72,51 +73,35 @@ request(#sip{hdrs = Hdrs} = Req, SIPSock) -> reason = Reason}) end; [{_, _URI, _Params}|_] = Contacts -> - ExpiresList = lists:map( - fun({_, _, Params}) -> - case to_integer( - esip:get_param( - <<"expires">>, Params), - 0, (1 bsl 32)-1) of - {ok, E} -> E; - _ -> Expires - end - end, Contacts), - Expires1 = lists:max(ExpiresList), - Contact = {<<"">>, #uri{user = LUser, host = LServer}, - [{<<"expires">>, jlib:integer_to_binary(Expires1)}]}, + ContactsWithExpires = make_contacts_with_expires(Contacts, Expires), + Expires1 = lists:max([E || {_, E} <- ContactsWithExpires]), MinExpires = min_expires(), - if Expires1 >= MinExpires -> - case register_session(US, SIPSock, CallID, CSeq, Expires1) of - ok -> - ?INFO_MSG("register SIP session for user ~s@~s from ~s", - [LUser, LServer, inet_parse:ntoa(PeerIP)]), + if Expires1 > 0, Expires1 < MinExpires -> + mod_sip:make_response( + Req, #sip{type = response, + status = 423, + hdrs = [{'min-expires', MinExpires}]}); + true -> + case register_session(US, SIPSock, CallID, CSeq, + ContactsWithExpires) of + {ok, Res} -> + if Res == updated -> + ?INFO_MSG("register SIP session for user " + "~s@~s from ~s", + [LUser, LServer, + inet_parse:ntoa(PeerIP)]); + Res == deleted -> + ?INFO_MSG("unregister SIP session for user " + "~s@~s from ~s", + [LUser, LServer, + inet_parse:ntoa(PeerIP)]) + end, + Cs = prepare_contacts_to_send(ContactsWithExpires), mod_sip:make_response( Req, #sip{type = response, status = 200, - hdrs = [{'contact', [Contact]}]}); - {error, Why} -> - {Status, Reason} = make_status(Why), - mod_sip:make_response( - Req, #sip{type = response, - status = Status, - reason = Reason}) - end; - Expires1 > 0, Expires1 < MinExpires -> - mod_sip:make_response( - Req, #sip{type = response, - status = 423, - hdrs = [{'min-expires', MinExpires}]}); - true -> - case unregister_session(US, SIPSock, CallID, CSeq) of - ok -> - ?INFO_MSG("unregister SIP session for user ~s@~s from ~s", - [LUser, LServer, inet_parse:ntoa(PeerIP)]), - mod_sip:make_response( - Req, - #sip{type = response, status = 200, - hdrs = [{'contact', [Contact]}]}); + hdrs = [{'contact', Cs}]}); {error, Why} -> {Status, Reason} = make_status(Why), mod_sip:make_response( @@ -128,22 +113,15 @@ request(#sip{hdrs = Hdrs} = Req, SIPSock) -> [] -> case mnesia:dirty_read(sip_session, US) of [#sip_session{bindings = Bindings}] -> - case pop_previous_binding(SIPSock, Bindings) of - {ok, #binding{expires = Expires1}, _} -> - Contact = {<<"">>, - #uri{user = LUser, host = LServer}, - [{<<"expires">>, - jlib:integer_to_binary(Expires1)}]}, - mod_sip:make_response( - Req, #sip{type = response, status = 200, - hdrs = [{'contact', [Contact]}]}); - {error, notfound} -> - {Status, Reason} = make_status(notfound), - mod_sip:make_response( - Req, #sip{type = response, - status = Status, - reason = Reason}) - end; + ContactsWithExpires = + lists:map( + fun(#binding{contact = Contact, expires = Es}) -> + {Contact, Es} + end, Bindings), + Cs = prepare_contacts_to_send(ContactsWithExpires), + mod_sip:make_response( + Req, #sip{type = response, status = 200, + hdrs = [{'contact', Cs}]}); [] -> {Status, Reason} = make_status(notfound), mod_sip:make_response( @@ -158,7 +136,11 @@ request(#sip{hdrs = Hdrs} = Req, SIPSock) -> find_sockets(U, S) -> case mnesia:dirty_read(sip_session, {U, S}) of [#sip_session{bindings = Bindings}] -> - [Binding#binding.socket || Binding <- Bindings]; + lists:map( + fun(#binding{contact = {_, URI, _}, + socket = Socket}) -> + {Socket, URI} + end, Bindings); [] -> [] end. @@ -176,8 +158,8 @@ init([]) -> handle_call({write, Session}, _From, State) -> Res = write_session(Session), {reply, Res, State}; -handle_call({delete, US, SIPSocket, CallID, CSeq}, _From, State) -> - Res = delete_session(US, SIPSocket, CallID, CSeq), +handle_call({delete, US, CallID, CSeq}, _From, State) -> + Res = delete_session(US, CallID, CSeq), {reply, Res, State}; handle_call(_Request, _From, State) -> Reply = ok, @@ -189,8 +171,8 @@ handle_cast(_Msg, State) -> handle_info({write, Session}, State) -> write_session(Session), {noreply, State}; -handle_info({delete, US, SIPSocket, CallID, CSeq}, State) -> - delete_session(US, SIPSocket, CallID, CSeq), +handle_info({delete, US, CallID, CSeq}, State) -> + delete_session(US, CallID, CSeq), {noreply, State}; handle_info({timeout, TRef, US}, State) -> delete_expired_session(US, TRef), @@ -208,70 +190,102 @@ code_change(_OldVsn, State, _Extra) -> %%%=================================================================== %%% Internal functions %%%=================================================================== -register_session(US, SIPSocket, CallID, CSeq, Expires) -> - Session = #sip_session{us = US, - bindings = [#binding{socket = SIPSocket, - call_id = CallID, - cseq = CSeq, - timestamp = now(), - expires = Expires}]}, +register_session(US, SIPSocket, CallID, CSeq, ContactsWithExpires) -> + Bindings = lists:map( + fun({Contact, Expires}) -> + #binding{socket = SIPSocket, + call_id = CallID, + cseq = CSeq, + timestamp = now(), + contact = Contact, + expires = Expires} + end, ContactsWithExpires), + Session = #sip_session{us = US, bindings = Bindings}, call({write, Session}). -unregister_session(US, SIPSocket, CallID, CSeq) -> - Msg = {delete, US, SIPSocket, CallID, CSeq}, +unregister_session(US, CallID, CSeq) -> + Msg = {delete, US, CallID, CSeq}, call(Msg). -write_session(#sip_session{us = {U, S} = US, - bindings = [#binding{socket = SIPSocket, - call_id = CallID, - expires = Expires, - cseq = CSeq} = Binding]}) -> - case mnesia:dirty_read(sip_session, US) of - [#sip_session{bindings = Bindings}] -> - case pop_previous_binding(SIPSocket, Bindings) of - {ok, #binding{call_id = CallID, cseq = PrevCSeq}, _} - when PrevCSeq > CSeq -> - {error, cseq_out_of_order}; - {ok, #binding{tref = Tref}, Bindings1} -> - erlang:cancel_timer(Tref), - NewTRef = erlang:start_timer(Expires * 1000, self(), US), - NewBindings = [Binding#binding{tref = NewTRef}|Bindings1], - mnesia:dirty_write( - #sip_session{us = US, bindings = NewBindings}); - {error, notfound} -> - MaxSessions = ejabberd_sm:get_max_user_sessions(U, S), - if length(Bindings) < MaxSessions -> - NewTRef = erlang:start_timer(Expires * 1000, self(), US), - NewBindings = [Binding#binding{tref = NewTRef}|Bindings], - mnesia:dirty_write( - #sip_session{us = US, bindings = NewBindings}); - true -> - {error, too_many_sessions} +write_session(#sip_session{us = {U, S} = US, bindings = NewBindings}) -> + PrevBindings = case mnesia:dirty_read(sip_session, US) of + [#sip_session{bindings = PrevBindings1}] -> + PrevBindings1; + [] -> + [] + end, + Res = lists:foldl( + fun(_, {error, _} = Err) -> + Err; + (#binding{call_id = CallID, + expires = Expires, + cseq = CSeq} = Binding, {Add, Del}) -> + case find_binding(Binding, PrevBindings) of + {ok, #binding{call_id = CallID, cseq = PrevCSeq}} + when PrevCSeq > CSeq -> + {error, cseq_out_of_order}; + {ok, PrevBinding} when Expires == 0 -> + {Add, [PrevBinding|Del]}; + {ok, _} -> + {[Binding|Add], Del}; + {error, notfound} when Expires == 0 -> + {error, notfound}; + {error, notfound} -> + {[Binding|Add], Del} end - end; - [] -> - NewTRef = erlang:start_timer(Expires * 1000, self(), US), - NewBindings = [Binding#binding{tref = NewTRef}], - mnesia:dirty_write(#sip_session{us = US, bindings = NewBindings}) + end, {[], []}, NewBindings), + MaxSessions = ejabberd_sm:get_max_user_sessions(U, S), + case Res of + {error, Why} -> + {error, Why}; + {AddBindings, _} when length(AddBindings) > MaxSessions -> + {error, too_many_sessions}; + {AddBindings, DelBindings} -> + lists:foreach( + fun(#binding{tref = TRef}) -> + erlang:cancel_timer(TRef) + end, DelBindings), + Bindings = lists:map( + fun(#binding{tref = TRef, + expires = Expires} = Binding) -> + erlang:cancel_timer(TRef), + NewTRef = erlang:start_timer( + Expires * 1000, self(), US), + Binding#binding{tref = NewTRef} + end, AddBindings), + case Bindings of + [] -> + mnesia:dirty_delete(sip_session, US), + {ok, deleted}; + _ -> + mnesia:dirty_write( + #sip_session{us = US, bindings = Bindings}), + {ok, updated} + end end. -delete_session(US, SIPSocket, CallID, CSeq) -> +delete_session(US, CallID, CSeq) -> case mnesia:dirty_read(sip_session, US) of [#sip_session{bindings = Bindings}] -> - case pop_previous_binding(SIPSocket, Bindings) of - {ok, #binding{call_id = CallID, cseq = PrevCSeq}, _} - when PrevCSeq > CSeq -> - {error, cseq_out_of_order}; - {ok, #binding{tref = TRef}, []} -> - erlang:cancel_timer(TRef), - mnesia:dirty_delete(sip_session, US); - {ok, #binding{tref = TRef}, NewBindings} -> - erlang:cancel_timer(TRef), - mnesia:dirty_write(sip_session, - #sip_session{us = US, - bindings = NewBindings}); - {error, notfound} -> - {error, notfound} + case lists:all( + fun(B) when B#binding.call_id == CallID, + B#binding.cseq > CSeq -> + false; + (_) -> + true + end, Bindings) of + true -> + ContactsWithExpires = + lists:map( + fun(#binding{contact = Contact, + tref = TRef}) -> + erlang:cancel_timer(TRef), + {Contact, 0} + end, Bindings), + mnesia:dirty_delete(sip_session, US), + {ok, ContactsWithExpires}; + false -> + {error, cseq_out_of_order} end; [] -> {error, notfound} @@ -308,17 +322,6 @@ to_integer(Bin, Min, Max) -> error end. -pop_previous_binding(#sip_socket{peer = Peer}, Bindings) -> - case lists:partition( - fun(#binding{socket = #sip_socket{peer = Peer1}}) -> - Peer1 == Peer - end, Bindings) of - {[Binding], RestBindings} -> - {ok, Binding, RestBindings}; - _ -> - {error, notfound} - end. - call(Msg) -> case catch ?GEN_SERVER:call(?MODULE, Msg, ?CALL_TIMEOUT) of {'EXIT', {timeout, _}} -> @@ -329,6 +332,47 @@ call(Msg) -> Reply end. +make_contacts_with_expires(Contacts, Expires) -> + lists:map( + fun({Name, URI, Params}) -> + E1 = case to_integer(esip:get_param(<<"expires">>, Params), + 0, (1 bsl 32)-1) of + {ok, E} -> E; + _ -> Expires + end, + Params1 = lists:keydelete(<<"expires">>, 1, Params), + {{Name, URI, Params1}, E1} + end, Contacts). + +prepare_contacts_to_send(ContactsWithExpires) -> + lists:map( + fun({{Name, URI, Params}, Expires}) -> + Params1 = esip:set_param(<<"expires">>, + list_to_binary( + integer_to_list(Expires)), + Params), + {Name, URI, Params1} + end, ContactsWithExpires). + +find_binding(#binding{contact = {_, URI1, _}} = OrigBinding, + [#binding{contact = {_, URI2, _}} = Binding|Bindings]) -> + case cmp_uri(URI1, URI2) of + true -> + {ok, Binding}; + false -> + find_binding(OrigBinding, Bindings) + end; +find_binding(_, []) -> + {error, notfound}. + +%% TODO: this is *totally* wrong. +%% Rewrite this using URI comparison rules +cmp_uri(#uri{user = U, host = H, port = P}, + #uri{user = U, host = H, port = P}) -> + true; +cmp_uri(_, _) -> + false. + make_status(notfound) -> {404, esip:reason(404)}; make_status(cseq_out_of_order) ->