diff --git a/src/Makefile.in b/src/Makefile.in index 42af5b2f2..a70e25c10 100644 --- a/src/Makefile.in +++ b/src/Makefile.in @@ -35,7 +35,7 @@ ERLANG_CFLAGS += @ERLANG_SSLVER@ # make debug=true to compile Erlang module with debug informations. ifdef debug - EFLAGS+=+debug_info +export_all + EFLAGS+=+debug_info endif DEBUGTOOLS = p1_prof.erl @@ -79,7 +79,7 @@ exec_prefix = @exec_prefix@ SUBDIRS = @mod_irc@ @mod_pubsub@ @mod_muc@ @mod_proxy65@ @eldap@ @pam@ @web@ stringprep stun @tls@ @odbc@ @ejabberd_zlib@ ERLSHLIBS += expat_erl.so -ERLBEHAVS = cyrsasl.erl gen_mod.erl p1_fsm.erl +ERLBEHAVS = cyrsasl.erl gen_mod.erl p1_fsm.erl ejabberd_auth.erl SOURCES_ALL = $(wildcard *.erl) SOURCES_MISC = $(ERLBEHAVS) $(DEBUGTOOLS) SOURCES += $(filter-out $(SOURCES_MISC),$(SOURCES_ALL)) diff --git a/src/acl.erl b/src/acl.erl index b5f64141b..77c55e79d 100644 --- a/src/acl.erl +++ b/src/acl.erl @@ -25,225 +25,264 @@ %%%---------------------------------------------------------------------- -module(acl). + -author('alexey@process-one.net'). --export([start/0, - to_record/3, - add/3, - add_list/3, - match_rule/3, - % for debugging only - match_acl/3]). +-export([start/0, to_record/3, add/3, add_list/3, + match_rule/3, match_acl/3]). -include("ejabberd.hrl"). +-include("jlib.hrl"). -record(acl, {aclname, aclspec}). +-type regexp() :: binary(). +-type glob() :: binary(). +-type aclname() :: {atom(), binary() | global}. +-type aclspec() :: all | none | + {user, binary()} | + {user, binary(), binary()} | + {server, binary()} | + {resource, binary()} | + {user_regexp, regexp()} | + {shared_group, binary()} | + {shared_group, binary(), binary()} | + {user_regexp, regexp(), binary()} | + {server_regexp, regexp()} | + {resource_regexp, regexp()} | + {node_regexp, regexp(), regexp()} | + {user_glob, glob()} | + {user_glob, glob(), binary()} | + {server_glob, glob()} | + {resource_glob, glob()} | + {node_glob, glob(), glob()}. + +-type acl() :: #acl{aclname :: aclname(), + aclspec :: aclspec()}. + +-export_type([acl/0]). + start() -> mnesia:create_table(acl, - [{disc_copies, [node()]}, - {type, bag}, + [{disc_copies, [node()]}, {type, bag}, {attributes, record_info(fields, acl)}]), mnesia:add_table_copy(acl, node(), ram_copies), + update_table(), ok. +-spec to_record(binary(), atom(), aclspec()) -> acl(). + to_record(Host, ACLName, ACLSpec) -> - #acl{aclname = {ACLName, Host}, aclspec = normalize_spec(ACLSpec)}. + #acl{aclname = {ACLName, Host}, + aclspec = normalize_spec(ACLSpec)}. + +-spec add(binary(), aclname(), aclspec()) -> {atomic, ok} | {aborted, any()}. add(Host, ACLName, ACLSpec) -> - F = fun() -> + F = fun () -> mnesia:write(#acl{aclname = {ACLName, Host}, aclspec = normalize_spec(ACLSpec)}) end, mnesia:transaction(F). +-spec add_list(binary(), [acl()], boolean()) -> false | ok. + add_list(Host, ACLs, Clear) -> - F = fun() -> - if - Clear -> - Ks = mnesia:select( - acl, [{{acl, {'$1', Host}, '$2'}, [], ['$1']}]), - lists:foreach(fun(K) -> - mnesia:delete({acl, {K, Host}}) - end, Ks); - true -> - ok + F = fun () -> + if Clear -> + Ks = mnesia:select(acl, + [{{acl, {'$1', Host}, '$2'}, [], + ['$1']}]), + lists:foreach(fun (K) -> mnesia:delete({acl, {K, Host}}) + end, + Ks); + true -> ok end, - lists:foreach(fun(ACL) -> + lists:foreach(fun (ACL) -> case ACL of - #acl{aclname = ACLName, - aclspec = ACLSpec} -> - mnesia:write( - #acl{aclname = {ACLName, Host}, - aclspec = normalize_spec(ACLSpec)}) + #acl{aclname = ACLName, + aclspec = ACLSpec} -> + mnesia:write(#acl{aclname = + {ACLName, + Host}, + aclspec = + normalize_spec(ACLSpec)}) end - end, ACLs) + end, + ACLs) end, case mnesia:transaction(F) of - {atomic, _} -> - ok; - _ -> - false + {atomic, _} -> ok; + _ -> false end. -normalize(A) -> - jlib:nodeprep(A). -normalize_spec({A, B}) -> - {A, normalize(B)}; +normalize(A) -> jlib:nodeprep(iolist_to_binary(A)). + +normalize_spec({A, B}) -> {A, normalize(B)}; normalize_spec({A, B, C}) -> {A, normalize(B), normalize(C)}; -normalize_spec(all) -> - all; -normalize_spec(none) -> - none. - +normalize_spec(all) -> all; +normalize_spec(none) -> none. +-spec match_rule(global | binary(), atom(), jid() | ljid()) -> any(). match_rule(global, Rule, JID) -> case Rule of - all -> allow; - none -> deny; - _ -> - case ejabberd_config:get_global_option({access, Rule, global}) of - undefined -> - deny; - GACLs -> - match_acls(GACLs, JID, global) - end + all -> allow; + none -> deny; + _ -> + case ejabberd_config:get_global_option( + {access, Rule, global}, fun(V) -> V end) + of + undefined -> deny; + GACLs -> match_acls(GACLs, JID, global) + end end; - match_rule(Host, Rule, JID) -> case Rule of - all -> allow; - none -> deny; - _ -> - case ejabberd_config:get_global_option({access, Rule, global}) of - undefined -> - case ejabberd_config:get_global_option({access, Rule, Host}) of - undefined -> - deny; - ACLs -> - match_acls(ACLs, JID, Host) - end; - GACLs -> - case ejabberd_config:get_global_option({access, Rule, Host}) of - undefined -> - match_acls(GACLs, JID, Host); - ACLs -> - case lists:reverse(GACLs) of - [{allow, all} | Rest] -> - match_acls( - lists:reverse(Rest) ++ ACLs ++ - [{allow, all}], - JID, Host); - _ -> - match_acls(GACLs ++ ACLs, JID, Host) - end - end - end + all -> allow; + none -> deny; + _ -> + case ejabberd_config:get_global_option( + {access, Rule, global}, fun(V) -> V end) + of + undefined -> + case ejabberd_config:get_global_option( + {access, Rule, Host}, fun(V) -> V end) + of + undefined -> deny; + ACLs -> match_acls(ACLs, JID, Host) + end; + GACLs -> + case ejabberd_config:get_global_option( + {access, Rule, Host}, fun(V) -> V end) + of + undefined -> match_acls(GACLs, JID, Host); + ACLs -> + case lists:reverse(GACLs) of + [{allow, all} | Rest] -> + match_acls(lists:reverse(Rest) ++ + ACLs ++ [{allow, all}], + JID, Host); + _ -> match_acls(GACLs ++ ACLs, JID, Host) + end + end + end end. -match_acls([], _, _Host) -> - deny; +match_acls([], _, _Host) -> deny; match_acls([{Access, ACL} | ACLs], JID, Host) -> case match_acl(ACL, JID, Host) of - true -> - Access; - _ -> - match_acls(ACLs, JID, Host) + true -> Access; + _ -> match_acls(ACLs, JID, Host) end. +-spec match_acl(atom(), jid() | ljid(), binary()) -> boolean(). + match_acl(ACL, JID, Host) -> case ACL of - all -> true; - none -> false; - _ -> - {User, Server, Resource} = jlib:jid_tolower(JID), - lists:any(fun(#acl{aclspec = Spec}) -> - case Spec of - all -> - true; - {user, U} -> - (U == User) - andalso - ((Host == Server) orelse - ((Host == global) andalso - lists:member(Server, ?MYHOSTS))); - {user, U, S} -> - (U == User) andalso (S == Server); - {server, S} -> - S == Server; - {resource, R} -> - R == Resource; - {user_regexp, UR} -> - ((Host == Server) orelse - ((Host == global) andalso - lists:member(Server, ?MYHOSTS))) - andalso is_regexp_match(User, UR); - {shared_group, G} -> - Mod = loaded_shared_roster_module(Host), - Mod:is_user_in_group({User, Server}, G, Host); - {shared_group, G, H} -> - Mod = loaded_shared_roster_module(H), - Mod:is_user_in_group({User, Server}, G, H); - {user_regexp, UR, S} -> - (S == Server) andalso - is_regexp_match(User, UR); - {server_regexp, SR} -> - is_regexp_match(Server, SR); - {resource_regexp, RR} -> - is_regexp_match(Resource, RR); - {node_regexp, UR, SR} -> - is_regexp_match(Server, SR) andalso - is_regexp_match(User, UR); - {user_glob, UR} -> - ((Host == Server) orelse - ((Host == global) andalso - lists:member(Server, ?MYHOSTS))) - andalso - is_glob_match(User, UR); - {user_glob, UR, S} -> - (S == Server) andalso - is_glob_match(User, UR); - {server_glob, SR} -> - is_glob_match(Server, SR); - {resource_glob, RR} -> - is_glob_match(Resource, RR); - {node_glob, UR, SR} -> - is_glob_match(Server, SR) andalso - is_glob_match(User, UR); - WrongSpec -> - ?ERROR_MSG( - "Wrong ACL expression: ~p~n" - "Check your config file and reload it with the override_acls option enabled", - [WrongSpec]), - false - end - end, - ets:lookup(acl, {ACL, global}) ++ + all -> true; + none -> false; + _ -> + {User, Server, Resource} = jlib:jid_tolower(JID), + lists:any(fun (#acl{aclspec = Spec}) -> + case Spec of + all -> true; + {user, U} -> + U == User andalso + (Host == Server orelse + Host == global andalso + lists:member(Server, ?MYHOSTS)); + {user, U, S} -> U == User andalso S == Server; + {server, S} -> S == Server; + {resource, R} -> R == Resource; + {user_regexp, UR} -> + (Host == Server orelse + Host == global andalso + lists:member(Server, ?MYHOSTS)) + andalso is_regexp_match(User, UR); + {shared_group, G} -> + Mod = loaded_shared_roster_module(Host), + Mod:is_user_in_group({User, Server}, G, Host); + {shared_group, G, H} -> + Mod = loaded_shared_roster_module(H), + Mod:is_user_in_group({User, Server}, G, H); + {user_regexp, UR, S} -> + S == Server andalso is_regexp_match(User, UR); + {server_regexp, SR} -> + is_regexp_match(Server, SR); + {resource_regexp, RR} -> + is_regexp_match(Resource, RR); + {node_regexp, UR, SR} -> + is_regexp_match(Server, SR) andalso + is_regexp_match(User, UR); + {user_glob, UR} -> + (Host == Server orelse + Host == global andalso + lists:member(Server, ?MYHOSTS)) + andalso is_glob_match(User, UR); + {user_glob, UR, S} -> + S == Server andalso is_glob_match(User, UR); + {server_glob, SR} -> is_glob_match(Server, SR); + {resource_glob, RR} -> + is_glob_match(Resource, RR); + {node_glob, UR, SR} -> + is_glob_match(Server, SR) andalso + is_glob_match(User, UR); + WrongSpec -> + ?ERROR_MSG("Wrong ACL expression: ~p~nCheck your " + "config file and reload it with the override_a" + "cls option enabled", + [WrongSpec]), + false + end + end, + ets:lookup(acl, {ACL, global}) ++ ets:lookup(acl, {ACL, Host})) end. is_regexp_match(String, RegExp) -> case ejabberd_regexp:run(String, RegExp) of - nomatch -> - false; - match -> - true; - {error, ErrDesc} -> - ?ERROR_MSG( - "Wrong regexp ~p in ACL: ~p", - [RegExp, ErrDesc]), - false + nomatch -> false; + match -> true; + {error, ErrDesc} -> + ?ERROR_MSG("Wrong regexp ~p in ACL: ~p", + [RegExp, ErrDesc]), + false end. is_glob_match(String, Glob) -> - is_regexp_match(String, ejabberd_regexp:sh_to_awk(Glob)). + is_regexp_match(String, + ejabberd_regexp:sh_to_awk(Glob)). loaded_shared_roster_module(Host) -> case gen_mod:is_loaded(Host, mod_shared_roster_ldap) of - true -> - mod_shared_roster_ldap; - false -> - mod_shared_roster + true -> mod_shared_roster_ldap; + false -> mod_shared_roster + end. + +update_table() -> + Fields = record_info(fields, acl), + case mnesia:table_info(acl, attributes) of + Fields -> + ejabberd_config:convert_table_to_binary( + acl, Fields, bag, + fun(#acl{aclspec = Spec}) when is_tuple(Spec) -> + element(2, Spec); + (_) -> + '$next' + end, + fun(#acl{aclname = {ACLName, Host}, + aclspec = Spec} = R) -> + NewHost = if Host == global -> + Host; + true -> + iolist_to_binary(Host) + end, + R#acl{aclname = {ACLName, NewHost}, + aclspec = normalize_spec(Spec)} + end); + _ -> + ?INFO_MSG("Recreating acl table", []), + mnesia:transform_table(acl, ignore, Fields) end. diff --git a/src/adhoc.erl b/src/adhoc.erl index 94353aead..6af65e5d1 100644 --- a/src/adhoc.erl +++ b/src/adhoc.erl @@ -25,11 +25,14 @@ %%%---------------------------------------------------------------------- -module(adhoc). + -author('henoch@dtek.chalmers.se'). --export([parse_request/1, - produce_response/2, - produce_response/1]). +-export([ + parse_request/1, + produce_response/2, + produce_response/1 +]). -include("ejabberd.hrl"). -include("jlib.hrl"). @@ -37,93 +40,121 @@ %% Parse an ad-hoc request. Return either an adhoc_request record or %% an {error, ErrorType} tuple. +%% +-spec(parse_request/1 :: +( + IQ :: iq_request()) + -> adhoc_response() + %% + | {error, _} +). + parse_request(#iq{type = set, lang = Lang, sub_el = SubEl, xmlns = ?NS_COMMANDS}) -> ?DEBUG("entering parse_request...", []), - Node = xml:get_tag_attr_s("node", SubEl), - SessionID = xml:get_tag_attr_s("sessionid", SubEl), - Action = xml:get_tag_attr_s("action", SubEl), + Node = xml:get_tag_attr_s(<<"node">>, SubEl), + SessionID = xml:get_tag_attr_s(<<"sessionid">>, SubEl), + Action = xml:get_tag_attr_s(<<"action">>, SubEl), XData = find_xdata_el(SubEl), - {xmlelement, _, _, AllEls} = SubEl, + #xmlel{children = AllEls} = SubEl, Others = case XData of - false -> - AllEls; - _ -> - lists:delete(XData, AllEls) + false -> AllEls; + _ -> lists:delete(XData, AllEls) end, - - #adhoc_request{lang = Lang, - node = Node, - sessionid = SessionID, - action = Action, - xdata = XData, - others = Others}; -parse_request(_) -> - {error, ?ERR_BAD_REQUEST}. + #adhoc_request{ + lang = Lang, + node = Node, + sessionid = SessionID, + action = Action, + xdata = XData, + others = Others + }; +parse_request(_) -> {error, ?ERR_BAD_REQUEST}. %% Borrowed from mod_vcard.erl -find_xdata_el({xmlelement, _Name, _Attrs, SubEls}) -> +find_xdata_el(#xmlel{children = SubEls}) -> find_xdata_el1(SubEls). -find_xdata_el1([]) -> - false; -find_xdata_el1([{xmlelement, Name, Attrs, SubEls} | Els]) -> - case xml:get_attr_s("xmlns", Attrs) of - ?NS_XDATA -> - {xmlelement, Name, Attrs, SubEls}; - _ -> - find_xdata_el1(Els) +find_xdata_el1([]) -> false; +find_xdata_el1([El | Els]) when is_record(El, xmlel) -> + case xml:get_tag_attr_s(<<"xmlns">>, El) of + ?NS_XDATA -> El; + _ -> find_xdata_el1(Els) end; -find_xdata_el1([_ | Els]) -> - find_xdata_el1(Els). +find_xdata_el1([_ | Els]) -> find_xdata_el1(Els). %% Produce a node to use as response from an adhoc_response %% record, filling in values for language, node and session id from %% the request. -produce_response(#adhoc_request{lang = Lang, - node = Node, - sessionid = SessionID}, - Response) -> - produce_response(Response#adhoc_response{lang = Lang, - node = Node, - sessionid = SessionID}). +%% +-spec(produce_response/2 :: +( + Adhoc_Request :: adhoc_request(), + Adhoc_Response :: adhoc_response()) + -> Xmlel::xmlel() +). %% Produce a node to use as response from an adhoc_response %% record. -produce_response(#adhoc_response{lang = _Lang, - node = Node, - sessionid = ProvidedSessionID, - status = Status, - defaultaction = DefaultAction, - actions = Actions, - notes = Notes, - elements = Elements}) -> - SessionID = if is_list(ProvidedSessionID), ProvidedSessionID /= "" -> - ProvidedSessionID; - true -> - jlib:now_to_utc_string(now()) - end, +produce_response(#adhoc_request{lang = Lang, node = Node, sessionid = SessionID}, + Adhoc_Response) -> + produce_response(Adhoc_Response#adhoc_response{ + lang = Lang, node = Node, sessionid = SessionID + }). + +%% +-spec(produce_response/1 :: +( + Adhoc_Response::adhoc_response()) + -> Xmlel::xmlel() +). + +produce_response( + #adhoc_response{ + %lang = _Lang, + node = Node, + sessionid = ProvidedSessionID, + status = Status, + defaultaction = DefaultAction, + actions = Actions, + notes = Notes, + elements = Elements + }) -> + SessionID = if is_binary(ProvidedSessionID), + ProvidedSessionID /= <<"">> -> ProvidedSessionID; + true -> jlib:now_to_utc_string(now()) + end, case Actions of - [] -> - ActionsEls = []; - _ -> - case DefaultAction of - "" -> - ActionsElAttrs = []; - _ -> - ActionsElAttrs = [{"execute", DefaultAction}] - end, - ActionsEls = [{xmlelement, "actions", - ActionsElAttrs, - [{xmlelement, Action, [], []} || Action <- Actions]}] + [] -> + ActionsEls = []; + _ -> + case DefaultAction of + <<"">> -> ActionsElAttrs = []; + _ -> ActionsElAttrs = [{<<"execute">>, DefaultAction}] + end, + ActionsEls = [ + #xmlel{ + name = <<"actions">>, + attrs = ActionsElAttrs, + children = [ + #xmlel{name = Action, attrs = [], children = []} + || Action <- Actions] + } + ] end, NotesEls = lists:map(fun({Type, Text}) -> - {xmlelement, "note", - [{"type", Type}], - [{xmlcdata, Text}]} - end, Notes), - {xmlelement, "command", - [{"xmlns", ?NS_COMMANDS}, - {"sessionid", SessionID}, - {"node", Node}, - {"status", atom_to_list(Status)}], - ActionsEls ++ NotesEls ++ Elements}. + #xmlel{ + name = <<"note">>, + attrs = [{<<"type">>, Type}], + children = [{xmlcdata, Text}] + } + end, Notes), + #xmlel{ + name = <<"command">>, + attrs = [ + {<<"xmlns">>, ?NS_COMMANDS}, + {<<"sessionid">>, SessionID}, + {<<"node">>, Node}, + {<<"status">>, iolist_to_binary(atom_to_list(Status))} + ], + children = ActionsEls ++ NotesEls ++ Elements + }. diff --git a/src/adhoc.hrl b/src/adhoc.hrl index b294f84fb..0910dc621 100644 --- a/src/adhoc.hrl +++ b/src/adhoc.hrl @@ -19,18 +19,27 @@ %%% %%%---------------------------------------------------------------------- --record(adhoc_request, {lang, - node, - sessionid, - action, - xdata, - others}). +-record(adhoc_request, +{ + lang = <<"">> :: binary(), + node = <<"">> :: binary(), + sessionid = <<"">> :: binary(), + action = <<"">> :: binary(), + xdata = false :: false | xmlel(), + others = [] :: [xmlel()] +}). --record(adhoc_response, {lang, - node, - sessionid, - status, - defaultaction = "", - actions = [], - notes = [], - elements = []}). +-record(adhoc_response, +{ + lang = <<"">> :: binary(), + node = <<"">> :: binary(), + sessionid = <<"">> :: binary(), + status :: atom(), + defaultaction = <<"">> :: binary(), + actions = [] :: [binary()], + notes = [] :: [{binary(), binary()}], + elements = [] :: [xmlel()] +}). + +-type adhoc_request() :: #adhoc_request{}. +-type adhoc_response() :: #adhoc_response{}. diff --git a/src/cache_tab.erl b/src/cache_tab.erl index 74f47db6a..95343e4f5 100644 --- a/src/cache_tab.erl +++ b/src/cache_tab.erl @@ -380,11 +380,11 @@ do_setopts(#state{procs_num = N} = State, Opts) -> shrink_size = ShrinkSize}. get_proc_num() -> - case erlang:system_info(logical_processors) of - unknown -> - 1; - Num -> - Num + case catch erlang:system_info(logical_processors) of + Num when is_integer(Num) -> + Num; + _ -> + 1 end. get_proc_by_hash(Tab, Term) -> diff --git a/src/configure.erl b/src/configure.erl index 34246fd30..87b7bc208 100644 --- a/src/configure.erl +++ b/src/configure.erl @@ -62,7 +62,7 @@ start() -> RootDirS = "ERLANG_DIR = " ++ code:root_dir() ++ "\n", %% Load the ejabberd application description so that ?VERSION can read the vsn key application:load(ejabberd), - Version = "EJABBERD_VERSION = " ++ ?VERSION ++ "\n", + Version = "EJABBERD_VERSION = " ++ binary_to_list(?VERSION) ++ "\n", ExpatDir = "EXPAT_DIR = c:\\sdk\\Expat-2.0.0\n", OpenSSLDir = "OPENSSL_DIR = c:\\sdk\\OpenSSL\n", DBType = "DBTYPE = generic\n", %% 'generic' or 'mssql' diff --git a/src/cyrsasl.erl b/src/cyrsasl.erl index 9d1377ffc..0672267b2 100644 --- a/src/cyrsasl.erl +++ b/src/cyrsasl.erl @@ -25,43 +25,76 @@ %%%---------------------------------------------------------------------- -module(cyrsasl). + -author('alexey@process-one.net'). --export([start/0, - register_mechanism/3, - listmech/1, - server_new/7, - server_start/3, - server_step/2]). +-export([start/0, register_mechanism/3, listmech/1, + server_new/7, server_start/3, server_step/2]). -include("ejabberd.hrl"). --record(sasl_mechanism, {mechanism, module, password_type}). --record(sasl_state, {service, myname, realm, - get_password, check_password, check_password_digest, - mech_mod, mech_state}). +%% +-export_type([ + mechanism/0, + mechanisms/0, + sasl_mechanism/0 +]). --export([behaviour_info/1]). +-record(sasl_mechanism, + {mechanism = <<"">> :: mechanism() | '$1', + module :: atom(), + password_type = plain :: password_type() | '$2'}). -behaviour_info(callbacks) -> - [{mech_new, 4}, {mech_step, 2}]; -behaviour_info(_Other) -> - undefined. +-type(mechanism() :: binary()). +-type(mechanisms() :: [mechanism(),...]). +-type(password_type() :: plain | digest | scram). +-type(props() :: [{username, binary()} | + {authzid, binary()} | + {auth_module, atom()}]). + +-type(sasl_mechanism() :: #sasl_mechanism{}). + +-record(sasl_state, +{ + service, + myname, + realm, + get_password, + check_password, + check_password_digest, + mech_mod, + mech_state +}). + +-callback mech_new(binary(), fun(), fun(), fun()) -> any(). +-callback mech_step(any(), binary()) -> {ok, props()} | + {ok, props(), binary()} | + {continue, binary(), any()} | + {error, binary()} | + {error, binary(), binary()}. start() -> - ets:new(sasl_mechanism, [named_table, - public, - {keypos, #sasl_mechanism.mechanism}]), + ets:new(sasl_mechanism, + [named_table, public, + {keypos, #sasl_mechanism.mechanism}]), cyrsasl_plain:start([]), cyrsasl_digest:start([]), cyrsasl_scram:start([]), cyrsasl_anonymous:start([]), ok. +%% +-spec(register_mechanism/3 :: +( + Mechanim :: mechanism(), + Module :: module(), + PasswordType :: password_type()) + -> any() +). + register_mechanism(Mechanism, Module, PasswordType) -> ets:insert(sasl_mechanism, - #sasl_mechanism{mechanism = Mechanism, - module = Module, + #sasl_mechanism{mechanism = Mechanism, module = Module, password_type = PasswordType}). %%% TODO: use callbacks @@ -89,95 +122,96 @@ register_mechanism(Mechanism, Module, PasswordType) -> %% end. check_credentials(_State, Props) -> - User = xml:get_attr_s(username, Props), + User = proplists:get_value(username, Props, <<>>), case jlib:nodeprep(User) of - error -> - {error, "not-authorized"}; - "" -> - {error, "not-authorized"}; - _LUser -> - ok + error -> {error, <<"not-authorized">>}; + <<"">> -> {error, <<"not-authorized">>}; + _LUser -> ok end. +-spec(listmech/1 :: +( + Host ::binary()) + -> Mechanisms::mechanisms() +). + listmech(Host) -> Mechs = ets:select(sasl_mechanism, [{#sasl_mechanism{mechanism = '$1', - password_type = '$2', - _ = '_'}, + password_type = '$2', _ = '_'}, case catch ejabberd_auth:store_type(Host) of - external -> - [{'==', '$2', plain}]; - scram -> - [{'/=', '$2', digest}]; - {'EXIT',{undef,[{Module,store_type,[]} | _]}} -> - ?WARNING_MSG("~p doesn't implement the function store_type/0", [Module]), - []; - _Else -> - [] + external -> [{'==', '$2', plain}]; + scram -> [{'/=', '$2', digest}]; + {'EXIT', {undef, [{Module, store_type, []} | _]}} -> + ?WARNING_MSG("~p doesn't implement the function store_type/0", + [Module]), + []; + _Else -> [] end, ['$1']}]), filter_anonymous(Host, Mechs). server_new(Service, ServerFQDN, UserRealm, _SecFlags, GetPassword, CheckPassword, CheckPasswordDigest) -> - #sasl_state{service = Service, - myname = ServerFQDN, - realm = UserRealm, - get_password = GetPassword, + #sasl_state{service = Service, myname = ServerFQDN, + realm = UserRealm, get_password = GetPassword, check_password = CheckPassword, - check_password_digest= CheckPasswordDigest}. + check_password_digest = CheckPasswordDigest}. server_start(State, Mech, ClientIn) -> - case lists:member(Mech, listmech(State#sasl_state.myname)) of - true -> - case ets:lookup(sasl_mechanism, Mech) of - [#sasl_mechanism{module = Module}] -> - {ok, MechState} = Module:mech_new( - State#sasl_state.myname, - State#sasl_state.get_password, - State#sasl_state.check_password, - State#sasl_state.check_password_digest), - server_step(State#sasl_state{mech_mod = Module, - mech_state = MechState}, - ClientIn); - _ -> - {error, "no-mechanism"} - end; - false -> - {error, "no-mechanism"} + case lists:member(Mech, + listmech(State#sasl_state.myname)) + of + true -> + case ets:lookup(sasl_mechanism, Mech) of + [#sasl_mechanism{module = Module}] -> + {ok, MechState} = + Module:mech_new(State#sasl_state.myname, + State#sasl_state.get_password, + State#sasl_state.check_password, + State#sasl_state.check_password_digest), + server_step(State#sasl_state{mech_mod = Module, + mech_state = MechState}, + ClientIn); + _ -> {error, <<"no-mechanism">>} + end; + false -> {error, <<"no-mechanism">>} end. server_step(State, ClientIn) -> Module = State#sasl_state.mech_mod, MechState = State#sasl_state.mech_state, case Module:mech_step(MechState, ClientIn) of - {ok, Props} -> - case check_credentials(State, Props) of - ok -> - {ok, Props}; - {error, Error} -> - {error, Error} - end; - {ok, Props, ServerOut} -> - case check_credentials(State, Props) of - ok -> - {ok, Props, ServerOut}; - {error, Error} -> - {error, Error} - end; - {continue, ServerOut, NewMechState} -> - {continue, ServerOut, - State#sasl_state{mech_state = NewMechState}}; - {error, Error, Username} -> - {error, Error, Username}; - {error, Error} -> - {error, Error} + {ok, Props} -> + case check_credentials(State, Props) of + ok -> {ok, Props}; + {error, Error} -> {error, Error} + end; + {ok, Props, ServerOut} -> + case check_credentials(State, Props) of + ok -> {ok, Props, ServerOut}; + {error, Error} -> {error, Error} + end; + {continue, ServerOut, NewMechState} -> + {continue, ServerOut, State#sasl_state{mech_state = NewMechState}}; + {error, Error, Username} -> + {error, Error, Username}; + {error, Error} -> + {error, Error} end. %% Remove the anonymous mechanism from the list if not enabled for the given %% host +%% +-spec(filter_anonymous/2 :: +( + Host :: binary(), + Mechs :: mechanisms()) + -> mechanisms() +). + filter_anonymous(Host, Mechs) -> case ejabberd_auth_anonymous:is_sasl_anonymous_enabled(Host) of - true -> Mechs; - false -> Mechs -- ["ANONYMOUS"] + true -> Mechs; + false -> Mechs -- [<<"ANONYMOUS">>] end. diff --git a/src/cyrsasl_anonymous.erl b/src/cyrsasl_anonymous.erl index cb0b1e3ff..3090cfe9d 100644 --- a/src/cyrsasl_anonymous.erl +++ b/src/cyrsasl_anonymous.erl @@ -31,26 +31,20 @@ -behaviour(cyrsasl). --record(state, {server}). +-record(state, {server = <<"">> :: binary()}). start(_Opts) -> - cyrsasl:register_mechanism("ANONYMOUS", ?MODULE, plain), + cyrsasl:register_mechanism(<<"ANONYMOUS">>, ?MODULE, plain), ok. -stop() -> - ok. +stop() -> ok. mech_new(Host, _GetPassword, _CheckPassword, _CheckPasswordDigest) -> {ok, #state{server = Host}}. -mech_step(State, _ClientIn) -> - %% We generate a random username: - User = lists:concat([randoms:get_string() | tuple_to_list(now())]), - Server = State#state.server, - - %% Checks that the username is available +mech_step(#state{server = Server}, _ClientIn) -> + User = iolist_to_binary([randoms:get_string() | tuple_to_list(now())]), case ejabberd_auth:is_user_exists(User, Server) of - true -> {error, "not-authorized"}; - false -> {ok, [{username, User}, - {auth_module, ejabberd_auth_anonymous}]} + true -> {error, <<"not-authorized">>}; + false -> {ok, [{username, User}, {auth_module, ejabberd_auth_anonymous}]} end. diff --git a/src/cyrsasl_digest.erl b/src/cyrsasl_digest.erl index 557e498cd..3bb88431b 100644 --- a/src/cyrsasl_digest.erl +++ b/src/cyrsasl_digest.erl @@ -25,134 +25,145 @@ %%%---------------------------------------------------------------------- -module(cyrsasl_digest). + -author('alexey@sevcom.net'). --export([start/1, - stop/0, - mech_new/4, - mech_step/2]). +-export([start/1, stop/0, mech_new/4, mech_step/2, parse/1]). -include("ejabberd.hrl"). -behaviour(cyrsasl). --record(state, {step, nonce, username, authzid, get_password, check_password, auth_module, - host, hostfqdn}). +-type get_password_fun() :: fun((binary()) -> {false, any()} | + {binary(), atom()}). + +-type check_password_fun() :: fun((binary(), binary(), binary(), + fun((binary()) -> binary())) -> + {boolean(), any()} | + false). + +-record(state, {step = 1 :: 1 | 3 | 5, + nonce = <<"">> :: binary(), + username = <<"">> :: binary(), + authzid = <<"">> :: binary(), + get_password = fun(_) -> {false, <<>>} end :: get_password_fun(), + check_password = fun(_, _, _, _) -> false end :: check_password_fun(), + auth_module :: atom(), + host = <<"">> :: binary(), + hostfqdn = <<"">> :: binary()}). start(_Opts) -> Fqdn = get_local_fqdn(), - ?INFO_MSG("FQDN used to check DIGEST-MD5 SASL authentication: ~p", [Fqdn]), - cyrsasl:register_mechanism("DIGEST-MD5", ?MODULE, digest). + ?INFO_MSG("FQDN used to check DIGEST-MD5 SASL authentication: ~s", + [Fqdn]), + cyrsasl:register_mechanism(<<"DIGEST-MD5">>, ?MODULE, + digest). -stop() -> - ok. +stop() -> ok. -mech_new(Host, GetPassword, _CheckPassword, CheckPasswordDigest) -> - {ok, #state{step = 1, - nonce = randoms:get_string(), - host = Host, - hostfqdn = get_local_fqdn(), - get_password = GetPassword, - check_password = CheckPasswordDigest}}. +mech_new(Host, GetPassword, _CheckPassword, + CheckPasswordDigest) -> + {ok, + #state{step = 1, nonce = randoms:get_string(), + host = Host, hostfqdn = get_local_fqdn(), + get_password = GetPassword, + check_password = CheckPasswordDigest}}. mech_step(#state{step = 1, nonce = Nonce} = State, _) -> {continue, - "nonce=\"" ++ Nonce ++ - "\",qop=\"auth\",charset=utf-8,algorithm=md5-sess", + <<"nonce=\"", Nonce/binary, + "\",qop=\"auth\",charset=utf-8,algorithm=md5-sess">>, State#state{step = 3}}; -mech_step(#state{step = 3, nonce = Nonce} = State, ClientIn) -> +mech_step(#state{step = 3, nonce = Nonce} = State, + ClientIn) -> case parse(ClientIn) of - bad -> - {error, "bad-protocol"}; - KeyVals -> - DigestURI = xml:get_attr_s("digest-uri", KeyVals), - UserName = xml:get_attr_s("username", KeyVals), - case is_digesturi_valid(DigestURI, State#state.host, State#state.hostfqdn) of - false -> - ?DEBUG("User login not authorized because digest-uri " - "seems invalid: ~p (checking for Host ~p, FQDN ~p)", [DigestURI, - State#state.host, State#state.hostfqdn]), - {error, "not-authorized", UserName}; - true -> - AuthzId = xml:get_attr_s("authzid", KeyVals), - case (State#state.get_password)(UserName) of - {false, _} -> - {error, "not-authorized", UserName}; - {Passwd, AuthModule} -> - case (State#state.check_password)(UserName, "", - xml:get_attr_s("response", KeyVals), - fun(PW) -> response(KeyVals, UserName, PW, Nonce, AuthzId, - "AUTHENTICATE") end) of - {true, _} -> - RspAuth = response(KeyVals, - UserName, Passwd, - Nonce, AuthzId, ""), - {continue, - "rspauth=" ++ RspAuth, - State#state{step = 5, - auth_module = AuthModule, - username = UserName, - authzid = AuthzId}}; - false -> - {error, "not-authorized", UserName}; - {false, _} -> - {error, "not-authorized", UserName} - end - end - end + bad -> {error, <<"bad-protocol">>}; + KeyVals -> + DigestURI = proplists:get_value(<<"digest-uri">>, KeyVals, <<>>), + %DigestURI = xml:get_attr_s(<<"digest-uri">>, KeyVals), + UserName = proplists:get_value(<<"username">>, KeyVals, <<>>), + %UserName = xml:get_attr_s(<<"username">>, KeyVals), + case is_digesturi_valid(DigestURI, State#state.host, + State#state.hostfqdn) + of + false -> + ?DEBUG("User login not authorized because digest-uri " + "seems invalid: ~p (checking for Host " + "~p, FQDN ~p)", + [DigestURI, State#state.host, State#state.hostfqdn]), + {error, <<"not-authorized">>, UserName}; + true -> + AuthzId = proplists:get_value(<<"authzid">>, KeyVals, <<>>), + %AuthzId = xml:get_attr_s(<<"authzid">>, KeyVals), + case (State#state.get_password)(UserName) of + {false, _} -> {error, <<"not-authorized">>, UserName}; + {Passwd, AuthModule} -> + case (State#state.check_password)(UserName, <<"">>, + proplists:get_value(<<"response">>, KeyVals, <<>>), + %xml:get_attr_s(<<"response">>, KeyVals), + fun (PW) -> + response(KeyVals, + UserName, + PW, + Nonce, + AuthzId, + <<"AUTHENTICATE">>) + end) + of + {true, _} -> + RspAuth = response(KeyVals, UserName, Passwd, Nonce, + AuthzId, <<"">>), + {continue, <<"rspauth=", RspAuth/binary>>, + State#state{step = 5, auth_module = AuthModule, + username = UserName, + authzid = AuthzId}}; + false -> {error, <<"not-authorized">>, UserName}; + {false, _} -> {error, <<"not-authorized">>, UserName} + end + end + end end; -mech_step(#state{step = 5, - auth_module = AuthModule, - username = UserName, - authzid = AuthzId}, "") -> - {ok, [{username, UserName}, {authzid, AuthzId}, - {auth_module, AuthModule}]}; +mech_step(#state{step = 5, auth_module = AuthModule, + username = UserName, authzid = AuthzId}, + <<"">>) -> + {ok, + [{username, UserName}, {authzid, AuthzId}, + {auth_module, AuthModule}]}; mech_step(A, B) -> - ?DEBUG("SASL DIGEST: A ~p B ~p", [A,B]), - {error, "bad-protocol"}. + ?DEBUG("SASL DIGEST: A ~p B ~p", [A, B]), + {error, <<"bad-protocol">>}. -parse(S) -> - parse1(S, "", []). +parse(S) -> parse1(binary_to_list(S), "", []). parse1([$= | Cs], S, Ts) -> parse2(Cs, lists:reverse(S), "", Ts); -parse1([$, | Cs], [], Ts) -> - parse1(Cs, [], Ts); -parse1([$\s | Cs], [], Ts) -> - parse1(Cs, [], Ts); -parse1([C | Cs], S, Ts) -> - parse1(Cs, [C | S], Ts); -parse1([], [], T) -> - lists:reverse(T); -parse1([], _S, _T) -> - bad. +parse1([$, | Cs], [], Ts) -> parse1(Cs, [], Ts); +parse1([$\s | Cs], [], Ts) -> parse1(Cs, [], Ts); +parse1([C | Cs], S, Ts) -> parse1(Cs, [C | S], Ts); +parse1([], [], T) -> lists:reverse(T); +parse1([], _S, _T) -> bad. -parse2([$\" | Cs], Key, Val, Ts) -> +parse2([$" | Cs], Key, Val, Ts) -> parse3(Cs, Key, Val, Ts); parse2([C | Cs], Key, Val, Ts) -> parse4(Cs, Key, [C | Val], Ts); -parse2([], _, _, _) -> - bad. +parse2([], _, _, _) -> bad. -parse3([$\" | Cs], Key, Val, Ts) -> +parse3([$" | Cs], Key, Val, Ts) -> parse4(Cs, Key, Val, Ts); parse3([$\\, C | Cs], Key, Val, Ts) -> parse3(Cs, Key, [C | Val], Ts); parse3([C | Cs], Key, Val, Ts) -> parse3(Cs, Key, [C | Val], Ts); -parse3([], _, _, _) -> - bad. +parse3([], _, _, _) -> bad. parse4([$, | Cs], Key, Val, Ts) -> - parse1(Cs, "", [{Key, lists:reverse(Val)} | Ts]); + parse1(Cs, "", [{list_to_binary(Key), list_to_binary(lists:reverse(Val))} | Ts]); parse4([$\s | Cs], Key, Val, Ts) -> parse4(Cs, Key, Val, Ts); parse4([C | Cs], Key, Val, Ts) -> parse4(Cs, Key, [C | Val], Ts); parse4([], Key, Val, Ts) -> - parse1([], "", [{Key, lists:reverse(Val)} | Ts]). - - %% @doc Check if the digest-uri is valid. %% RFC-2831 allows to provide the IP address in Host, %% however ejabberd doesn't allow that. @@ -162,14 +173,17 @@ parse4([], Key, Val, Ts) -> %% xmpp/server3.example.org/jabber.example.org, xmpp/server3.example.org and %% xmpp/jabber.example.org %% The last version is not actually allowed by the RFC, but implemented by popular clients -is_digesturi_valid(DigestURICase, JabberDomain, JabberFQDN) -> + parse1([], "", [{list_to_binary(Key), list_to_binary(lists:reverse(Val))} | Ts]). + +is_digesturi_valid(DigestURICase, JabberDomain, + JabberFQDN) -> DigestURI = stringprep:tolower(DigestURICase), - case catch string:tokens(DigestURI, "/") of - ["xmpp", Host] -> - IsHostFqdn = is_host_fqdn(Host, JabberFQDN), + case catch str:tokens(DigestURI, <<"/">>) of + [<<"xmpp">>, Host] -> + IsHostFqdn = is_host_fqdn(binary_to_list(Host), binary_to_list(JabberFQDN)), (Host == JabberDomain) or IsHostFqdn; - ["xmpp", Host, ServName] -> - IsHostFqdn = is_host_fqdn(Host, JabberFQDN), + [<<"xmpp">>, Host, ServName] -> + IsHostFqdn = is_host_fqdn(binary_to_list(Host), binary_to_list(JabberFQDN)), (ServName == JabberDomain) and IsHostFqdn; _ -> false @@ -185,62 +199,60 @@ is_host_fqdn(Host, [Fqdn | FqdnTail]) when Host /= Fqdn -> is_host_fqdn(Host, FqdnTail). get_local_fqdn() -> - case (catch get_local_fqdn2()) of - Str when is_list(Str) -> Str; - _ -> "unknown-fqdn, please configure fqdn option in ejabberd.cfg!" - end. -get_local_fqdn2() -> - case ejabberd_config:get_local_option(fqdn) of - ConfiguredFqdn when is_list(ConfiguredFqdn) -> - ConfiguredFqdn; - _undefined -> - {ok, Hostname} = inet:gethostname(), - {ok, {hostent, Fqdn, _, _, _, _}} = inet:gethostbyname(Hostname), - Fqdn + case catch get_local_fqdn2() of + Str when is_binary(Str) -> Str; + _ -> + <<"unknown-fqdn, please configure fqdn " + "option in ejabberd.cfg!">> end. -digit_to_xchar(D) when (D >= 0) and (D < 10) -> - D + 48; -digit_to_xchar(D) -> - D + 87. +get_local_fqdn2() -> + case ejabberd_config:get_local_option( + fqdn, fun iolist_to_binary/1) of + ConfiguredFqdn when is_binary(ConfiguredFqdn) -> + ConfiguredFqdn; + undefined -> + {ok, Hostname} = inet:gethostname(), + {ok, {hostent, Fqdn, _, _, _, _}} = + inet:gethostbyname(Hostname), + list_to_binary(Fqdn) + end. hex(S) -> - hex(S, []). + sha:to_hexlist(S). -hex([], Res) -> - lists:reverse(Res); -hex([N | Ns], Res) -> - hex(Ns, [digit_to_xchar(N rem 16), - digit_to_xchar(N div 16) | Res]). +proplists_get_bin_value(Key, Pairs, Default) -> + case proplists:get_value(Key, Pairs, Default) of + L when is_list(L) -> + list_to_binary(L); + L2 -> + L2 + end. - -response(KeyVals, User, Passwd, Nonce, AuthzId, A2Prefix) -> - Realm = xml:get_attr_s("realm", KeyVals), - CNonce = xml:get_attr_s("cnonce", KeyVals), - DigestURI = xml:get_attr_s("digest-uri", KeyVals), - NC = xml:get_attr_s("nc", KeyVals), - QOP = xml:get_attr_s("qop", KeyVals), +response(KeyVals, User, Passwd, Nonce, AuthzId, + A2Prefix) -> + Realm = proplists_get_bin_value(<<"realm">>, KeyVals, <<>>), + CNonce = proplists_get_bin_value(<<"cnonce">>, KeyVals, <<>>), + DigestURI = proplists_get_bin_value(<<"digest-uri">>, KeyVals, <<>>), + NC = proplists_get_bin_value(<<"nc">>, KeyVals, <<>>), + QOP = proplists_get_bin_value(<<"qop">>, KeyVals, <<>>), + MD5Hash = crypto:md5(<>), A1 = case AuthzId of - "" -> - binary_to_list( - crypto:md5(User ++ ":" ++ Realm ++ ":" ++ Passwd)) ++ - ":" ++ Nonce ++ ":" ++ CNonce; - _ -> - binary_to_list( - crypto:md5(User ++ ":" ++ Realm ++ ":" ++ Passwd)) ++ - ":" ++ Nonce ++ ":" ++ CNonce ++ ":" ++ AuthzId + <<"">> -> + <>; + _ -> + <> end, A2 = case QOP of - "auth" -> - A2Prefix ++ ":" ++ DigestURI; - _ -> - A2Prefix ++ ":" ++ DigestURI ++ - ":00000000000000000000000000000000" + <<"auth">> -> + <>; + _ -> + <> end, - T = hex(binary_to_list(crypto:md5(A1))) ++ ":" ++ Nonce ++ ":" ++ - NC ++ ":" ++ CNonce ++ ":" ++ QOP ++ ":" ++ - hex(binary_to_list(crypto:md5(A2))), - hex(binary_to_list(crypto:md5(T))). - - - + T = <<(hex((crypto:md5(A1))))/binary, ":", Nonce/binary, + ":", NC/binary, ":", CNonce/binary, ":", QOP/binary, + ":", (hex((crypto:md5(A2))))/binary>>, + hex((crypto:md5(T))). diff --git a/src/cyrsasl_plain.erl b/src/cyrsasl_plain.erl index 7192cd161..c5c5f2e02 100644 --- a/src/cyrsasl_plain.erl +++ b/src/cyrsasl_plain.erl @@ -25,6 +25,7 @@ %%%---------------------------------------------------------------------- -module(cyrsasl_plain). + -author('alexey@process-one.net'). -export([start/1, stop/0, mech_new/4, mech_step/2, parse/1]). @@ -34,67 +35,56 @@ -record(state, {check_password}). start(_Opts) -> - cyrsasl:register_mechanism("PLAIN", ?MODULE, plain), + cyrsasl:register_mechanism(<<"PLAIN">>, ?MODULE, plain), ok. -stop() -> - ok. +stop() -> ok. mech_new(_Host, _GetPassword, CheckPassword, _CheckPasswordDigest) -> {ok, #state{check_password = CheckPassword}}. mech_step(State, ClientIn) -> case prepare(ClientIn) of - [AuthzId, User, Password] -> - case (State#state.check_password)(User, Password) of - {true, AuthModule} -> - {ok, [{username, User}, {authzid, AuthzId}, - {auth_module, AuthModule}]}; - _ -> - {error, "not-authorized", User} - end; - _ -> - {error, "bad-protocol"} + [AuthzId, User, Password] -> + case (State#state.check_password)(User, Password) of + {true, AuthModule} -> + {ok, + [{username, User}, {authzid, AuthzId}, + {auth_module, AuthModule}]}; + _ -> {error, <<"not-authorized">>, User} + end; + _ -> {error, <<"bad-protocol">>} end. prepare(ClientIn) -> case parse(ClientIn) of - [[], UserMaybeDomain, Password] -> - case parse_domain(UserMaybeDomain) of - %% login@domainpwd - [User, _Domain] -> - [UserMaybeDomain, User, Password]; - %% loginpwd - [User] -> - ["", User, Password] - end; - %% login@domainloginpwd - [AuthzId, User, Password] -> - [AuthzId, User, Password]; - _ -> - error + [<<"">>, UserMaybeDomain, Password] -> + case parse_domain(UserMaybeDomain) of + %% login@domainpwd + [User, _Domain] -> [UserMaybeDomain, User, Password]; + %% loginpwd + [User] -> [<<"">>, User, Password] + end; + %% login@domainloginpwd + [AuthzId, User, Password] -> [AuthzId, User, Password]; + _ -> error end. - -parse(S) -> - parse1(S, "", []). +parse(S) -> parse1(binary_to_list(S), "", []). parse1([0 | Cs], S, T) -> - parse1(Cs, "", [lists:reverse(S) | T]); -parse1([C | Cs], S, T) -> - parse1(Cs, [C | S], T); + parse1(Cs, "", [list_to_binary(lists:reverse(S)) | T]); +parse1([C | Cs], S, T) -> parse1(Cs, [C | S], T); %parse1([], [], T) -> % lists:reverse(T); parse1([], S, T) -> - lists:reverse([lists:reverse(S) | T]). + lists:reverse([list_to_binary(lists:reverse(S)) | T]). - -parse_domain(S) -> - parse_domain1(S, "", []). +parse_domain(S) -> parse_domain1(binary_to_list(S), "", []). parse_domain1([$@ | Cs], S, T) -> - parse_domain1(Cs, "", [lists:reverse(S) | T]); + parse_domain1(Cs, "", [list_to_binary(lists:reverse(S)) | T]); parse_domain1([C | Cs], S, T) -> parse_domain1(Cs, [C | S], T); parse_domain1([], S, T) -> - lists:reverse([lists:reverse(S) | T]). + lists:reverse([list_to_binary(lists:reverse(S)) | T]). diff --git a/src/cyrsasl_scram.erl b/src/cyrsasl_scram.erl index dc671b243..33d18cd1a 100644 --- a/src/cyrsasl_scram.erl +++ b/src/cyrsasl_scram.erl @@ -25,166 +25,185 @@ %%%---------------------------------------------------------------------- -module(cyrsasl_scram). + -author('stephen.roettger@googlemail.com'). --export([start/1, - stop/0, - mech_new/4, - mech_step/2]). +-export([start/1, stop/0, mech_new/4, mech_step/2]). -include("ejabberd.hrl"). -behaviour(cyrsasl). --record(state, {step, stored_key, server_key, username, get_password, check_password, - auth_message, client_nonce, server_nonce}). +-record(state, + {step = 2 :: 2 | 4, + stored_key = <<"">> :: binary(), + server_key = <<"">> :: binary(), + username = <<"">> :: binary(), + get_password :: fun(), + check_password :: fun(), + auth_message = <<"">> :: binary(), + client_nonce = <<"">> :: binary(), + server_nonce = <<"">> :: binary()}). -define(SALT_LENGTH, 16). + -define(NONCE_LENGTH, 16). start(_Opts) -> - cyrsasl:register_mechanism("SCRAM-SHA-1", ?MODULE, scram). + cyrsasl:register_mechanism(<<"SCRAM-SHA-1">>, ?MODULE, + scram). -stop() -> - ok. +stop() -> ok. -mech_new(_Host, GetPassword, _CheckPassword, _CheckPasswordDigest) -> +mech_new(_Host, GetPassword, _CheckPassword, + _CheckPasswordDigest) -> {ok, #state{step = 2, get_password = GetPassword}}. mech_step(#state{step = 2} = State, ClientIn) -> - case string:tokens(ClientIn, ",") of - [CBind, UserNameAttribute, ClientNonceAttribute] when (CBind == "y") or (CBind == "n") -> - case parse_attribute(UserNameAttribute) of - {error, Reason} -> - {error, Reason}; - {_, EscapedUserName} -> - case unescape_username(EscapedUserName) of - error -> - {error, "protocol-error-bad-username"}; - UserName -> - case parse_attribute(ClientNonceAttribute) of - {$r, ClientNonce} -> - case (State#state.get_password)(UserName) of - {false, _} -> - {error, "not-authorized", UserName}; - {Ret, _AuthModule} -> - {StoredKey, ServerKey, Salt, IterationCount} = if - is_tuple(Ret) -> - Ret; - true -> - TempSalt = crypto:rand_bytes(?SALT_LENGTH), - SaltedPassword = scram:salted_password(Ret, TempSalt, ?SCRAM_DEFAULT_ITERATION_COUNT), - {scram:stored_key(scram:client_key(SaltedPassword)), - scram:server_key(SaltedPassword), TempSalt, ?SCRAM_DEFAULT_ITERATION_COUNT} - end, - ClientFirstMessageBare = string:substr(ClientIn, string:str(ClientIn, "n=")), - ServerNonce = base64:encode_to_string(crypto:rand_bytes(?NONCE_LENGTH)), - ServerFirstMessage = "r=" ++ ClientNonce ++ ServerNonce ++ "," ++ - "s=" ++ base64:encode_to_string(Salt) ++ "," ++ - "i=" ++ integer_to_list(IterationCount), - {continue, - ServerFirstMessage, - State#state{step = 4, stored_key = StoredKey, server_key = ServerKey, - auth_message = ClientFirstMessageBare ++ "," ++ ServerFirstMessage, - client_nonce = ClientNonce, server_nonce = ServerNonce, username = UserName}} - end; - _Else -> - {error, "not-supported"} - end - end - end; - _Else -> - {error, "bad-protocol"} - end; + case str:tokens(ClientIn, <<",">>) of + [CBind, UserNameAttribute, ClientNonceAttribute] + when (CBind == <<"y">>) or (CBind == <<"n">>) -> + case parse_attribute(UserNameAttribute) of + {error, Reason} -> {error, Reason}; + {_, EscapedUserName} -> + case unescape_username(EscapedUserName) of + error -> {error, <<"protocol-error-bad-username">>}; + UserName -> + case parse_attribute(ClientNonceAttribute) of + {$r, ClientNonce} -> + case (State#state.get_password)(UserName) of + {false, _} -> {error, <<"not-authorized">>, UserName}; + {Ret, _AuthModule} -> + {StoredKey, ServerKey, Salt, IterationCount} = + if is_tuple(Ret) -> Ret; + true -> + TempSalt = + crypto:rand_bytes(?SALT_LENGTH), + SaltedPassword = + scram:salted_password(Ret, + TempSalt, + ?SCRAM_DEFAULT_ITERATION_COUNT), + {scram:stored_key(scram:client_key(SaltedPassword)), + scram:server_key(SaltedPassword), + TempSalt, + ?SCRAM_DEFAULT_ITERATION_COUNT} + end, + ClientFirstMessageBare = + str:substr(ClientIn, + str:str(ClientIn, <<"n=">>)), + ServerNonce = + jlib:encode_base64(crypto:rand_bytes(?NONCE_LENGTH)), + ServerFirstMessage = + iolist_to_binary( + ["r=", + ClientNonce, + ServerNonce, + ",", "s=", + jlib:encode_base64(Salt), + ",", "i=", + integer_to_list(IterationCount)]), + {continue, ServerFirstMessage, + State#state{step = 4, stored_key = StoredKey, + server_key = ServerKey, + auth_message = + <>, + client_nonce = ClientNonce, + server_nonce = ServerNonce, + username = UserName}} + end; + _Else -> {error, <<"not-supported">>} + end + end + end; + _Else -> {error, <<"bad-protocol">>} + end; mech_step(#state{step = 4} = State, ClientIn) -> - case string:tokens(ClientIn, ",") of - [GS2ChannelBindingAttribute, NonceAttribute, ClientProofAttribute] -> - case parse_attribute(GS2ChannelBindingAttribute) of - {$c, CVal} when (CVal == "biws") or (CVal == "eSws") -> - %% biws is base64 for n,, => channelbinding not supported - %% eSws is base64 for y,, => channelbinding supported by client only - Nonce = State#state.client_nonce ++ State#state.server_nonce, - case parse_attribute(NonceAttribute) of - {$r, CompareNonce} when CompareNonce == Nonce -> - case parse_attribute(ClientProofAttribute) of - {$p, ClientProofB64} -> - ClientProof = base64:decode(ClientProofB64), - AuthMessage = State#state.auth_message ++ "," ++ string:substr(ClientIn, 1, string:str(ClientIn, ",p=")-1), - ClientSignature = scram:client_signature(State#state.stored_key, AuthMessage), - ClientKey = scram:client_key(ClientProof, ClientSignature), - CompareStoredKey = scram:stored_key(ClientKey), - if CompareStoredKey == State#state.stored_key -> - ServerSignature = scram:server_signature(State#state.server_key, AuthMessage), - {ok, [{username, State#state.username}], "v=" ++ base64:encode_to_string(ServerSignature)}; - true -> - {error, "bad-auth"} - end; - _Else -> - {error, "bad-protocol"} - end; - {$r, _} -> - {error, "bad-nonce"}; - _Else -> - {error, "bad-protocol"} - end; - _Else -> - {error, "bad-protocol"} + case str:tokens(ClientIn, <<",">>) of + [GS2ChannelBindingAttribute, NonceAttribute, + ClientProofAttribute] -> + case parse_attribute(GS2ChannelBindingAttribute) of + {$c, CVal} when (CVal == <<"biws">>) or (CVal == <<"eSws">>) -> + %% biws is base64 for n,, => channelbinding not supported + %% eSws is base64 for y,, => channelbinding supported by client only + Nonce = <<(State#state.client_nonce)/binary, + (State#state.server_nonce)/binary>>, + case parse_attribute(NonceAttribute) of + {$r, CompareNonce} when CompareNonce == Nonce -> + case parse_attribute(ClientProofAttribute) of + {$p, ClientProofB64} -> + ClientProof = jlib:decode_base64(ClientProofB64), + AuthMessage = + iolist_to_binary( + [State#state.auth_message, + ",", + str:substr(ClientIn, 1, + str:str(ClientIn, <<",p=">>) + - 1)]), + ClientSignature = + scram:client_signature(State#state.stored_key, + AuthMessage), + ClientKey = scram:client_key(ClientProof, + ClientSignature), + CompareStoredKey = scram:stored_key(ClientKey), + if CompareStoredKey == State#state.stored_key -> + ServerSignature = + scram:server_signature(State#state.server_key, + AuthMessage), + {ok, [{username, State#state.username}], + <<"v=", + (jlib:encode_base64(ServerSignature))/binary>>}; + true -> {error, <<"bad-auth">>} + end; + _Else -> {error, <<"bad-protocol">>} + end; + {$r, _} -> {error, <<"bad-nonce">>}; + _Else -> {error, <<"bad-protocol">>} end; - _Else -> - {error, "bad-protocol"} - end. + _Else -> {error, <<"bad-protocol">>} + end; + _Else -> {error, <<"bad-protocol">>} + end. parse_attribute(Attribute) -> - AttributeLen = string:len(Attribute), - if - AttributeLen >= 3 -> - SecondChar = lists:nth(2, Attribute), - case is_alpha(lists:nth(1, Attribute)) of - true -> - if - SecondChar == $= -> - String = string:substr(Attribute, 3), - {lists:nth(1, Attribute), String}; - true -> - {error, "bad-format second char not equal sign"} - end; - _Else -> - {error, "bad-format first char not a letter"} - end; - true -> - {error, "bad-format attribute too short"} - end. + AttributeLen = byte_size(Attribute), + if AttributeLen >= 3 -> + AttributeS = binary_to_list(Attribute), + SecondChar = lists:nth(2, AttributeS), + case is_alpha(lists:nth(1, AttributeS)) of + true -> + if SecondChar == $= -> + String = str:substr(Attribute, 3), + {lists:nth(1, AttributeS), String}; + true -> {error, <<"bad-format second char not equal sign">>} + end; + _Else -> {error, <<"bad-format first char not a letter">>} + end; + true -> {error, <<"bad-format attribute too short">>} + end. -unescape_username("") -> - ""; +unescape_username(<<"">>) -> <<"">>; unescape_username(EscapedUsername) -> - Pos = string:str(EscapedUsername, "="), - if - Pos == 0 -> - EscapedUsername; - true -> - Start = string:substr(EscapedUsername, 1, Pos-1), - End = string:substr(EscapedUsername, Pos), - EndLen = string:len(End), - if - EndLen < 3 -> - error; - true -> - case string:substr(End, 1, 3) of - "=2C" -> - Start ++ "," ++ unescape_username(string:substr(End, 4)); - "=3D" -> - Start ++ "=" ++ unescape_username(string:substr(End, 4)); - _Else -> - error - end - end - end. - -is_alpha(Char) when Char >= $a, Char =< $z -> - true; -is_alpha(Char) when Char >= $A, Char =< $Z -> - true; -is_alpha(_) -> - false. + Pos = str:str(EscapedUsername, <<"=">>), + if Pos == 0 -> EscapedUsername; + true -> + Start = str:substr(EscapedUsername, 1, Pos - 1), + End = str:substr(EscapedUsername, Pos), + EndLen = byte_size(End), + if EndLen < 3 -> error; + true -> + case str:substr(End, 1, 3) of + <<"=2C">> -> + <>; + <<"=3D">> -> + <>). + +-define(CONFIG_PATH, <<"ejabberd.cfg">>). + +-define(LOG_PATH, <<"ejabberd.log">>). + +-define(EJABBERD_URI, <<"http://www.process-one.net/en/ejabberd/">>). -define(S2STIMEOUT, 600000). %%-define(DBGFSM, true). --record(scram, {storedkey, serverkey, salt, iterationcount}). +-record(scram, + {storedkey = <<"">>, + serverkey = <<"">>, + salt = <<"">>, + iterationcount = 0 :: integer()}). + +-type scram() :: #scram{}. + -define(SCRAM_DEFAULT_ITERATION_COUNT, 4096). %% --------------------------------- %% Logging mechanism %% Print in standard output --define(PRINT(Format, Args), - io:format(Format, Args)). +-define(PRINT(Format, Args), io:format(Format, Args)). -define(DEBUG(Format, Args), - ejabberd_logger:debug_msg(?MODULE,?LINE,Format, Args)). + ejabberd_logger:debug_msg(?MODULE, ?LINE, Format, + Args)). -define(INFO_MSG(Format, Args), - ejabberd_logger:info_msg(?MODULE,?LINE,Format, Args)). - + ejabberd_logger:info_msg(?MODULE, ?LINE, Format, Args)). + -define(WARNING_MSG(Format, Args), - ejabberd_logger:warning_msg(?MODULE,?LINE,Format, Args)). - + ejabberd_logger:warning_msg(?MODULE, ?LINE, Format, + Args)). + -define(ERROR_MSG(Format, Args), - ejabberd_logger:error_msg(?MODULE,?LINE,Format, Args)). + ejabberd_logger:error_msg(?MODULE, ?LINE, Format, + Args)). -define(CRITICAL_MSG(Format, Args), - ejabberd_logger:critical_msg(?MODULE,?LINE,Format, Args)). - + ejabberd_logger:critical_msg(?MODULE, ?LINE, Format, + Args)). diff --git a/src/ejabberd_admin.erl b/src/ejabberd_admin.erl index 031470c04..9fa95abf6 100644 --- a/src/ejabberd_admin.erl +++ b/src/ejabberd_admin.erl @@ -117,17 +117,17 @@ commands() -> #ejabberd_commands{name = register, tags = [accounts], desc = "Register a user", module = ?MODULE, function = register, - args = [{user, string}, {host, string}, {password, string}], + args = [{user, binary}, {host, binary}, {password, binary}], result = {res, restuple}}, #ejabberd_commands{name = unregister, tags = [accounts], desc = "Unregister a user", module = ?MODULE, function = unregister, - args = [{user, string}, {host, string}], + args = [{user, binary}, {host, binary}], result = {res, restuple}}, #ejabberd_commands{name = registered_users, tags = [accounts], desc = "List all registered users in HOST", module = ?MODULE, function = registered_users, - args = [{host, string}], + args = [{host, binary}], result = {users, {list, {username, string}}}}, #ejabberd_commands{name = registered_vhosts, tags = [server], desc = "List all registered vhosts in SERVER", @@ -158,6 +158,11 @@ commands() -> module = ejabberd_piefxis, function = export_host, args = [{dir, string}, {host, string}], result = {res, rescode}}, + #ejabberd_commands{name = export_odbc, tags = [mnesia, odbc], + desc = "Export all tables as SQL queries to a file", + module = ejd2odbc, function = export, + args = [{host, string}, {file, string}], result = {res, rescode}}, + #ejabberd_commands{name = delete_expired_messages, tags = [purge], desc = "Delete expired offline messages from database", module = ?MODULE, function = delete_expired_messages, @@ -296,11 +301,12 @@ stop_kindly(DelaySeconds, AnnouncementText) -> ok. send_service_message_all_mucs(Subject, AnnouncementText) -> - Message = io_lib:format("~s~n~s", [Subject, AnnouncementText]), + Message = list_to_binary( + io_lib:format("~s~n~s", [Subject, AnnouncementText])), lists:foreach( fun(ServerHost) -> MUCHost = gen_mod:get_module_opt_host( - ServerHost, mod_muc, "conference.@HOST@"), + ServerHost, mod_muc, <<"conference.@HOST@">>), mod_muc:broadcast_service_message(MUCHost, Message) end, ?MYHOSTS). @@ -320,6 +326,8 @@ update("all") -> update(ModStr) -> update_module(ModStr). +update_module(ModuleNameBin) when is_binary(ModuleNameBin) -> + update_module(binary_to_list(ModuleNameBin)); update_module(ModuleNameString) -> ModuleName = list_to_atom(ModuleNameString), case ejabberd_update:update([ModuleName]) of diff --git a/src/ejabberd_app.erl b/src/ejabberd_app.erl index 562c4870f..393d7afb2 100644 --- a/src/ejabberd_app.erl +++ b/src/ejabberd_app.erl @@ -57,8 +57,6 @@ start(normal, _Args) -> ejabberd_config:start(), ejabberd_check:config(), connect_nodes(), - %% Loading ASN.1 driver explicitly to avoid races in LDAP - catch asn1rt:load_driver(), Sup = ejabberd_sup:start_link(), ejabberd_rdbms:start(), ejabberd_auth:start(), @@ -135,41 +133,48 @@ db_init() -> start_modules() -> lists:foreach( fun(Host) -> - case ejabberd_config:get_local_option({modules, Host}) of - undefined -> - ok; - Modules -> - lists:foreach( - fun({Module, Args}) -> - gen_mod:start_module(Host, Module, Args) - end, Modules) - end + Modules = ejabberd_config:get_local_option( + {modules, Host}, + fun(Mods) -> + lists:map( + fun({M, A}) when is_atom(M), is_list(A) -> + {M, A} + end, Mods) + end, []), + lists:foreach( + fun({Module, Args}) -> + gen_mod:start_module(Host, Module, Args) + end, Modules) end, ?MYHOSTS). %% Stop all the modules in all the hosts stop_modules() -> lists:foreach( fun(Host) -> - case ejabberd_config:get_local_option({modules, Host}) of - undefined -> - ok; - Modules -> - lists:foreach( - fun({Module, _Args}) -> - gen_mod:stop_module_keep_config(Host, Module) - end, Modules) - end + Modules = ejabberd_config:get_local_option( + {modules, Host}, + fun(Mods) -> + lists:map( + fun({M, A}) when is_atom(M), is_list(A) -> + {M, A} + end, Mods) + end, []), + lists:foreach( + fun({Module, _Args}) -> + gen_mod:stop_module_keep_config(Host, Module) + end, Modules) end, ?MYHOSTS). connect_nodes() -> - case ejabberd_config:get_local_option(cluster_nodes) of - undefined -> - ok; - Nodes when is_list(Nodes) -> - lists:foreach(fun(Node) -> - net_kernel:connect_node(Node) - end, Nodes) - end. + Nodes = ejabberd_config:get_local_option( + cluster_nodes, + fun(Ns) -> + true = lists:all(fun is_atom/1, Ns), + Ns + end, []), + lists:foreach(fun(Node) -> + net_kernel:connect_node(Node) + end, Nodes). %% @spec () -> string() %% @doc Returns the full path to the ejabberd log file. diff --git a/src/ejabberd_auth.erl b/src/ejabberd_auth.erl index 7485f8234..298cdf1eb 100644 --- a/src/ejabberd_auth.erl +++ b/src/ejabberd_auth.erl @@ -27,32 +27,21 @@ %% TODO: Use the functions in ejabberd auth to add and remove users. -module(ejabberd_auth). + -author('alexey@process-one.net'). %% External exports --export([start/0, - set_password/3, - check_password/3, - check_password/5, - check_password_with_authmodule/3, - check_password_with_authmodule/5, - try_register/3, - dirty_get_registered_users/0, - get_vh_registered_users/1, - get_vh_registered_users/2, +-export([start/0, set_password/3, check_password/3, + check_password/5, check_password_with_authmodule/3, + check_password_with_authmodule/5, try_register/3, + dirty_get_registered_users/0, get_vh_registered_users/1, + get_vh_registered_users/2, export/1, get_vh_registered_users_number/1, - get_vh_registered_users_number/2, - get_password/2, - get_password_s/2, - get_password_with_authmodule/2, - is_user_exists/2, - is_user_exists_in_other_modules/3, - remove_user/2, - remove_user/3, - plain_password_required/1, - store_type/1, - entropy/1 - ]). + get_vh_registered_users_number/2, get_password/2, + get_password_s/2, get_password_with_authmodule/2, + is_user_exists/2, is_user_exists_in_other_modules/3, + remove_user/2, remove_user/3, plain_password_required/1, + store_type/1, entropy/1]). -export([auth_modules/1]). @@ -61,55 +50,80 @@ %%%---------------------------------------------------------------------- %%% API %%%---------------------------------------------------------------------- -start() -> - lists:foreach( - fun(Host) -> - lists:foreach( - fun(M) -> - M:start(Host) - end, auth_modules(Host)) - end, ?MYHOSTS). +-type opts() :: [{prefix, binary()} | {from, integer()} | + {to, integer()} | {limit, integer()} | + {offset, integer()}]. +-callback start(binary()) -> any(). +-callback plain_password_required() -> boolean(). +-callback store_type() -> plain | external | scram. +-callback set_password(binary(), binary(), binary()) -> ok | {error, atom()}. +-callback remove_user(binary(), binary()) -> any(). +-callback remove_user(binary(), binary(), binary()) -> any(). +-callback is_user_exists(binary(), binary()) -> boolean() | {error, atom()}. +-callback check_password(binary(), binary(), binary()) -> boolean(). +-callback check_password(binary(), binary(), binary(), binary(), + fun((binary()) -> binary())) -> boolean(). +-callback try_register(binary(), binary(), binary()) -> {atomic, atom()} | + {error, atom()}. +-callback dirty_get_registered_users() -> [{binary(), binary()}]. +-callback get_vh_registered_users(binary()) -> [{binary(), binary()}]. +-callback get_vh_registered_users(binary(), opts()) -> [{binary(), binary()}]. +-callback get_vh_registered_users_number(binary()) -> number(). +-callback get_vh_registered_users_number(binary(), opts()) -> number(). +-callback get_password(binary(), binary()) -> false | binary(). +-callback get_password_s(binary(), binary()) -> binary(). + +start() -> %% This is only executed by ejabberd_c2s for non-SASL auth client + lists:foreach(fun (Host) -> + lists:foreach(fun (M) -> M:start(Host) end, + auth_modules(Host)) + end, + ?MYHOSTS). + plain_password_required(Server) -> - lists:any( - fun(M) -> - M:plain_password_required() - end, auth_modules(Server)). + lists:any(fun (M) -> M:plain_password_required() end, + auth_modules(Server)). store_type(Server) -> - lists:foldl( - fun(_, external) -> - external; - (M, scram) -> - case M:store_type() of - external -> - external; - _Else -> - scram - end; - (M, plain) -> - M:store_type() - end, plain, auth_modules(Server)). - %% @doc Check if the user and password can login in server. %% @spec (User::string(), Server::string(), Password::string()) -> %% true | false + lists:foldl(fun (_, external) -> external; + (M, scram) -> + case M:store_type() of + external -> external; + _Else -> scram + end; + (M, plain) -> M:store_type() + end, + plain, auth_modules(Server)). + +-spec check_password(binary(), binary(), binary()) -> boolean(). + check_password(User, Server, Password) -> - case check_password_with_authmodule(User, Server, Password) of - {true, _AuthModule} -> true; - false -> false + case check_password_with_authmodule(User, Server, + Password) + of + {true, _AuthModule} -> true; + false -> false end. %% @doc Check if the user and password can login in server. %% @spec (User::string(), Server::string(), Password::string(), %% Digest::string(), DigestGen::function()) -> %% true | false -check_password(User, Server, Password, Digest, DigestGen) -> - case check_password_with_authmodule(User, Server, Password, - Digest, DigestGen) of - {true, _AuthModule} -> true; - false -> false +-spec check_password(binary(), binary(), binary(), binary(), + fun((binary()) -> binary())) -> boolean(). + +check_password(User, Server, Password, Digest, + DigestGen) -> + case check_password_with_authmodule(User, Server, + Password, Digest, DigestGen) + of + {true, _AuthModule} -> true; + false -> false end. %% @doc Check if the user and password can login in server. @@ -122,199 +136,224 @@ check_password(User, Server, Password, Digest, DigestGen) -> %% AuthModule = ejabberd_auth_anonymous | ejabberd_auth_external %% | ejabberd_auth_internal | ejabberd_auth_ldap %% | ejabberd_auth_odbc | ejabberd_auth_pam -check_password_with_authmodule(User, Server, Password) -> - check_password_loop(auth_modules(Server), [User, Server, Password]). +-spec check_password_with_authmodule(binary(), binary(), binary()) -> false | + {true, atom()}. -check_password_with_authmodule(User, Server, Password, Digest, DigestGen) -> - check_password_loop(auth_modules(Server), [User, Server, Password, - Digest, DigestGen]). +check_password_with_authmodule(User, Server, + Password) -> + check_password_loop(auth_modules(Server), + [User, Server, Password]). -check_password_loop([], _Args) -> - false; +-spec check_password_with_authmodule(binary(), binary(), binary(), binary(), + fun((binary()) -> binary())) -> false | + {true, atom()}. + +check_password_with_authmodule(User, Server, Password, + Digest, DigestGen) -> + check_password_loop(auth_modules(Server), + [User, Server, Password, Digest, DigestGen]). + +check_password_loop([], _Args) -> false; check_password_loop([AuthModule | AuthModules], Args) -> case apply(AuthModule, check_password, Args) of - true -> - {true, AuthModule}; - false -> - check_password_loop(AuthModules, Args) + true -> {true, AuthModule}; + false -> check_password_loop(AuthModules, Args) end. +-spec set_password(binary(), binary(), binary()) -> ok | + {error, atom()}. %% @spec (User::string(), Server::string(), Password::string()) -> %% ok | {error, ErrorType} %% where ErrorType = empty_password | not_allowed | invalid_jid -set_password(_User, _Server, "") -> - %% We do not allow empty password +set_password(_User, _Server, <<"">>) -> {error, empty_password}; set_password(User, Server, Password) -> - lists:foldl( - fun(M, {error, _}) -> - M:set_password(User, Server, Password); - (_M, Res) -> - Res - end, {error, not_allowed}, auth_modules(Server)). - %% @spec (User, Server, Password) -> {atomic, ok} | {atomic, exists} | {error, not_allowed} -try_register(_User, _Server, "") -> - %% We do not allow empty password - {error, not_allowed}; + lists:foldl(fun (M, {error, _}) -> + M:set_password(User, Server, Password); + (_M, Res) -> Res + end, + {error, not_allowed}, auth_modules(Server)). + +-spec try_register(binary(), binary(), binary()) -> {atomic, atom()} | + {error, atom()}. + +try_register(_User, _Server, <<"">>) -> + {error, not_allowed}; try_register(User, Server, Password) -> - case is_user_exists(User,Server) of - true -> - {atomic, exists}; - false -> - case lists:member(jlib:nameprep(Server), ?MYHOSTS) of - true -> - Res = lists:foldl( - fun(_M, {atomic, ok} = Res) -> - Res; - (M, _) -> - M:try_register(User, Server, Password) - end, {error, not_allowed}, auth_modules(Server)), - case Res of - {atomic, ok} -> - ejabberd_hooks:run(register_user, Server, - [User, Server]), - {atomic, ok}; - _ -> Res - end; - false -> - {error, not_allowed} - end + case is_user_exists(User, Server) of + true -> {atomic, exists}; + false -> + case lists:member(jlib:nameprep(Server), ?MYHOSTS) of + true -> + Res = lists:foldl(fun (_M, {atomic, ok} = Res) -> Res; + (M, _) -> + M:try_register(User, Server, Password) + end, + {error, not_allowed}, auth_modules(Server)), + case Res of + {atomic, ok} -> + ejabberd_hooks:run(register_user, Server, + [User, Server]), + {atomic, ok}; + _ -> Res + end; + false -> {error, not_allowed} + end end. %% Registered users list do not include anonymous users logged +-spec dirty_get_registered_users() -> [{binary(), binary()}]. + dirty_get_registered_users() -> - lists:flatmap( - fun(M) -> - M:dirty_get_registered_users() - end, auth_modules()). + lists:flatmap(fun (M) -> M:dirty_get_registered_users() + end, + auth_modules()). + +-spec get_vh_registered_users(binary()) -> [{binary(), binary()}]. %% Registered users list do not include anonymous users logged get_vh_registered_users(Server) -> - lists:flatmap( - fun(M) -> - M:get_vh_registered_users(Server) - end, auth_modules(Server)). + lists:flatmap(fun (M) -> + M:get_vh_registered_users(Server) + end, + auth_modules(Server)). + +-spec get_vh_registered_users(binary(), opts()) -> [{binary(), binary()}]. get_vh_registered_users(Server, Opts) -> - lists:flatmap( - fun(M) -> - case erlang:function_exported( - M, get_vh_registered_users, 2) of - true -> - M:get_vh_registered_users(Server, Opts); - false -> - M:get_vh_registered_users(Server) - end - end, auth_modules(Server)). + lists:flatmap(fun (M) -> + case erlang:function_exported(M, + get_vh_registered_users, + 2) + of + true -> M:get_vh_registered_users(Server, Opts); + false -> M:get_vh_registered_users(Server) + end + end, + auth_modules(Server)). get_vh_registered_users_number(Server) -> - lists:sum( - lists:map( - fun(M) -> - case erlang:function_exported( - M, get_vh_registered_users_number, 1) of - true -> - M:get_vh_registered_users_number(Server); - false -> - length(M:get_vh_registered_users(Server)) - end - end, auth_modules(Server))). + lists:sum(lists:map(fun (M) -> + case erlang:function_exported(M, + get_vh_registered_users_number, + 1) + of + true -> + M:get_vh_registered_users_number(Server); + false -> + length(M:get_vh_registered_users(Server)) + end + end, + auth_modules(Server))). + +-spec get_vh_registered_users_number(binary(), opts()) -> number(). get_vh_registered_users_number(Server, Opts) -> - lists:sum( - lists:map( - fun(M) -> - case erlang:function_exported( - M, get_vh_registered_users_number, 2) of - true -> - M:get_vh_registered_users_number(Server, Opts); - false -> - length(M:get_vh_registered_users(Server)) - end - end, auth_modules(Server))). - %% @doc Get the password of the user. %% @spec (User::string(), Server::string()) -> Password::string() + lists:sum(lists:map(fun (M) -> + case erlang:function_exported(M, + get_vh_registered_users_number, + 2) + of + true -> + M:get_vh_registered_users_number(Server, + Opts); + false -> + length(M:get_vh_registered_users(Server)) + end + end, + auth_modules(Server))). + +-spec get_password(binary(), binary()) -> false | binary(). + get_password(User, Server) -> - lists:foldl( - fun(M, false) -> - M:get_password(User, Server); - (_M, Password) -> - Password - end, false, auth_modules(Server)). + lists:foldl(fun (M, false) -> + M:get_password(User, Server); + (_M, Password) -> Password + end, + false, auth_modules(Server)). + +-spec get_password_s(binary(), binary()) -> binary(). get_password_s(User, Server) -> case get_password(User, Server) of - false -> - ""; - Password when is_list(Password) -> - Password; - _ -> - "" + false -> <<"">>; + Password -> Password end. %% @doc Get the password of the user and the auth module. %% @spec (User::string(), Server::string()) -> %% {Password::string(), AuthModule::atom()} | {false, none} -get_password_with_authmodule(User, Server) -> - lists:foldl( - fun(M, {false, _}) -> - {M:get_password(User, Server), M}; - (_M, {Password, AuthModule}) -> - {Password, AuthModule} - end, {false, none}, auth_modules(Server)). +-spec get_password_with_authmodule(binary(), binary()) -> {false | binary(), atom()}. +get_password_with_authmodule(User, Server) -> %% Returns true if the user exists in the DB or if an anonymous user is logged %% under the given name -is_user_exists(User, Server) -> - lists:any( - fun(M) -> - case M:is_user_exists(User, Server) of - {error, Error} -> - ?ERROR_MSG("The authentication module ~p returned an " - "error~nwhen checking user ~p in server ~p~n" - "Error message: ~p", - [M, User, Server, Error]), - false; - Else -> - Else - end - end, auth_modules(Server)). + lists:foldl(fun (M, {false, _}) -> + {M:get_password(User, Server), M}; + (_M, {Password, AuthModule}) -> {Password, AuthModule} + end, + {false, none}, auth_modules(Server)). +-spec is_user_exists(binary(), binary()) -> boolean(). + +is_user_exists(User, Server) -> %% Check if the user exists in all authentications module except the module %% passed as parameter %% @spec (Module::atom(), User, Server) -> true | false | maybe + lists:any(fun (M) -> + case M:is_user_exists(User, Server) of + {error, Error} -> + ?ERROR_MSG("The authentication module ~p returned " + "an error~nwhen checking user ~p in server " + "~p~nError message: ~p", + [M, User, Server, Error]), + false; + Else -> Else + end + end, + auth_modules(Server)). + +-spec is_user_exists_in_other_modules(atom(), binary(), binary()) -> boolean() | maybe. + is_user_exists_in_other_modules(Module, User, Server) -> - is_user_exists_in_other_modules_loop( - auth_modules(Server)--[Module], - User, Server). -is_user_exists_in_other_modules_loop([], _User, _Server) -> + is_user_exists_in_other_modules_loop(auth_modules(Server) + -- [Module], + User, Server). + +is_user_exists_in_other_modules_loop([], _User, + _Server) -> false; -is_user_exists_in_other_modules_loop([AuthModule|AuthModules], User, Server) -> +is_user_exists_in_other_modules_loop([AuthModule + | AuthModules], + User, Server) -> case AuthModule:is_user_exists(User, Server) of - true -> - true; - false -> - is_user_exists_in_other_modules_loop(AuthModules, User, Server); - {error, Error} -> - ?DEBUG("The authentication module ~p returned an error~nwhen " - "checking user ~p in server ~p~nError message: ~p", - [AuthModule, User, Server, Error]), - maybe + true -> true; + false -> + is_user_exists_in_other_modules_loop(AuthModules, User, + Server); + {error, Error} -> + ?DEBUG("The authentication module ~p returned " + "an error~nwhen checking user ~p in server " + "~p~nError message: ~p", + [AuthModule, User, Server, Error]), + maybe end. +-spec remove_user(binary(), binary()) -> ok. %% @spec (User, Server) -> ok %% @doc Remove user. %% Note: it may return ok even if there was some problem removing the user. remove_user(User, Server) -> - lists:foreach( - fun(M) -> - M:remove_user(User, Server) - end, auth_modules(Server)), - ejabberd_hooks:run(remove_user, jlib:nameprep(Server), [User, Server]), + lists:foreach(fun (M) -> M:remove_user(User, Server) + end, + auth_modules(Server)), + ejabberd_hooks:run(remove_user, jlib:nameprep(Server), + [User, Server]), ok. %% @spec (User, Server, Password) -> ok | not_exists | not_allowed | bad_request | error @@ -322,41 +361,49 @@ remove_user(User, Server) -> %% The removal is attempted in each auth method provided: %% when one returns 'ok' the loop stops; %% if no method returns 'ok' then it returns the error message indicated by the last method attempted. +-spec remove_user(binary(), binary(), binary()) -> any(). + remove_user(User, Server, Password) -> - R = lists:foldl( - fun(_M, ok = Res) -> - Res; - (M, _) -> - M:remove_user(User, Server, Password) - end, error, auth_modules(Server)), + R = lists:foldl(fun (_M, ok = Res) -> Res; + (M, _) -> M:remove_user(User, Server, Password) + end, + error, auth_modules(Server)), case R of - ok -> ejabberd_hooks:run(remove_user, jlib:nameprep(Server), [User, Server]); - _ -> none + ok -> + ejabberd_hooks:run(remove_user, jlib:nameprep(Server), + [User, Server]); + _ -> none end, R. %% @spec (IOList) -> non_negative_float() %% @doc Calculate informational entropy. -entropy(IOList) -> - case binary_to_list(iolist_to_binary(IOList)) of - "" -> - 0.0; - S -> - Set = lists:foldl( - fun(C, [Digit, Printable, LowLetter, HiLetter, Other]) -> - if C >= $a, C =< $z -> - [Digit, Printable, 26, HiLetter, Other]; - C >= $0, C =< $9 -> - [9, Printable, LowLetter, HiLetter, Other]; - C >= $A, C =< $Z -> - [Digit, Printable, LowLetter, 26, Other]; - C >= 16#21, C =< 16#7e -> - [Digit, 33, LowLetter, HiLetter, Other]; - true -> - [Digit, Printable, LowLetter, HiLetter, 128] - end - end, [0, 0, 0, 0, 0], S), - length(S) * math:log(lists:sum(Set))/math:log(2) +entropy(B) -> + case binary_to_list(B) of + "" -> 0.0; + S -> + Set = lists:foldl(fun (C, + [Digit, Printable, LowLetter, HiLetter, + Other]) -> + if C >= $a, C =< $z -> + [Digit, Printable, 26, HiLetter, + Other]; + C >= $0, C =< $9 -> + [9, Printable, LowLetter, HiLetter, + Other]; + C >= $A, C =< $Z -> + [Digit, Printable, LowLetter, 26, + Other]; + C >= 33, C =< 126 -> + [Digit, 33, LowLetter, HiLetter, + Other]; + true -> + [Digit, Printable, LowLetter, + HiLetter, 128] + end + end, + [0, 0, 0, 0, 0], S), + length(S) * math:log(lists:sum(Set)) / math:log(2) end. %%%---------------------------------------------------------------------- @@ -365,19 +412,27 @@ entropy(IOList) -> %% Return the lists of all the auth modules actually used in the %% configuration auth_modules() -> - lists:usort( - lists:flatmap( - fun(Server) -> - auth_modules(Server) - end, ?MYHOSTS)). + lists:usort(lists:flatmap(fun (Server) -> + auth_modules(Server) + end, + ?MYHOSTS)). + +-spec auth_modules(binary()) -> [atom()]. %% Return the list of authenticated modules for a given host auth_modules(Server) -> LServer = jlib:nameprep(Server), - Method = ejabberd_config:get_local_option({auth_method, LServer}), - Methods = if - Method == undefined -> []; - is_list(Method) -> Method; - is_atom(Method) -> [Method] - end, - [list_to_atom("ejabberd_auth_" ++ atom_to_list(M)) || M <- Methods]. + Methods = ejabberd_config:get_local_option( + {auth_method, LServer}, + fun(V) when is_list(V) -> + true = lists:all(fun is_atom/1, V), + V; + (V) when is_atom(V) -> + [V] + end, []), + [jlib:binary_to_atom(<<"ejabberd_auth_", + (jlib:atom_to_binary(M))/binary>>) + || M <- Methods]. + +export(Server) -> + ejabberd_auth_internal:export(Server). diff --git a/src/ejabberd_auth_anonymous.erl b/src/ejabberd_auth_anonymous.erl index ebdbf9680..c19effabe 100644 --- a/src/ejabberd_auth_anonymous.erl +++ b/src/ejabberd_auth_anonymous.erl @@ -39,27 +39,24 @@ %% Function used by ejabberd_auth: --export([login/2, - set_password/3, - check_password/3, - check_password/5, - try_register/3, - dirty_get_registered_users/0, - get_vh_registered_users/1, - get_password/2, - get_password/3, - is_user_exists/2, - remove_user/2, - remove_user/3, - store_type/0, +-export([login/2, set_password/3, check_password/3, + check_password/5, try_register/3, + dirty_get_registered_users/0, get_vh_registered_users/1, + get_vh_registered_users/2, get_vh_registered_users_number/1, + get_vh_registered_users_number/2, get_password_s/2, + get_password/2, get_password/3, is_user_exists/2, + remove_user/2, remove_user/3, store_type/0, plain_password_required/0]). -include("ejabberd.hrl"). + -include("jlib.hrl"). --record(anonymous, {us, sid}). %% Create the anonymous table if at least one virtual host has anonymous features enabled %% Register to login / logout events +-record(anonymous, {us = {<<"">>, <<"">>} :: {binary(), binary()}, + sid = {now(), self()} :: ejabberd_sm:sid()}). + start(Host) -> %% TODO: Check cluster mode mnesia:create_table(anonymous, [{ram_copies, [node()]}, @@ -80,13 +77,13 @@ allow_anonymous(Host) -> %% anonymous protocol can be: sasl_anon|login_anon|both is_sasl_anonymous_enabled(Host) -> case allow_anonymous(Host) of - false -> false; - true -> - case anonymous_protocol(Host) of - sasl_anon -> true; - both -> true; - _Other -> false - end + false -> false; + true -> + case anonymous_protocol(Host) of + sasl_anon -> true; + both -> true; + _Other -> false + end end. %% Return true if anonymous login is enabled on the server @@ -94,30 +91,33 @@ is_sasl_anonymous_enabled(Host) -> %% clients that do not support anonymous login) is_login_anonymous_enabled(Host) -> case allow_anonymous(Host) of - false -> false; - true -> - case anonymous_protocol(Host) of - login_anon -> true; - both -> true; - _Other -> false - end + false -> false; + true -> + case anonymous_protocol(Host) of + login_anon -> true; + both -> true; + _Other -> false + end end. %% Return the anonymous protocol to use: sasl_anon|login_anon|both %% defaults to login_anon anonymous_protocol(Host) -> - case ejabberd_config:get_local_option({anonymous_protocol, Host}) of - sasl_anon -> sasl_anon; - login_anon -> login_anon; - both -> both; - _Other -> sasl_anon - end. + ejabberd_config:get_local_option( + {anonymous_protocol, Host}, + fun(sasl_anon) -> sasl_anon; + (login_anon) -> login_anon; + (both) -> both + end, + sasl_anon). %% Return true if multiple connections have been allowed in the config file %% defaults to false allow_multiple_connections(Host) -> ejabberd_config:get_local_option( - {allow_multiple_connections, Host}) =:= true. + {allow_multiple_connections, Host}, + fun(V) when is_boolean(V) -> V end, + false). %% Check if user exist in the anonymus database anonymous_user_exist(User, Server) -> @@ -134,36 +134,39 @@ anonymous_user_exist(User, Server) -> %% Remove connection from Mnesia tables remove_connection(SID, LUser, LServer) -> US = {LUser, LServer}, - F = fun() -> - mnesia:delete_object({anonymous, US, SID}) - end, + F = fun () -> mnesia:delete_object({anonymous, US, SID}) + end, mnesia:transaction(F). %% Register connection -register_connection(SID, #jid{luser = LUser, lserver = LServer}, Info) -> - AuthModule = xml:get_attr_s(auth_module, Info), - case AuthModule == ?MODULE of - true -> - ejabberd_hooks:run(register_user, LServer, [LUser, LServer]), - US = {LUser, LServer}, - mnesia:sync_dirty( - fun() -> mnesia:write(#anonymous{us = US, sid=SID}) - end); - false -> - ok +register_connection(SID, + #jid{luser = LUser, lserver = LServer}, Info) -> + AuthModule = list_to_atom(binary_to_list(xml:get_attr_s(<<"auth_module">>, Info))), + case AuthModule == (?MODULE) of + true -> + ejabberd_hooks:run(register_user, LServer, + [LUser, LServer]), + US = {LUser, LServer}, + mnesia:sync_dirty(fun () -> + mnesia:write(#anonymous{us = US, + sid = SID}) + end); + false -> ok end. %% Remove an anonymous user from the anonymous users table -unregister_connection(SID, #jid{luser = LUser, lserver = LServer}, _) -> - purge_hook(anonymous_user_exist(LUser, LServer), - LUser, LServer), +unregister_connection(SID, + #jid{luser = LUser, lserver = LServer}, _) -> + purge_hook(anonymous_user_exist(LUser, LServer), LUser, + LServer), remove_connection(SID, LUser, LServer). %% Launch the hook to purge user data only for anonymous users purge_hook(false, _LUser, _LServer) -> ok; purge_hook(true, LUser, LServer) -> - ejabberd_hooks:run(anonymous_purge_hook, LServer, [LUser, LServer]). + ejabberd_hooks:run(anonymous_purge_hook, LServer, + [LUser, LServer]). %% --------------------------------- %% Specific anonymous auth functions @@ -172,41 +175,42 @@ purge_hook(true, LUser, LServer) -> %% When anonymous login is enabled, check the password for permenant users %% before allowing access check_password(User, Server, Password) -> - check_password(User, Server, Password, undefined, undefined). -check_password(User, Server, _Password, _Digest, _DigestGen) -> - %% We refuse login for registered accounts (They cannot logged but - %% they however are "reserved") - case ejabberd_auth:is_user_exists_in_other_modules(?MODULE, - User, Server) of - %% If user exists in other module, reject anonnymous authentication - true -> false; - %% If we are not sure whether the user exists in other module, reject anon auth - maybe -> false; - false -> login(User, Server) + check_password(User, Server, Password, undefined, + undefined). + +check_password(User, Server, _Password, _Digest, + _DigestGen) -> + case + ejabberd_auth:is_user_exists_in_other_modules(?MODULE, + User, Server) + of + %% If user exists in other module, reject anonnymous authentication + true -> false; + %% If we are not sure whether the user exists in other module, reject anon auth + maybe -> false; + false -> login(User, Server) end. login(User, Server) -> case is_login_anonymous_enabled(Server) of - false -> false; - true -> - case anonymous_user_exist(User, Server) of - %% Reject the login if an anonymous user with the same login - %% is already logged and if multiple login has not been enable - %% in the config file. - true -> allow_multiple_connections(Server); - %% Accept login and add user to the anonymous table - false -> true - end + false -> false; + true -> + case anonymous_user_exist(User, Server) of + %% Reject the login if an anonymous user with the same login + %% is already logged and if multiple login has not been enable + %% in the config file. + true -> allow_multiple_connections(Server); + %% Accept login and add user to the anonymous table + false -> true + end end. %% When anonymous login is enabled, check that the user is permanent before %% changing its password set_password(User, Server, _Password) -> case anonymous_user_exist(User, Server) of - true -> - ok; - false -> - {error, not_allowed} + true -> ok; + false -> {error, not_allowed} end. %% When anonymous login is enabled, check if permanent users are allowed on @@ -214,25 +218,42 @@ set_password(User, Server, _Password) -> try_register(_User, _Server, _Password) -> {error, not_allowed}. -dirty_get_registered_users() -> - []. +dirty_get_registered_users() -> []. get_vh_registered_users(Server) -> - [{U, S} || {U, S, _R} <- ejabberd_sm:get_vh_session_list(Server)]. + [{U, S} + || {U, S, _R} + <- ejabberd_sm:get_vh_session_list(Server)]. +get_vh_registered_users(Server, _) -> + get_vh_registered_users(Server). + +get_vh_registered_users_number(Server) -> + length(get_vh_registered_users(Server)). + +get_vh_registered_users_number(Server, _) -> + get_vh_registered_users_number(Server). %% Return password of permanent user or false for anonymous users get_password(User, Server) -> - get_password(User, Server, ""). + get_password(User, Server, <<"">>). get_password(User, Server, DefaultValue) -> - case anonymous_user_exist(User, Server) or login(User, Server) of - %% We return the default value if the user is anonymous - true -> - DefaultValue; - %% We return the permanent user password otherwise - false -> - false + case anonymous_user_exist(User, Server) or + login(User, Server) + of + %% We return the default value if the user is anonymous + true -> DefaultValue; + %% We return the permanent user password otherwise + false -> false + end. + +get_password_s(User, Server) -> + case get_password(User, Server) of + false -> + <<"">>; + Password -> + Password end. %% Returns true if the user exists in the DB or if an anonymous user is logged @@ -240,14 +261,11 @@ get_password(User, Server, DefaultValue) -> is_user_exists(User, Server) -> anonymous_user_exist(User, Server). -remove_user(_User, _Server) -> - {error, not_allowed}. +remove_user(_User, _Server) -> {error, not_allowed}. -remove_user(_User, _Server, _Password) -> - not_allowed. +remove_user(_User, _Server, _Password) -> not_allowed. -plain_password_required() -> - false. +plain_password_required() -> false. store_type() -> plain. diff --git a/src/ejabberd_auth_external.erl b/src/ejabberd_auth_external.erl index 08f2856ac..8ae6a1df5 100644 --- a/src/ejabberd_auth_external.erl +++ b/src/ejabberd_auth_external.erl @@ -25,27 +25,21 @@ %%%---------------------------------------------------------------------- -module(ejabberd_auth_external). + -author('alexey@process-one.net'). +-behaviour(ejabberd_auth). + %% External exports --export([start/1, - set_password/3, - check_password/3, - check_password/5, - try_register/3, - dirty_get_registered_users/0, - get_vh_registered_users/1, +-export([start/1, set_password/3, check_password/3, + check_password/5, try_register/3, + dirty_get_registered_users/0, get_vh_registered_users/1, get_vh_registered_users/2, get_vh_registered_users_number/1, - get_vh_registered_users_number/2, - get_password/2, - get_password_s/2, - is_user_exists/2, - remove_user/2, - remove_user/3, - store_type/0, - plain_password_required/0 - ]). + get_vh_registered_users_number/2, get_password/2, + get_password_s/2, is_user_exists/2, remove_user/2, + remove_user/3, store_type/0, + plain_password_required/0]). -include("ejabberd.hrl"). @@ -53,55 +47,59 @@ %%% API %%%---------------------------------------------------------------------- start(Host) -> - extauth:start( - Host, ejabberd_config:get_local_option({extauth_program, Host})), + Cmd = ejabberd_config:get_local_option( + {extauth_program, Host}, + fun(V) -> + binary_to_list(iolist_to_binary(V)) + end, + "extauth"), + extauth:start(Host, Cmd), case check_cache_last_options(Host) of - cache -> - ok = ejabberd_auth_internal:start(Host); - no_cache -> - ok + cache -> ok = ejabberd_auth_internal:start(Host); + no_cache -> ok end. check_cache_last_options(Server) -> - %% if extauth_cache is enabled, then a mod_last module must also be enabled case get_cache_option(Server) of - false -> no_cache; - {true, _CacheTime} -> - case get_mod_last_configured(Server) of - no_mod_last -> - ?ERROR_MSG("In host ~p extauth is used, extauth_cache is enabled but " - "mod_last is not enabled.", [Server]), - no_cache; - _ -> cache - end + false -> no_cache; + {true, _CacheTime} -> + case get_mod_last_configured(Server) of + no_mod_last -> + ?ERROR_MSG("In host ~p extauth is used, extauth_cache " + "is enabled but mod_last is not enabled.", + [Server]), + no_cache; + _ -> cache + end end. -plain_password_required() -> - true. +plain_password_required() -> true. -store_type() -> - external. +store_type() -> external. check_password(User, Server, Password) -> case get_cache_option(Server) of - false -> check_password_extauth(User, Server, Password); - {true, CacheTime} -> check_password_cache(User, Server, Password, CacheTime) + false -> check_password_extauth(User, Server, Password); + {true, CacheTime} -> + check_password_cache(User, Server, Password, CacheTime) end. -check_password(User, Server, Password, _Digest, _DigestGen) -> +check_password(User, Server, Password, _Digest, + _DigestGen) -> check_password(User, Server, Password). set_password(User, Server, Password) -> case extauth:set_password(User, Server, Password) of - true -> set_password_internal(User, Server, Password), - ok; - _ -> {error, unknown_problem} + true -> + set_password_internal(User, Server, Password), ok; + _ -> {error, unknown_problem} end. try_register(User, Server, Password) -> case get_cache_option(Server) of - false -> try_register_extauth(User, Server, Password); - {true, _CacheTime} -> try_register_external_cache(User, Server, Password) + false -> try_register_extauth(User, Server, Password); + {true, _CacheTime} -> + try_register_external_cache(User, Server, Password) end. dirty_get_registered_users() -> @@ -110,56 +108,60 @@ dirty_get_registered_users() -> get_vh_registered_users(Server) -> ejabberd_auth_internal:get_vh_registered_users(Server). -get_vh_registered_users(Server, Data) -> - ejabberd_auth_internal:get_vh_registered_users(Server, Data). +get_vh_registered_users(Server, Data) -> + ejabberd_auth_internal:get_vh_registered_users(Server, + Data). get_vh_registered_users_number(Server) -> ejabberd_auth_internal:get_vh_registered_users_number(Server). get_vh_registered_users_number(Server, Data) -> - ejabberd_auth_internal:get_vh_registered_users_number(Server, Data). + ejabberd_auth_internal:get_vh_registered_users_number(Server, + Data). %% The password can only be returned if cache is enabled, cached info exists and is fresh enough. get_password(User, Server) -> case get_cache_option(Server) of - false -> false; - {true, CacheTime} -> get_password_cache(User, Server, CacheTime) + false -> false; + {true, CacheTime} -> + get_password_cache(User, Server, CacheTime) end. get_password_s(User, Server) -> case get_password(User, Server) of - false -> []; - Other -> Other + false -> <<"">>; + Other -> Other end. %% @spec (User, Server) -> true | false | {error, Error} is_user_exists(User, Server) -> try extauth:is_user_exists(User, Server) of - Res -> Res + Res -> Res catch - _:Error -> {error, Error} + _:Error -> {error, Error} end. remove_user(User, Server) -> case extauth:remove_user(User, Server) of - false -> false; - true -> - case get_cache_option(Server) of - false -> false; - {true, _CacheTime} -> - ejabberd_auth_internal:remove_user(User, Server) - end + false -> false; + true -> + case get_cache_option(Server) of + false -> false; + {true, _CacheTime} -> + ejabberd_auth_internal:remove_user(User, Server) + end end. remove_user(User, Server, Password) -> case extauth:remove_user(User, Server, Password) of - false -> false; - true -> - case get_cache_option(Server) of - false -> false; - {true, _CacheTime} -> - ejabberd_auth_internal:remove_user(User, Server, Password) - end + false -> false; + true -> + case get_cache_option(Server) of + false -> false; + {true, _CacheTime} -> + ejabberd_auth_internal:remove_user(User, Server, + Password) + end end. %%% @@ -168,45 +170,50 @@ remove_user(User, Server, Password) -> %% @spec (Host::string()) -> false | {true, CacheTime::integer()} get_cache_option(Host) -> - case ejabberd_config:get_local_option({extauth_cache, Host}) of - CacheTime when is_integer(CacheTime) -> {true, CacheTime}; - _ -> false + case ejabberd_config:get_local_option( + {extauth_cache, Host}, + fun(I) when is_integer(I), I > 0 -> I end) of + undefined -> false; + CacheTime -> {true, CacheTime} end. %% @spec (User, Server, Password) -> true | false check_password_extauth(User, Server, Password) -> - extauth:check_password(User, Server, Password) andalso Password /= "". + extauth:check_password(User, Server, Password) andalso + Password /= <<"">>. %% @spec (User, Server, Password) -> true | false try_register_extauth(User, Server, Password) -> extauth:try_register(User, Server, Password). -check_password_cache(User, Server, Password, CacheTime) -> +check_password_cache(User, Server, Password, + CacheTime) -> case get_last_access(User, Server) of - online -> - check_password_internal(User, Server, Password); - never -> - check_password_external_cache(User, Server, Password); - mod_last_required -> - ?ERROR_MSG("extauth is used, extauth_cache is enabled but mod_last is not enabled in that host", []), - check_password_external_cache(User, Server, Password); - TimeStamp -> - %% If last access exists, compare last access with cache refresh time - case is_fresh_enough(TimeStamp, CacheTime) of - %% If no need to refresh, check password against Mnesia - true -> - case check_password_internal(User, Server, Password) of - %% If password valid in Mnesia, accept it - true -> - true; - %% Else (password nonvalid in Mnesia), check in extauth and cache result - false -> - check_password_external_cache(User, Server, Password) - end; - %% Else (need to refresh), check in extauth and cache result - false -> - check_password_external_cache(User, Server, Password) - end + online -> + check_password_internal(User, Server, Password); + never -> + check_password_external_cache(User, Server, Password); + mod_last_required -> + ?ERROR_MSG("extauth is used, extauth_cache is enabled " + "but mod_last is not enabled in that " + "host", + []), + check_password_external_cache(User, Server, Password); + TimeStamp -> + case is_fresh_enough(TimeStamp, CacheTime) of + %% If no need to refresh, check password against Mnesia + true -> + case check_password_internal(User, Server, Password) of + %% If password valid in Mnesia, accept it + true -> true; + %% Else (password nonvalid in Mnesia), check in extauth and cache result + false -> + check_password_external_cache(User, Server, Password) + end; + %% Else (need to refresh), check in extauth and cache result + false -> + check_password_external_cache(User, Server, Password) + end end. get_password_internal(User, Server) -> @@ -215,60 +222,54 @@ get_password_internal(User, Server) -> %% @spec (User, Server, CacheTime) -> false | Password::string() get_password_cache(User, Server, CacheTime) -> case get_last_access(User, Server) of - online -> - get_password_internal(User, Server); - never -> - false; - mod_last_required -> - ?ERROR_MSG("extauth is used, extauth_cache is enabled but mod_last is not enabled in that host", []), - false; - TimeStamp -> - case is_fresh_enough(TimeStamp, CacheTime) of - true -> - get_password_internal(User, Server); - false -> - false - end + online -> get_password_internal(User, Server); + never -> false; + mod_last_required -> + ?ERROR_MSG("extauth is used, extauth_cache is enabled " + "but mod_last is not enabled in that " + "host", + []), + false; + TimeStamp -> + case is_fresh_enough(TimeStamp, CacheTime) of + true -> get_password_internal(User, Server); + false -> false + end end. - %% Check the password using extauth; if success then cache it check_password_external_cache(User, Server, Password) -> case check_password_extauth(User, Server, Password) of - true -> - set_password_internal(User, Server, Password), true; - false -> - false + true -> + set_password_internal(User, Server, Password), true; + false -> false end. %% Try to register using extauth; if success then cache it try_register_external_cache(User, Server, Password) -> case try_register_extauth(User, Server, Password) of - {atomic, ok} = R -> - set_password_internal(User, Server, Password), - R; - _ -> {error, not_allowed} + {atomic, ok} = R -> + set_password_internal(User, Server, Password), R; + _ -> {error, not_allowed} end. %% @spec (User, Server, Password) -> true | false check_password_internal(User, Server, Password) -> - ejabberd_auth_internal:check_password(User, Server, Password). + ejabberd_auth_internal:check_password(User, Server, + Password). %% @spec (User, Server, Password) -> ok | {error, invalid_jid} set_password_internal(User, Server, Password) -> - ejabberd_auth_internal:set_password(User, Server, Password). - %% @spec (TimeLast, CacheTime) -> true | false %% TimeLast = online | never | integer() %% CacheTime = integer() | false -is_fresh_enough(online, _CacheTime) -> - true; -is_fresh_enough(never, _CacheTime) -> - false; + ejabberd_auth_internal:set_password(User, Server, + Password). + is_fresh_enough(TimeStampLast, CacheTime) -> {MegaSecs, Secs, _MicroSecs} = now(), Now = MegaSecs * 1000000 + Secs, - (TimeStampLast + CacheTime > Now). + TimeStampLast + CacheTime > Now. %% @spec (User, Server) -> online | never | mod_last_required | TimeStamp::integer() %% Code copied from mod_configure.erl @@ -276,38 +277,35 @@ is_fresh_enough(TimeStampLast, CacheTime) -> %% TODO: Update time format to XEP-0202: Entity Time get_last_access(User, Server) -> case ejabberd_sm:get_user_resources(User, Server) of - [] -> - _US = {User, Server}, - case get_last_info(User, Server) of - mod_last_required -> - mod_last_required; - not_found -> - never; - {ok, Timestamp, _Status} -> - Timestamp - end; - _ -> - online + [] -> + _US = {User, Server}, + case get_last_info(User, Server) of + mod_last_required -> mod_last_required; + not_found -> never; + {ok, Timestamp, _Status} -> Timestamp + end; + _ -> online end. %% @spec (User, Server) -> {ok, Timestamp, Status} | not_found | mod_last_required + get_last_info(User, Server) -> case get_mod_last_enabled(Server) of - mod_last -> mod_last:get_last_info(User, Server); - no_mod_last -> mod_last_required + mod_last -> mod_last:get_last_info(User, Server); + no_mod_last -> mod_last_required end. %% @spec (Server) -> mod_last | no_mod_last get_mod_last_enabled(Server) -> case gen_mod:is_loaded(Server, mod_last) of - true -> mod_last; - false -> no_mod_last + true -> mod_last; + false -> no_mod_last end. get_mod_last_configured(Server) -> case is_configured(Server, mod_last) of - true -> mod_last; - false -> no_mod_last + true -> mod_last; + false -> no_mod_last end. is_configured(Host, Module) -> - lists:keymember(Module, 1, ejabberd_config:get_local_option({modules, Host})). + gen_mod:is_loaded(Host, Module). diff --git a/src/ejabberd_auth_internal.erl b/src/ejabberd_auth_internal.erl index 4b5bcd327..b3587e211 100644 --- a/src/ejabberd_auth_internal.erl +++ b/src/ejabberd_auth_internal.erl @@ -25,32 +25,29 @@ %%%---------------------------------------------------------------------- -module(ejabberd_auth_internal). + -author('alexey@process-one.net'). +-behaviour(ejabberd_auth). + %% External exports --export([start/1, - set_password/3, - check_password/3, - check_password/5, - try_register/3, - dirty_get_registered_users/0, - get_vh_registered_users/1, +-export([start/1, set_password/3, check_password/3, + check_password/5, try_register/3, + dirty_get_registered_users/0, get_vh_registered_users/1, get_vh_registered_users/2, get_vh_registered_users_number/1, - get_vh_registered_users_number/2, - get_password/2, - get_password_s/2, - is_user_exists/2, - remove_user/2, - remove_user/3, - store_type/0, - plain_password_required/0 - ]). + get_vh_registered_users_number/2, get_password/2, + get_password_s/2, is_user_exists/2, remove_user/2, + remove_user/3, store_type/0, export/1, + plain_password_required/0]). -include("ejabberd.hrl"). --record(passwd, {us, password}). --record(reg_users_counter, {vhost, count}). +-record(passwd, {us = {<<"">>, <<"">>} :: {binary(), binary()} | '$1', + password = <<"">> :: binary() | scram() | '_'}). + +-record(reg_users_counter, {vhost = <<"">> :: binary(), + count = 0 :: integer() | '$1'}). -define(SALT_LENGTH, 16). @@ -58,8 +55,9 @@ %%% API %%%---------------------------------------------------------------------- start(Host) -> - mnesia:create_table(passwd, [{disc_copies, [node()]}, - {attributes, record_info(fields, passwd)}]), + mnesia:create_table(passwd, + [{disc_copies, [node()]}, + {attributes, record_info(fields, passwd)}]), mnesia:create_table(reg_users_counter, [{ram_copies, [node()]}, {attributes, record_info(fields, reg_users_counter)}]), @@ -72,22 +70,22 @@ update_reg_users_counter_table(Server) -> Set = get_vh_registered_users(Server), Size = length(Set), LServer = jlib:nameprep(Server), - F = fun() -> - mnesia:write(#reg_users_counter{vhost = LServer, - count = Size}) + F = fun () -> + mnesia:write(#reg_users_counter{vhost = LServer, + count = Size}) end, mnesia:sync_dirty(F). plain_password_required() -> case is_scrammed() of - false -> false; - true -> true + false -> false; + true -> true end. store_type() -> case is_scrammed() of - false -> plain; %% allows: PLAIN DIGEST-MD5 SCRAM - true -> scram %% allows: PLAIN SCRAM + false -> plain; %% allows: PLAIN DIGEST-MD5 SCRAM + true -> scram %% allows: PLAIN SCRAM end. check_password(User, Server, Password) -> @@ -95,46 +93,40 @@ check_password(User, Server, Password) -> LServer = jlib:nameprep(Server), US = {LUser, LServer}, case catch mnesia:dirty_read({passwd, US}) of - [#passwd{password = Password}] when is_list(Password) -> - Password /= ""; - [#passwd{password = Scram}] when is_record(Scram, scram) -> - is_password_scram_valid(Password, Scram); - _ -> - false + [#passwd{password = Password}] + when is_binary(Password) -> + Password /= <<"">>; + [#passwd{password = Scram}] + when is_record(Scram, scram) -> + is_password_scram_valid(Password, Scram); + _ -> false end. -check_password(User, Server, Password, Digest, DigestGen) -> +check_password(User, Server, Password, Digest, + DigestGen) -> LUser = jlib:nodeprep(User), LServer = jlib:nameprep(Server), US = {LUser, LServer}, case catch mnesia:dirty_read({passwd, US}) of - [#passwd{password = Passwd}] when is_list(Passwd) -> - DigRes = if - Digest /= "" -> - Digest == DigestGen(Passwd); - true -> - false - end, - if DigRes -> - true; - true -> - (Passwd == Password) and (Password /= "") - end; - [#passwd{password = Scram}] when is_record(Scram, scram) -> - Passwd = base64:decode(Scram#scram.storedkey), - DigRes = if - Digest /= "" -> - Digest == DigestGen(Passwd); - true -> - false - end, - if DigRes -> - true; - true -> - (Passwd == Password) and (Password /= "") - end; - _ -> - false + [#passwd{password = Passwd}] when is_binary(Passwd) -> + DigRes = if Digest /= <<"">> -> + Digest == DigestGen(Passwd); + true -> false + end, + if DigRes -> true; + true -> (Passwd == Password) and (Password /= <<"">>) + end; + [#passwd{password = Scram}] + when is_record(Scram, scram) -> + Passwd = jlib:decode_base64(Scram#scram.storedkey), + DigRes = if Digest /= <<"">> -> + Digest == DigestGen(Passwd); + true -> false + end, + if DigRes -> true; + true -> (Passwd == Password) and (Password /= <<"">>) + end; + _ -> false end. %% @spec (User::string(), Server::string(), Password::string()) -> @@ -143,49 +135,48 @@ set_password(User, Server, Password) -> LUser = jlib:nodeprep(User), LServer = jlib:nameprep(Server), US = {LUser, LServer}, - if - (LUser == error) or (LServer == error) -> - {error, invalid_jid}; - true -> - F = fun() -> - Password2 = case is_scrammed() and is_list(Password) of - true -> password_to_scram(Password); - false -> Password - end, - mnesia:write(#passwd{us = US, - password = Password2}) - end, - {atomic, ok} = mnesia:transaction(F), - ok + if (LUser == error) or (LServer == error) -> + {error, invalid_jid}; + true -> + F = fun () -> + Password2 = case is_scrammed() and is_binary(Password) + of + true -> password_to_scram(Password); + false -> Password + end, + mnesia:write(#passwd{us = US, password = Password2}) + end, + {atomic, ok} = mnesia:transaction(F), + ok end. %% @spec (User, Server, Password) -> {atomic, ok} | {atomic, exists} | {error, invalid_jid} | {aborted, Reason} -try_register(User, Server, Password) -> +try_register(User, Server, PasswordList) -> LUser = jlib:nodeprep(User), LServer = jlib:nameprep(Server), + Password = iolist_to_binary(PasswordList), US = {LUser, LServer}, - if - (LUser == error) or (LServer == error) -> - {error, invalid_jid}; - true -> - F = fun() -> - case mnesia:read({passwd, US}) of - [] -> - Password2 = case is_scrammed() and is_list(Password) of - true -> password_to_scram(Password); - false -> Password - end, - mnesia:write(#passwd{us = US, - password = Password2}), - mnesia:dirty_update_counter( - reg_users_counter, - LServer, 1), - ok; - [_E] -> - exists - end - end, - mnesia:transaction(F) + if (LUser == error) or (LServer == error) -> + {error, invalid_jid}; + true -> + F = fun () -> + case mnesia:read({passwd, US}) of + [] -> + Password2 = case is_scrammed() and + is_binary(Password) + of + true -> password_to_scram(Password); + false -> Password + end, + mnesia:write(#passwd{us = US, + password = Password2}), + mnesia:dirty_update_counter(reg_users_counter, + LServer, 1), + ok; + [_E] -> exists + end + end, + mnesia:transaction(F) end. %% Get all registered users in Mnesia @@ -194,75 +185,81 @@ dirty_get_registered_users() -> get_vh_registered_users(Server) -> LServer = jlib:nameprep(Server), - mnesia:dirty_select( - passwd, - [{#passwd{us = '$1', _ = '_'}, - [{'==', {element, 2, '$1'}, LServer}], - ['$1']}]). + mnesia:dirty_select(passwd, + [{#passwd{us = '$1', _ = '_'}, + [{'==', {element, 2, '$1'}, LServer}], ['$1']}]). -get_vh_registered_users(Server, [{from, Start}, {to, End}]) - when is_integer(Start) and is_integer(End) -> - get_vh_registered_users(Server, [{limit, End-Start+1}, {offset, Start}]); - -get_vh_registered_users(Server, [{limit, Limit}, {offset, Offset}]) - when is_integer(Limit) and is_integer(Offset) -> +get_vh_registered_users(Server, + [{from, Start}, {to, End}]) + when is_integer(Start) and is_integer(End) -> + get_vh_registered_users(Server, + [{limit, End - Start + 1}, {offset, Start}]); +get_vh_registered_users(Server, + [{limit, Limit}, {offset, Offset}]) + when is_integer(Limit) and is_integer(Offset) -> case get_vh_registered_users(Server) of - [] -> - []; - Users -> - Set = lists:keysort(1, Users), - L = length(Set), - Start = if Offset < 1 -> 1; - Offset > L -> L; - true -> Offset - end, - lists:sublist(Set, Start, Limit) + [] -> []; + Users -> + Set = lists:keysort(1, Users), + L = length(Set), + Start = if Offset < 1 -> 1; + Offset > L -> L; + true -> Offset + end, + lists:sublist(Set, Start, Limit) end; - -get_vh_registered_users(Server, [{prefix, Prefix}]) - when is_list(Prefix) -> - Set = [{U,S} || {U, S} <- get_vh_registered_users(Server), lists:prefix(Prefix, U)], +get_vh_registered_users(Server, [{prefix, Prefix}]) + when is_binary(Prefix) -> + Set = [{U, S} + || {U, S} <- get_vh_registered_users(Server), + str:prefix(Prefix, U)], lists:keysort(1, Set); - -get_vh_registered_users(Server, [{prefix, Prefix}, {from, Start}, {to, End}]) - when is_list(Prefix) and is_integer(Start) and is_integer(End) -> - get_vh_registered_users(Server, [{prefix, Prefix}, {limit, End-Start+1}, {offset, Start}]); - -get_vh_registered_users(Server, [{prefix, Prefix}, {limit, Limit}, {offset, Offset}]) - when is_list(Prefix) and is_integer(Limit) and is_integer(Offset) -> - case [{U,S} || {U, S} <- get_vh_registered_users(Server), lists:prefix(Prefix, U)] of - [] -> - []; - Users -> - Set = lists:keysort(1, Users), - L = length(Set), - Start = if Offset < 1 -> 1; - Offset > L -> L; - true -> Offset - end, - lists:sublist(Set, Start, Limit) +get_vh_registered_users(Server, + [{prefix, Prefix}, {from, Start}, {to, End}]) + when is_binary(Prefix) and is_integer(Start) and + is_integer(End) -> + get_vh_registered_users(Server, + [{prefix, Prefix}, {limit, End - Start + 1}, + {offset, Start}]); +get_vh_registered_users(Server, + [{prefix, Prefix}, {limit, Limit}, {offset, Offset}]) + when is_binary(Prefix) and is_integer(Limit) and + is_integer(Offset) -> + case [{U, S} + || {U, S} <- get_vh_registered_users(Server), + str:prefix(Prefix, U)] + of + [] -> []; + Users -> + Set = lists:keysort(1, Users), + L = length(Set), + Start = if Offset < 1 -> 1; + Offset > L -> L; + true -> Offset + end, + lists:sublist(Set, Start, Limit) end; - get_vh_registered_users(Server, _) -> get_vh_registered_users(Server). get_vh_registered_users_number(Server) -> LServer = jlib:nameprep(Server), - Query = mnesia:dirty_select( - reg_users_counter, - [{#reg_users_counter{vhost = LServer, count = '$1'}, - [], - ['$1']}]), + Query = mnesia:dirty_select(reg_users_counter, + [{#reg_users_counter{vhost = LServer, + count = '$1'}, + [], ['$1']}]), case Query of - [Count] -> - Count; - _ -> 0 + [Count] -> Count; + _ -> 0 end. -get_vh_registered_users_number(Server, [{prefix, Prefix}]) when is_list(Prefix) -> - Set = [{U, S} || {U, S} <- get_vh_registered_users(Server), lists:prefix(Prefix, U)], +get_vh_registered_users_number(Server, + [{prefix, Prefix}]) + when is_binary(Prefix) -> + Set = [{U, S} + || {U, S} <- get_vh_registered_users(Server), + str:prefix(Prefix, U)], length(Set); - get_vh_registered_users_number(Server, _) -> get_vh_registered_users_number(Server). @@ -271,15 +268,16 @@ get_password(User, Server) -> LServer = jlib:nameprep(Server), US = {LUser, LServer}, case catch mnesia:dirty_read(passwd, US) of - [#passwd{password = Password}] when is_list(Password) -> - Password; - [#passwd{password = Scram}] when is_record(Scram, scram) -> - {base64:decode(Scram#scram.storedkey), - base64:decode(Scram#scram.serverkey), - base64:decode(Scram#scram.salt), - Scram#scram.iterationcount}; - _ -> - false + [#passwd{password = Password}] + when is_binary(Password) -> + Password; + [#passwd{password = Scram}] + when is_record(Scram, scram) -> + {jlib:decode_base64(Scram#scram.storedkey), + jlib:decode_base64(Scram#scram.serverkey), + jlib:decode_base64(Scram#scram.salt), + Scram#scram.iterationcount}; + _ -> false end. get_password_s(User, Server) -> @@ -287,12 +285,13 @@ get_password_s(User, Server) -> LServer = jlib:nameprep(Server), US = {LUser, LServer}, case catch mnesia:dirty_read(passwd, US) of - [#passwd{password = Password}] when is_list(Password) -> - Password; - [#passwd{password = Scram}] when is_record(Scram, scram) -> - []; - _ -> - [] + [#passwd{password = Password}] + when is_binary(Password) -> + Password; + [#passwd{password = Scram}] + when is_record(Scram, scram) -> + <<"">>; + _ -> <<"">> end. %% @spec (User, Server) -> true | false | {error, Error} @@ -301,12 +300,9 @@ is_user_exists(User, Server) -> LServer = jlib:nameprep(Server), US = {LUser, LServer}, case catch mnesia:dirty_read({passwd, US}) of - [] -> - false; - [_] -> - true; - Other -> - {error, Other} + [] -> false; + [_] -> true; + Other -> {error, Other} end. %% @spec (User, Server) -> ok @@ -316,13 +312,13 @@ remove_user(User, Server) -> LUser = jlib:nodeprep(User), LServer = jlib:nameprep(Server), US = {LUser, LServer}, - F = fun() -> + F = fun () -> mnesia:delete({passwd, US}), - mnesia:dirty_update_counter(reg_users_counter, - LServer, -1) - end, + mnesia:dirty_update_counter(reg_users_counter, LServer, + -1) + end, mnesia:transaction(F), - ok. + ok. %% @spec (User, Server, Password) -> ok | not_exists | not_allowed | bad_request %% @doc Remove user if the provided password is correct. @@ -330,79 +326,65 @@ remove_user(User, Server, Password) -> LUser = jlib:nodeprep(User), LServer = jlib:nameprep(Server), US = {LUser, LServer}, - F = fun() -> + F = fun () -> case mnesia:read({passwd, US}) of - [#passwd{password = Password}] when is_list(Password) -> - mnesia:delete({passwd, US}), - mnesia:dirty_update_counter(reg_users_counter, - LServer, -1), - ok; - [#passwd{password = Scram}] when is_record(Scram, scram) -> - case is_password_scram_valid(Password, Scram) of - true -> - mnesia:delete({passwd, US}), - mnesia:dirty_update_counter(reg_users_counter, - LServer, -1), - ok; - false -> - not_allowed - end; - _ -> - not_exists + [#passwd{password = Password}] + when is_binary(Password) -> + mnesia:delete({passwd, US}), + mnesia:dirty_update_counter(reg_users_counter, LServer, + -1), + ok; + [#passwd{password = Scram}] + when is_record(Scram, scram) -> + case is_password_scram_valid(Password, Scram) of + true -> + mnesia:delete({passwd, US}), + mnesia:dirty_update_counter(reg_users_counter, + LServer, -1), + ok; + false -> not_allowed + end; + _ -> not_exists end - end, + end, case mnesia:transaction(F) of - {atomic, ok} -> - ok; - {atomic, Res} -> - Res; - _ -> - bad_request + {atomic, ok} -> ok; + {atomic, Res} -> Res; + _ -> bad_request end. update_table() -> Fields = record_info(fields, passwd), case mnesia:table_info(passwd, attributes) of - Fields -> - maybe_scram_passwords(), - ok; - [user, password] -> - ?INFO_MSG("Converting passwd table from " - "{user, password} format", []), - Host = ?MYNAME, - {atomic, ok} = mnesia:create_table( - ejabberd_auth_internal_tmp_table, - [{disc_only_copies, [node()]}, - {type, bag}, - {local_content, true}, - {record_name, passwd}, - {attributes, record_info(fields, passwd)}]), - mnesia:transform_table(passwd, ignore, Fields), - F1 = fun() -> - mnesia:write_lock_table(ejabberd_auth_internal_tmp_table), - mnesia:foldl( - fun(#passwd{us = U} = R, _) -> - mnesia:dirty_write( - ejabberd_auth_internal_tmp_table, - R#passwd{us = {U, Host}}) - end, ok, passwd) - end, - mnesia:transaction(F1), - mnesia:clear_table(passwd), - F2 = fun() -> - mnesia:write_lock_table(passwd), - mnesia:foldl( - fun(R, _) -> - mnesia:dirty_write(R) - end, ok, ejabberd_auth_internal_tmp_table) - end, - mnesia:transaction(F2), - mnesia:delete_table(ejabberd_auth_internal_tmp_table); - _ -> - ?INFO_MSG("Recreating passwd table", []), - mnesia:transform_table(passwd, ignore, Fields) + Fields -> + convert_to_binary(Fields), + maybe_scram_passwords(), + ok; + _ -> + ?INFO_MSG("Recreating passwd table", []), + mnesia:transform_table(passwd, ignore, Fields) end. +convert_to_binary(Fields) -> + ejabberd_config:convert_table_to_binary( + passwd, Fields, set, + fun(#passwd{us = {U, _}}) -> U end, + fun(#passwd{us = {U, S}, password = Pass} = R) -> + NewUS = {iolist_to_binary(U), iolist_to_binary(S)}, + NewPass = case Pass of + #scram{storedkey = StoredKey, + serverkey = ServerKey, + salt = Salt} -> + Pass#scram{ + storedkey = iolist_to_binary(StoredKey), + serverkey = iolist_to_binary(ServerKey), + salt = iolist_to_binary(Salt)}; + _ -> + iolist_to_binary(Pass) + end, + R#passwd{us = NewUS, password = NewPass} + end). + %%% %%% SCRAM %%% @@ -411,38 +393,43 @@ update_table() -> %% or if at least the first password is scrammed. is_scrammed() -> OptionScram = is_option_scram(), - FirstElement = mnesia:dirty_read(passwd, mnesia:dirty_first(passwd)), + FirstElement = mnesia:dirty_read(passwd, + mnesia:dirty_first(passwd)), case {OptionScram, FirstElement} of - {true, _} -> - true; - {false, [#passwd{password = Scram}]} when is_record(Scram, scram) -> - true; - _ -> - false + {true, _} -> true; + {false, [#passwd{password = Scram}]} + when is_record(Scram, scram) -> + true; + _ -> false end. is_option_scram() -> - scram == ejabberd_config:get_local_option({auth_password_format, ?MYNAME}). + scram == + ejabberd_config:get_local_option({auth_password_format, ?MYNAME}, + fun(V) -> V end). maybe_alert_password_scrammed_without_option() -> case is_scrammed() andalso not is_option_scram() of - true -> - ?ERROR_MSG("Some passwords were stored in the database as SCRAM, " - "but 'auth_password_format' is not configured 'scram'. " - "The option will now be considered to be 'scram'.", []); - false -> - ok + true -> + ?ERROR_MSG("Some passwords were stored in the database " + "as SCRAM, but 'auth_password_format' " + "is not configured 'scram'. The option " + "will now be considered to be 'scram'.", + []); + false -> ok end. maybe_scram_passwords() -> case is_scrammed() of - true -> scram_passwords(); - false -> ok + true -> scram_passwords(); + false -> ok end. scram_passwords() -> - ?INFO_MSG("Converting the stored passwords into SCRAM bits", []), - Fun = fun(#passwd{password = Password} = P) -> + ?INFO_MSG("Converting the stored passwords into " + "SCRAM bits", + []), + Fun = fun (#passwd{password = Password} = P) -> Scram = password_to_scram(Password), P#passwd{password = Scram} end, @@ -450,21 +437,39 @@ scram_passwords() -> mnesia:transform_table(passwd, Fun, Fields). password_to_scram(Password) -> - password_to_scram(Password, ?SCRAM_DEFAULT_ITERATION_COUNT). + password_to_scram(Password, + ?SCRAM_DEFAULT_ITERATION_COUNT). password_to_scram(Password, IterationCount) -> Salt = crypto:rand_bytes(?SALT_LENGTH), - SaltedPassword = scram:salted_password(Password, Salt, IterationCount), - StoredKey = scram:stored_key(scram:client_key(SaltedPassword)), + SaltedPassword = scram:salted_password(Password, Salt, + IterationCount), + StoredKey = + scram:stored_key(scram:client_key(SaltedPassword)), ServerKey = scram:server_key(SaltedPassword), - #scram{storedkey = base64:encode(StoredKey), - serverkey = base64:encode(ServerKey), - salt = base64:encode(Salt), + #scram{storedkey = jlib:encode_base64(StoredKey), + serverkey = jlib:encode_base64(ServerKey), + salt = jlib:encode_base64(Salt), iterationcount = IterationCount}. is_password_scram_valid(Password, Scram) -> IterationCount = Scram#scram.iterationcount, - Salt = base64:decode(Scram#scram.salt), - SaltedPassword = scram:salted_password(Password, Salt, IterationCount), - StoredKey = scram:stored_key(scram:client_key(SaltedPassword)), - (base64:decode(Scram#scram.storedkey) == StoredKey). + Salt = jlib:decode_base64(Scram#scram.salt), + SaltedPassword = scram:salted_password(Password, Salt, + IterationCount), + StoredKey = + scram:stored_key(scram:client_key(SaltedPassword)), + jlib:decode_base64(Scram#scram.storedkey) == StoredKey. + +export(_Server) -> + [{passwd, + fun(Host, #passwd{us = {LUser, LServer}, password = Password}) + when LServer == Host -> + Username = ejabberd_odbc:escape(LUser), + Pass = ejabberd_odbc:escape(Password), + [[<<"delete from users where username='">>, Username, <<"';">>], + [<<"insert into users(username, password) " + "values ('">>, Username, <<"', '">>, Pass, <<"');">>]]; + (_Host, _R) -> + [] + end}]. diff --git a/src/ejabberd_auth_ldap.erl b/src/ejabberd_auth_ldap.erl index 5e5ca2422..998f21215 100644 --- a/src/ejabberd_auth_ldap.erl +++ b/src/ejabberd_auth_ldap.erl @@ -25,73 +25,59 @@ %%%---------------------------------------------------------------------- -module(ejabberd_auth_ldap). + -author('alexey@process-one.net'). -behaviour(gen_server). +-behaviour(ejabberd_auth). %% gen_server callbacks --export([init/1, - handle_info/2, - handle_call/3, - handle_cast/2, - terminate/2, - code_change/3 - ]). +-export([init/1, handle_info/2, handle_call/3, + handle_cast/2, terminate/2, code_change/3]). %% External exports --export([start/1, - stop/1, - start_link/1, - set_password/3, - check_password/3, - check_password/5, - try_register/3, - dirty_get_registered_users/0, - get_vh_registered_users/1, - get_vh_registered_users_number/1, - get_password/2, - get_password_s/2, - is_user_exists/2, - remove_user/2, - remove_user/3, - store_type/0, - plain_password_required/0 - ]). +-export([start/1, stop/1, start_link/1, set_password/3, + check_password/3, check_password/5, try_register/3, + dirty_get_registered_users/0, get_vh_registered_users/1, + get_vh_registered_users/2, + get_vh_registered_users_number/1, + get_vh_registered_users_number/2, get_password/2, + get_password_s/2, is_user_exists/2, remove_user/2, + remove_user/3, store_type/0, + plain_password_required/0]). -include("ejabberd.hrl"). --include("eldap/eldap.hrl"). - --record(state, {host, - eldap_id, - bind_eldap_id, - servers, - backups, - port, - tls_options, - dn, - password, - base, - uids, - ufilter, - sfilter, - lfilter, %% Local filter (performed by ejabberd, not LDAP) - deref_aliases, - dn_filter, - dn_filter_attrs - }). - %% Unused callbacks. -handle_cast(_Request, State) -> - {noreply, State}. -code_change(_OldVsn, State, _Extra) -> - {ok, State}. -handle_info(_Info, State) -> - {noreply, State}. %% ----- +-include("eldap/eldap.hrl"). --define(LDAP_SEARCH_TIMEOUT, 5). % Timeout for LDAP search queries in seconds +-record(state, + {host = <<"">> :: binary(), + eldap_id = <<"">> :: binary(), + bind_eldap_id = <<"">> :: binary(), + servers = [] :: [binary()], + backups = [] :: [binary()], + port = ?LDAP_PORT :: inet:port_number(), + tls_options = [] :: list(), + dn = <<"">> :: binary(), + password = <<"">> :: binary(), + base = <<"">> :: binary(), + uids = [] :: [{binary()} | {binary(), binary()}], + ufilter = <<"">> :: binary(), + sfilter = <<"">> :: binary(), + lfilter :: {any(), any()}, + deref_aliases = never :: never | searching | finding | always, + dn_filter :: binary(), + dn_filter_attrs = [] :: [binary()]}). +handle_cast(_Request, State) -> {noreply, State}. + +code_change(_OldVsn, State, _Extra) -> {ok, State}. + +handle_info(_Info, State) -> {noreply, State}. + +-define(LDAP_SEARCH_TIMEOUT, 5). %%%---------------------------------------------------------------------- %%% API @@ -99,10 +85,8 @@ handle_info(_Info, State) -> start(Host) -> Proc = gen_mod:get_module_proc(Host, ?MODULE), - ChildSpec = { - Proc, {?MODULE, start_link, [Host]}, - transient, 1000, worker, [?MODULE] - }, + ChildSpec = {Proc, {?MODULE, start_link, [Host]}, + transient, 1000, worker, [?MODULE]}, supervisor:start_child(ejabberd_sup, ChildSpec). stop(Host) -> @@ -115,56 +99,45 @@ start_link(Host) -> Proc = gen_mod:get_module_proc(Host, ?MODULE), gen_server:start_link({local, Proc}, ?MODULE, Host, []). -terminate(_Reason, _State) -> - ok. +terminate(_Reason, _State) -> ok. init(Host) -> State = parse_options(Host), eldap_pool:start_link(State#state.eldap_id, - State#state.servers, - State#state.backups, - State#state.port, - State#state.dn, - State#state.password, - State#state.tls_options), + State#state.servers, State#state.backups, + State#state.port, State#state.dn, + State#state.password, State#state.tls_options), eldap_pool:start_link(State#state.bind_eldap_id, - State#state.servers, - State#state.backups, - State#state.port, - State#state.dn, - State#state.password, - State#state.tls_options), + State#state.servers, State#state.backups, + State#state.port, State#state.dn, + State#state.password, State#state.tls_options), {ok, State}. -plain_password_required() -> - true. +plain_password_required() -> true. -store_type() -> - external. +store_type() -> external. check_password(User, Server, Password) -> - %% In LDAP spec: empty password means anonymous authentication. - %% As ejabberd is providing other anonymous authentication mechanisms - %% we simply prevent the use of LDAP anonymous authentication. - if Password == "" -> - false; - true -> - case catch check_password_ldap(User, Server, Password) of - {'EXIT', _} -> false; - Result -> Result - end + if Password == <<"">> -> false; + true -> + case catch check_password_ldap(User, Server, Password) + of + {'EXIT', _} -> false; + Result -> Result + end end. -check_password(User, Server, Password, _Digest, _DigestGen) -> +check_password(User, Server, Password, _Digest, + _DigestGen) -> check_password(User, Server, Password). set_password(User, Server, Password) -> {ok, State} = eldap_utils:get_state(Server, ?MODULE), case find_user_dn(User, State) of - false -> - {error, user_not_found}; - DN -> - eldap_pool:modify_passwd(State#state.eldap_id, DN, Password) + false -> {error, user_not_found}; + DN -> + eldap_pool:modify_passwd(State#state.eldap_id, DN, + Password) end. %% @spec (User, Server, Password) -> {error, not_allowed} @@ -173,55 +146,56 @@ try_register(_User, _Server, _Password) -> dirty_get_registered_users() -> Servers = ejabberd_config:get_vh_by_auth_method(ldap), - lists:flatmap( - fun(Server) -> - get_vh_registered_users(Server) - end, Servers). + lists:flatmap(fun (Server) -> + get_vh_registered_users(Server) + end, + Servers). get_vh_registered_users(Server) -> case catch get_vh_registered_users_ldap(Server) of - {'EXIT', _} -> []; - Result -> Result - end. + {'EXIT', _} -> []; + Result -> Result + end. + +get_vh_registered_users(Server, _) -> + get_vh_registered_users(Server). get_vh_registered_users_number(Server) -> length(get_vh_registered_users(Server)). -get_password(_User, _Server) -> - false. +get_vh_registered_users_number(Server, _) -> + get_vh_registered_users_number(Server). -get_password_s(_User, _Server) -> - "". +get_password(_User, _Server) -> false. + +get_password_s(_User, _Server) -> <<"">>. %% @spec (User, Server) -> true | false | {error, Error} is_user_exists(User, Server) -> case catch is_user_exists_ldap(User, Server) of - {'EXIT', Error} -> - {error, Error}; - Result -> - Result + {'EXIT', Error} -> {error, Error}; + Result -> Result end. -remove_user(_User, _Server) -> - {error, not_allowed}. +remove_user(_User, _Server) -> {error, not_allowed}. -remove_user(_User, _Server, _Password) -> - not_allowed. +remove_user(_User, _Server, _Password) -> not_allowed. %%%---------------------------------------------------------------------- %%% Internal functions %%%---------------------------------------------------------------------- check_password_ldap(User, Server, Password) -> - {ok, State} = eldap_utils:get_state(Server, ?MODULE), - case find_user_dn(User, State) of - false -> - false; - DN -> - case eldap_pool:bind(State#state.bind_eldap_id, DN, Password) of - ok -> true; - _ -> false - end - end. + {ok, State} = eldap_utils:get_state(Server, ?MODULE), + case find_user_dn(User, State) of + false -> false; + DN -> + case eldap_pool:bind(State#state.bind_eldap_id, DN, + Password) + of + ok -> true; + _ -> false + end + end. get_vh_registered_users_ldap(Server) -> {ok, State} = eldap_utils:get_state(Server, ?MODULE), @@ -230,114 +204,123 @@ get_vh_registered_users_ldap(Server) -> Server = State#state.host, ResAttrs = result_attrs(State), case eldap_filter:parse(State#state.sfilter) of - {ok, EldapFilter} -> - case eldap_pool:search(Eldap_ID, - [{base, State#state.base}, - {filter, EldapFilter}, - {timeout, ?LDAP_SEARCH_TIMEOUT}, - {deref_aliases, State#state.deref_aliases}, - {attributes, ResAttrs}]) of - #eldap_search_result{entries = Entries} -> - lists:flatmap( - fun(#eldap_entry{attributes = Attrs, - object_name = DN}) -> + {ok, EldapFilter} -> + case eldap_pool:search(Eldap_ID, + [{base, State#state.base}, + {filter, EldapFilter}, + {timeout, ?LDAP_SEARCH_TIMEOUT}, + {deref_aliases, State#state.deref_aliases}, + {attributes, ResAttrs}]) + of + #eldap_search_result{entries = Entries} -> + lists:flatmap(fun (#eldap_entry{attributes = Attrs, + object_name = DN}) -> case is_valid_dn(DN, Attrs, State) of - false -> []; - _ -> - case eldap_utils:find_ldap_attrs(UIDs, Attrs) of - "" -> []; - {User, UIDFormat} -> - case eldap_utils:get_user_part(User, UIDFormat) of - {ok, U} -> - case jlib:nodeprep(U) of - error -> []; - LU -> [{LU, jlib:nameprep(Server)}] - end; - _ -> [] - end - end + false -> []; + _ -> + case + eldap_utils:find_ldap_attrs(UIDs, + Attrs) + of + <<"">> -> []; + {User, UIDFormat} -> + case + eldap_utils:get_user_part(User, + UIDFormat) + of + {ok, U} -> + case jlib:nodeprep(U) of + error -> []; + LU -> + [{LU, + jlib:nameprep(Server)}] + end; + _ -> [] + end + end end - end, Entries); - _ -> - [] - end; - _ -> - [] - end. + end, + Entries); + _ -> [] + end; + _ -> [] + end. is_user_exists_ldap(User, Server) -> {ok, State} = eldap_utils:get_state(Server, ?MODULE), case find_user_dn(User, State) of - false -> false; - _DN -> true - end. + false -> false; + _DN -> true + end. handle_call(get_state, _From, State) -> - {reply, {ok, State}, State}; - + {reply, {ok, State}, State}; handle_call(stop, _From, State) -> {stop, normal, ok, State}; - handle_call(_Request, _From, State) -> {reply, bad_request, State}. find_user_dn(User, State) -> ResAttrs = result_attrs(State), - case eldap_filter:parse(State#state.ufilter, [{"%u", User}]) of - {ok, Filter} -> - case eldap_pool:search(State#state.eldap_id, - [{base, State#state.base}, - {filter, Filter}, - {deref_aliases, State#state.deref_aliases}, - {attributes, ResAttrs}]) of - #eldap_search_result{entries = [#eldap_entry{attributes = Attrs, - object_name = DN} | _]} -> - dn_filter(DN, Attrs, State); - _ -> - false - end; - _ -> - false + case eldap_filter:parse(State#state.ufilter, + [{<<"%u">>, User}]) + of + {ok, Filter} -> + case eldap_pool:search(State#state.eldap_id, + [{base, State#state.base}, {filter, Filter}, + {deref_aliases, State#state.deref_aliases}, + {attributes, ResAttrs}]) + of + #eldap_search_result{entries = + [#eldap_entry{attributes = Attrs, + object_name = DN} + | _]} -> + dn_filter(DN, Attrs, State); + _ -> false + end; + _ -> false end. %% apply the dn filter and the local filter: dn_filter(DN, Attrs, State) -> - %% Check if user is denied access by attribute value (local check) case check_local_filter(Attrs, State) of - false -> false; - true -> is_valid_dn(DN, Attrs, State) + false -> false; + true -> is_valid_dn(DN, Attrs, State) end. %% Check that the DN is valid, based on the dn filter -is_valid_dn(DN, _, #state{dn_filter = undefined}) -> - DN; - +is_valid_dn(DN, _, #state{dn_filter = undefined}) -> DN; is_valid_dn(DN, Attrs, State) -> DNAttrs = State#state.dn_filter_attrs, UIDs = State#state.uids, - Values = [{"%s", eldap_utils:get_ldap_attr(Attr, Attrs), 1} || Attr <- DNAttrs], - SubstValues = case eldap_utils:find_ldap_attrs(UIDs, Attrs) of - "" -> Values; - {S, UAF} -> - case eldap_utils:get_user_part(S, UAF) of - {ok, U} -> [{"%u", U} | Values]; - _ -> Values - end - end ++ [{"%d", State#state.host}, {"%D", DN}], - case eldap_filter:parse(State#state.dn_filter, SubstValues) of - {ok, EldapFilter} -> - case eldap_pool:search(State#state.eldap_id, - [{base, State#state.base}, - {filter, EldapFilter}, - {deref_aliases, State#state.deref_aliases}, - {attributes, ["dn"]}]) of - #eldap_search_result{entries = [_|_]} -> - DN; - _ -> - false - end; - _ -> - false + Values = [{<<"%s">>, + eldap_utils:get_ldap_attr(Attr, Attrs), 1} + || Attr <- DNAttrs], + SubstValues = case eldap_utils:find_ldap_attrs(UIDs, + Attrs) + of + <<"">> -> Values; + {S, UAF} -> + case eldap_utils:get_user_part(S, UAF) of + {ok, U} -> [{<<"%u">>, U} | Values]; + _ -> Values + end + end + ++ [{<<"%d">>, State#state.host}, {<<"%D">>, DN}], + case eldap_filter:parse(State#state.dn_filter, + SubstValues) + of + {ok, EldapFilter} -> + case eldap_pool:search(State#state.eldap_id, + [{base, State#state.base}, + {filter, EldapFilter}, + {deref_aliases, State#state.deref_aliases}, + {attributes, [<<"dn">>]}]) + of + #eldap_search_result{entries = [_ | _]} -> DN; + _ -> false + end; + _ -> false end. %% The local filter is used to check an attribute in ejabberd @@ -346,109 +329,92 @@ is_valid_dn(DN, Attrs, State) -> %% {equal, {"accountStatus",["active"]}} %% {notequal, {"accountStatus",["disabled"]}} %% {ldap_local_filter, {notequal, {"accountStatus",["disabled"]}}} -check_local_filter(_Attrs, #state{lfilter = undefined}) -> +check_local_filter(_Attrs, + #state{lfilter = undefined}) -> true; -check_local_filter(Attrs, #state{lfilter = LocalFilter}) -> +check_local_filter(Attrs, + #state{lfilter = LocalFilter}) -> {Operation, FilterMatch} = LocalFilter, local_filter(Operation, Attrs, FilterMatch). - + local_filter(equal, Attrs, FilterMatch) -> {Attr, Value} = FilterMatch, case lists:keysearch(Attr, 1, Attrs) of - false -> false; - {value,{Attr,Value}} -> true; - _ -> false + false -> false; + {value, {Attr, Value}} -> true; + _ -> false end; local_filter(notequal, Attrs, FilterMatch) -> not local_filter(equal, Attrs, FilterMatch). -result_attrs(#state{uids = UIDs, dn_filter_attrs = DNFilterAttrs}) -> - lists:foldl( - fun({UID}, Acc) -> - [UID | Acc]; - ({UID, _}, Acc) -> - [UID | Acc] - end, DNFilterAttrs, UIDs). +result_attrs(#state{uids = UIDs, + dn_filter_attrs = DNFilterAttrs}) -> + lists:foldl(fun ({UID}, Acc) -> [UID | Acc]; + ({UID, _}, Acc) -> [UID | Acc] + end, + DNFilterAttrs, UIDs). %%%---------------------------------------------------------------------- %%% Auxiliary functions %%%---------------------------------------------------------------------- parse_options(Host) -> - Eldap_ID = atom_to_list(gen_mod:get_module_proc(Host, ?MODULE)), - Bind_Eldap_ID = atom_to_list(gen_mod:get_module_proc(Host, bind_ejabberd_auth_ldap)), - LDAPServers = ejabberd_config:get_local_option({ldap_servers, Host}), - LDAPBackups = case ejabberd_config:get_local_option({ldap_backups, Host}) of - undefined -> []; - Backups -> Backups - end, - LDAPEncrypt = ejabberd_config:get_local_option({ldap_encrypt, Host}), - LDAPTLSVerify = ejabberd_config:get_local_option({ldap_tls_verify, Host}), - LDAPTLSCAFile = ejabberd_config:get_local_option({ldap_tls_cacertfile, Host}), - LDAPTLSDepth = ejabberd_config:get_local_option({ldap_tls_depth, Host}), - LDAPPort = case ejabberd_config:get_local_option({ldap_port, Host}) of - undefined -> case LDAPEncrypt of - tls -> ?LDAPS_PORT; - starttls -> ?LDAP_PORT; - _ -> ?LDAP_PORT - end; - P -> P - end, - RootDN = case ejabberd_config:get_local_option({ldap_rootdn, Host}) of - undefined -> ""; - RDN -> RDN - end, - Password = case ejabberd_config:get_local_option({ldap_password, Host}) of - undefined -> ""; - Pass -> Pass - end, - UIDs = case ejabberd_config:get_local_option({ldap_uids, Host}) of - undefined -> [{"uid", "%u"}]; - UI -> eldap_utils:uids_domain_subst(Host, UI) - end, - SubFilter = lists:flatten(eldap_utils:generate_subfilter(UIDs)), - UserFilter = case ejabberd_config:get_local_option({ldap_filter, Host}) of - undefined -> SubFilter; - "" -> SubFilter; - F -> - eldap_utils:check_filter(F), - "(&" ++ SubFilter ++ F ++ ")" - end, - SearchFilter = eldap_filter:do_sub(UserFilter, [{"%u", "*"}]), - LDAPBase = ejabberd_config:get_local_option({ldap_base, Host}), + Cfg = eldap_utils:get_config(Host, []), + Eldap_ID = jlib:atom_to_binary(gen_mod:get_module_proc(Host, ?MODULE)), + Bind_Eldap_ID = jlib:atom_to_binary( + gen_mod:get_module_proc(Host, bind_ejabberd_auth_ldap)), + UIDsTemp = eldap_utils:get_opt( + {ldap_uids, Host}, [], + fun(Us) -> + lists:map( + fun({U, P}) -> + {iolist_to_binary(U), + iolist_to_binary(P)}; + ({U}) -> + {iolist_to_binary(U)} + end, Us) + end, [{<<"uid">>, <<"%u">>}]), + UIDs = eldap_utils:uids_domain_subst(Host, UIDsTemp), + SubFilter = eldap_utils:generate_subfilter(UIDs), + UserFilter = case eldap_utils:get_opt( + {ldap_filter, Host}, [], + fun check_filter/1, <<"">>) of + <<"">> -> + SubFilter; + F -> + <<"(&", SubFilter/binary, F/binary, ")">> + end, + SearchFilter = eldap_filter:do_sub(UserFilter, + [{<<"%u">>, <<"*">>}]), {DNFilter, DNFilterAttrs} = - case ejabberd_config:get_local_option({ldap_dn_filter, Host}) of - undefined -> - {undefined, []}; - {DNF, undefined} -> - {DNF, []}; - {DNF, DNFA} -> - {DNF, DNFA} - end, - eldap_utils:check_filter(DNFilter), - LocalFilter = ejabberd_config:get_local_option({ldap_local_filter, Host}), - DerefAliases = case ejabberd_config:get_local_option( - {ldap_deref_aliases, Host}) of - undefined -> never; - Val -> Val - end, - #state{host = Host, - eldap_id = Eldap_ID, - bind_eldap_id = Bind_Eldap_ID, - servers = LDAPServers, - backups = LDAPBackups, - port = LDAPPort, - tls_options = [{encrypt, LDAPEncrypt}, - {tls_verify, LDAPTLSVerify}, - {tls_cacertfile, LDAPTLSCAFile}, - {tls_depth, LDAPTLSDepth}], - dn = RootDN, - password = Password, - base = LDAPBase, - uids = UIDs, - ufilter = UserFilter, - sfilter = SearchFilter, - lfilter = LocalFilter, - deref_aliases = DerefAliases, - dn_filter = DNFilter, - dn_filter_attrs = DNFilterAttrs - }. + eldap_utils:get_opt({ldap_dn_filter, Host}, [], + fun({DNF, DNFA}) -> + NewDNFA = case DNFA of + undefined -> + []; + _ -> + [iolist_to_binary(A) + || A <- DNFA] + end, + NewDNF = check_filter(DNF), + {NewDNF, NewDNFA} + end, {undefined, []}), + LocalFilter = eldap_utils:get_opt( + {ldap_local_filter, Host}, [], fun(V) -> V end), + #state{host = Host, eldap_id = Eldap_ID, + bind_eldap_id = Bind_Eldap_ID, + servers = Cfg#eldap_config.servers, + backups = Cfg#eldap_config.backups, + port = Cfg#eldap_config.port, + tls_options = Cfg#eldap_config.tls_options, + dn = Cfg#eldap_config.dn, + password = Cfg#eldap_config.password, + base = Cfg#eldap_config.base, + deref_aliases = Cfg#eldap_config.deref_aliases, + uids = UIDs, ufilter = UserFilter, + sfilter = SearchFilter, lfilter = LocalFilter, + dn_filter = DNFilter, dn_filter_attrs = DNFilterAttrs}. + +check_filter(F) -> + NewF = iolist_to_binary(F), + {ok, _} = eldap_filter:parse(NewF), + NewF. diff --git a/src/ejabberd_auth_odbc.erl b/src/ejabberd_auth_odbc.erl index 3f648d666..7a2e90e02 100644 --- a/src/ejabberd_auth_odbc.erl +++ b/src/ejabberd_auth_odbc.erl @@ -25,223 +25,197 @@ %%%---------------------------------------------------------------------- -module(ejabberd_auth_odbc). + -author('alexey@process-one.net'). +-behaviour(ejabberd_auth). + %% External exports --export([start/1, - set_password/3, - check_password/3, - check_password/5, - try_register/3, - dirty_get_registered_users/0, - get_vh_registered_users/1, +-export([start/1, set_password/3, check_password/3, + check_password/5, try_register/3, + dirty_get_registered_users/0, get_vh_registered_users/1, get_vh_registered_users/2, get_vh_registered_users_number/1, - get_vh_registered_users_number/2, - get_password/2, - get_password_s/2, - is_user_exists/2, - remove_user/2, - remove_user/3, - store_type/0, - plain_password_required/0 - ]). + get_vh_registered_users_number/2, get_password/2, + get_password_s/2, is_user_exists/2, remove_user/2, + remove_user/3, store_type/0, + plain_password_required/0]). -include("ejabberd.hrl"). %%%---------------------------------------------------------------------- %%% API %%%---------------------------------------------------------------------- -start(_Host) -> - ok. +start(_Host) -> ok. -plain_password_required() -> - false. +plain_password_required() -> false. -store_type() -> - plain. +store_type() -> plain. %% @spec (User, Server, Password) -> true | false | {error, Error} check_password(User, Server, Password) -> case jlib:nodeprep(User) of - error -> - false; - LUser -> - Username = ejabberd_odbc:escape(LUser), - LServer = jlib:nameprep(Server), - try odbc_queries:get_password(LServer, Username) of - {selected, ["password"], [{Password}]} -> - Password /= ""; %% Password is correct, and not empty - {selected, ["password"], [{_Password2}]} -> - false; %% Password is not correct - {selected, ["password"], []} -> - false; %% Account does not exist - {error, _Error} -> - false %% Typical error is that table doesn't exist - catch - _:_ -> - false %% Typical error is database not accessible - end + error -> false; + LUser -> + Username = ejabberd_odbc:escape(LUser), + LServer = jlib:nameprep(Server), + try odbc_queries:get_password(LServer, Username) of + {selected, [<<"password">>], [[Password]]} -> + Password /= <<"">>; + {selected, [<<"password">>], [[_Password2]]} -> + false; %% Password is not correct + {selected, [<<"password">>], []} -> + false; %% Account does not exist + {error, _Error} -> + false %% Typical error is that table doesn't exist + catch + _:_ -> + false %% Typical error is database not accessible + end end. %% @spec (User, Server, Password, Digest, DigestGen) -> true | false | {error, Error} -check_password(User, Server, Password, Digest, DigestGen) -> +check_password(User, Server, Password, Digest, + DigestGen) -> case jlib:nodeprep(User) of - error -> - false; - LUser -> - Username = ejabberd_odbc:escape(LUser), - LServer = jlib:nameprep(Server), - try odbc_queries:get_password(LServer, Username) of - %% Account exists, check if password is valid - {selected, ["password"], [{Passwd}]} -> - DigRes = if - Digest /= "" -> - Digest == DigestGen(Passwd); - true -> - false - end, - if DigRes -> - true; - true -> - (Passwd == Password) and (Password /= "") - end; - {selected, ["password"], []} -> - false; %% Account does not exist - {error, _Error} -> - false %% Typical error is that table doesn't exist - catch - _:_ -> - false %% Typical error is database not accessible - end + error -> false; + LUser -> + Username = ejabberd_odbc:escape(LUser), + LServer = jlib:nameprep(Server), + try odbc_queries:get_password(LServer, Username) of + %% Account exists, check if password is valid + {selected, [<<"password">>], [[Passwd]]} -> + DigRes = if Digest /= <<"">> -> + Digest == DigestGen(Passwd); + true -> false + end, + if DigRes -> true; + true -> (Passwd == Password) and (Password /= <<"">>) + end; + {selected, [<<"password">>], []} -> + false; %% Account does not exist + {error, _Error} -> + false %% Typical error is that table doesn't exist + catch + _:_ -> + false %% Typical error is database not accessible + end end. %% @spec (User::string(), Server::string(), Password::string()) -> %% ok | {error, invalid_jid} set_password(User, Server, Password) -> case jlib:nodeprep(User) of - error -> - {error, invalid_jid}; - LUser -> - Username = ejabberd_odbc:escape(LUser), - Pass = ejabberd_odbc:escape(Password), - LServer = jlib:nameprep(Server), - case catch odbc_queries:set_password_t(LServer, Username, Pass) of - {atomic, ok} -> ok; - Other -> {error, Other} - end + error -> {error, invalid_jid}; + LUser -> + Username = ejabberd_odbc:escape(LUser), + Pass = ejabberd_odbc:escape(Password), + LServer = jlib:nameprep(Server), + case catch odbc_queries:set_password_t(LServer, + Username, Pass) + of + {atomic, ok} -> ok; + Other -> {error, Other} + end end. - %% @spec (User, Server, Password) -> {atomic, ok} | {atomic, exists} | {error, invalid_jid} try_register(User, Server, Password) -> case jlib:nodeprep(User) of - error -> - {error, invalid_jid}; - LUser -> - Username = ejabberd_odbc:escape(LUser), - Pass = ejabberd_odbc:escape(Password), - LServer = jlib:nameprep(Server), - case catch odbc_queries:add_user(LServer, Username, Pass) of - {updated, 1} -> - {atomic, ok}; - _ -> - {atomic, exists} - end + error -> {error, invalid_jid}; + LUser -> + Username = ejabberd_odbc:escape(LUser), + Pass = ejabberd_odbc:escape(Password), + LServer = jlib:nameprep(Server), + case catch odbc_queries:add_user(LServer, Username, + Pass) + of + {updated, 1} -> {atomic, ok}; + _ -> {atomic, exists} + end end. dirty_get_registered_users() -> Servers = ejabberd_config:get_vh_by_auth_method(odbc), - lists:flatmap( - fun(Server) -> - get_vh_registered_users(Server) - end, Servers). + lists:flatmap(fun (Server) -> + get_vh_registered_users(Server) + end, + Servers). get_vh_registered_users(Server) -> LServer = jlib:nameprep(Server), case catch odbc_queries:list_users(LServer) of - {selected, ["username"], Res} -> - [{U, LServer} || {U} <- Res]; - _ -> - [] + {selected, [<<"username">>], Res} -> + [{U, LServer} || [U] <- Res]; + _ -> [] end. get_vh_registered_users(Server, Opts) -> LServer = jlib:nameprep(Server), case catch odbc_queries:list_users(LServer, Opts) of - {selected, ["username"], Res} -> - [{U, LServer} || {U} <- Res]; - _ -> - [] + {selected, [<<"username">>], Res} -> + [{U, LServer} || [U] <- Res]; + _ -> [] end. get_vh_registered_users_number(Server) -> LServer = jlib:nameprep(Server), case catch odbc_queries:users_number(LServer) of - {selected, [_], [{Res}]} -> - list_to_integer(Res); - _ -> - 0 + {selected, [_], [[Res]]} -> + jlib:binary_to_integer(Res); + _ -> 0 end. get_vh_registered_users_number(Server, Opts) -> LServer = jlib:nameprep(Server), case catch odbc_queries:users_number(LServer, Opts) of - {selected, [_], [{Res}]} -> - list_to_integer(Res); - _Other -> - 0 + {selected, [_], [[Res]]} -> + jlib:binary_to_integer(Res); + _Other -> 0 end. get_password(User, Server) -> case jlib:nodeprep(User) of - error -> - false; - LUser -> - Username = ejabberd_odbc:escape(LUser), - LServer = jlib:nameprep(Server), - case catch odbc_queries:get_password(LServer, Username) of - {selected, ["password"], [{Password}]} -> - Password; - _ -> - false - end + error -> false; + LUser -> + Username = ejabberd_odbc:escape(LUser), + LServer = jlib:nameprep(Server), + case catch odbc_queries:get_password(LServer, Username) + of + {selected, [<<"password">>], [[Password]]} -> Password; + _ -> false + end end. get_password_s(User, Server) -> case jlib:nodeprep(User) of - error -> - ""; - LUser -> - Username = ejabberd_odbc:escape(LUser), - LServer = jlib:nameprep(Server), - case catch odbc_queries:get_password(LServer, Username) of - {selected, ["password"], [{Password}]} -> - Password; - _ -> - "" - end + error -> <<"">>; + LUser -> + Username = ejabberd_odbc:escape(LUser), + LServer = jlib:nameprep(Server), + case catch odbc_queries:get_password(LServer, Username) + of + {selected, [<<"password">>], [[Password]]} -> Password; + _ -> <<"">> + end end. %% @spec (User, Server) -> true | false | {error, Error} is_user_exists(User, Server) -> case jlib:nodeprep(User) of - error -> - false; - LUser -> - Username = ejabberd_odbc:escape(LUser), - LServer = jlib:nameprep(Server), - try odbc_queries:get_password(LServer, Username) of - {selected, ["password"], [{_Password}]} -> - true; %% Account exists - {selected, ["password"], []} -> - false; %% Account does not exist - {error, Error} -> - {error, Error} %% Typical error is that table doesn't exist - catch - _:B -> - {error, B} %% Typical error is database not accessible - end + error -> false; + LUser -> + Username = ejabberd_odbc:escape(LUser), + LServer = jlib:nameprep(Server), + try odbc_queries:get_password(LServer, Username) of + {selected, [<<"password">>], [[_Password]]} -> + true; %% Account exists + {selected, [<<"password">>], []} -> + false; %% Account does not exist + {error, Error} -> {error, Error} + catch + _:B -> {error, B} + end end. %% @spec (User, Server) -> ok | error @@ -249,37 +223,34 @@ is_user_exists(User, Server) -> %% Note: it may return ok even if there was some problem removing the user. remove_user(User, Server) -> case jlib:nodeprep(User) of - error -> - error; - LUser -> - Username = ejabberd_odbc:escape(LUser), - LServer = jlib:nameprep(Server), - catch odbc_queries:del_user(LServer, Username), - ok + error -> error; + LUser -> + Username = ejabberd_odbc:escape(LUser), + LServer = jlib:nameprep(Server), + catch odbc_queries:del_user(LServer, Username), + ok end. %% @spec (User, Server, Password) -> ok | error | not_exists | not_allowed %% @doc Remove user if the provided password is correct. remove_user(User, Server, Password) -> case jlib:nodeprep(User) of - error -> - error; - LUser -> - Username = ejabberd_odbc:escape(LUser), - Pass = ejabberd_odbc:escape(Password), - LServer = jlib:nameprep(Server), - F = fun() -> - Result = odbc_queries:del_user_return_password( - LServer, Username, Pass), - case Result of - {selected, ["password"], [{Password}]} -> - ok; - {selected, ["password"], []} -> - not_exists; - _ -> - not_allowed - end - end, - {atomic, Result} = odbc_queries:sql_transaction(LServer, F), - Result + error -> error; + LUser -> + Username = ejabberd_odbc:escape(LUser), + Pass = ejabberd_odbc:escape(Password), + LServer = jlib:nameprep(Server), + F = fun () -> + Result = odbc_queries:del_user_return_password(LServer, + Username, + Pass), + case Result of + {selected, [<<"password">>], [[Password]]} -> ok; + {selected, [<<"password">>], []} -> not_exists; + _ -> not_allowed + end + end, + {atomic, Result} = odbc_queries:sql_transaction(LServer, + F), + Result end. diff --git a/src/ejabberd_auth_pam.erl b/src/ejabberd_auth_pam.erl index 29752ba8d..a1400fe8e 100644 --- a/src/ejabberd_auth_pam.erl +++ b/src/ejabberd_auth_pam.erl @@ -24,102 +24,102 @@ %%% %%%------------------------------------------------------------------- -module(ejabberd_auth_pam). + -author('xram@jabber.ru'). -%% External exports --export([start/1, - set_password/3, - check_password/3, - check_password/5, - try_register/3, - dirty_get_registered_users/0, - get_vh_registered_users/1, - get_password/2, - get_password_s/2, - is_user_exists/2, - remove_user/2, - remove_user/3, - store_type/0, - plain_password_required/0 - ]). +-behaviour(ejabberd_auth). +%% External exports %%==================================================================== %% API %%==================================================================== +-export([start/1, set_password/3, check_password/3, + check_password/5, try_register/3, + dirty_get_registered_users/0, get_vh_registered_users/1, + get_vh_registered_users/2, get_vh_registered_users_number/1, + get_vh_registered_users_number/2, + get_password/2, get_password_s/2, is_user_exists/2, + remove_user/2, remove_user/3, store_type/0, + plain_password_required/0]). + start(_Host) -> case epam:start() of - {ok, _} -> ok; - {error,{already_started, _}} -> ok; - Err -> Err + {ok, _} -> ok; + {error, {already_started, _}} -> ok; + Err -> Err end. set_password(_User, _Server, _Password) -> {error, not_allowed}. -check_password(User, Server, Password, _Digest, _DigestGen) -> +check_password(User, Server, Password, _Digest, + _DigestGen) -> check_password(User, Server, Password). check_password(User, Host, Password) -> Service = get_pam_service(Host), UserInfo = case get_pam_userinfotype(Host) of - username -> User; - jid -> User++"@"++Host - end, - case catch epam:authenticate(Service, UserInfo, Password) of - true -> true; - _ -> false + username -> User; + jid -> <> + end, + case catch epam:authenticate(Service, UserInfo, + Password) + of + true -> true; + _ -> false end. try_register(_User, _Server, _Password) -> {error, not_allowed}. -dirty_get_registered_users() -> - []. +dirty_get_registered_users() -> []. -get_vh_registered_users(_Host) -> - []. +get_vh_registered_users(_Host) -> []. -get_password(_User, _Server) -> - false. +get_vh_registered_users(_Host, _) -> []. -get_password_s(_User, _Server) -> - "". +get_vh_registered_users_number(_Host) -> 0. + +get_vh_registered_users_number(_Host, _) -> 0. + +get_password(_User, _Server) -> false. + +get_password_s(_User, _Server) -> <<"">>. %% @spec (User, Server) -> true | false | {error, Error} %% TODO: Improve this function to return an error instead of 'false' when connection to PAM failed is_user_exists(User, Host) -> Service = get_pam_service(Host), UserInfo = case get_pam_userinfotype(Host) of - username -> User; - jid -> User++"@"++Host - end, + username -> User; + jid -> <> + end, case catch epam:acct_mgmt(Service, UserInfo) of - true -> true; - _ -> false + true -> true; + _ -> false end. -remove_user(_User, _Server) -> - {error, not_allowed}. +remove_user(_User, _Server) -> {error, not_allowed}. -remove_user(_User, _Server, _Password) -> - not_allowed. +remove_user(_User, _Server, _Password) -> not_allowed. -plain_password_required() -> - true. +plain_password_required() -> true. -store_type() -> - external. +store_type() -> external. %%==================================================================== %% Internal functions %%==================================================================== get_pam_service(Host) -> - case ejabberd_config:get_local_option({pam_service, Host}) of - undefined -> "ejabberd"; - Service -> Service - end. + ejabberd_config:get_local_option( + {pam_service, Host}, + fun iolist_to_binary/1, + <<"ejabberd">>). + get_pam_userinfotype(Host) -> - case ejabberd_config:get_local_option({pam_userinfotype, Host}) of - undefined -> username; - Type -> Type - end. + ejabberd_config:get_local_option( + {pam_userinfotype, Host}, + fun(username) -> username; + (jid) -> jid + end, + username). diff --git a/src/ejabberd_c2s.erl b/src/ejabberd_c2s.erl index ed26400f0..f1cde0e0f 100644 --- a/src/ejabberd_c2s.erl +++ b/src/ejabberd_c2s.erl @@ -25,7 +25,9 @@ %%%---------------------------------------------------------------------- -module(ejabberd_c2s). + -author('alexey@process-one.net'). + -update_info({update, 0}). -define(GEN_FSM, p1_fsm). @@ -61,11 +63,13 @@ code_change/4, handle_info/3, terminate/3, - print_state/1 + print_state/1 ]). -include("ejabberd.hrl"). + -include("jlib.hrl"). + -include("mod_privacy.hrl"). -define(SETS, gb_sets). @@ -88,7 +92,7 @@ tls_options = [], authenticated = false, jid, - user = "", server = ?MYNAME, resource = "", + user = "", server = ?MYNAME, resource = <<"">>, sid, pres_t = ?SETS:new(), pres_f = ?SETS:new(), @@ -107,9 +111,13 @@ %-define(DBGFSM, true). -ifdef(DBGFSM). + -define(FSMOPTS, [{debug, [trace]}]). + -else. + -define(FSMOPTS, []). + -endif. %% Module start with or without supervisor: @@ -124,22 +132,26 @@ %% This is the timeout to apply between event when starting a new %% session: -define(C2S_OPEN_TIMEOUT, 60000). + -define(C2S_HIBERNATE_TIMEOUT, 90000). -define(STREAM_HEADER, - "" - "" - ). + <<"">>). --define(STREAM_TRAILER, ""). +-define(STREAM_TRAILER, <<"">>). -define(INVALID_NS_ERR, ?SERR_INVALID_NAMESPACE). + -define(INVALID_XML_ERR, ?SERR_XML_NOT_WELL_FORMED). + -define(HOST_UNKNOWN_ERR, ?SERR_HOST_UNKNOWN). + -define(POLICY_VIOLATION_ERR(Lang, Text), ?SERRT_POLICY_VIOLATION(Lang, Text)). + -define(INVALID_FROM, ?SERR_INVALID_FROM). @@ -153,24 +165,23 @@ start_link(SockData, Opts) -> ?GEN_FSM:start_link(ejabberd_c2s, [SockData, Opts], fsm_limit_opts(Opts) ++ ?FSMOPTS). -socket_type() -> - xml_stream. +socket_type() -> xml_stream. %% Return Username, Resource and presence information get_presence(FsmRef) -> - ?GEN_FSM:sync_send_all_state_event(FsmRef, {get_presence}, 1000). + (?GEN_FSM):sync_send_all_state_event(FsmRef, + {get_presence}, 1000). get_aux_field(Key, #state{aux_fields = Opts}) -> case lists:keysearch(Key, 1, Opts) of - {value, {_, Val}} -> - {ok, Val}; - _ -> - error + {value, {_, Val}} -> {ok, Val}; + _ -> error end. -set_aux_field(Key, Val, #state{aux_fields = Opts} = State) -> +set_aux_field(Key, Val, + #state{aux_fields = Opts} = State) -> Opts1 = lists:keydelete(Key, 1, Opts), - State#state{aux_fields = [{Key, Val}|Opts1]}. + State#state{aux_fields = [{Key, Val} | Opts1]}. del_aux_field(Key, #state{aux_fields = Opts} = State) -> Opts1 = lists:keydelete(Key, 1, Opts), @@ -179,11 +190,13 @@ del_aux_field(Key, #state{aux_fields = Opts} = State) -> get_subscription(From = #jid{}, StateData) -> get_subscription(jlib:jid_tolower(From), StateData); get_subscription(LFrom, StateData) -> - LBFrom = setelement(3, LFrom, ""), - F = ?SETS:is_element(LFrom, StateData#state.pres_f) orelse - ?SETS:is_element(LBFrom, StateData#state.pres_f), - T = ?SETS:is_element(LFrom, StateData#state.pres_t) orelse - ?SETS:is_element(LBFrom, StateData#state.pres_t), + LBFrom = setelement(3, LFrom, <<"">>), + F = (?SETS):is_element(LFrom, StateData#state.pres_f) + orelse + (?SETS):is_element(LBFrom, StateData#state.pres_f), + T = (?SETS):is_element(LFrom, StateData#state.pres_t) + orelse + (?SETS):is_element(LBFrom, StateData#state.pres_t), if F and T -> both; F -> from; T -> to; @@ -193,8 +206,7 @@ get_subscription(LFrom, StateData) -> broadcast(FsmRef, Type, From, Packet) -> FsmRef ! {broadcast, Type, From, Packet}. -stop(FsmRef) -> - ?GEN_FSM:send_event(FsmRef, closed). +stop(FsmRef) -> (?GEN_FSM):send_event(FsmRef, closed). %%%---------------------------------------------------------------------- %%% Callback functions from gen_fsm @@ -209,63 +221,58 @@ stop(FsmRef) -> %%---------------------------------------------------------------------- init([{SockMod, Socket}, Opts]) -> Access = case lists:keysearch(access, 1, Opts) of - {value, {_, A}} -> A; - _ -> all + {value, {_, A}} -> A; + _ -> all end, Shaper = case lists:keysearch(shaper, 1, Opts) of - {value, {_, S}} -> S; - _ -> none + {value, {_, S}} -> S; + _ -> none end, - XMLSocket = - case lists:keysearch(xml_socket, 1, Opts) of - {value, {_, XS}} -> XS; - _ -> false - end, + XMLSocket = case lists:keysearch(xml_socket, 1, Opts) of + {value, {_, XS}} -> XS; + _ -> false + end, Zlib = lists:member(zlib, Opts), StartTLS = lists:member(starttls, Opts), - StartTLSRequired = lists:member(starttls_required, Opts), + StartTLSRequired = lists:member(starttls_required, + Opts), TLSEnabled = lists:member(tls, Opts), - TLS = StartTLS orelse StartTLSRequired orelse TLSEnabled, - TLSOpts1 = - lists:filter(fun({certfile, _}) -> true; - (_) -> false - end, Opts), + TLS = StartTLS orelse + StartTLSRequired orelse TLSEnabled, + TLSOpts1 = lists:filter(fun ({certfile, _}) -> true; + (_) -> false + end, + Opts), TLSOpts = [verify_none | TLSOpts1], IP = peerip(SockMod, Socket), %% Check if IP is blacklisted: case is_ip_blacklisted(IP) of - true -> - ?INFO_MSG("Connection attempt from blacklisted IP: ~s (~w)", - [jlib:ip_to_list(IP), IP]), - {stop, normal}; - false -> - Socket1 = - if - TLSEnabled -> - SockMod:starttls(Socket, TLSOpts); - true -> - Socket - end, - SocketMonitor = SockMod:monitor(Socket1), - {ok, wait_for_stream, #state{socket = Socket1, - sockmod = SockMod, - socket_monitor = SocketMonitor, - xml_socket = XMLSocket, - zlib = Zlib, - tls = TLS, - tls_required = StartTLSRequired, - tls_enabled = TLSEnabled, - tls_options = TLSOpts, - streamid = new_id(), - access = Access, - shaper = Shaper, - ip = IP}, - ?C2S_OPEN_TIMEOUT} + true -> + ?INFO_MSG("Connection attempt from blacklisted " + "IP: ~s (~w)", + [jlib:ip_to_list(IP), IP]), + {stop, normal}; + false -> + Socket1 = if TLSEnabled andalso + SockMod /= ejabberd_frontend_socket -> + SockMod:starttls(Socket, TLSOpts); + true -> Socket + end, + SocketMonitor = SockMod:monitor(Socket1), + StateData = #state{socket = Socket1, sockmod = SockMod, + socket_monitor = SocketMonitor, + xml_socket = XMLSocket, zlib = Zlib, tls = TLS, + tls_required = StartTLSRequired, + tls_enabled = TLSEnabled, tls_options = TLSOpts, + streamid = new_id(), access = Access, + shaper = Shaper, ip = IP}, + {ok, wait_for_stream, StateData, ?C2S_OPEN_TIMEOUT} end. %% Return list of all available resources of contacts, get_subscribed(FsmRef) -> - ?GEN_FSM:sync_send_all_state_event(FsmRef, get_subscribed, 1000). + (?GEN_FSM):sync_send_all_state_event(FsmRef, + get_subscribed, 1000). %%---------------------------------------------------------------------- %% Func: StateName/2 @@ -276,17 +283,15 @@ get_subscribed(FsmRef) -> wait_for_stream({xmlstreamstart, _Name, Attrs}, StateData) -> DefaultLang = case ?MYLANG of - undefined -> - "en"; - DL -> - DL - end, - case xml:get_attr_s("xmlns:stream", Attrs) of + undefined -> <<"en">>; + DL -> DL + end, + case xml:get_attr_s(<<"xmlns:stream">>, Attrs) of ?NS_STREAM -> - Server = jlib:nameprep(xml:get_attr_s("to", Attrs)), + Server = jlib:nameprep(xml:get_attr_s(<<"to">>, Attrs)), case lists:member(Server, ?MYHOSTS) of true -> - Lang = case xml:get_attr_s("xml:lang", Attrs) of + Lang = case xml:get_attr_s(<<"xml:lang">>, Attrs) of Lang1 when length(Lang1) =< 35 -> %% As stated in BCP47, 4.4.1: %% Protocols or specifications that @@ -297,17 +302,17 @@ wait_for_stream({xmlstreamstart, _Name, Attrs}, StateData) -> _ -> %% Do not store long language tag to %% avoid possible DoS/flood attacks - "" + <<"">> end, - change_shaper(StateData, jlib:make_jid("", Server, "")), - case xml:get_attr_s("version", Attrs) of - "1.0" -> - send_header(StateData, Server, "1.0", DefaultLang), + change_shaper(StateData, jlib:make_jid(<<"">>, Server, <<"">>)), + case xml:get_attr_s(<<"version">>, Attrs) of + <<"1.0">> -> + send_header(StateData, Server, <<"1.0">>, DefaultLang), case StateData#state.authenticated of false -> SASLState = cyrsasl:server_new( - "jabber", Server, "", [], + <<"jabber">>, Server, <<"">>, [], fun(U) -> ejabberd_auth:get_password_with_authmodule( U, Server) @@ -320,11 +325,12 @@ wait_for_stream({xmlstreamstart, _Name, Attrs}, StateData) -> ejabberd_auth:check_password_with_authmodule( U, Server, P, D, DG) end), - Mechs = lists:map( - fun(S) -> - {xmlelement, "mechanism", [], - [{xmlcdata, S}]} - end, cyrsasl:listmech(Server)), + Mechs = lists:map(fun (S) -> + #xmlel{name = <<"mechanism">>, + attrs = [], + children = [{xmlcdata, S}]} + end, + cyrsasl:listmech(Server)), SockMod = (StateData#state.sockmod):get_sockmod( StateData#state.socket), @@ -334,10 +340,11 @@ wait_for_stream({xmlstreamstart, _Name, Attrs}, StateData) -> ((SockMod == gen_tcp) orelse (SockMod == tls)) of true -> - [{xmlelement, "compression", - [{"xmlns", ?NS_FEATURE_COMPRESS}], - [{xmlelement, "method", - [], [{xmlcdata, "zlib"}]}]}]; + [#xmlel{name = <<"compression">>, + attrs = [{<<"xmlns">>, ?NS_FEATURE_COMPRESS}], + children = [#xmlel{name = <<"method">>, + attrs = [], + children = [{xmlcdata, <<"zlib">>}]}]}]; _ -> [] end, @@ -351,27 +358,30 @@ wait_for_stream({xmlstreamstart, _Name, Attrs}, StateData) -> true -> case TLSRequired of true -> - [{xmlelement, "starttls", - [{"xmlns", ?NS_TLS}], - [{xmlelement, "required", - [], []}]}]; + [#xmlel{name = <<"starttls">>, + attrs = [{<<"xmlns">>, ?NS_TLS}], + children = [#xmlel{name = <<"required">>, + attrs = [], + children = []}]}]; _ -> - [{xmlelement, "starttls", - [{"xmlns", ?NS_TLS}], []}] + [#xmlel{name = <<"starttls">>, + attrs = [{<<"xmlns">>, ?NS_TLS}], + children = []}] end; false -> [] end, send_element(StateData, - {xmlelement, "stream:features", [], - TLSFeature ++ CompressFeature ++ - [{xmlelement, "mechanisms", - [{"xmlns", ?NS_SASL}], - Mechs}] ++ - ejabberd_hooks:run_fold( - c2s_stream_features, - Server, - [], [Server])}), + #xmlel{name = <<"stream:features">>, + attrs = [], + children = + TLSFeature ++ CompressFeature ++ + [#xmlel{name = <<"mechanisms">>, + attrs = [{<<"xmlns">>, ?NS_SASL}], + children = Mechs}] + ++ + ejabberd_hooks:run_fold(c2s_stream_features, + Server, [], [Server])}), fsm_next_state(wait_for_feature_request, StateData#state{ server = Server, @@ -379,556 +389,578 @@ wait_for_stream({xmlstreamstart, _Name, Attrs}, StateData) -> lang = Lang}); _ -> case StateData#state.resource of - "" -> - RosterVersioningFeature = - ejabberd_hooks:run_fold( - roster_get_versioning_feature, - Server, [], [Server]), - StreamFeatures = - [{xmlelement, "bind", - [{"xmlns", ?NS_BIND}], []}, - {xmlelement, "session", - [{"xmlns", ?NS_SESSION}], []}] - ++ RosterVersioningFeature - ++ ejabberd_hooks:run_fold( - c2s_stream_features, - Server, - [], [Server]), - send_element( - StateData, - {xmlelement, "stream:features", [], - StreamFeatures}), - fsm_next_state(wait_for_bind, - StateData#state{ - server = Server, - lang = Lang}); - _ -> - send_element( - StateData, - {xmlelement, "stream:features", [], []}), - fsm_next_state(wait_for_session, - StateData#state{ - server = Server, - lang = Lang}) + <<"">> -> + RosterVersioningFeature = + ejabberd_hooks:run_fold(roster_get_versioning_feature, + Server, [], + [Server]), + StreamFeatures = [#xmlel{name = <<"bind">>, + attrs = [{<<"xmlns">>, ?NS_BIND}], + children = []}, + #xmlel{name = <<"session">>, + attrs = [{<<"xmlns">>, ?NS_SESSION}], + children = []}] + ++ + RosterVersioningFeature ++ + ejabberd_hooks:run_fold(c2s_stream_features, + Server, [], [Server]), + send_element(StateData, + #xmlel{name = <<"stream:features">>, + attrs = [], + children = StreamFeatures}), + fsm_next_state(wait_for_bind, + StateData#state{server = Server, lang = Lang}); + _ -> + send_element(StateData, + #xmlel{name = <<"stream:features">>, + attrs = [], + children = []}), + fsm_next_state(wait_for_session, + StateData#state{server = Server, lang = Lang}) end end; - _ -> - send_header(StateData, Server, "", DefaultLang), - if - (not StateData#state.tls_enabled) and - StateData#state.tls_required -> - send_element( - StateData, - ?POLICY_VIOLATION_ERR( - Lang, - "Use of STARTTLS required")), - send_trailer(StateData), - {stop, normal, StateData}; - true -> - fsm_next_state(wait_for_auth, - StateData#state{ - server = Server, - lang = Lang}) - end - end; _ -> - send_header(StateData, ?MYNAME, "", DefaultLang), - send_element(StateData, ?HOST_UNKNOWN_ERR), - send_trailer(StateData), - {stop, normal, StateData} + send_header(StateData, Server, <<"">>, DefaultLang), + if not StateData#state.tls_enabled and + StateData#state.tls_required -> + send_element(StateData, + ?POLICY_VIOLATION_ERR(Lang, + <<"Use of STARTTLS required">>)), + send_trailer(StateData), + {stop, normal, StateData}; + true -> + fsm_next_state(wait_for_auth, + StateData#state{server = Server, + lang = Lang}) + end end; _ -> - send_header(StateData, ?MYNAME, "", DefaultLang), - send_element(StateData, ?INVALID_NS_ERR), + send_header(StateData, ?MYNAME, <<"">>, DefaultLang), + send_element(StateData, ?HOST_UNKNOWN_ERR), send_trailer(StateData), {stop, normal, StateData} + end; + _ -> + send_header(StateData, ?MYNAME, <<"">>, DefaultLang), + send_element(StateData, ?INVALID_NS_ERR), + send_trailer(StateData), + {stop, normal, StateData} end; - wait_for_stream(timeout, StateData) -> {stop, normal, StateData}; - wait_for_stream({xmlstreamelement, _}, StateData) -> send_element(StateData, ?INVALID_XML_ERR), send_trailer(StateData), {stop, normal, StateData}; - wait_for_stream({xmlstreamend, _}, StateData) -> send_element(StateData, ?INVALID_XML_ERR), send_trailer(StateData), {stop, normal, StateData}; - wait_for_stream({xmlstreamerror, _}, StateData) -> - send_header(StateData, ?MYNAME, "1.0", ""), + send_header(StateData, ?MYNAME, <<"1.0">>, <<"">>), send_element(StateData, ?INVALID_XML_ERR), send_trailer(StateData), {stop, normal, StateData}; - wait_for_stream(closed, StateData) -> {stop, normal, StateData}. - wait_for_auth({xmlstreamelement, El}, StateData) -> case is_auth_packet(El) of - {auth, _ID, get, {U, _, _, _}} -> - {xmlelement, Name, Attrs, _Els} = jlib:make_result_iq_reply(El), - case U of - "" -> - UCdata = []; - _ -> - UCdata = [{xmlcdata, U}] - end, - Res = case ejabberd_auth:plain_password_required( - StateData#state.server) of - false -> - {xmlelement, Name, Attrs, - [{xmlelement, "query", [{"xmlns", ?NS_AUTH}], - [{xmlelement, "username", [], UCdata}, - {xmlelement, "password", [], []}, - {xmlelement, "digest", [], []}, - {xmlelement, "resource", [], []} - ]}]}; - true -> - {xmlelement, Name, Attrs, - [{xmlelement, "query", [{"xmlns", ?NS_AUTH}], - [{xmlelement, "username", [], UCdata}, - {xmlelement, "password", [], []}, - {xmlelement, "resource", [], []} - ]}]} - end, - send_element(StateData, Res), - fsm_next_state(wait_for_auth, StateData); - {auth, _ID, set, {_U, _P, _D, ""}} -> - Err = jlib:make_error_reply( - El, - ?ERR_AUTH_NO_RESOURCE_PROVIDED(StateData#state.lang)), - send_element(StateData, Err), - fsm_next_state(wait_for_auth, StateData); - {auth, _ID, set, {U, P, D, R}} -> - JID = jlib:make_jid(U, StateData#state.server, R), - case (JID /= error) andalso - (acl:match_rule(StateData#state.server, - StateData#state.access, JID) == allow) of - true -> - DGen = fun(PW) -> - sha:sha(StateData#state.streamid ++ PW) end, - case ejabberd_auth:check_password_with_authmodule( - U, StateData#state.server, P, D, DGen) of - {true, AuthModule} -> - ?INFO_MSG( - "(~w) Accepted legacy authentication for ~s by ~p", - [StateData#state.socket, - jlib:jid_to_string(JID), AuthModule]), - SID = {now(), self()}, - Conn = get_conn_type(StateData), - Info = [{ip, StateData#state.ip}, {conn, Conn}, + {auth, _ID, get, {U, _, _, _}} -> + #xmlel{name = Name, attrs = Attrs} = + jlib:make_result_iq_reply(El), + case U of + <<"">> -> UCdata = []; + _ -> UCdata = [{xmlcdata, U}] + end, + Res = case + ejabberd_auth:plain_password_required(StateData#state.server) + of + false -> + #xmlel{name = Name, attrs = Attrs, + children = + [#xmlel{name = <<"query">>, + attrs = [{<<"xmlns">>, ?NS_AUTH}], + children = + [#xmlel{name = <<"username">>, + attrs = [], + children = UCdata}, + #xmlel{name = <<"password">>, + attrs = [], children = []}, + #xmlel{name = <<"digest">>, + attrs = [], children = []}, + #xmlel{name = <<"resource">>, + attrs = [], + children = []}]}]}; + true -> + #xmlel{name = Name, attrs = Attrs, + children = + [#xmlel{name = <<"query">>, + attrs = [{<<"xmlns">>, ?NS_AUTH}], + children = + [#xmlel{name = <<"username">>, + attrs = [], + children = UCdata}, + #xmlel{name = <<"password">>, + attrs = [], children = []}, + #xmlel{name = <<"resource">>, + attrs = [], + children = []}]}]} + end, + send_element(StateData, Res), + fsm_next_state(wait_for_auth, StateData); + {auth, _ID, set, {_U, _P, _D, <<"">>}} -> + Err = jlib:make_error_reply(El, + ?ERR_AUTH_NO_RESOURCE_PROVIDED((StateData#state.lang))), + send_element(StateData, Err), + fsm_next_state(wait_for_auth, StateData); + {auth, _ID, set, {U, P, D, R}} -> + JID = jlib:make_jid(U, StateData#state.server, R), + case JID /= error andalso + acl:match_rule(StateData#state.server, + StateData#state.access, JID) + == allow + of + true -> + DGen = fun (PW) -> + sha:sha(<<(StateData#state.streamid)/binary, PW/binary>>) + end, + case ejabberd_auth:check_password_with_authmodule(U, + StateData#state.server, + P, D, DGen) + of + {true, AuthModule} -> + ?INFO_MSG("(~w) Accepted legacy authentication for ~s by ~p", + [StateData#state.socket, + jlib:jid_to_string(JID), AuthModule]), + SID = {now(), self()}, + Conn = (StateData#state.sockmod):get_conn_type( + StateData#state.socket), + Info = [{ip, StateData#state.ip}, {conn, Conn}, {auth_module, AuthModule}], - Res1 = jlib:make_result_iq_reply(El), - Res = setelement(4, Res1, []), - send_element(StateData, Res), - ejabberd_sm:open_session( - SID, U, StateData#state.server, R, Info), - change_shaper(StateData, JID), - {Fs, Ts} = ejabberd_hooks:run_fold( - roster_get_subscription_lists, - StateData#state.server, - {[], []}, - [U, StateData#state.server]), - LJID = jlib:jid_tolower( - jlib:jid_remove_resource(JID)), - Fs1 = [LJID | Fs], - Ts1 = [LJID | Ts], - PrivList = - ejabberd_hooks:run_fold( - privacy_get_user_list, StateData#state.server, - #userlist{}, - [U, StateData#state.server]), - NewStateData = - StateData#state{ - user = U, - resource = R, - jid = JID, - sid = SID, - conn = Conn, - auth_module = AuthModule, - pres_f = ?SETS:from_list(Fs1), - pres_t = ?SETS:from_list(Ts1), - privacy_list = PrivList}, - fsm_next_state_pack(session_established, - NewStateData); - _ -> - IP = peerip(StateData#state.sockmod, StateData#state.socket), - ?INFO_MSG( - "(~w) Failed legacy authentication for ~s from IP ~s (~w)", - [StateData#state.socket, - jlib:jid_to_string(JID), jlib:ip_to_list(IP), IP]), - Err = jlib:make_error_reply( - El, ?ERR_NOT_AUTHORIZED), - send_element(StateData, Err), - fsm_next_state(wait_for_auth, StateData) - end; - _ -> - if - JID == error -> - ?INFO_MSG( - "(~w) Forbidden legacy authentication for " - "username '~s' with resource '~s'", - [StateData#state.socket, U, R]), - Err = jlib:make_error_reply(El, ?ERR_JID_MALFORMED), - send_element(StateData, Err), - fsm_next_state(wait_for_auth, StateData); - true -> - ?INFO_MSG( - "(~w) Forbidden legacy authentication for ~s", - [StateData#state.socket, - jlib:jid_to_string(JID)]), - Err = jlib:make_error_reply(El, ?ERR_NOT_ALLOWED), - send_element(StateData, Err), - fsm_next_state(wait_for_auth, StateData) - end - end; - _ -> - process_unauthenticated_stanza(StateData, El), - fsm_next_state(wait_for_auth, StateData) + Res1 = jlib:make_result_iq_reply(El), + Res = Res1#xmlel{children = []}, + send_element(StateData, Res), + ejabberd_sm:open_session(SID, U, StateData#state.server, R, Info), + change_shaper(StateData, JID), + {Fs, Ts} = + ejabberd_hooks:run_fold(roster_get_subscription_lists, + StateData#state.server, + {[], []}, + [U, + StateData#state.server]), + LJID = + jlib:jid_tolower(jlib:jid_remove_resource(JID)), + Fs1 = [LJID | Fs], + Ts1 = [LJID | Ts], + PrivList = ejabberd_hooks:run_fold(privacy_get_user_list, + StateData#state.server, + #userlist{}, + [U, StateData#state.server]), + NewStateData = StateData#state{user = U, + resource = R, + jid = JID, sid = SID, + conn = Conn, + auth_module = AuthModule, + pres_f = (?SETS):from_list(Fs1), + pres_t = (?SETS):from_list(Ts1), + privacy_list = PrivList}, + fsm_next_state(session_established, NewStateData); + _ -> + IP = peerip(StateData#state.sockmod, + StateData#state.socket), + ?INFO_MSG("(~w) Failed legacy authentication for " + "~s from IP ~s", + [StateData#state.socket, + jlib:jid_to_string(JID), jlib:ip_to_list(IP)]), + Err = jlib:make_error_reply(El, ?ERR_NOT_AUTHORIZED), + send_element(StateData, Err), + fsm_next_state(wait_for_auth, StateData) + end; + _ -> + if JID == error -> + ?INFO_MSG("(~w) Forbidden legacy authentication " + "for username '~s' with resource '~s'", + [StateData#state.socket, U, R]), + Err = jlib:make_error_reply(El, ?ERR_JID_MALFORMED), + send_element(StateData, Err), + fsm_next_state(wait_for_auth, StateData); + true -> + ?INFO_MSG("(~w) Forbidden legacy authentication " + "for ~s", + [StateData#state.socket, + jlib:jid_to_string(JID)]), + Err = jlib:make_error_reply(El, ?ERR_NOT_ALLOWED), + send_element(StateData, Err), + fsm_next_state(wait_for_auth, StateData) + end + end; + _ -> + process_unauthenticated_stanza(StateData, El), + fsm_next_state(wait_for_auth, StateData) end; - wait_for_auth(timeout, StateData) -> {stop, normal, StateData}; - wait_for_auth({xmlstreamend, _Name}, StateData) -> - send_trailer(StateData), - {stop, normal, StateData}; - + send_trailer(StateData), {stop, normal, StateData}; wait_for_auth({xmlstreamerror, _}, StateData) -> send_element(StateData, ?INVALID_XML_ERR), send_trailer(StateData), {stop, normal, StateData}; - wait_for_auth(closed, StateData) -> {stop, normal, StateData}. - -wait_for_feature_request({xmlstreamelement, El}, StateData) -> - {xmlelement, Name, Attrs, Els} = El, +wait_for_feature_request({xmlstreamelement, El}, + StateData) -> + #xmlel{name = Name, attrs = Attrs, children = Els} = El, Zlib = StateData#state.zlib, TLS = StateData#state.tls, TLSEnabled = StateData#state.tls_enabled, TLSRequired = StateData#state.tls_required, - SockMod = (StateData#state.sockmod):get_sockmod(StateData#state.socket), - case {xml:get_attr_s("xmlns", Attrs), Name} of - {?NS_SASL, "auth"} when not ((SockMod == gen_tcp) and TLSRequired) -> - Mech = xml:get_attr_s("mechanism", Attrs), - ClientIn = jlib:decode_base64(xml:get_cdata(Els)), - case cyrsasl:server_start(StateData#state.sasl_state, - Mech, - ClientIn) of - {ok, Props} -> - (StateData#state.sockmod):reset_stream( - StateData#state.socket), - send_element(StateData, - {xmlelement, "success", - [{"xmlns", ?NS_SASL}], []}), - U = xml:get_attr_s(username, Props), - AuthModule = xml:get_attr_s(auth_module, Props), - ?INFO_MSG("(~w) Accepted authentication for ~s by ~p", - [StateData#state.socket, U, AuthModule]), - fsm_next_state(wait_for_stream, - StateData#state{ - streamid = new_id(), - authenticated = true, - auth_module = AuthModule, - user = U }); - {continue, ServerOut, NewSASLState} -> - send_element(StateData, - {xmlelement, "challenge", - [{"xmlns", ?NS_SASL}], - [{xmlcdata, - jlib:encode_base64(ServerOut)}]}), - fsm_next_state(wait_for_sasl_response, - StateData#state{ - sasl_state = NewSASLState}); - {error, Error, Username} -> - IP = peerip(StateData#state.sockmod, StateData#state.socket), - ?INFO_MSG( - "(~w) Failed authentication for ~s@~s from IP ~s (~w)", + SockMod = + (StateData#state.sockmod):get_sockmod(StateData#state.socket), + case {xml:get_attr_s(<<"xmlns">>, Attrs), Name} of + {?NS_SASL, <<"auth">>} + when not ((SockMod == gen_tcp) and TLSRequired) -> + Mech = xml:get_attr_s(<<"mechanism">>, Attrs), + ClientIn = jlib:decode_base64(xml:get_cdata(Els)), + case cyrsasl:server_start(StateData#state.sasl_state, + Mech, ClientIn) + of + {ok, Props} -> + (StateData#state.sockmod):reset_stream(StateData#state.socket), + %U = xml:get_attr_s(username, Props), + U = proplists:get_value(username, Props, <<>>), + %AuthModule = xml:get_attr_s(auth_module, Props), + AuthModule = proplists:get_value(auth_module, Props, undefined), + ?INFO_MSG("(~w) Accepted authentication for ~s " + "by ~p", + [StateData#state.socket, U, AuthModule]), + send_element(StateData, + #xmlel{name = <<"success">>, + attrs = [{<<"xmlns">>, ?NS_SASL}], + children = []}), + fsm_next_state(wait_for_stream, + StateData#state{streamid = new_id(), + authenticated = true, + auth_module = AuthModule, + user = U}); + {continue, ServerOut, NewSASLState} -> + send_element(StateData, + #xmlel{name = <<"challenge">>, + attrs = [{<<"xmlns">>, ?NS_SASL}], + children = + [{xmlcdata, + jlib:encode_base64(ServerOut)}]}), + fsm_next_state(wait_for_sasl_response, + StateData#state{sasl_state = NewSASLState}); + {error, Error, Username} -> + IP = peerip(StateData#state.sockmod, StateData#state.socket), + ?INFO_MSG("(~w) Failed authentication for ~s@~s from IP ~s", [StateData#state.socket, - Username, StateData#state.server, jlib:ip_to_list(IP), IP]), - send_element(StateData, - {xmlelement, "failure", - [{"xmlns", ?NS_SASL}], - [{xmlelement, Error, [], []}]}), - {next_state, wait_for_feature_request, StateData, - ?C2S_OPEN_TIMEOUT}; - {error, Error} -> - send_element(StateData, - {xmlelement, "failure", - [{"xmlns", ?NS_SASL}], - [{xmlelement, Error, [], []}]}), - fsm_next_state(wait_for_feature_request, StateData) - end; - {?NS_TLS, "starttls"} when TLS == true, - TLSEnabled == false, - SockMod == gen_tcp -> - TLSOpts = case ejabberd_config:get_local_option( - {domain_certfile, StateData#state.server}) of - undefined -> - StateData#state.tls_options; - CertFile -> - [{certfile, CertFile} | - lists:keydelete( - certfile, 1, StateData#state.tls_options)] - end, - Socket = StateData#state.socket, - TLSSocket = (StateData#state.sockmod):starttls( - Socket, TLSOpts, - xml:element_to_binary( - {xmlelement, "proceed", [{"xmlns", ?NS_TLS}], []})), - fsm_next_state(wait_for_stream, - StateData#state{socket = TLSSocket, - streamid = new_id(), - tls_enabled = true - }); - {?NS_COMPRESS, "compress"} when Zlib == true, - ((SockMod == gen_tcp) or - (SockMod == tls)) -> - case xml:get_subtag(El, "method") of - false -> - send_element(StateData, - {xmlelement, "failure", - [{"xmlns", ?NS_COMPRESS}], - [{xmlelement, "setup-failed", [], []}]}), - fsm_next_state(wait_for_feature_request, StateData); - Method -> - case xml:get_tag_cdata(Method) of - "zlib" -> - Socket = StateData#state.socket, - ZlibSocket = (StateData#state.sockmod):compress( - Socket, - xml:element_to_binary( - {xmlelement, "compressed", - [{"xmlns", ?NS_COMPRESS}], []})), - fsm_next_state(wait_for_stream, - StateData#state{socket = ZlibSocket, - streamid = new_id() - }); - _ -> - send_element(StateData, - {xmlelement, "failure", - [{"xmlns", ?NS_COMPRESS}], - [{xmlelement, "unsupported-method", - [], []}]}), - fsm_next_state(wait_for_feature_request, - StateData) - end - end; - _ -> - if - (SockMod == gen_tcp) and TLSRequired -> - Lang = StateData#state.lang, - send_element(StateData, ?POLICY_VIOLATION_ERR( - Lang, - "Use of STARTTLS required")), - send_trailer(StateData), - {stop, normal, StateData}; - true -> - process_unauthenticated_stanza(StateData, El), - fsm_next_state(wait_for_feature_request, StateData) - end + Username, StateData#state.server, jlib:ip_to_list(IP)]), + send_element(StateData, + #xmlel{name = <<"failure">>, + attrs = [{<<"xmlns">>, ?NS_SASL}], + children = + [#xmlel{name = Error, attrs = [], + children = []}]}), + {next_state, wait_for_feature_request, StateData, + ?C2S_OPEN_TIMEOUT}; + {error, Error} -> + send_element(StateData, + #xmlel{name = <<"failure">>, + attrs = [{<<"xmlns">>, ?NS_SASL}], + children = + [#xmlel{name = Error, attrs = [], + children = []}]}), + fsm_next_state(wait_for_feature_request, StateData) + end; + {?NS_TLS, <<"starttls">>} + when TLS == true, TLSEnabled == false, + SockMod == gen_tcp -> + TLSOpts = case + ejabberd_config:get_local_option( + {domain_certfile, StateData#state.server}, + fun iolist_to_binary/1) + of + undefined -> StateData#state.tls_options; + CertFile -> + [{certfile, CertFile} | lists:keydelete(certfile, 1, + StateData#state.tls_options)] + end, + Socket = StateData#state.socket, + TLSSocket = (StateData#state.sockmod):starttls(Socket, + TLSOpts, + xml:element_to_binary(#xmlel{name + = + <<"proceed">>, + attrs + = + [{<<"xmlns">>, + ?NS_TLS}], + children + = + []})), + fsm_next_state(wait_for_stream, + StateData#state{socket = TLSSocket, + streamid = new_id(), + tls_enabled = true}); + {?NS_COMPRESS, <<"compress">>} + when Zlib == true, + (SockMod == gen_tcp) or (SockMod == tls) -> + case xml:get_subtag(El, <<"method">>) of + false -> + send_element(StateData, + #xmlel{name = <<"failure">>, + attrs = [{<<"xmlns">>, ?NS_COMPRESS}], + children = + [#xmlel{name = <<"setup-failed">>, + attrs = [], children = []}]}), + fsm_next_state(wait_for_feature_request, StateData); + Method -> + case xml:get_tag_cdata(Method) of + <<"zlib">> -> + Socket = StateData#state.socket, + ZlibSocket = (StateData#state.sockmod):compress(Socket, + xml:element_to_binary(#xmlel{name + = + <<"compressed">>, + attrs + = + [{<<"xmlns">>, + ?NS_COMPRESS}], + children + = + []})), + fsm_next_state(wait_for_stream, + StateData#state{socket = ZlibSocket, + streamid = new_id()}); + _ -> + send_element(StateData, + #xmlel{name = <<"failure">>, + attrs = [{<<"xmlns">>, ?NS_COMPRESS}], + children = + [#xmlel{name = + <<"unsupported-method">>, + attrs = [], + children = []}]}), + fsm_next_state(wait_for_feature_request, StateData) + end + end; + _ -> + if (SockMod == gen_tcp) and TLSRequired -> + Lang = StateData#state.lang, + send_element(StateData, + ?POLICY_VIOLATION_ERR(Lang, + <<"Use of STARTTLS required">>)), + send_trailer(StateData), + {stop, normal, StateData}; + true -> + process_unauthenticated_stanza(StateData, El), + fsm_next_state(wait_for_feature_request, StateData) + end end; - wait_for_feature_request(timeout, StateData) -> {stop, normal, StateData}; - -wait_for_feature_request({xmlstreamend, _Name}, StateData) -> - send_trailer(StateData), - {stop, normal, StateData}; - -wait_for_feature_request({xmlstreamerror, _}, StateData) -> +wait_for_feature_request({xmlstreamend, _Name}, + StateData) -> + send_trailer(StateData), {stop, normal, StateData}; +wait_for_feature_request({xmlstreamerror, _}, + StateData) -> send_element(StateData, ?INVALID_XML_ERR), send_trailer(StateData), {stop, normal, StateData}; - wait_for_feature_request(closed, StateData) -> {stop, normal, StateData}. - -wait_for_sasl_response({xmlstreamelement, El}, StateData) -> - {xmlelement, Name, Attrs, Els} = El, - case {xml:get_attr_s("xmlns", Attrs), Name} of - {?NS_SASL, "response"} -> - ClientIn = jlib:decode_base64(xml:get_cdata(Els)), - case cyrsasl:server_step(StateData#state.sasl_state, - ClientIn) of - {ok, Props} -> - (StateData#state.sockmod):reset_stream( - StateData#state.socket), - send_element(StateData, - {xmlelement, "success", - [{"xmlns", ?NS_SASL}], []}), - U = xml:get_attr_s(username, Props), - AuthModule = xml:get_attr_s(auth_module, Props), - ?INFO_MSG("(~w) Accepted authentication for ~s by ~p", - [StateData#state.socket, U, AuthModule]), - fsm_next_state(wait_for_stream, - StateData#state{ - streamid = new_id(), - authenticated = true, - auth_module = AuthModule, - user = U}); - {ok, Props, ServerOut} -> - (StateData#state.sockmod):reset_stream( - StateData#state.socket), - send_element(StateData, - {xmlelement, "success", - [{"xmlns", ?NS_SASL}], - [{xmlcdata, - jlib:encode_base64(ServerOut)}]}), - U = xml:get_attr_s(username, Props), - AuthModule = xml:get_attr_s(auth_module, Props), - ?INFO_MSG("(~w) Accepted authentication for ~s by ~p", - [StateData#state.socket, U, AuthModule]), - fsm_next_state(wait_for_stream, - StateData#state{ - streamid = new_id(), - authenticated = true, - auth_module = AuthModule, - user = U}); - {continue, ServerOut, NewSASLState} -> - send_element(StateData, - {xmlelement, "challenge", - [{"xmlns", ?NS_SASL}], - [{xmlcdata, - jlib:encode_base64(ServerOut)}]}), - fsm_next_state(wait_for_sasl_response, - StateData#state{sasl_state = NewSASLState}); - {error, Error, Username} -> - IP = peerip(StateData#state.sockmod, StateData#state.socket), - ?INFO_MSG( - "(~w) Failed authentication for ~s@~s from IP ~s (~w)", +wait_for_sasl_response({xmlstreamelement, El}, + StateData) -> + #xmlel{name = Name, attrs = Attrs, children = Els} = El, + case {xml:get_attr_s(<<"xmlns">>, Attrs), Name} of + {?NS_SASL, <<"response">>} -> + ClientIn = jlib:decode_base64(xml:get_cdata(Els)), + case cyrsasl:server_step(StateData#state.sasl_state, + ClientIn) + of + {ok, Props} -> + catch + (StateData#state.sockmod):reset_stream(StateData#state.socket), +% U = xml:get_attr_s(username, Props), + U = proplists:get_value(username, Props, <<>>), +% AuthModule = xml:get_attr_s(auth_module, Props), + AuthModule = proplists:get_value(auth_module, Props, <<>>), + ?INFO_MSG("(~w) Accepted authentication for ~s " + "by ~p", + [StateData#state.socket, U, AuthModule]), + send_element(StateData, + #xmlel{name = <<"success">>, + attrs = [{<<"xmlns">>, ?NS_SASL}], + children = []}), + fsm_next_state(wait_for_stream, + StateData#state{streamid = new_id(), + authenticated = true, + auth_module = AuthModule, + user = U}); + {ok, Props, ServerOut} -> + (StateData#state.sockmod):reset_stream(StateData#state.socket), +% U = xml:get_attr_s(username, Props), + U = proplists:get_value(username, Props, <<>>), +% AuthModule = xml:get_attr_s(auth_module, Props), + AuthModule = proplists:get_value(auth_module, Props, undefined), + ?INFO_MSG("(~w) Accepted authentication for ~s " + "by ~p", + [StateData#state.socket, U, AuthModule]), + send_element(StateData, + #xmlel{name = <<"success">>, + attrs = [{<<"xmlns">>, ?NS_SASL}], + children = + [{xmlcdata, + jlib:encode_base64(ServerOut)}]}), + fsm_next_state(wait_for_stream, + StateData#state{streamid = new_id(), + authenticated = true, + auth_module = AuthModule, + user = U}); + {continue, ServerOut, NewSASLState} -> + send_element(StateData, + #xmlel{name = <<"challenge">>, + attrs = [{<<"xmlns">>, ?NS_SASL}], + children = + [{xmlcdata, + jlib:encode_base64(ServerOut)}]}), + fsm_next_state(wait_for_sasl_response, + StateData#state{sasl_state = NewSASLState}); + {error, Error, Username} -> + IP = peerip(StateData#state.sockmod, StateData#state.socket), + ?INFO_MSG("(~w) Failed authentication for ~s@~s from IP ~s", [StateData#state.socket, - Username, StateData#state.server, jlib:ip_to_list(IP), IP]), - send_element(StateData, - {xmlelement, "failure", - [{"xmlns", ?NS_SASL}], - [{xmlelement, Error, [], []}]}), - fsm_next_state(wait_for_feature_request, StateData); - {error, Error} -> - send_element(StateData, - {xmlelement, "failure", - [{"xmlns", ?NS_SASL}], - [{xmlelement, Error, [], []}]}), - fsm_next_state(wait_for_feature_request, StateData) - end; - _ -> - process_unauthenticated_stanza(StateData, El), - fsm_next_state(wait_for_feature_request, StateData) + Username, StateData#state.server, jlib:ip_to_list(IP)]), + send_element(StateData, + #xmlel{name = <<"failure">>, + attrs = [{<<"xmlns">>, ?NS_SASL}], + children = + [#xmlel{name = Error, attrs = [], + children = []}]}), + fsm_next_state(wait_for_feature_request, StateData); + {error, Error} -> + send_element(StateData, + #xmlel{name = <<"failure">>, + attrs = [{<<"xmlns">>, ?NS_SASL}], + children = + [#xmlel{name = Error, attrs = [], + children = []}]}), + fsm_next_state(wait_for_feature_request, StateData) + end; + _ -> + process_unauthenticated_stanza(StateData, El), + fsm_next_state(wait_for_feature_request, StateData) end; - wait_for_sasl_response(timeout, StateData) -> {stop, normal, StateData}; - -wait_for_sasl_response({xmlstreamend, _Name}, StateData) -> - send_trailer(StateData), - {stop, normal, StateData}; - -wait_for_sasl_response({xmlstreamerror, _}, StateData) -> +wait_for_sasl_response({xmlstreamend, _Name}, + StateData) -> + send_trailer(StateData), {stop, normal, StateData}; +wait_for_sasl_response({xmlstreamerror, _}, + StateData) -> send_element(StateData, ?INVALID_XML_ERR), send_trailer(StateData), {stop, normal, StateData}; - wait_for_sasl_response(closed, StateData) -> {stop, normal, StateData}. - resource_conflict_action(U, S, R) -> - OptionRaw = case ejabberd_sm:is_existing_resource(U, S, R) of - true -> - ejabberd_config:get_local_option({resource_conflict,S}); - false -> - acceptnew + OptionRaw = case ejabberd_sm:is_existing_resource(U, S, + R) + of + true -> + ejabberd_config:get_local_option( + {resource_conflict, S}, + fun(setresource) -> setresource; + (closeold) -> closeold; + (closenew) -> closenew; + (acceptnew) -> acceptnew + end); + false -> + acceptnew end, Option = case OptionRaw of - setresource -> setresource; - closeold -> acceptnew; %% ejabberd_sm will close old session - closenew -> closenew; - acceptnew -> acceptnew; - _ -> acceptnew %% default ejabberd behavior + setresource -> setresource; + closeold -> + acceptnew; %% ejabberd_sm will close old session + closenew -> closenew; + acceptnew -> acceptnew; + _ -> acceptnew %% default ejabberd behavior end, case Option of - acceptnew -> - {accept_resource, R}; - closenew -> - closenew; - setresource -> - Rnew = lists:concat([randoms:get_string() | tuple_to_list(now())]), - {accept_resource, Rnew} + acceptnew -> {accept_resource, R}; + closenew -> closenew; + setresource -> + Rnew = iolist_to_binary([randoms:get_string() + | tuple_to_list(now())]), + {accept_resource, Rnew} end. wait_for_bind({xmlstreamelement, El}, StateData) -> case jlib:iq_query_info(El) of - #iq{type = set, xmlns = ?NS_BIND, sub_el = SubEl} = IQ -> - U = StateData#state.user, - R1 = xml:get_path_s(SubEl, [{elem, "resource"}, cdata]), - R = case jlib:resourceprep(R1) of - error -> error; - "" -> - lists:concat( - [randoms:get_string() | tuple_to_list(now())]); - Resource -> Resource - end, - case R of - error -> - Err = jlib:make_error_reply(El, ?ERR_BAD_REQUEST), - send_element(StateData, Err), - fsm_next_state(wait_for_bind, StateData); - _ -> - %%Server = StateData#state.server, - %%RosterVersioningFeature = - %% ejabberd_hooks:run_fold( - %% roster_get_versioning_feature, Server, [], [Server]), - %%StreamFeatures = [{xmlelement, "session", - %% [{"xmlns", ?NS_SESSION}], []} | - %% RosterVersioningFeature], - %%send_element(StateData, {xmlelement, "stream:features", - %% [], StreamFeatures}), - case resource_conflict_action(U, StateData#state.server, R) of - closenew -> - Err = jlib:make_error_reply(El, ?STANZA_ERROR("409", "modify", "conflict")), - send_element(StateData, Err), - fsm_next_state(wait_for_bind, StateData); - {accept_resource, R2} -> - JID = jlib:make_jid(U, StateData#state.server, R2), - Res = IQ#iq{type = result, - sub_el = [{xmlelement, "bind", - [{"xmlns", ?NS_BIND}], - [{xmlelement, "jid", [], - [{xmlcdata, - jlib:jid_to_string(JID)}]}]}]}, - send_element(StateData, jlib:iq_to_xml(Res)), - fsm_next_state(wait_for_session, - StateData#state{resource = R2, jid = JID}) - end - end; - _ -> - fsm_next_state(wait_for_bind, StateData) + #iq{type = set, xmlns = ?NS_BIND, sub_el = SubEl} = + IQ -> + U = StateData#state.user, + R1 = xml:get_path_s(SubEl, + [{elem, <<"resource">>}, cdata]), + R = case jlib:resourceprep(R1) of + error -> error; + <<"">> -> + iolist_to_binary([randoms:get_string() + | tuple_to_list(now())]); + Resource -> Resource + end, + case R of + error -> + Err = jlib:make_error_reply(El, ?ERR_BAD_REQUEST), + send_element(StateData, Err), + fsm_next_state(wait_for_bind, StateData); + _ -> + case resource_conflict_action(U, StateData#state.server, + R) + of + closenew -> + Err = jlib:make_error_reply(El, + ?STANZA_ERROR(<<"409">>, + <<"modify">>, + <<"conflict">>)), + send_element(StateData, Err), + fsm_next_state(wait_for_bind, StateData); + {accept_resource, R2} -> + JID = jlib:make_jid(U, StateData#state.server, R2), + Res = IQ#iq{type = result, + sub_el = + [#xmlel{name = <<"bind">>, + attrs = [{<<"xmlns">>, ?NS_BIND}], + children = + [#xmlel{name = <<"jid">>, + attrs = [], + children = + [{xmlcdata, + jlib:jid_to_string(JID)}]}]}]}, + send_element(StateData, jlib:iq_to_xml(Res)), + fsm_next_state(wait_for_session, + StateData#state{resource = R2, jid = JID}) + end + end; + _ -> fsm_next_state(wait_for_bind, StateData) end; - wait_for_bind(timeout, StateData) -> {stop, normal, StateData}; - wait_for_bind({xmlstreamend, _Name}, StateData) -> - send_trailer(StateData), - {stop, normal, StateData}; - + send_trailer(StateData), {stop, normal, StateData}; wait_for_bind({xmlstreamerror, _}, StateData) -> send_element(StateData, ?INVALID_XML_ERR), send_trailer(StateData), {stop, normal, StateData}; - wait_for_bind(closed, StateData) -> {stop, normal, StateData}. - - wait_for_session({xmlstreamelement, El}, StateData) -> case jlib:iq_query_info(El) of #iq{type = set, xmlns = ?NS_SESSION} -> @@ -988,23 +1020,18 @@ wait_for_session({xmlstreamelement, El}, StateData) -> wait_for_session(timeout, StateData) -> {stop, normal, StateData}; - wait_for_session({xmlstreamend, _Name}, StateData) -> - send_trailer(StateData), - {stop, normal, StateData}; - + send_trailer(StateData), {stop, normal, StateData}; wait_for_session({xmlstreamerror, _}, StateData) -> send_element(StateData, ?INVALID_XML_ERR), send_trailer(StateData), {stop, normal, StateData}; - wait_for_session(closed, StateData) -> {stop, normal, StateData}. - -session_established({xmlstreamelement, El}, StateData) -> +session_established({xmlstreamelement, El}, + StateData) -> FromJID = StateData#state.jid, - % Check 'from' attribute in stanza RFC 3920 Section 9.1.2 case check_from(El, FromJID) of 'invalid-from' -> send_element(StateData, ?INVALID_FROM), @@ -1013,124 +1040,109 @@ session_established({xmlstreamelement, El}, StateData) -> _NewEl -> session_established2(El, StateData) end; - %% We hibernate the process to reduce memory consumption after a %% configurable activity timeout session_established(timeout, StateData) -> - %% TODO: Options must be stored in state: Options = [], proc_lib:hibernate(?GEN_FSM, enter_loop, [?MODULE, Options, session_established, StateData]), fsm_next_state(session_established, StateData); - session_established({xmlstreamend, _Name}, StateData) -> + send_trailer(StateData), {stop, normal, StateData}; +session_established({xmlstreamerror, + <<"XML stanza is too big">> = E}, + StateData) -> + send_element(StateData, + ?POLICY_VIOLATION_ERR((StateData#state.lang), E)), send_trailer(StateData), {stop, normal, StateData}; - -session_established({xmlstreamerror, "XML stanza is too big" = E}, StateData) -> - send_element(StateData, ?POLICY_VIOLATION_ERR(StateData#state.lang, E)), - send_trailer(StateData), - {stop, normal, StateData}; - session_established({xmlstreamerror, _}, StateData) -> send_element(StateData, ?INVALID_XML_ERR), send_trailer(StateData), {stop, normal, StateData}; - session_established(closed, StateData) -> {stop, normal, StateData}. %% Process packets sent by user (coming from user on c2s XMPP %% connection) session_established2(El, StateData) -> - {xmlelement, Name, Attrs, _Els} = El, + #xmlel{name = Name, attrs = Attrs} = El, User = StateData#state.user, Server = StateData#state.server, FromJID = StateData#state.jid, - To = xml:get_attr_s("to", Attrs), + To = xml:get_attr_s(<<"to">>, Attrs), ToJID = case To of - "" -> - jlib:make_jid(User, Server, ""); - _ -> - jlib:string_to_jid(To) + <<"">> -> jlib:make_jid(User, Server, <<"">>); + _ -> jlib:string_to_jid(To) end, - NewEl1 = jlib:remove_attr("xmlns", El), - NewEl = case xml:get_attr_s("xml:lang", Attrs) of - "" -> - case StateData#state.lang of - "" -> NewEl1; - Lang -> - xml:replace_tag_attr("xml:lang", Lang, NewEl1) - end; - _ -> - NewEl1 + NewEl1 = jlib:remove_attr(<<"xmlns">>, El), + NewEl = case xml:get_attr_s(<<"xml:lang">>, Attrs) of + <<"">> -> + case StateData#state.lang of + <<"">> -> NewEl1; + Lang -> + xml:replace_tag_attr(<<"xml:lang">>, Lang, NewEl1) + end; + _ -> NewEl1 end, - NewState = - case ToJID of - error -> - case xml:get_attr_s("type", Attrs) of - "error" -> StateData; - "result" -> StateData; - _ -> - Err = jlib:make_error_reply(NewEl, ?ERR_JID_MALFORMED), - send_element(StateData, Err), - StateData - end; - _ -> - case Name of - "presence" -> - PresenceEl = ejabberd_hooks:run_fold( - c2s_update_presence, - Server, - NewEl, - [User, Server]), - ejabberd_hooks:run( - user_send_packet, - Server, - [FromJID, ToJID, PresenceEl]), - case ToJID of - #jid{user = User, - server = Server, - resource = ""} -> - ?DEBUG("presence_update(~p,~n\t~p,~n\t~p)", - [FromJID, PresenceEl, StateData]), - presence_update(FromJID, PresenceEl, - StateData); - _ -> - presence_track(FromJID, ToJID, PresenceEl, - StateData) - end; - "iq" -> - case jlib:iq_query_info(NewEl) of - #iq{xmlns = Xmlns} = IQ - when Xmlns == ?NS_PRIVACY; - Xmlns == ?NS_BLOCKING -> - process_privacy_iq( - FromJID, ToJID, IQ, StateData); - _ -> - ejabberd_hooks:run( - user_send_packet, - Server, - [FromJID, ToJID, NewEl]), - check_privacy_route(FromJID, StateData, FromJID, ToJID, NewEl), - StateData - end; - "message" -> - ejabberd_hooks:run(user_send_packet, - Server, - [FromJID, ToJID, NewEl]), - check_privacy_route(FromJID, StateData, FromJID, - ToJID, NewEl), - StateData; - _ -> - StateData - end - end, - ejabberd_hooks:run(c2s_loop_debug, [{xmlstreamelement, El}]), + NewState = case ToJID of + error -> + case xml:get_attr_s(<<"type">>, Attrs) of + <<"error">> -> StateData; + <<"result">> -> StateData; + _ -> + Err = jlib:make_error_reply(NewEl, + ?ERR_JID_MALFORMED), + send_element(StateData, Err), + StateData + end; + _ -> + case Name of + <<"presence">> -> + PresenceEl = + ejabberd_hooks:run_fold(c2s_update_presence, + Server, NewEl, + [User, Server]), + ejabberd_hooks:run(user_send_packet, Server, + [FromJID, ToJID, PresenceEl]), + case ToJID of + #jid{user = User, server = Server, + resource = <<"">>} -> + ?DEBUG("presence_update(~p,~n\t~p,~n\t~p)", + [FromJID, PresenceEl, StateData]), + presence_update(FromJID, PresenceEl, + StateData); + _ -> + presence_track(FromJID, ToJID, PresenceEl, + StateData) + end; + <<"iq">> -> + case jlib:iq_query_info(NewEl) of + #iq{xmlns = Xmlns} = IQ + when Xmlns == (?NS_PRIVACY); + Xmlns == (?NS_BLOCKING) -> + process_privacy_iq(FromJID, ToJID, IQ, + StateData); + _ -> + ejabberd_hooks:run(user_send_packet, Server, + [FromJID, ToJID, NewEl]), + check_privacy_route(FromJID, StateData, + FromJID, ToJID, NewEl), + StateData + end; + <<"message">> -> + ejabberd_hooks:run(user_send_packet, Server, + [FromJID, ToJID, NewEl]), + check_privacy_route(FromJID, StateData, FromJID, + ToJID, NewEl), + StateData; + _ -> StateData + end + end, + ejabberd_hooks:run(c2s_loop_debug, + [{xmlstreamelement, El}]), fsm_next_state(session_established, NewState). - - %%---------------------------------------------------------------------- %% Func: StateName/3 %% Returns: {next_state, NextStateName, NextStateData} | @@ -1162,24 +1174,22 @@ handle_event(_Event, StateName, StateData) -> %% {stop, Reason, NewStateData} | %% {stop, Reason, Reply, NewStateData} %%---------------------------------------------------------------------- -handle_sync_event({get_presence}, _From, StateName, StateData) -> +handle_sync_event({get_presence}, _From, StateName, + StateData) -> User = StateData#state.user, PresLast = StateData#state.pres_last, - Show = get_showtag(PresLast), Status = get_statustag(PresLast), Resource = StateData#state.resource, - Reply = {User, Resource, Show, Status}, fsm_reply(Reply, StateName, StateData); - -handle_sync_event(get_subscribed, _From, StateName, StateData) -> - Subscribed = ?SETS:to_list(StateData#state.pres_f), +handle_sync_event(get_subscribed, _From, StateName, + StateData) -> + Subscribed = (?SETS):to_list(StateData#state.pres_f), {reply, Subscribed, StateName, StateData}; - -handle_sync_event(_Event, _From, StateName, StateData) -> - Reply = ok, - fsm_reply(Reply, StateName, StateData). +handle_sync_event(_Event, _From, StateName, + StateData) -> + Reply = ok, fsm_reply(Reply, StateName, StateData). code_change(_OldVsn, StateName, StateData, _Extra) -> {ok, StateName, StateData}. @@ -1197,209 +1207,301 @@ handle_info({send_text, Text}, StateName, StateData) -> handle_info(replaced, _StateName, StateData) -> Lang = StateData#state.lang, send_element(StateData, - ?SERRT_CONFLICT(Lang, "Replaced by new connection")), + ?SERRT_CONFLICT(Lang, + <<"Replaced by new connection">>)), send_trailer(StateData), - {stop, normal, StateData#state{authenticated = replaced}}; + {stop, normal, + StateData#state{authenticated = replaced}}; +handle_info({route, _From, _To, {broadcast, Data}}, + StateName, StateData) -> + ?DEBUG("broadcast~n~p~n", [Data]), + case Data of + {item, IJID, ISubscription} -> + fsm_next_state(StateName, + roster_change(IJID, ISubscription, StateData)); + {exit, Reason} -> + Lang = StateData#state.lang, + send_element(StateData, ?SERRT_CONFLICT(Lang, Reason)), + catch send_trailer(StateData), + {stop, normal, StateData}; + {privacy_list, PrivList, PrivListName} -> + case ejabberd_hooks:run_fold(privacy_updated_list, + StateData#state.server, + false, + [StateData#state.privacy_list, + PrivList]) of + false -> + fsm_next_state(StateName, StateData); + NewPL -> + PrivPushIQ = #iq{type = set, + xmlns = ?NS_PRIVACY, + id = <<"push", + (randoms:get_string())/binary>>, + sub_el = + [#xmlel{name = <<"query">>, + attrs = [{<<"xmlns">>, + ?NS_PRIVACY}], + children = + [#xmlel{name = <<"list">>, + attrs = [{<<"name">>, + PrivListName}], + children = []}]}]}, + PrivPushEl = jlib:replace_from_to( + jlib:jid_remove_resource(StateData#state.jid), + StateData#state.jid, + jlib:iq_to_xml(PrivPushIQ)), + send_element(StateData, PrivPushEl), + fsm_next_state(StateName, + StateData#state{privacy_list = NewPL}) + end; + {blocking, What} -> + route_blocking(What, StateData), + fsm_next_state(StateName, StateData); + _ -> + fsm_next_state(StateName, StateData) + end; %% Process Packets that are to be send to the user -handle_info({route, From, To, Packet}, StateName, StateData) -> - {xmlelement, Name, Attrs, Els} = Packet, - {Pass, NewAttrs, NewState} = - case Name of - "presence" -> - State = ejabberd_hooks:run_fold( - c2s_presence_in, StateData#state.server, - StateData, - [{From, To, Packet}]), - case xml:get_attr_s("type", Attrs) of - "probe" -> - LFrom = jlib:jid_tolower(From), - LBFrom = jlib:jid_remove_resource(LFrom), - NewStateData = - case ?SETS:is_element( - LFrom, State#state.pres_a) orelse - ?SETS:is_element( - LBFrom, State#state.pres_a) of - true -> - State; - false -> - case ?SETS:is_element( - LFrom, State#state.pres_f) of - true -> - A = ?SETS:add_element( - LFrom, - State#state.pres_a), - State#state{pres_a = A}; - false -> - case ?SETS:is_element( - LBFrom, State#state.pres_f) of - true -> - A = ?SETS:add_element( - LBFrom, - State#state.pres_a), - State#state{pres_a = A}; - false -> - State - end - end - end, - process_presence_probe(From, To, NewStateData), - {false, Attrs, NewStateData}; - "error" -> - NewA = remove_element(jlib:jid_tolower(From), - State#state.pres_a), - {true, Attrs, State#state{pres_a = NewA}}; - "invisible" -> - Attrs1 = lists:keydelete("type", 1, Attrs), - {true, [{"type", "unavailable"} | Attrs1], State}; - "subscribe" -> - SRes = is_privacy_allow(State, From, To, Packet, in), - {SRes, Attrs, State}; - "subscribed" -> - SRes = is_privacy_allow(State, From, To, Packet, in), - {SRes, Attrs, State}; - "unsubscribe" -> - SRes = is_privacy_allow(State, From, To, Packet, in), - {SRes, Attrs, State}; - "unsubscribed" -> - SRes = is_privacy_allow(State, From, To, Packet, in), - {SRes, Attrs, State}; - _ -> - case privacy_check_packet(State, From, To, Packet, in) of - allow -> - LFrom = jlib:jid_tolower(From), - LBFrom = jlib:jid_remove_resource(LFrom), - case ?SETS:is_element( - LFrom, State#state.pres_a) orelse - ?SETS:is_element( - LBFrom, State#state.pres_a) of - true -> - {true, Attrs, State}; - false -> - case ?SETS:is_element( - LFrom, State#state.pres_f) of - true -> - A = ?SETS:add_element( - LFrom, - State#state.pres_a), - {true, Attrs, - State#state{pres_a = A}}; - false -> - case ?SETS:is_element( - LBFrom, State#state.pres_f) of - true -> - A = ?SETS:add_element( - LBFrom, - State#state.pres_a), - {true, Attrs, - State#state{pres_a = A}}; - false -> - {true, Attrs, State} - end - end - end; - deny -> - {false, Attrs, State} - end - end; - "broadcast" -> - ?DEBUG("broadcast~n~p~n", [Els]), - case Els of - [{item, IJID, ISubscription}] -> - {false, Attrs, - roster_change(IJID, ISubscription, - StateData)}; - [{exit, Reason}] -> - {exit, Attrs, Reason}; - [{privacy_list, PrivList, PrivListName}] -> - case ejabberd_hooks:run_fold( - privacy_updated_list, StateData#state.server, - false, - [StateData#state.privacy_list, - PrivList]) of - false -> - {false, Attrs, StateData}; - NewPL -> - PrivPushIQ = - #iq{type = set, xmlns = ?NS_PRIVACY, - id = "push" ++ randoms:get_string(), - sub_el = [{xmlelement, "query", - [{"xmlns", ?NS_PRIVACY}], - [{xmlelement, "list", - [{"name", PrivListName}], - []}]}]}, - PrivPushEl = - jlib:replace_from_to( - jlib:jid_remove_resource( - StateData#state.jid), - StateData#state.jid, - jlib:iq_to_xml(PrivPushIQ)), - send_element(StateData, PrivPushEl), - {false, Attrs, StateData#state{privacy_list = NewPL}} - end; - [{blocking, What}] -> - route_blocking(What, StateData), - {false, Attrs, StateData}; - _ -> - {false, Attrs, StateData} - end; - "iq" -> - IQ = jlib:iq_query_info(Packet), - case IQ of - #iq{xmlns = ?NS_LAST} -> - LFrom = jlib:jid_tolower(From), - LBFrom = jlib:jid_remove_resource(LFrom), - HasFromSub = (?SETS:is_element(LFrom, StateData#state.pres_f) orelse ?SETS:is_element(LBFrom, StateData#state.pres_f)) - andalso is_privacy_allow(StateData, To, From, {xmlelement, "presence", [], []}, out), - case HasFromSub of - true -> - case privacy_check_packet(StateData, From, To, Packet, in) of - allow -> - {true, Attrs, StateData}; - deny -> - {false, Attrs, StateData} - end; - _ -> - Err = jlib:make_error_reply(Packet, ?ERR_FORBIDDEN), - ejabberd_router:route(To, From, Err), - {false, Attrs, StateData} - end; - IQ when (is_record(IQ, iq)) or (IQ == reply) -> - case privacy_check_packet(StateData, From, To, Packet, in) of - allow -> - {true, Attrs, StateData}; - deny when is_record(IQ, iq) -> - Err = jlib:make_error_reply( - Packet, ?ERR_SERVICE_UNAVAILABLE), - ejabberd_router:route(To, From, Err), - {false, Attrs, StateData}; - deny when IQ == reply -> - {false, Attrs, StateData} - end; - IQ when (IQ == invalid) or (IQ == not_iq) -> - {false, Attrs, StateData} - end; - "message" -> - case privacy_check_packet(StateData, From, To, Packet, in) of - allow -> - {true, Attrs, StateData}; - deny -> - {false, Attrs, StateData} - end; - _ -> - {true, Attrs, StateData} - end, - if - Pass == exit -> +handle_info({route, From, To, + #xmlel{name = Name, attrs = Attrs, children = Els} = Packet}, + StateName, StateData) -> + {Pass, NewAttrs, NewState} = case Name of + <<"presence">> -> + State = + ejabberd_hooks:run_fold(c2s_presence_in, + StateData#state.server, + StateData, + [{From, To, + Packet}]), + case xml:get_attr_s(<<"type">>, Attrs) of + <<"probe">> -> + LFrom = jlib:jid_tolower(From), + LBFrom = + jlib:jid_remove_resource(LFrom), + NewStateData = case + (?SETS):is_element(LFrom, + State#state.pres_a) + orelse + (?SETS):is_element(LBFrom, + State#state.pres_a) + of + true -> State; + false -> + case + (?SETS):is_element(LFrom, + State#state.pres_f) + of + true -> + A = + (?SETS):add_element(LFrom, + State#state.pres_a), + State#state{pres_a + = + A}; + false -> + case + (?SETS):is_element(LBFrom, + State#state.pres_f) + of + true -> + A = + (?SETS):add_element(LBFrom, + State#state.pres_a), + State#state{pres_a + = + A}; + false -> + State + end + end + end, + process_presence_probe(From, To, + NewStateData), + {false, Attrs, NewStateData}; + <<"error">> -> + NewA = + remove_element(jlib:jid_tolower(From), + State#state.pres_a), + {true, Attrs, + State#state{pres_a = NewA}}; + <<"subscribe">> -> + SRes = is_privacy_allow(State, + From, To, + Packet, + in), + {SRes, Attrs, State}; + <<"subscribed">> -> + SRes = is_privacy_allow(State, + From, To, + Packet, + in), + {SRes, Attrs, State}; + <<"unsubscribe">> -> + SRes = is_privacy_allow(State, + From, To, + Packet, + in), + {SRes, Attrs, State}; + <<"unsubscribed">> -> + SRes = is_privacy_allow(State, + From, To, + Packet, + in), + {SRes, Attrs, State}; + _ -> + case privacy_check_packet(State, + From, To, + Packet, + in) + of + allow -> + LFrom = + jlib:jid_tolower(From), + LBFrom = + jlib:jid_remove_resource(LFrom), + case + (?SETS):is_element(LFrom, + State#state.pres_a) + orelse + (?SETS):is_element(LBFrom, + State#state.pres_a) + of + true -> + {true, Attrs, State}; + false -> + case + (?SETS):is_element(LFrom, + State#state.pres_f) + of + true -> + A = + (?SETS):add_element(LFrom, + State#state.pres_a), + {true, Attrs, + State#state{pres_a + = + A}}; + false -> + case + (?SETS):is_element(LBFrom, + State#state.pres_f) + of + true -> + A = + (?SETS):add_element(LBFrom, + State#state.pres_a), + {true, + Attrs, + State#state{pres_a + = + A}}; + false -> + {true, + Attrs, + State} + end + end + end; + deny -> {false, Attrs, State} + end + end; + <<"iq">> -> + IQ = jlib:iq_query_info(Packet), + case IQ of + #iq{xmlns = ?NS_LAST} -> + LFrom = jlib:jid_tolower(From), + LBFrom = + jlib:jid_remove_resource(LFrom), + HasFromSub = + ((?SETS):is_element(LFrom, + StateData#state.pres_f) + orelse + (?SETS):is_element(LBFrom, + StateData#state.pres_f)) + andalso + is_privacy_allow(StateData, + To, From, + #xmlel{name + = + <<"presence">>, + attrs + = + [], + children + = + []}, + out), + case HasFromSub of + true -> + case + privacy_check_packet(StateData, + From, + To, + Packet, + in) + of + allow -> + {true, Attrs, + StateData}; + deny -> + {false, Attrs, + StateData} + end; + _ -> + Err = + jlib:make_error_reply(Packet, + ?ERR_FORBIDDEN), + ejabberd_router:route(To, + From, + Err), + {false, Attrs, StateData} + end; + IQ + when is_record(IQ, iq) or + (IQ == reply) -> + case + privacy_check_packet(StateData, + From, To, + Packet, in) + of + allow -> + {true, Attrs, StateData}; + deny when is_record(IQ, iq) -> + Err = + jlib:make_error_reply(Packet, + ?ERR_SERVICE_UNAVAILABLE), + ejabberd_router:route(To, + From, + Err), + {false, Attrs, StateData}; + deny when IQ == reply -> + {false, Attrs, StateData} + end; + IQ + when (IQ == invalid) or + (IQ == not_iq) -> + {false, Attrs, StateData} + end; + <<"message">> -> + case privacy_check_packet(StateData, + From, To, + Packet, in) + of + allow -> {true, Attrs, StateData}; + deny -> {false, Attrs, StateData} + end; + _ -> {true, Attrs, StateData} + end, + if Pass == exit -> %% When Pass==exit, NewState contains a string instead of a #state{} Lang = StateData#state.lang, send_element(StateData, ?SERRT_CONFLICT(Lang, NewState)), send_trailer(StateData), {stop, normal, StateData}; Pass -> - Attrs2 = jlib:replace_from_to_attrs(jlib:jid_to_string(From), - jlib:jid_to_string(To), - NewAttrs), - FixedPacket = {xmlelement, Name, Attrs2, Els}, + Attrs2 = + jlib:replace_from_to_attrs(jlib:jid_to_string(From), + jlib:jid_to_string(To), NewAttrs), + FixedPacket = #xmlel{name = Name, attrs = Attrs2, children = Els}, send_element(StateData, FixedPacket), ejabberd_hooks:run(user_receive_packet, StateData#state.server, @@ -1410,40 +1512,38 @@ handle_info({route, From, To, Packet}, StateName, StateData) -> ejabberd_hooks:run(c2s_loop_debug, [{route, From, To, Packet}]), fsm_next_state(StateName, NewState) end; -handle_info({'DOWN', Monitor, _Type, _Object, _Info}, _StateName, StateData) - when Monitor == StateData#state.socket_monitor -> +handle_info({'DOWN', Monitor, _Type, _Object, _Info}, + _StateName, StateData) + when Monitor == StateData#state.socket_monitor -> {stop, normal, StateData}; handle_info(system_shutdown, StateName, StateData) -> case StateName of - wait_for_stream -> - send_header(StateData, ?MYNAME, "1.0", "en"), - send_element(StateData, ?SERR_SYSTEM_SHUTDOWN), - send_trailer(StateData), - ok; - _ -> - send_element(StateData, ?SERR_SYSTEM_SHUTDOWN), - send_trailer(StateData), - ok + wait_for_stream -> + send_header(StateData, ?MYNAME, <<"1.0">>, <<"en">>), + send_element(StateData, ?SERR_SYSTEM_SHUTDOWN), + send_trailer(StateData), + ok; + _ -> + send_element(StateData, ?SERR_SYSTEM_SHUTDOWN), + send_trailer(StateData), + ok end, {stop, normal, StateData}; handle_info({force_update_presence, LUser}, StateName, - #state{user = LUser, server = LServer} = StateData) -> - NewStateData = - case StateData#state.pres_last of - {xmlelement, "presence", _Attrs, _Els} -> - PresenceEl = ejabberd_hooks:run_fold( - c2s_update_presence, - LServer, - StateData#state.pres_last, - [LUser, LServer]), - StateData2 = StateData#state{pres_last = PresenceEl}, - presence_update(StateData2#state.jid, - PresenceEl, - StateData2), - StateData2; - _ -> - StateData - end, + #state{user = LUser, server = LServer} = StateData) -> + NewStateData = case StateData#state.pres_last of + #xmlel{name = <<"presence">>} -> + PresenceEl = + ejabberd_hooks:run_fold(c2s_update_presence, + LServer, + StateData#state.pres_last, + [LUser, LServer]), + StateData2 = StateData#state{pres_last = PresenceEl}, + presence_update(StateData2#state.jid, PresenceEl, + StateData2), + StateData2; + _ -> StateData + end, {next_state, StateName, NewStateData}; handle_info({broadcast, Type, From, Packet}, StateName, StateData) -> Recipients = ejabberd_hooks:run_fold( @@ -1480,61 +1580,59 @@ print_state(State = #state{pres_t = T, pres_f = F, pres_a = A, pres_i = I}) -> %%---------------------------------------------------------------------- terminate(_Reason, StateName, StateData) -> case StateName of - session_established -> - case StateData#state.authenticated of - replaced -> - ?INFO_MSG("(~w) Replaced session for ~s", - [StateData#state.socket, - jlib:jid_to_string(StateData#state.jid)]), - From = StateData#state.jid, - Packet = {xmlelement, "presence", - [{"type", "unavailable"}], - [{xmlelement, "status", [], - [{xmlcdata, "Replaced by new connection"}]}]}, - ejabberd_sm:close_session_unset_presence( - StateData#state.sid, - StateData#state.user, - StateData#state.server, - StateData#state.resource, - "Replaced by new connection"), - presence_broadcast( - StateData, From, StateData#state.pres_a, Packet), - presence_broadcast( - StateData, From, StateData#state.pres_i, Packet); - _ -> - ?INFO_MSG("(~w) Close session for ~s", - [StateData#state.socket, - jlib:jid_to_string(StateData#state.jid)]), - - EmptySet = ?SETS:new(), - case StateData of - #state{pres_last = undefined, - pres_a = EmptySet, - pres_i = EmptySet, - pres_invis = false} -> - ejabberd_sm:close_session(StateData#state.sid, - StateData#state.user, - StateData#state.server, - StateData#state.resource); - _ -> - From = StateData#state.jid, - Packet = {xmlelement, "presence", - [{"type", "unavailable"}], []}, - ejabberd_sm:close_session_unset_presence( - StateData#state.sid, - StateData#state.user, - StateData#state.server, - StateData#state.resource, - ""), - presence_broadcast( - StateData, From, StateData#state.pres_a, Packet), - presence_broadcast( - StateData, From, StateData#state.pres_i, Packet) - end - end, - bounce_messages(); - _ -> - ok + session_established -> + case StateData#state.authenticated of + replaced -> + ?INFO_MSG("(~w) Replaced session for ~s", + [StateData#state.socket, + jlib:jid_to_string(StateData#state.jid)]), + From = StateData#state.jid, + Packet = #xmlel{name = <<"presence">>, + attrs = [{<<"type">>, <<"unavailable">>}], + children = + [#xmlel{name = <<"status">>, attrs = [], + children = + [{xmlcdata, + <<"Replaced by new connection">>}]}]}, + ejabberd_sm:close_session_unset_presence(StateData#state.sid, + StateData#state.user, + StateData#state.server, + StateData#state.resource, + <<"Replaced by new connection">>), + presence_broadcast(StateData, From, + StateData#state.pres_a, Packet), + presence_broadcast(StateData, From, + StateData#state.pres_i, Packet); + _ -> + ?INFO_MSG("(~w) Close session for ~s", + [StateData#state.socket, + jlib:jid_to_string(StateData#state.jid)]), + EmptySet = (?SETS):new(), + case StateData of + #state{pres_last = undefined, pres_a = EmptySet, pres_i = EmptySet, pres_invis = false} -> + ejabberd_sm:close_session(StateData#state.sid, + StateData#state.user, + StateData#state.server, + StateData#state.resource); + _ -> + From = StateData#state.jid, + Packet = #xmlel{name = <<"presence">>, + attrs = [{<<"type">>, <<"unavailable">>}], + children = []}, + ejabberd_sm:close_session_unset_presence(StateData#state.sid, + StateData#state.user, + StateData#state.server, + StateData#state.resource, + <<"">>), + presence_broadcast(StateData, From, + StateData#state.pres_a, Packet), + presence_broadcast(StateData, From, + StateData#state.pres_i, Packet) + end + end, + bounce_messages(); + _ -> + ok end, (StateData#state.sockmod):close(StateData#state.socket), ok. @@ -1546,10 +1644,11 @@ terminate(_Reason, StateName, StateData) -> change_shaper(StateData, JID) -> Shaper = acl:match_rule(StateData#state.server, StateData#state.shaper, JID), - (StateData#state.sockmod):change_shaper(StateData#state.socket, Shaper). + (StateData#state.sockmod):change_shaper(StateData#state.socket, + Shaper). send_text(StateData, Text) when StateData#state.xml_socket -> - ?DEBUG("Send Text on stream = ~p", [lists:flatten(Text)]), + ?DEBUG("Send Text on stream = ~p", [Text]), (StateData#state.sockmod):send_xml(StateData#state.socket, {xmlstreamraw, Text}); send_text(StateData, Text) -> @@ -1563,82 +1662,67 @@ send_element(StateData, El) -> send_text(StateData, xml:element_to_binary(El)). send_header(StateData, Server, Version, Lang) - when StateData#state.xml_socket -> - VersionAttr = - case Version of - "" -> []; - _ -> [{"version", Version}] - end, - LangAttr = - case Lang of - "" -> []; - _ -> [{"xml:lang", Lang}] - end, - Header = - {xmlstreamstart, - "stream:stream", - VersionAttr ++ - LangAttr ++ - [{"xmlns", "jabber:client"}, - {"xmlns:stream", "http://etherx.jabber.org/streams"}, - {"id", StateData#state.streamid}, - {"from", Server}]}, - (StateData#state.sockmod):send_xml( - StateData#state.socket, Header); + when StateData#state.xml_socket -> + VersionAttr = case Version of + <<"">> -> []; + _ -> [{<<"version">>, Version}] + end, + LangAttr = case Lang of + <<"">> -> []; + _ -> [{<<"xml:lang">>, Lang}] + end, + Header = {xmlstreamstart, <<"stream:stream">>, + VersionAttr ++ + LangAttr ++ + [{<<"xmlns">>, <<"jabber:client">>}, + {<<"xmlns:stream">>, + <<"http://etherx.jabber.org/streams">>}, + {<<"id">>, StateData#state.streamid}, + {<<"from">>, Server}]}, + (StateData#state.sockmod):send_xml(StateData#state.socket, + Header); send_header(StateData, Server, Version, Lang) -> - VersionStr = - case Version of - "" -> ""; - _ -> [" version='", Version, "'"] - end, - LangStr = - case Lang of - "" -> ""; - _ -> [" xml:lang='", Lang, "'"] - end, + VersionStr = case Version of + <<"">> -> <<"">>; + _ -> [<<" version='">>, Version, <<"'">>] + end, + LangStr = case Lang of + <<"">> -> <<"">>; + _ -> [<<" xml:lang='">>, Lang, <<"'">>] + end, Header = io_lib:format(?STREAM_HEADER, - [StateData#state.streamid, - Server, - VersionStr, + [StateData#state.streamid, Server, VersionStr, LangStr]), - send_text(StateData, Header). + send_text(StateData, iolist_to_binary(Header)). -send_trailer(StateData) when StateData#state.xml_socket -> - (StateData#state.sockmod):send_xml( - StateData#state.socket, - {xmlstreamend, "stream:stream"}); +send_trailer(StateData) + when StateData#state.xml_socket -> + (StateData#state.sockmod):send_xml(StateData#state.socket, + {xmlstreamend, <<"stream:stream">>}); send_trailer(StateData) -> send_text(StateData, ?STREAM_TRAILER). - -new_id() -> - randoms:get_string(). - +new_id() -> randoms:get_string(). is_auth_packet(El) -> case jlib:iq_query_info(El) of - #iq{id = ID, type = Type, xmlns = ?NS_AUTH, sub_el = SubEl} -> - {xmlelement, _, _, Els} = SubEl, - {auth, ID, Type, - get_auth_tags(Els, "", "", "", "")}; - _ -> - false + #iq{id = ID, type = Type, xmlns = ?NS_AUTH, + sub_el = SubEl} -> + #xmlel{children = Els} = SubEl, + {auth, ID, Type, + get_auth_tags(Els, <<"">>, <<"">>, <<"">>, <<"">>)}; + _ -> false end. - -get_auth_tags([{xmlelement, Name, _Attrs, Els}| L], U, P, D, R) -> +get_auth_tags([#xmlel{name = Name, children = Els} | L], + U, P, D, R) -> CData = xml:get_cdata(Els), case Name of - "username" -> - get_auth_tags(L, CData, P, D, R); - "password" -> - get_auth_tags(L, U, CData, D, R); - "digest" -> - get_auth_tags(L, U, P, CData, R); - "resource" -> - get_auth_tags(L, U, P, D, CData); - _ -> - get_auth_tags(L, U, P, D, R) + <<"username">> -> get_auth_tags(L, CData, P, D, R); + <<"password">> -> get_auth_tags(L, U, CData, D, R); + <<"digest">> -> get_auth_tags(L, U, P, CData, R); + <<"resource">> -> get_auth_tags(L, U, P, D, CData); + _ -> get_auth_tags(L, U, P, D, R) end; get_auth_tags([_ | L], U, P, D, R) -> get_auth_tags(L, U, P, D, R); @@ -1664,7 +1748,7 @@ get_conn_type(StateData) -> process_presence_probe(From, To, StateData) -> LFrom = jlib:jid_tolower(From), - LBFrom = setelement(3, LFrom, ""), + LBFrom = setelement(3, LFrom, <<"">>), case StateData#state.pres_last of undefined -> ok; @@ -1688,7 +1772,7 @@ process_presence_probe(From, To, StateData) -> Packet = xml:append_subtags( StateData#state.pres_last, %% To is the one sending the presence (the target of the probe) - [jlib:timestamp_to_xml(Timestamp, utc, To, ""), + [jlib:timestamp_to_xml(Timestamp, utc, To, <<"">>), %% TODO: Delete the next line once XEP-0091 is Obsolete jlib:timestamp_to_xml(Timestamp)]), case privacy_check_packet(StateData, To, From, Packet, out) of @@ -1707,9 +1791,9 @@ process_presence_probe(From, To, StateData) -> end; Cond2 -> ejabberd_router:route(To, From, - {xmlelement, "presence", - [], - []}); + #xmlel{name = <<"presence">>, + attrs = [], + children = []}); true -> ok end @@ -1717,500 +1801,414 @@ process_presence_probe(From, To, StateData) -> %% User updates his presence (non-directed presence packet) presence_update(From, Packet, StateData) -> - {xmlelement, _Name, Attrs, _Els} = Packet, - case xml:get_attr_s("type", Attrs) of - "unavailable" -> - Status = case xml:get_subtag(Packet, "status") of - false -> - ""; - StatusTag -> - xml:get_tag_cdata(StatusTag) + #xmlel{attrs = Attrs} = Packet, + case xml:get_attr_s(<<"type">>, Attrs) of + <<"unavailable">> -> + Status = case xml:get_subtag(Packet, <<"status">>) of + false -> <<"">>; + StatusTag -> xml:get_tag_cdata(StatusTag) + end, + Info = [{ip, StateData#state.ip}, + {conn, StateData#state.conn}, + {auth_module, StateData#state.auth_module}], + ejabberd_sm:unset_presence(StateData#state.sid, + StateData#state.user, + StateData#state.server, + StateData#state.resource, Status, Info), + presence_broadcast(StateData, From, + StateData#state.pres_a, Packet), + StateData#state{pres_last = undefined, + pres_timestamp = undefined, pres_a = (?SETS):new()}; + <<"error">> -> StateData; + <<"probe">> -> StateData; + <<"subscribe">> -> StateData; + <<"subscribed">> -> StateData; + <<"unsubscribe">> -> StateData; + <<"unsubscribed">> -> StateData; + _ -> + OldPriority = case StateData#state.pres_last of + undefined -> 0; + OldPresence -> get_priority_from_presence(OldPresence) + end, + NewPriority = get_priority_from_presence(Packet), + Timestamp = calendar:now_to_universal_time(now()), + update_priority(NewPriority, Packet, StateData), + FromUnavail = (StateData#state.pres_last == undefined), + ?DEBUG("from unavail = ~p~n", [FromUnavail]), + NewStateData = StateData#state{pres_last = Packet, + pres_timestamp = Timestamp}, + NewState = if FromUnavail -> + ejabberd_hooks:run(user_available_hook, + NewStateData#state.server, + [NewStateData#state.jid]), + if NewPriority >= 0 -> + resend_offline_messages(NewStateData), + resend_subscription_requests(NewStateData); + true -> ok + end, + presence_broadcast_first(From, NewStateData, + Packet); + true -> + presence_broadcast_to_trusted(NewStateData, From, + NewStateData#state.pres_f, + NewStateData#state.pres_a, + Packet), + if OldPriority < 0, NewPriority >= 0 -> + resend_offline_messages(NewStateData); + true -> ok + end, + NewStateData end, - Info = [{ip, StateData#state.ip}, {conn, StateData#state.conn}, - {auth_module, StateData#state.auth_module}], - ejabberd_sm:unset_presence(StateData#state.sid, - StateData#state.user, - StateData#state.server, - StateData#state.resource, - Status, - Info), - presence_broadcast(StateData, From, StateData#state.pres_a, Packet), - presence_broadcast(StateData, From, StateData#state.pres_i, Packet), - StateData#state{pres_last = undefined, - pres_timestamp = undefined, - pres_a = ?SETS:new(), - pres_i = ?SETS:new(), - pres_invis = false}; - "invisible" -> - NewPriority = get_priority_from_presence(Packet), - update_priority(NewPriority, Packet, StateData), - NewState = - if - not StateData#state.pres_invis -> - presence_broadcast(StateData, From, - StateData#state.pres_a, - Packet), - presence_broadcast(StateData, From, - StateData#state.pres_i, - Packet), - S1 = StateData#state{pres_last = undefined, - pres_timestamp = undefined, - pres_a = ?SETS:new(), - pres_i = ?SETS:new(), - pres_invis = true}, - presence_broadcast_first(From, S1, Packet); - true -> - StateData - end, - NewState; - "error" -> - StateData; - "probe" -> - StateData; - "subscribe" -> - StateData; - "subscribed" -> - StateData; - "unsubscribe" -> - StateData; - "unsubscribed" -> - StateData; - _ -> - OldPriority = case StateData#state.pres_last of - undefined -> - 0; - OldPresence -> - get_priority_from_presence(OldPresence) - end, - NewPriority = get_priority_from_presence(Packet), - Timestamp = calendar:now_to_universal_time(now()), - update_priority(NewPriority, Packet, StateData), - FromUnavail = (StateData#state.pres_last == undefined) or - StateData#state.pres_invis, - ?DEBUG("from unavail = ~p~n", [FromUnavail]), - NewStateData = StateData#state{pres_last = Packet, - pres_invis = false, - pres_timestamp = Timestamp}, - NewState = - if - FromUnavail -> - ejabberd_hooks:run(user_available_hook, - NewStateData#state.server, - [NewStateData#state.jid]), - if NewPriority >= 0 -> - resend_offline_messages(NewStateData), - resend_subscription_requests(NewStateData); - true -> - ok - end, - presence_broadcast_first(From, NewStateData, Packet); - true -> - presence_broadcast_to_trusted(NewStateData, - From, - NewStateData#state.pres_f, - NewStateData#state.pres_a, - Packet), - if OldPriority < 0, NewPriority >= 0 -> - resend_offline_messages(NewStateData); - true -> - ok - end, - NewStateData - end, - NewState + NewState end. %% User sends a directed presence packet presence_track(From, To, Packet, StateData) -> - {xmlelement, _Name, Attrs, _Els} = Packet, + #xmlel{attrs = Attrs} = Packet, LTo = jlib:jid_tolower(To), User = StateData#state.user, Server = StateData#state.server, - case xml:get_attr_s("type", Attrs) of - "unavailable" -> - check_privacy_route(From, StateData, From, To, Packet), - I = remove_element(LTo, StateData#state.pres_i), - A = remove_element(LTo, StateData#state.pres_a), - StateData#state{pres_i = I, - pres_a = A}; - "invisible" -> - check_privacy_route(From, StateData, From, To, Packet), - I = ?SETS:add_element(LTo, StateData#state.pres_i), - A = remove_element(LTo, StateData#state.pres_a), - StateData#state{pres_i = I, - pres_a = A}; - "subscribe" -> - ejabberd_hooks:run(roster_out_subscription, - Server, - [User, Server, To, subscribe]), - check_privacy_route(From, StateData, jlib:jid_remove_resource(From), - To, Packet), - StateData; - "subscribed" -> - ejabberd_hooks:run(roster_out_subscription, - Server, - [User, Server, To, subscribed]), - check_privacy_route(From, StateData, jlib:jid_remove_resource(From), - To, Packet), - StateData; - "unsubscribe" -> - ejabberd_hooks:run(roster_out_subscription, - Server, - [User, Server, To, unsubscribe]), - check_privacy_route(From, StateData, jlib:jid_remove_resource(From), - To, Packet), - StateData; - "unsubscribed" -> - ejabberd_hooks:run(roster_out_subscription, - Server, - [User, Server, To, unsubscribed]), - check_privacy_route(From, StateData, jlib:jid_remove_resource(From), - To, Packet), - StateData; - "error" -> - check_privacy_route(From, StateData, From, To, Packet), - StateData; - "probe" -> - check_privacy_route(From, StateData, From, To, Packet), - StateData; - _ -> - check_privacy_route(From, StateData, From, To, Packet), - I = remove_element(LTo, StateData#state.pres_i), - A = ?SETS:add_element(LTo, StateData#state.pres_a), - StateData#state{pres_i = I, - pres_a = A} + case xml:get_attr_s(<<"type">>, Attrs) of + <<"unavailable">> -> + check_privacy_route(From, StateData, From, To, Packet), + A = remove_element(LTo, StateData#state.pres_a), + StateData#state{pres_a = A}; + <<"subscribe">> -> + ejabberd_hooks:run(roster_out_subscription, Server, + [User, Server, To, subscribe]), + check_privacy_route(From, StateData, + jlib:jid_remove_resource(From), To, Packet), + StateData; + <<"subscribed">> -> + ejabberd_hooks:run(roster_out_subscription, Server, + [User, Server, To, subscribed]), + check_privacy_route(From, StateData, + jlib:jid_remove_resource(From), To, Packet), + StateData; + <<"unsubscribe">> -> + ejabberd_hooks:run(roster_out_subscription, Server, + [User, Server, To, unsubscribe]), + check_privacy_route(From, StateData, + jlib:jid_remove_resource(From), To, Packet), + StateData; + <<"unsubscribed">> -> + ejabberd_hooks:run(roster_out_subscription, Server, + [User, Server, To, unsubscribed]), + check_privacy_route(From, StateData, + jlib:jid_remove_resource(From), To, Packet), + StateData; + <<"error">> -> + check_privacy_route(From, StateData, From, To, Packet), + StateData; + <<"probe">> -> + check_privacy_route(From, StateData, From, To, Packet), + StateData; + _ -> + check_privacy_route(From, StateData, From, To, Packet), + A = (?SETS):add_element(LTo, StateData#state.pres_a), + StateData#state{pres_a = A} end. -check_privacy_route(From, StateData, FromRoute, To, Packet) -> - case privacy_check_packet(StateData, From, To, Packet, out) of - deny -> - Lang = StateData#state.lang, - ErrText = "Your active privacy list has denied the routing of this stanza.", - Err = jlib:make_error_reply(Packet, ?ERRT_NOT_ACCEPTABLE(Lang, ErrText)), - ejabberd_router:route(To, From, Err), - ok; - allow -> - ejabberd_router:route(FromRoute, To, Packet) +check_privacy_route(From, StateData, FromRoute, To, + Packet) -> + case privacy_check_packet(StateData, From, To, Packet, + out) + of + deny -> + Lang = StateData#state.lang, + ErrText = <<"Your active privacy list has denied " + "the routing of this stanza.">>, + Err = jlib:make_error_reply(Packet, + ?ERRT_NOT_ACCEPTABLE(Lang, ErrText)), + ejabberd_router:route(To, From, Err), + ok; + allow -> ejabberd_router:route(FromRoute, To, Packet) end. -privacy_check_packet(StateData, From, To, Packet, Dir) -> - ejabberd_hooks:run_fold( - privacy_check_packet, StateData#state.server, - allow, - [StateData#state.user, - StateData#state.server, - StateData#state.privacy_list, - {From, To, Packet}, - Dir]). - %% Check if privacy rules allow this delivery +privacy_check_packet(StateData, From, To, Packet, + Dir) -> + ejabberd_hooks:run_fold(privacy_check_packet, + StateData#state.server, allow, + [StateData#state.user, StateData#state.server, + StateData#state.privacy_list, {From, To, Packet}, + Dir]). + is_privacy_allow(StateData, From, To, Packet, Dir) -> - allow == privacy_check_packet(StateData, From, To, Packet, Dir). + allow == + privacy_check_packet(StateData, From, To, Packet, Dir). +%% Send presence when disconnecting presence_broadcast(StateData, From, JIDSet, Packet) -> - lists:foreach(fun(JID) -> - FJID = jlib:make_jid(JID), - case privacy_check_packet(StateData, From, FJID, Packet, out) of - deny -> - ok; - allow -> - ejabberd_router:route(From, FJID, Packet) - end - end, ?SETS:to_list(JIDSet)). - -presence_broadcast_to_trusted(StateData, From, T, A, Packet) -> - lists:foreach( - fun(JID) -> - case ?SETS:is_element(JID, T) of - true -> - FJID = jlib:make_jid(JID), - case privacy_check_packet(StateData, From, FJID, Packet, out) of - deny -> - ok; - allow -> - ejabberd_router:route(From, FJID, Packet) - end; - _ -> - ok - end - end, ?SETS:to_list(A)). + JIDs = ?SETS:to_list(JIDSet), + JIDs2 = format_and_check_privacy(From, StateData, Packet, JIDs, out), + Server = StateData#state.server, + send_multiple(From, Server, JIDs2, Packet). +%% Send presence when updating presence +presence_broadcast_to_trusted(StateData, From, Trusted, JIDSet, Packet) -> + JIDs = ?SETS:to_list(JIDSet), + JIDs_trusted = [JID || JID <- JIDs, ?SETS:is_element(JID, Trusted)], + JIDs2 = format_and_check_privacy(From, StateData, Packet, JIDs_trusted, out), + Server = StateData#state.server, + send_multiple(From, Server, JIDs2, Packet). +%% Send presence when connecting presence_broadcast_first(From, StateData, Packet) -> - ?SETS:fold(fun(JID, X) -> - ejabberd_router:route( - From, - jlib:make_jid(JID), - {xmlelement, "presence", - [{"type", "probe"}], - []}), - X - end, - [], - StateData#state.pres_t), - if - StateData#state.pres_invis -> - StateData; - true -> - As = ?SETS:fold( - fun(JID, A) -> - FJID = jlib:make_jid(JID), - case privacy_check_packet(StateData, From, FJID, Packet, out) of - deny -> - ok; - allow -> - ejabberd_router:route(From, FJID, Packet) - end, - ?SETS:add_element(JID, A) - end, - StateData#state.pres_a, - StateData#state.pres_f), - StateData#state{pres_a = As} - end. + JIDsProbe = + ?SETS:fold( + fun(JID, L) -> [JID | L] end, + [], + StateData#state.pres_t), + PacketProbe = #xmlel{name = <<"presence">>, attrs = [{<<"type">>,<<"probe">>}], children = []}, + JIDs2Probe = format_and_check_privacy(From, StateData, Packet, JIDsProbe, out), + Server = StateData#state.server, + send_multiple(From, Server, JIDs2Probe, PacketProbe), + {As, JIDs} = + ?SETS:fold( + fun(JID, {A, JID_list}) -> + {?SETS:add_element(JID, A), JID_list++[JID]} + end, + {StateData#state.pres_a, []}, + StateData#state.pres_f), + JIDs2 = format_and_check_privacy(From, StateData, Packet, JIDs, out), + Server = StateData#state.server, + send_multiple(From, Server, JIDs2, Packet), + StateData#state{pres_a = As}. + +format_and_check_privacy(From, StateData, Packet, JIDs, Dir) -> + FJIDs = [jlib:make_jid(JID) || JID <- JIDs], + lists:filter( + fun(FJID) -> + case ejabberd_hooks:run_fold( + privacy_check_packet, StateData#state.server, + allow, + [StateData#state.user, + StateData#state.server, + StateData#state.privacy_list, + {From, FJID, Packet}, + Dir]) of + deny -> false; + allow -> true + end + end, + FJIDs). + +send_multiple(From, Server, JIDs, Packet) -> + ejabberd_router_multicast:route_multicast(From, Server, JIDs, Packet). remove_element(E, Set) -> - case ?SETS:is_element(E, Set) of - true -> - ?SETS:del_element(E, Set); - _ -> - Set + case (?SETS):is_element(E, Set) of + true -> (?SETS):del_element(E, Set); + _ -> Set end. - roster_change(IJID, ISubscription, StateData) -> LIJID = jlib:jid_tolower(IJID), - IsFrom = (ISubscription == both) or (ISubscription == from), - IsTo = (ISubscription == both) or (ISubscription == to), - OldIsFrom = ?SETS:is_element(LIJID, StateData#state.pres_f), - FSet = if - IsFrom -> - ?SETS:add_element(LIJID, StateData#state.pres_f); - true -> - remove_element(LIJID, StateData#state.pres_f) + IsFrom = (ISubscription == both) or + (ISubscription == from), + IsTo = (ISubscription == both) or (ISubscription == to), + OldIsFrom = (?SETS):is_element(LIJID, + StateData#state.pres_f), + FSet = if IsFrom -> + (?SETS):add_element(LIJID, StateData#state.pres_f); + true -> remove_element(LIJID, StateData#state.pres_f) end, - TSet = if - IsTo -> - ?SETS:add_element(LIJID, StateData#state.pres_t); - true -> - remove_element(LIJID, StateData#state.pres_t) + TSet = if IsTo -> + (?SETS):add_element(LIJID, StateData#state.pres_t); + true -> remove_element(LIJID, StateData#state.pres_t) end, case StateData#state.pres_last of - undefined -> - StateData#state{pres_f = FSet, pres_t = TSet}; - P -> - ?DEBUG("roster changed for ~p~n", [StateData#state.user]), - From = StateData#state.jid, - To = jlib:make_jid(IJID), - Cond1 = (not StateData#state.pres_invis) and IsFrom - and (not OldIsFrom), - Cond2 = (not IsFrom) and OldIsFrom - and (?SETS:is_element(LIJID, StateData#state.pres_a) or - ?SETS:is_element(LIJID, StateData#state.pres_i)), - if - Cond1 -> - ?DEBUG("C1: ~p~n", [LIJID]), - case privacy_check_packet(StateData, From, To, P, out) of - deny -> - ok; - allow -> - ejabberd_router:route(From, To, P) - end, - A = ?SETS:add_element(LIJID, - StateData#state.pres_a), - StateData#state{pres_a = A, - pres_f = FSet, - pres_t = TSet}; - Cond2 -> - ?DEBUG("C2: ~p~n", [LIJID]), - PU = {xmlelement, "presence", - [{"type", "unavailable"}], []}, - case privacy_check_packet(StateData, From, To, PU, out) of - deny -> - ok; - allow -> - ejabberd_router:route(From, To, PU) - end, - I = remove_element(LIJID, - StateData#state.pres_i), - A = remove_element(LIJID, - StateData#state.pres_a), - StateData#state{pres_i = I, - pres_a = A, - pres_f = FSet, - pres_t = TSet}; - true -> - StateData#state{pres_f = FSet, pres_t = TSet} - end + undefined -> + StateData#state{pres_f = FSet, pres_t = TSet}; + P -> + ?DEBUG("roster changed for ~p~n", + [StateData#state.user]), + From = StateData#state.jid, + To = jlib:make_jid(IJID), + Cond1 = IsFrom andalso not OldIsFrom, + Cond2 = not IsFrom andalso OldIsFrom andalso + ((?SETS):is_element(LIJID, StateData#state.pres_a)), + if Cond1 -> + ?DEBUG("C1: ~p~n", [LIJID]), + case privacy_check_packet(StateData, From, To, P, out) + of + deny -> ok; + allow -> ejabberd_router:route(From, To, P) + end, + A = (?SETS):add_element(LIJID, StateData#state.pres_a), + StateData#state{pres_a = A, pres_f = FSet, + pres_t = TSet}; + Cond2 -> + ?DEBUG("C2: ~p~n", [LIJID]), + PU = #xmlel{name = <<"presence">>, + attrs = [{<<"type">>, <<"unavailable">>}], + children = []}, + case privacy_check_packet(StateData, From, To, PU, out) + of + deny -> ok; + allow -> ejabberd_router:route(From, To, PU) + end, + A = remove_element(LIJID, StateData#state.pres_a), + StateData#state{pres_a = A, pres_f = FSet, + pres_t = TSet}; + true -> StateData#state{pres_f = FSet, pres_t = TSet} + end end. - update_priority(Priority, Packet, StateData) -> Info = [{ip, StateData#state.ip}, {conn, StateData#state.conn}, {auth_module, StateData#state.auth_module}], ejabberd_sm:set_presence(StateData#state.sid, - StateData#state.user, - StateData#state.server, - StateData#state.resource, - Priority, - Packet, - Info). + StateData#state.user, StateData#state.server, + StateData#state.resource, Priority, Packet, Info). get_priority_from_presence(PresencePacket) -> - case xml:get_subtag(PresencePacket, "priority") of - false -> - 0; - SubEl -> - case catch list_to_integer(xml:get_tag_cdata(SubEl)) of - P when is_integer(P) -> - P; - _ -> - 0 - end + case xml:get_subtag(PresencePacket, <<"priority">>) of + false -> 0; + SubEl -> + case catch + jlib:binary_to_integer(xml:get_tag_cdata(SubEl)) + of + P when is_integer(P) -> P; + _ -> 0 + end end. process_privacy_iq(From, To, - #iq{type = Type, sub_el = SubEl} = IQ, - StateData) -> - {Res, NewStateData} = - case Type of - get -> - R = ejabberd_hooks:run_fold( - privacy_iq_get, StateData#state.server, - {error, ?ERR_FEATURE_NOT_IMPLEMENTED}, - [From, To, IQ, StateData#state.privacy_list]), - {R, StateData}; - set -> - case ejabberd_hooks:run_fold( - privacy_iq_set, StateData#state.server, - {error, ?ERR_FEATURE_NOT_IMPLEMENTED}, - [From, To, IQ]) of - {result, R, NewPrivList} -> - {{result, R}, - StateData#state{privacy_list = NewPrivList}}; - R -> {R, StateData} - end - end, - IQRes = - case Res of - {result, Result} -> - IQ#iq{type = result, sub_el = Result}; - {error, Error} -> - IQ#iq{type = error, sub_el = [SubEl, Error]} - end, - ejabberd_router:route( - To, From, jlib:iq_to_xml(IQRes)), + #iq{type = Type, sub_el = SubEl} = IQ, StateData) -> + {Res, NewStateData} = case Type of + get -> + R = ejabberd_hooks:run_fold(privacy_iq_get, + StateData#state.server, + {error, + ?ERR_FEATURE_NOT_IMPLEMENTED}, + [From, To, IQ, + StateData#state.privacy_list]), + {R, StateData}; + set -> + case ejabberd_hooks:run_fold(privacy_iq_set, + StateData#state.server, + {error, + ?ERR_FEATURE_NOT_IMPLEMENTED}, + [From, To, IQ]) + of + {result, R, NewPrivList} -> + {{result, R}, + StateData#state{privacy_list = + NewPrivList}}; + R -> {R, StateData} + end + end, + IQRes = case Res of + {result, Result} -> + IQ#iq{type = result, sub_el = Result}; + {error, Error} -> + IQ#iq{type = error, sub_el = [SubEl, Error]} + end, + ejabberd_router:route(To, From, jlib:iq_to_xml(IQRes)), NewStateData. - resend_offline_messages(StateData) -> - case ejabberd_hooks:run_fold( - resend_offline_messages_hook, StateData#state.server, - [], - [StateData#state.user, StateData#state.server]) of - Rs when is_list(Rs) -> - lists:foreach( - fun({route, - From, To, {xmlelement, _Name, _Attrs, _Els} = Packet}) -> - Pass = case privacy_check_packet(StateData, From, To, Packet, in) of - allow -> - true; - deny -> - false - end, - if - Pass -> - %% Attrs2 = jlib:replace_from_to_attrs( - %% jlib:jid_to_string(From), - %% jlib:jid_to_string(To), - %% Attrs), - %% FixedPacket = {xmlelement, Name, Attrs2, Els}, - %% Use route instead of send_element to go through standard workflow - ejabberd_router:route(From, To, Packet); - %% send_element(StateData, FixedPacket), - %% ejabberd_hooks:run(user_receive_packet, - %% StateData#state.server, - %% [StateData#state.jid, - %% From, To, FixedPacket]); - true -> - ok - end - end, Rs) + case + ejabberd_hooks:run_fold(resend_offline_messages_hook, + StateData#state.server, [], + [StateData#state.user, StateData#state.server]) + of + Rs -> %%when is_list(Rs) -> + lists:foreach(fun ({route, From, To, + #xmlel{} = Packet}) -> + Pass = case privacy_check_packet(StateData, + From, To, + Packet, in) + of + allow -> true; + deny -> false + end, + if Pass -> + ejabberd_router:route(From, To, Packet); + %% send_element(StateData, FixedPacket), + %% ejabberd_hooks:run(user_receive_packet, + %% StateData#state.server, + %% [StateData#state.jid, + %% From, To, FixedPacket]); + true -> ok + end + end, + Rs) end. resend_subscription_requests(#state{user = User, - server = Server} = StateData) -> - PendingSubscriptions = ejabberd_hooks:run_fold( - resend_subscription_requests_hook, - Server, - [], - [User, Server]), - lists:foreach(fun(XMLPacket) -> - send_element(StateData, - XMLPacket) + server = Server} = + StateData) -> + PendingSubscriptions = + ejabberd_hooks:run_fold(resend_subscription_requests_hook, + Server, [], [User, Server]), + lists:foreach(fun (XMLPacket) -> + send_element(StateData, XMLPacket) end, PendingSubscriptions). -get_showtag(undefined) -> - "unavailable"; +get_showtag(undefined) -> <<"unavailable">>; get_showtag(Presence) -> - case xml:get_path_s(Presence, [{elem, "show"}, cdata]) of - "" -> "available"; - ShowTag -> ShowTag + case xml:get_path_s(Presence, + [{elem, <<"show">>}, cdata]) + of + <<"">> -> <<"available">>; + ShowTag -> ShowTag end. -get_statustag(undefined) -> - ""; +get_statustag(undefined) -> <<"">>; get_statustag(Presence) -> - case xml:get_path_s(Presence, [{elem, "status"}, cdata]) of - ShowTag -> ShowTag + case xml:get_path_s(Presence, + [{elem, <<"status">>}, cdata]) + of + ShowTag -> ShowTag end. process_unauthenticated_stanza(StateData, El) -> - NewEl = case xml:get_tag_attr_s("xml:lang", El) of - "" -> - case StateData#state.lang of - "" -> El; - Lang -> - xml:replace_tag_attr("xml:lang", Lang, El) - end; - _ -> - El + NewEl = case xml:get_tag_attr_s(<<"xml:lang">>, El) of + <<"">> -> + case StateData#state.lang of + <<"">> -> El; + Lang -> xml:replace_tag_attr(<<"xml:lang">>, Lang, El) + end; + _ -> El end, case jlib:iq_query_info(NewEl) of - #iq{} = IQ -> - Res = ejabberd_hooks:run_fold(c2s_unauthenticated_iq, - StateData#state.server, - empty, - [StateData#state.server, IQ, - StateData#state.ip]), - case Res of - empty -> - % The only reasonable IQ's here are auth and register IQ's - % They contain secrets, so don't include subelements to response - ResIQ = IQ#iq{type = error, - sub_el = [?ERR_SERVICE_UNAVAILABLE]}, - Res1 = jlib:replace_from_to( - jlib:make_jid("", StateData#state.server, ""), - jlib:make_jid("", "", ""), - jlib:iq_to_xml(ResIQ)), - send_element(StateData, jlib:remove_attr("to", Res1)); - _ -> - send_element(StateData, Res) - end; - _ -> - % Drop any stanza, which isn't IQ stanza - ok + #iq{} = IQ -> + Res = ejabberd_hooks:run_fold(c2s_unauthenticated_iq, + StateData#state.server, empty, + [StateData#state.server, IQ, + StateData#state.ip]), + case Res of + empty -> + ResIQ = IQ#iq{type = error, + sub_el = [?ERR_SERVICE_UNAVAILABLE]}, + Res1 = jlib:replace_from_to(jlib:make_jid(<<"">>, + StateData#state.server, + <<"">>), + jlib:make_jid(<<"">>, <<"">>, + <<"">>), + jlib:iq_to_xml(ResIQ)), + send_element(StateData, + jlib:remove_attr(<<"to">>, Res1)); + _ -> send_element(StateData, Res) + end; + _ -> + % Drop any stanza, which isn't IQ stanza + ok end. peerip(SockMod, Socket) -> IP = case SockMod of - gen_tcp -> inet:peername(Socket); - _ -> SockMod:peername(Socket) + gen_tcp -> inet:peername(Socket); + _ -> SockMod:peername(Socket) end, case IP of - {ok, IPOK} -> IPOK; - _ -> undefined + {ok, IPOK} -> IPOK; + _ -> undefined end. %% fsm_next_state_pack: Pack the StateData structure to improve @@ -2227,70 +2225,64 @@ fsm_next_state_gc(StateName, PackedStateData) -> %% fsm_next_state: Generate the next_state FSM tuple with different %% timeout, depending on the future state fsm_next_state(session_established, StateData) -> - {next_state, session_established, StateData, ?C2S_HIBERNATE_TIMEOUT}; + {next_state, session_established, StateData, + ?C2S_HIBERNATE_TIMEOUT}; fsm_next_state(StateName, StateData) -> {next_state, StateName, StateData, ?C2S_OPEN_TIMEOUT}. %% fsm_reply: Generate the reply FSM tuple with different timeout, %% depending on the future state fsm_reply(Reply, session_established, StateData) -> - {reply, Reply, session_established, StateData, ?C2S_HIBERNATE_TIMEOUT}; + {reply, Reply, session_established, StateData, + ?C2S_HIBERNATE_TIMEOUT}; fsm_reply(Reply, StateName, StateData) -> {reply, Reply, StateName, StateData, ?C2S_OPEN_TIMEOUT}. %% Used by c2s blacklist plugins -is_ip_blacklisted(undefined) -> - false; -is_ip_blacklisted({IP,_Port}) -> +is_ip_blacklisted(undefined) -> false; +is_ip_blacklisted({IP, _Port}) -> ejabberd_hooks:run_fold(check_bl_c2s, false, [IP]). %% Check from attributes %% returns invalid-from|NewElement check_from(El, FromJID) -> - case xml:get_tag_attr("from", El) of - false -> - El; - {value, SJID} -> - JID = jlib:string_to_jid(SJID), - case JID of - error -> - 'invalid-from'; - #jid{} -> - if - (JID#jid.luser == FromJID#jid.luser) and - (JID#jid.lserver == FromJID#jid.lserver) and - (JID#jid.lresource == FromJID#jid.lresource) -> - El; - (JID#jid.luser == FromJID#jid.luser) and - (JID#jid.lserver == FromJID#jid.lserver) and - (JID#jid.lresource == "") -> - El; - true -> - 'invalid-from' - end - end + case xml:get_tag_attr(<<"from">>, El) of + false -> El; + {value, SJID} -> + JID = jlib:string_to_jid(SJID), + case JID of + error -> 'invalid-from'; + #jid{} -> + if (JID#jid.luser == FromJID#jid.luser) and + (JID#jid.lserver == FromJID#jid.lserver) + and (JID#jid.lresource == FromJID#jid.lresource) -> + El; + (JID#jid.luser == FromJID#jid.luser) and + (JID#jid.lserver == FromJID#jid.lserver) + and (JID#jid.lresource == <<"">>) -> + El; + true -> 'invalid-from' + end + end end. fsm_limit_opts(Opts) -> case lists:keysearch(max_fsm_queue, 1, Opts) of - {value, {_, N}} when is_integer(N) -> - [{max_queue, N}]; - _ -> - case ejabberd_config:get_local_option(max_fsm_queue) of - N when is_integer(N) -> - [{max_queue, N}]; - _ -> - [] - end + {value, {_, N}} when is_integer(N) -> [{max_queue, N}]; + _ -> + case ejabberd_config:get_local_option( + max_fsm_queue, + fun(I) when is_integer(I), I > 0 -> I end) of + undefined -> []; + N -> [{max_queue, N}] + end end. bounce_messages() -> receive - {route, From, To, El} -> - ejabberd_router:route(From, To, El), - bounce_messages() - after 0 -> - ok + {route, From, To, El} -> + ejabberd_router:route(From, To, El), bounce_messages() + after 0 -> ok end. %%%---------------------------------------------------------------------- @@ -2298,40 +2290,40 @@ bounce_messages() -> %%%---------------------------------------------------------------------- route_blocking(What, StateData) -> - SubEl = - case What of - {block, JIDs} -> - {xmlelement, "block", - [{"xmlns", ?NS_BLOCKING}], - lists:map( - fun(JID) -> - {xmlelement, "item", - [{"jid", jlib:jid_to_string(JID)}], - []} - end, JIDs)}; - {unblock, JIDs} -> - {xmlelement, "unblock", - [{"xmlns", ?NS_BLOCKING}], - lists:map( - fun(JID) -> - {xmlelement, "item", - [{"jid", jlib:jid_to_string(JID)}], - []} - end, JIDs)}; - unblock_all -> - {xmlelement, "unblock", - [{"xmlns", ?NS_BLOCKING}], []} - end, - PrivPushIQ = - #iq{type = set, xmlns = ?NS_BLOCKING, - id = "push", - sub_el = [SubEl]}, + SubEl = case What of + {block, JIDs} -> + #xmlel{name = <<"block">>, + attrs = [{<<"xmlns">>, ?NS_BLOCKING}], + children = + lists:map(fun (JID) -> + #xmlel{name = <<"item">>, + attrs = + [{<<"jid">>, + jlib:jid_to_string(JID)}], + children = []} + end, + JIDs)}; + {unblock, JIDs} -> + #xmlel{name = <<"unblock">>, + attrs = [{<<"xmlns">>, ?NS_BLOCKING}], + children = + lists:map(fun (JID) -> + #xmlel{name = <<"item">>, + attrs = + [{<<"jid">>, + jlib:jid_to_string(JID)}], + children = []} + end, + JIDs)}; + unblock_all -> + #xmlel{name = <<"unblock">>, + attrs = [{<<"xmlns">>, ?NS_BLOCKING}], children = []} + end, + PrivPushIQ = #iq{type = set, xmlns = ?NS_BLOCKING, + id = <<"push">>, sub_el = [SubEl]}, PrivPushEl = - jlib:replace_from_to( - jlib:jid_remove_resource( - StateData#state.jid), - StateData#state.jid, - jlib:iq_to_xml(PrivPushIQ)), + jlib:replace_from_to(jlib:jid_remove_resource(StateData#state.jid), + StateData#state.jid, jlib:iq_to_xml(PrivPushIQ)), send_element(StateData, PrivPushEl), %% No need to replace active privacy list here, %% blocking pushes are always accompanied by @@ -2344,45 +2336,35 @@ route_blocking(What, StateData) -> %% Try to reduce the heap footprint of the four presence sets %% by ensuring that we re-use strings and Jids wherever possible. -pack(S = #state{pres_a=A, - pres_i=I, - pres_f=F, - pres_t=T}) -> - {NewA, Pack1} = pack_jid_set(A, gb_trees:empty()), - {NewI, Pack2} = pack_jid_set(I, Pack1), +pack(S = #state{pres_a = A, pres_f = F, + pres_t = T}) -> + {NewA, Pack2} = pack_jid_set(A, gb_trees:empty()), {NewF, Pack3} = pack_jid_set(F, Pack2), {NewT, _Pack4} = pack_jid_set(T, Pack3), - %% Throw away Pack4 so that if we delete references to - %% Strings or Jids in any of the sets there will be - %% no live references for the GC to find. - S#state{pres_a=NewA, - pres_i=NewI, - pres_f=NewF, - pres_t=NewT}. + S#state{pres_a = NewA, pres_f = NewF, + pres_t = NewT}. pack_jid_set(Set, Pack) -> - Jids = ?SETS:to_list(Set), + Jids = (?SETS):to_list(Set), {PackedJids, NewPack} = pack_jids(Jids, Pack, []), - {?SETS:from_list(PackedJids), NewPack}. + {(?SETS):from_list(PackedJids), NewPack}. pack_jids([], Pack, Acc) -> {Acc, Pack}; -pack_jids([{U,S,R}=Jid | Jids], Pack, Acc) -> +pack_jids([{U, S, R} = Jid | Jids], Pack, Acc) -> case gb_trees:lookup(Jid, Pack) of - {value, PackedJid} -> - pack_jids(Jids, Pack, [PackedJid | Acc]); - none -> - {NewU, Pack1} = pack_string(U, Pack), - {NewS, Pack2} = pack_string(S, Pack1), - {NewR, Pack3} = pack_string(R, Pack2), - NewJid = {NewU, NewS, NewR}, - NewPack = gb_trees:insert(NewJid, NewJid, Pack3), - pack_jids(Jids, NewPack, [NewJid | Acc]) + {value, PackedJid} -> + pack_jids(Jids, Pack, [PackedJid | Acc]); + none -> + {NewU, Pack1} = pack_string(U, Pack), + {NewS, Pack2} = pack_string(S, Pack1), + {NewR, Pack3} = pack_string(R, Pack2), + NewJid = {NewU, NewS, NewR}, + NewPack = gb_trees:insert(NewJid, NewJid, Pack3), + pack_jids(Jids, NewPack, [NewJid | Acc]) end. pack_string(String, Pack) -> case gb_trees:lookup(String, Pack) of - {value, PackedString} -> - {PackedString, Pack}; - none -> - {String, gb_trees:insert(String, String, Pack)} + {value, PackedString} -> {PackedString, Pack}; + none -> {String, gb_trees:insert(String, String, Pack)} end. diff --git a/src/ejabberd_c2s_config.erl b/src/ejabberd_c2s_config.erl index b416edd2b..4dbc48f38 100644 --- a/src/ejabberd_c2s_config.erl +++ b/src/ejabberd_c2s_config.erl @@ -26,6 +26,7 @@ %%%---------------------------------------------------------------------- -module(ejabberd_c2s_config). + -author('mremond@process-one.net'). -export([get_c2s_limits/0]). @@ -33,28 +34,33 @@ %% Get first c2s configuration limitations to apply it to other c2s %% connectors. get_c2s_limits() -> - case ejabberd_config:get_local_option(listen) of - undefined -> - []; - C2SFirstListen -> - case lists:keysearch(ejabberd_c2s, 2, C2SFirstListen) of - false -> - []; - {value, {_Port, ejabberd_c2s, Opts}} -> - select_opts_values(Opts) - end + case ejabberd_config:get_local_option(listen, fun(V) -> V end) of + undefined -> []; + C2SFirstListen -> + case lists:keysearch(ejabberd_c2s, 2, C2SFirstListen) of + false -> []; + {value, {_Port, ejabberd_c2s, Opts}} -> + select_opts_values(Opts) + end end. %% Only get access, shaper and max_stanza_size values + select_opts_values(Opts) -> select_opts_values(Opts, []). + select_opts_values([], SelectedValues) -> SelectedValues; -select_opts_values([{access,Value}|Opts], SelectedValues) -> - select_opts_values(Opts, [{access, Value}|SelectedValues]); -select_opts_values([{shaper,Value}|Opts], SelectedValues) -> - select_opts_values(Opts, [{shaper, Value}|SelectedValues]); -select_opts_values([{max_stanza_size,Value}|Opts], SelectedValues) -> - select_opts_values(Opts, [{max_stanza_size, Value}|SelectedValues]); -select_opts_values([_Opt|Opts], SelectedValues) -> +select_opts_values([{access, Value} | Opts], + SelectedValues) -> + select_opts_values(Opts, + [{access, Value} | SelectedValues]); +select_opts_values([{shaper, Value} | Opts], + SelectedValues) -> + select_opts_values(Opts, + [{shaper, Value} | SelectedValues]); +select_opts_values([{max_stanza_size, Value} | Opts], + SelectedValues) -> + select_opts_values(Opts, + [{max_stanza_size, Value} | SelectedValues]); +select_opts_values([_Opt | Opts], SelectedValues) -> select_opts_values(Opts, SelectedValues). - diff --git a/src/ejabberd_captcha.erl b/src/ejabberd_captcha.erl index 319bf8e81..6cf23a493 100644 --- a/src/ejabberd_captcha.erl +++ b/src/ejabberd_captcha.erl @@ -32,36 +32,44 @@ -export([start_link/0]). %% gen_server callbacks --export([init/1, handle_call/3, handle_cast/2, handle_info/2, - terminate/2, code_change/3]). +-export([init/1, handle_call/3, handle_cast/2, + handle_info/2, terminate/2, code_change/3]). --export([create_captcha/6, build_captcha_html/2, check_captcha/2, - process_reply/1, process/2, is_feature_available/0, - create_captcha_x/5, create_captcha_x/6]). +-export([create_captcha/6, build_captcha_html/2, + check_captcha/2, process_reply/1, process/2, + is_feature_available/0, create_captcha_x/5, + create_captcha_x/6]). -include("jlib.hrl"). + -include("ejabberd.hrl"). + -include("web/ejabberd_http.hrl"). -define(VFIELD(Type, Var, Value), - {xmlelement, "field", [{"type", Type}, {"var", Var}], - [{xmlelement, "value", [], [Value]}]}). + #xmlel{name = <<"field">>, + attrs = [{<<"type">>, Type}, {<<"var">>, Var}], + children = + [#xmlel{name = <<"value">>, attrs = [], + children = [Value]}]}). --define(CAPTCHA_TEXT(Lang), translate:translate(Lang, "Enter the text you see")). --define(CAPTCHA_LIFETIME, 120000). % two minutes --define(LIMIT_PERIOD, 60*1000*1000). % one minute +-define(CAPTCHA_TEXT(Lang), + translate:translate(Lang, + <<"Enter the text you see">>)). --record(state, {limits = treap:empty()}). --record(captcha, {id, pid, key, tref, args}). +-define(CAPTCHA_LIFETIME, 120000). --define(T(S), - case catch mnesia:transaction(fun() -> S end) of - {atomic, Res} -> - Res; - {_, Reason} -> - ?ERROR_MSG("mnesia transaction failed: ~p", [Reason]), - {error, Reason} - end). +-define(LIMIT_PERIOD, 60*1000*1000). + +-type error() :: efbig | enodata | limit | malformed_image | timeout. + +-record(state, {limits = treap:empty() :: treap:treap()}). + +-record(captcha, {id :: binary(), + pid :: pid(), + key :: binary(), + tref :: reference(), + args :: any()}). %%==================================================================== %% API @@ -71,98 +79,197 @@ %% Description: Starts the server %%-------------------------------------------------------------------- start_link() -> - gen_server:start_link({local, ?MODULE}, ?MODULE, [], []). + gen_server:start_link({local, ?MODULE}, ?MODULE, [], + []). -create_captcha(SID, From, To, Lang, Limiter, Args) - when is_list(Lang), is_list(SID), - is_record(From, jid), is_record(To, jid) -> +-spec create_captcha(binary(), jid(), jid(), + binary(), any(), any()) -> {error, error()} | + {ok, binary(), [xmlel()]}. + +create_captcha(SID, From, To, Lang, Limiter, Args) -> case create_image(Limiter) of - {ok, Type, Key, Image} -> - Id = randoms:get_string(), - B64Image = jlib:encode_base64(binary_to_list(Image)), - JID = jlib:jid_to_string(From), - CID = "sha1+" ++ sha:sha(Image) ++ "@bob.xmpp.org", - Data = {xmlelement, "data", - [{"xmlns", ?NS_BOB}, {"cid", CID}, - {"max-age", "0"}, {"type", Type}], - [{xmlcdata, B64Image}]}, - Captcha = - {xmlelement, "captcha", [{"xmlns", ?NS_CAPTCHA}], - [{xmlelement, "x", [{"xmlns", ?NS_XDATA}, {"type", "form"}], - [?VFIELD("hidden", "FORM_TYPE", {xmlcdata, ?NS_CAPTCHA}), - ?VFIELD("hidden", "from", {xmlcdata, jlib:jid_to_string(To)}), - ?VFIELD("hidden", "challenge", {xmlcdata, Id}), - ?VFIELD("hidden", "sid", {xmlcdata, SID}), - {xmlelement, "field", [{"var", "ocr"}, {"label", ?CAPTCHA_TEXT(Lang)}], - [{xmlelement, "required", [], []}, - {xmlelement, "media", [{"xmlns", ?NS_MEDIA}], - [{xmlelement, "uri", [{"type", Type}], - [{xmlcdata, "cid:" ++ CID}]}]}]}]}]}, - BodyString1 = translate:translate(Lang, "Your messages to ~s are being blocked. To unblock them, visit ~s"), - BodyString = io_lib:format(BodyString1, [JID, get_url(Id)]), - Body = {xmlelement, "body", [], - [{xmlcdata, BodyString}]}, - OOB = {xmlelement, "x", [{"xmlns", ?NS_OOB}], - [{xmlelement, "url", [], [{xmlcdata, get_url(Id)}]}]}, - Tref = erlang:send_after(?CAPTCHA_LIFETIME, ?MODULE, {remove_id, Id}), - case ?T(mnesia:write(#captcha{id=Id, pid=self(), key=Key, - tref=Tref, args=Args})) of - ok -> - {ok, Id, [Body, OOB, Captcha, Data]}; - Err -> - {error, Err} - end; - Err -> - Err + {ok, Type, Key, Image} -> + Id = <<(randoms:get_string())/binary>>, + B64Image = jlib:encode_base64((Image)), + JID = jlib:jid_to_string(From), + CID = <<"sha1+", (sha:sha(Image))/binary, + "@bob.xmpp.org">>, + Data = #xmlel{name = <<"data">>, + attrs = + [{<<"xmlns">>, ?NS_BOB}, {<<"cid">>, CID}, + {<<"max-age">>, <<"0">>}, {<<"type">>, Type}], + children = [{xmlcdata, B64Image}]}, + Captcha = #xmlel{name = <<"captcha">>, + attrs = [{<<"xmlns">>, ?NS_CAPTCHA}], + children = + [#xmlel{name = <<"x">>, + attrs = + [{<<"xmlns">>, ?NS_XDATA}, + {<<"type">>, <<"form">>}], + children = + [?VFIELD(<<"hidden">>, + <<"FORM_TYPE">>, + {xmlcdata, ?NS_CAPTCHA}), + ?VFIELD(<<"hidden">>, <<"from">>, + {xmlcdata, + jlib:jid_to_string(To)}), + ?VFIELD(<<"hidden">>, + <<"challenge">>, + {xmlcdata, Id}), + ?VFIELD(<<"hidden">>, <<"sid">>, + {xmlcdata, SID}), + #xmlel{name = <<"field">>, + attrs = + [{<<"var">>, <<"ocr">>}, + {<<"label">>, + ?CAPTCHA_TEXT(Lang)}], + children = + [#xmlel{name = + <<"required">>, + attrs = [], + children = []}, + #xmlel{name = + <<"media">>, + attrs = + [{<<"xmlns">>, + ?NS_MEDIA}], + children = + [#xmlel{name + = + <<"uri">>, + attrs + = + [{<<"type">>, + Type}], + children + = + [{xmlcdata, + <<"cid:", + CID/binary>>}]}]}]}]}]}, + BodyString1 = translate:translate(Lang, + <<"Your messages to ~s are being blocked. " + "To unblock them, visit ~s">>), + BodyString = iolist_to_binary(io_lib:format(BodyString1, + [JID, get_url(Id)])), + Body = #xmlel{name = <<"body">>, attrs = [], + children = [{xmlcdata, BodyString}]}, + OOB = #xmlel{name = <<"x">>, + attrs = [{<<"xmlns">>, ?NS_OOB}], + children = + [#xmlel{name = <<"url">>, attrs = [], + children = [{xmlcdata, get_url(Id)}]}]}, + Tref = erlang:send_after(?CAPTCHA_LIFETIME, ?MODULE, + {remove_id, Id}), + ets:insert(captcha, + #captcha{id = Id, pid = self(), key = Key, tref = Tref, + args = Args}), + {ok, Id, [Body, OOB, Captcha, Data]}; + Err -> Err end. +-spec create_captcha_x(binary(), jid(), binary(), + any(), [xmlel()]) -> {ok, [xmlel()]} | + {error, error()}. + create_captcha_x(SID, To, Lang, Limiter, HeadEls) -> create_captcha_x(SID, To, Lang, Limiter, HeadEls, []). -create_captcha_x(SID, To, Lang, Limiter, HeadEls, TailEls) -> +-spec create_captcha_x(binary(), jid(), binary(), + any(), [xmlel()], [xmlel()]) -> {ok, [xmlel()]} | + {error, error()}. + +create_captcha_x(SID, To, Lang, Limiter, HeadEls, + TailEls) -> case create_image(Limiter) of - {ok, Type, Key, Image} -> - Id = randoms:get_string(), - B64Image = jlib:encode_base64(binary_to_list(Image)), - CID = "sha1+" ++ sha:sha(Image) ++ "@bob.xmpp.org", - Data = {xmlelement, "data", - [{"xmlns", ?NS_BOB}, {"cid", CID}, - {"max-age", "0"}, {"type", Type}], - [{xmlcdata, B64Image}]}, - HelpTxt = translate:translate( - Lang, - "If you don't see the CAPTCHA image here, " - "visit the web page."), - Imageurl = get_url(Id ++ "/image"), - Captcha = - {xmlelement, "x", [{"xmlns", ?NS_XDATA}, {"type", "form"}], - [?VFIELD("hidden", "FORM_TYPE", {xmlcdata, ?NS_CAPTCHA}) | HeadEls] ++ - [{xmlelement, "field", [{"type", "fixed"}], - [{xmlelement, "value", [], [{xmlcdata, HelpTxt}]}]}, - {xmlelement, "field", [{"type", "hidden"}, {"var", "captchahidden"}], - [{xmlelement, "value", [], [{xmlcdata, "workaround-for-psi"}]}]}, - {xmlelement, "field", - [{"type", "text-single"}, - {"label", translate:translate(Lang, "CAPTCHA web page")}, - {"var", "url"}], - [{xmlelement, "value", [], [{xmlcdata, Imageurl}]}]}, - ?VFIELD("hidden", "from", {xmlcdata, jlib:jid_to_string(To)}), - ?VFIELD("hidden", "challenge", {xmlcdata, Id}), - ?VFIELD("hidden", "sid", {xmlcdata, SID}), - {xmlelement, "field", [{"var", "ocr"}, {"label", ?CAPTCHA_TEXT(Lang)}], - [{xmlelement, "required", [], []}, - {xmlelement, "media", [{"xmlns", ?NS_MEDIA}], - [{xmlelement, "uri", [{"type", Type}], - [{xmlcdata, "cid:" ++ CID}]}]}]}] ++ TailEls}, - Tref = erlang:send_after(?CAPTCHA_LIFETIME, ?MODULE, {remove_id, Id}), - case ?T(mnesia:write(#captcha{id=Id, key=Key, tref=Tref})) of - ok -> - {ok, [Captcha, Data]}; - Err -> - {error, Err} - end; - Err -> - Err + {ok, Type, Key, Image} -> + Id = <<(randoms:get_string())/binary>>, + B64Image = jlib:encode_base64((Image)), + CID = <<"sha1+", (sha:sha(Image))/binary, + "@bob.xmpp.org">>, + Data = #xmlel{name = <<"data">>, + attrs = + [{<<"xmlns">>, ?NS_BOB}, {<<"cid">>, CID}, + {<<"max-age">>, <<"0">>}, {<<"type">>, Type}], + children = [{xmlcdata, B64Image}]}, + HelpTxt = translate:translate(Lang, + <<"If you don't see the CAPTCHA image here, " + "visit the web page.">>), + Imageurl = get_url(<>), + Captcha = #xmlel{name = <<"x">>, + attrs = + [{<<"xmlns">>, ?NS_XDATA}, + {<<"type">>, <<"form">>}], + children = + [?VFIELD(<<"hidden">>, <<"FORM_TYPE">>, + {xmlcdata, ?NS_CAPTCHA}) + | HeadEls] + ++ + [#xmlel{name = <<"field">>, + attrs = [{<<"type">>, <<"fixed">>}], + children = + [#xmlel{name = <<"value">>, + attrs = [], + children = + [{xmlcdata, + HelpTxt}]}]}, + #xmlel{name = <<"field">>, + attrs = + [{<<"type">>, <<"hidden">>}, + {<<"var">>, <<"captchahidden">>}], + children = + [#xmlel{name = <<"value">>, + attrs = [], + children = + [{xmlcdata, + <<"workaround-for-psi">>}]}]}, + #xmlel{name = <<"field">>, + attrs = + [{<<"type">>, <<"text-single">>}, + {<<"label">>, + translate:translate(Lang, + <<"CAPTCHA web page">>)}, + {<<"var">>, <<"url">>}], + children = + [#xmlel{name = <<"value">>, + attrs = [], + children = + [{xmlcdata, + Imageurl}]}]}, + ?VFIELD(<<"hidden">>, <<"from">>, + {xmlcdata, jlib:jid_to_string(To)}), + ?VFIELD(<<"hidden">>, <<"challenge">>, + {xmlcdata, Id}), + ?VFIELD(<<"hidden">>, <<"sid">>, + {xmlcdata, SID}), + #xmlel{name = <<"field">>, + attrs = + [{<<"var">>, <<"ocr">>}, + {<<"label">>, + ?CAPTCHA_TEXT(Lang)}], + children = + [#xmlel{name = <<"required">>, + attrs = [], children = []}, + #xmlel{name = <<"media">>, + attrs = + [{<<"xmlns">>, + ?NS_MEDIA}], + children = + [#xmlel{name = + <<"uri">>, + attrs = + [{<<"type">>, + Type}], + children = + [{xmlcdata, + <<"cid:", + CID/binary>>}]}]}]}] + ++ TailEls}, + Tref = erlang:send_after(?CAPTCHA_LIFETIME, ?MODULE, + {remove_id, Id}), + ets:insert(captcha, + #captcha{id = Id, key = Key, tref = Tref}), + {ok, [Captcha, Data]}; + Err -> Err end. %% @spec (Id::string(), Lang::string()) -> {FormEl, {ImgEl, TextEl, IdEl, KeyEl}} | captcha_not_found @@ -171,206 +278,177 @@ create_captcha_x(SID, To, Lang, Limiter, HeadEls, TailEls) -> %% TextEl = xmlelement() %% IdEl = xmlelement() %% KeyEl = xmlelement() +-spec build_captcha_html(binary(), binary()) -> captcha_not_found | + {xmlel(), + {xmlel(), xmlel(), + xmlel(), xmlel()}}. + build_captcha_html(Id, Lang) -> - case mnesia:dirty_read(captcha, Id) of - [#captcha{}] -> - ImgEl = {xmlelement, "img", [{"src", get_url(Id ++ "/image")}], []}, - TextEl = {xmlcdata, ?CAPTCHA_TEXT(Lang)}, - IdEl = {xmlelement, "input", [{"type", "hidden"}, - {"name", "id"}, - {"value", Id}], []}, - KeyEl = {xmlelement, "input", [{"type", "text"}, - {"name", "key"}, - {"size", "10"}], []}, - FormEl = {xmlelement, "form", [{"action", get_url(Id)}, - {"name", "captcha"}, - {"method", "POST"}], - [ImgEl, - {xmlelement, "br", [], []}, - TextEl, - {xmlelement, "br", [], []}, - IdEl, - KeyEl, - {xmlelement, "br", [], []}, - {xmlelement, "input", [{"type", "submit"}, - {"name", "enter"}, - {"value", "OK"}], []} - ]}, - {FormEl, {ImgEl, TextEl, IdEl, KeyEl}}; - _ -> - captcha_not_found + case lookup_captcha(Id) of + {ok, _} -> + ImgEl = #xmlel{name = <<"img">>, + attrs = + [{<<"src">>, get_url(<>)}], + children = []}, + TextEl = {xmlcdata, ?CAPTCHA_TEXT(Lang)}, + IdEl = #xmlel{name = <<"input">>, + attrs = + [{<<"type">>, <<"hidden">>}, {<<"name">>, <<"id">>}, + {<<"value">>, Id}], + children = []}, + KeyEl = #xmlel{name = <<"input">>, + attrs = + [{<<"type">>, <<"text">>}, {<<"name">>, <<"key">>}, + {<<"size">>, <<"10">>}], + children = []}, + FormEl = #xmlel{name = <<"form">>, + attrs = + [{<<"action">>, get_url(Id)}, + {<<"name">>, <<"captcha">>}, + {<<"method">>, <<"POST">>}], + children = + [ImgEl, + #xmlel{name = <<"br">>, attrs = [], + children = []}, + TextEl, + #xmlel{name = <<"br">>, attrs = [], + children = []}, + IdEl, KeyEl, + #xmlel{name = <<"br">>, attrs = [], + children = []}, + #xmlel{name = <<"input">>, + attrs = + [{<<"type">>, <<"submit">>}, + {<<"name">>, <<"enter">>}, + {<<"value">>, <<"OK">>}], + children = []}]}, + {FormEl, {ImgEl, TextEl, IdEl, KeyEl}}; + _ -> captcha_not_found end. %% @spec (Id::string(), ProvidedKey::string()) -> captcha_valid | captcha_non_valid | captcha_not_found -check_captcha(Id, ProvidedKey) -> - ?T(case mnesia:read(captcha, Id, write) of - [#captcha{pid=Pid, args=Args, key=StoredKey, tref=Tref}] -> - mnesia:delete({captcha, Id}), - erlang:cancel_timer(Tref), - if StoredKey == ProvidedKey -> - if is_pid(Pid) -> - Pid ! {captcha_succeed, Args}; - true -> - ok - end, - captcha_valid; - true -> - if is_pid(Pid) -> - Pid ! {captcha_failed, Args}; - true -> - ok - end, - captcha_non_valid - end; - _ -> - captcha_not_found - end). +-spec check_captcha(binary(), binary()) -> captcha_not_found | + captcha_valid | + captcha_non_valid. -process_reply({xmlelement, _, _, _} = El) -> - case xml:get_subtag(El, "x") of - false -> - {error, malformed}; - Xdata -> - Fields = jlib:parse_xdata_submit(Xdata), - case catch {proplists:get_value("challenge", Fields), - proplists:get_value("ocr", Fields)} of - {[Id|_], [OCR|_]} -> - ?T(case mnesia:read(captcha, Id, write) of - [#captcha{pid=Pid, args=Args, key=Key, tref=Tref}] -> - mnesia:delete({captcha, Id}), - erlang:cancel_timer(Tref), - if OCR == Key -> - if is_pid(Pid) -> - Pid ! {captcha_succeed, Args}; - true -> - ok - end, - ok; - true -> - if is_pid(Pid) -> - Pid ! {captcha_failed, Args}; - true -> - ok - end, - {error, bad_match} - end; - _ -> - {error, not_found} - end); - _ -> - {error, malformed} - end +-spec process_reply(xmlel()) -> ok | {error, bad_match | not_found | malformed}. + +process_reply(#xmlel{} = El) -> + case xml:get_subtag(El, <<"x">>) of + false -> {error, malformed}; + Xdata -> + Fields = jlib:parse_xdata_submit(Xdata), + case catch {proplists:get_value(<<"challenge">>, + Fields), + proplists:get_value(<<"ocr">>, Fields)} + of + {[Id | _], [OCR | _]} -> + case check_captcha(Id, OCR) of + captcha_valid -> ok; + captcha_non_valid -> {error, bad_match}; + captcha_not_found -> {error, not_found} + end; + _ -> {error, malformed} + end end; -process_reply(_) -> - {error, malformed}. +process_reply(_) -> {error, malformed}. - -process(_Handlers, #request{method='GET', lang=Lang, path=[_, Id]}) -> +process(_Handlers, + #request{method = 'GET', lang = Lang, + path = [_, Id]}) -> case build_captcha_html(Id, Lang) of - {FormEl, _} when is_tuple(FormEl) -> - Form = - {xmlelement, "div", [{"align", "center"}], - [FormEl]}, - ejabberd_web:make_xhtml([Form]); - captcha_not_found -> - ejabberd_web:error(not_found) + {FormEl, _} when is_tuple(FormEl) -> + Form = #xmlel{name = <<"div">>, + attrs = [{<<"align">>, <<"center">>}], + children = [FormEl]}, + ejabberd_web:make_xhtml([Form]); + captcha_not_found -> ejabberd_web:error(not_found) end; - -process(_Handlers, #request{method='GET', path=[_, Id, "image"], ip = IP}) -> +process(_Handlers, + #request{method = 'GET', path = [_, Id, <<"image">>], + ip = IP}) -> {Addr, _Port} = IP, - case mnesia:dirty_read(captcha, Id) of - [#captcha{key=Key}] -> - case create_image(Addr, Key) of - {ok, Type, _, Img} -> - {200, - [{"Content-Type", Type}, - {"Cache-Control", "no-cache"}, - {"Last-Modified", httpd_util:rfc1123_date()}], - Img}; - {error, limit} -> - ejabberd_web:error(not_allowed); - _ -> - ejabberd_web:error(not_found) - end; - _ -> - ejabberd_web:error(not_found) + case lookup_captcha(Id) of + {ok, #captcha{key = Key}} -> + case create_image(Addr, Key) of + {ok, Type, _, Img} -> + {200, + [{<<"Content-Type">>, Type}, + {<<"Cache-Control">>, <<"no-cache">>}, + {<<"Last-Modified">>, list_to_binary(httpd_util:rfc1123_date())}], + Img}; + {error, limit} -> ejabberd_web:error(not_allowed); + _ -> ejabberd_web:error(not_found) + end; + _ -> ejabberd_web:error(not_found) end; - -process(_Handlers, #request{method='POST', q=Q, lang=Lang, path=[_, Id]}) -> - ProvidedKey = proplists:get_value("key", Q, none), +process(_Handlers, + #request{method = 'POST', q = Q, lang = Lang, + path = [_, Id]}) -> + ProvidedKey = proplists:get_value(<<"key">>, Q, none), case check_captcha(Id, ProvidedKey) of - captcha_valid -> - Form = - {xmlelement, "p", [], - [{xmlcdata, - translate:translate(Lang, "The CAPTCHA is valid.") - }]}, - ejabberd_web:make_xhtml([Form]); - captcha_non_valid -> - ejabberd_web:error(not_allowed); - captcha_not_found -> - ejabberd_web:error(not_found) + captcha_valid -> + Form = #xmlel{name = <<"p">>, attrs = [], + children = + [{xmlcdata, + translate:translate(Lang, + <<"The CAPTCHA is valid.">>)}]}, + ejabberd_web:make_xhtml([Form]); + captcha_non_valid -> ejabberd_web:error(not_allowed); + captcha_not_found -> ejabberd_web:error(not_found) end; - process(_Handlers, _Request) -> ejabberd_web:error(not_found). - %%==================================================================== %% gen_server callbacks %%==================================================================== init([]) -> - mnesia:create_table(captcha, - [{ram_copies, [node()]}, - {attributes, record_info(fields, captcha)}]), - mnesia:add_table_copy(captcha, node(), ram_copies), + mnesia:delete_table(captcha), + ets:new(captcha, + [named_table, public, {keypos, #captcha.id}]), check_captcha_setup(), {ok, #state{}}. -handle_call({is_limited, Limiter, RateLimit}, _From, State) -> +handle_call({is_limited, Limiter, RateLimit}, _From, + State) -> NowPriority = now_priority(), - CleanPriority = NowPriority + ?LIMIT_PERIOD, + CleanPriority = NowPriority + (?LIMIT_PERIOD), Limits = clean_treap(State#state.limits, CleanPriority), case treap:lookup(Limiter, Limits) of - {ok, _, Rate} when Rate >= RateLimit -> - {reply, true, State#state{limits = Limits}}; - {ok, Priority, Rate} -> - NewLimits = treap:insert(Limiter, Priority, Rate+1, Limits), - {reply, false, State#state{limits = NewLimits}}; - _ -> - NewLimits = treap:insert(Limiter, NowPriority, 1, Limits), - {reply, false, State#state{limits = NewLimits}} + {ok, _, Rate} when Rate >= RateLimit -> + {reply, true, State#state{limits = Limits}}; + {ok, Priority, Rate} -> + NewLimits = treap:insert(Limiter, Priority, Rate + 1, + Limits), + {reply, false, State#state{limits = NewLimits}}; + _ -> + NewLimits = treap:insert(Limiter, NowPriority, 1, + Limits), + {reply, false, State#state{limits = NewLimits}} end; handle_call(_Request, _From, State) -> {reply, bad_request, State}. -handle_cast(_Msg, State) -> - {noreply, State}. +handle_cast(_Msg, State) -> {noreply, State}. handle_info({remove_id, Id}, State) -> ?DEBUG("captcha ~p timed out", [Id]), - _ = ?T(case mnesia:read(captcha, Id, write) of - [#captcha{args=Args, pid=Pid}] -> - if is_pid(Pid) -> - Pid ! {captcha_failed, Args}; - true -> - ok - end, - mnesia:delete({captcha, Id}); - _ -> - ok - end), + case ets:lookup(captcha, Id) of + [#captcha{args = Args, pid = Pid}] -> + if is_pid(Pid) -> Pid ! {captcha_failed, Args}; + true -> ok + end, + ets:delete(captcha, Id); + _ -> ok + end, {noreply, State}; +handle_info(_Info, State) -> {noreply, State}. -handle_info(_Info, State) -> - {noreply, State}. +terminate(_Reason, _State) -> ok. -terminate(_Reason, _State) -> - ok. - -code_change(_OldVsn, State, _Extra) -> - {ok, State}. +code_change(_OldVsn, State, _Extra) -> {ok, State}. %%-------------------------------------------------------------------- %%% Internal functions @@ -382,126 +460,136 @@ code_change(_OldVsn, State, _Extra) -> %% Image = binary() %% Reason = atom() %%-------------------------------------------------------------------- -create_image() -> - create_image(undefined). +create_image() -> create_image(undefined). create_image(Limiter) -> - %% Six numbers from 1 to 9. - Key = string:substr(randoms:get_string(), 1, 6), + Key = str:substr(randoms:get_string(), 1, 6), create_image(Limiter, Key). create_image(Limiter, Key) -> case is_limited(Limiter) of - true -> - {error, limit}; - false -> - do_create_image(Key) + true -> {error, limit}; + false -> do_create_image(Key) end. do_create_image(Key) -> FileName = get_prog_name(), Cmd = lists:flatten(io_lib:format("~s ~s", [FileName, Key])), case cmd(Cmd) of - {ok, <<16#89, $P, $N, $G, $\r, $\n, 16#1a, $\n, _/binary>> = Img} -> - {ok, "image/png", Key, Img}; - {ok, <<16#ff, 16#d8, _/binary>> = Img} -> - {ok, "image/jpeg", Key, Img}; - {ok, <<$G, $I, $F, $8, X, $a, _/binary>> = Img} when X==$7; X==$9 -> - {ok, "image/gif", Key, Img}; - {error, enodata = Reason} -> - ?ERROR_MSG("Failed to process output from \"~s\". " - "Maybe ImageMagick's Convert program is not installed.", - [Cmd]), - {error, Reason}; - {error, Reason} -> - ?ERROR_MSG("Failed to process an output from \"~s\": ~p", - [Cmd, Reason]), - {error, Reason}; - _ -> - Reason = malformed_image, - ?ERROR_MSG("Failed to process an output from \"~s\": ~p", - [Cmd, Reason]), - {error, Reason} + {ok, + <<137, $P, $N, $G, $\r, $\n, 26, $\n, _/binary>> = + Img} -> + {ok, <<"image/png">>, Key, Img}; + {ok, <<255, 216, _/binary>> = Img} -> + {ok, <<"image/jpeg">>, Key, Img}; + {ok, <<$G, $I, $F, $8, X, $a, _/binary>> = Img} + when X == $7; X == $9 -> + {ok, <<"image/gif">>, Key, Img}; + {error, enodata = Reason} -> + ?ERROR_MSG("Failed to process output from \"~s\". " + "Maybe ImageMagick's Convert program " + "is not installed.", + [Cmd]), + {error, Reason}; + {error, Reason} -> + ?ERROR_MSG("Failed to process an output from \"~s\": ~p", + [Cmd, Reason]), + {error, Reason}; + _ -> + Reason = malformed_image, + ?ERROR_MSG("Failed to process an output from \"~s\": ~p", + [Cmd, Reason]), + {error, Reason} end. get_prog_name() -> - case ejabberd_config:get_local_option(captcha_cmd) of - FileName when is_list(FileName) -> - FileName; - Value when (Value == undefined) or (Value == "") -> - ?DEBUG("The option captcha_cmd is not configured, but some " - "module wants to use the CAPTCHA feature.", []), - false + case ejabberd_config:get_local_option( + captcha_cmd, + fun(FileName) -> + F = iolist_to_binary(FileName), + if F /= <<"">> -> F end + end) of + undefined -> + ?DEBUG("The option captcha_cmd is not configured, " + "but some module wants to use the CAPTCHA " + "feature.", + []), + false; + FileName -> + FileName end. get_url(Str) -> - CaptchaHost = ejabberd_config:get_local_option(captcha_host), - case string:tokens(CaptchaHost, ":") of - [Host] -> - "http://" ++ Host ++ "/captcha/" ++ Str; - ["http"++_ = TransferProt, Host] -> - TransferProt ++ ":" ++ Host ++ "/captcha/" ++ Str; - [Host, PortString] -> - TransferProt = atom_to_list(get_transfer_protocol(PortString)), - TransferProt ++ "://" ++ Host ++ ":" ++ PortString ++ "/captcha/" ++ Str; - [TransferProt, Host, PortString] -> - TransferProt ++ ":" ++ Host ++ ":" ++ PortString ++ "/captcha/" ++ Str; - _ -> - "http://" ++ ?MYNAME ++ "/captcha/" ++ Str + CaptchaHost = ejabberd_config:get_local_option( + captcha_host, + fun iolist_to_binary/1, + <<"">>), + case str:tokens(CaptchaHost, <<":">>) of + [Host] -> + <<"http://", Host/binary, "/captcha/", Str/binary>>; + [<<"http", _/binary>> = TransferProt, Host] -> + <>; + [Host, PortString] -> + TransferProt = + iolist_to_binary(atom_to_list(get_transfer_protocol(PortString))), + <>; + [TransferProt, Host, PortString] -> + <>; + _ -> + <<"http://", (?MYNAME)/binary, "/captcha/", Str/binary>> end. get_transfer_protocol(PortString) -> - PortNumber = list_to_integer(PortString), + PortNumber = jlib:binary_to_integer(PortString), PortListeners = get_port_listeners(PortNumber), get_captcha_transfer_protocol(PortListeners). get_port_listeners(PortNumber) -> - AllListeners = ejabberd_config:get_local_option(listen), - lists:filter( - fun({{Port, _Ip, _Netp}, _Module1, _Opts1}) when Port == PortNumber -> - true; - (_) -> - false - end, - AllListeners). + AllListeners = ejabberd_config:get_local_option(listen, fun(V) -> V end), + lists:filter(fun ({{Port, _Ip, _Netp}, _Module1, + _Opts1}) + when Port == PortNumber -> + true; + (_) -> false + end, + AllListeners). get_captcha_transfer_protocol([]) -> - throw("The port number mentioned in captcha_host is not " - "a ejabberd_http listener with 'captcha' option. " - "Change the port number or specify http:// in that option."); -get_captcha_transfer_protocol([{{_Port, _Ip, tcp}, ejabberd_http, Opts} + throw(<<"The port number mentioned in captcha_host " + "is not a ejabberd_http listener with " + "'captcha' option. Change the port number " + "or specify http:// in that option.">>); +get_captcha_transfer_protocol([{{_Port, _Ip, tcp}, + ejabberd_http, Opts} | Listeners]) -> case lists:member(captcha, Opts) of - true -> - case lists:member(tls, Opts) of - true -> - https; - false -> - http - end; - false -> - get_captcha_transfer_protocol(Listeners) + true -> + case lists:member(tls, Opts) of + true -> https; + false -> http + end; + false -> get_captcha_transfer_protocol(Listeners) end; get_captcha_transfer_protocol([_ | Listeners]) -> get_captcha_transfer_protocol(Listeners). -is_limited(undefined) -> - false; +is_limited(undefined) -> false; is_limited(Limiter) -> - case ejabberd_config:get_local_option(captcha_limit) of - Int when is_integer(Int), Int > 0 -> - case catch gen_server:call(?MODULE, {is_limited, Limiter, Int}, - 5000) of - true -> - true; - false -> - false; - Err -> - ?ERROR_MSG("Call failed: ~p", [Err]), - false - end; - _ -> - false + case ejabberd_config:get_local_option( + captcha_limit, + fun(I) when is_integer(I), I > 0 -> I end) of + undefined -> false; + Int -> + case catch gen_server:call(?MODULE, + {is_limited, Limiter, Int}, 5000) + of + true -> true; + false -> false; + Err -> ?ERROR_MSG("Call failed: ~p", [Err]), false + end end. %%-------------------------------------------------------------------- @@ -511,82 +599,97 @@ is_limited(Limiter) -> %% Description: os:cmd/1 replacement %%-------------------------------------------------------------------- -define(CMD_TIMEOUT, 5000). --define(MAX_FILE_SIZE, 64*1024). + +-define(MAX_FILE_SIZE, 64 * 1024). cmd(Cmd) -> Port = open_port({spawn, Cmd}, [stream, eof, binary]), - TRef = erlang:start_timer(?CMD_TIMEOUT, self(), timeout), + TRef = erlang:start_timer(?CMD_TIMEOUT, self(), + timeout), recv_data(Port, TRef, <<>>). recv_data(Port, TRef, Buf) -> receive - {Port, {data, Bytes}} -> - NewBuf = <>, - if size(NewBuf) > ?MAX_FILE_SIZE -> - return(Port, TRef, {error, efbig}); - true -> - recv_data(Port, TRef, NewBuf) - end; - {Port, {data, _}} -> - return(Port, TRef, {error, efbig}); - {Port, eof} when Buf /= <<>> -> - return(Port, TRef, {ok, Buf}); - {Port, eof} -> - return(Port, TRef, {error, enodata}); - {timeout, TRef, _} -> - return(Port, TRef, {error, timeout}) + {Port, {data, Bytes}} -> + NewBuf = <>, + if byte_size(NewBuf) > (?MAX_FILE_SIZE) -> + return(Port, TRef, {error, efbig}); + true -> recv_data(Port, TRef, NewBuf) + end; + {Port, {data, _}} -> return(Port, TRef, {error, efbig}); + {Port, eof} when Buf /= <<>> -> + return(Port, TRef, {ok, Buf}); + {Port, eof} -> return(Port, TRef, {error, enodata}); + {timeout, TRef, _} -> + return(Port, TRef, {error, timeout}) end. return(Port, TRef, Result) -> case erlang:cancel_timer(TRef) of - false -> - receive - {timeout, TRef, _} -> - ok - after 0 -> - ok - end; - _ -> - ok + false -> + receive {timeout, TRef, _} -> ok after 0 -> ok end; + _ -> ok end, catch port_close(Port), Result. is_feature_available() -> case get_prog_name() of - Prog when is_list(Prog) -> true; - false -> false + Prog when is_binary(Prog) -> true; + false -> false end. check_captcha_setup() -> case is_feature_available() of - true -> - case create_image() of - {ok, _, _, _} -> - ok; - _Err -> - ?CRITICAL_MSG("Captcha is enabled in the option captcha_cmd, " - "but it can't generate images.", []), - throw({error, captcha_cmd_enabled_but_fails}) - end; - false -> - ok + true -> + case create_image() of + {ok, _, _, _} -> ok; + _Err -> + ?CRITICAL_MSG("Captcha is enabled in the option captcha_cmd, " + "but it can't generate images.", + []), + throw({error, captcha_cmd_enabled_but_fails}) + end; + false -> ok + end. + +lookup_captcha(Id) -> + case ets:lookup(captcha, Id) of + [C] -> {ok, C}; + _ -> {error, enoent} + end. + +check_captcha(Id, ProvidedKey) -> + case ets:lookup(captcha, Id) of + [#captcha{pid = Pid, args = Args, key = ValidKey, + tref = Tref}] -> + ets:delete(captcha, Id), + erlang:cancel_timer(Tref), + if ValidKey == ProvidedKey -> + if is_pid(Pid) -> Pid ! {captcha_succeed, Args}; + true -> ok + end, + captcha_valid; + true -> + if is_pid(Pid) -> Pid ! {captcha_failed, Args}; + true -> ok + end, + captcha_non_valid + end; + _ -> captcha_not_found end. clean_treap(Treap, CleanPriority) -> case treap:is_empty(Treap) of - true -> - Treap; - false -> - {_Key, Priority, _Value} = treap:get_root(Treap), - if - Priority > CleanPriority -> - clean_treap(treap:delete_root(Treap), CleanPriority); - true -> - Treap - end + true -> Treap; + false -> + {_Key, Priority, _Value} = treap:get_root(Treap), + if Priority > CleanPriority -> + clean_treap(treap:delete_root(Treap), CleanPriority); + true -> Treap + end end. now_priority() -> {MSec, Sec, USec} = now(), - -((MSec*1000000 + Sec)*1000000 + USec). + -((MSec * 1000000 + Sec) * 1000000 + USec). diff --git a/src/ejabberd_check.erl b/src/ejabberd_check.erl index ddb2cbdb0..352251806 100644 --- a/src/ejabberd_check.erl +++ b/src/ejabberd_check.erl @@ -31,8 +31,6 @@ -include("ejabberd.hrl"). -include("ejabberd_config.hrl"). --compile([export_all]). - %% TODO: %% We want to implement library checking at launch time to issue %% human readable user messages. @@ -87,7 +85,7 @@ get_db_used() -> fun([Domain, DB], Acc) -> case check_odbc_option( ejabberd_config:get_local_option( - {auth_method, Domain})) of + {auth_method, Domain}, fun(V) -> V end)) of true -> [get_db_type(DB)|Acc]; _ -> Acc end diff --git a/src/ejabberd_commands.erl b/src/ejabberd_commands.erl index b61ef46de..3b8abf97b 100644 --- a/src/ejabberd_commands.erl +++ b/src/ejabberd_commands.erl @@ -228,7 +228,8 @@ init() -> ets:new(ejabberd_commands, [named_table, set, public, {keypos, #ejabberd_commands.name}]). -%% @spec ([ejabberd_commands()]) -> ok +-spec register_commands([ejabberd_commands()]) -> ok. + %% @doc Register ejabberd commands. %% If a command is already registered, a warning is printed and the old command is preserved. register_commands(Commands) -> @@ -243,7 +244,8 @@ register_commands(Commands) -> end, Commands). -%% @spec ([ejabberd_commands()]) -> ok +-spec unregister_commands([ejabberd_commands()]) -> ok. + %% @doc Unregister ejabberd commands. unregister_commands(Commands) -> lists:foreach( @@ -252,7 +254,8 @@ unregister_commands(Commands) -> end, Commands). -%% @spec () -> [{Name::atom(), Args::[aterm()], Desc::string()}] +-spec list_commands() -> [{atom(), [aterm()], string()}]. + %% @doc Get a list of all the available commands, arguments and description. list_commands() -> Commands = ets:match(ejabberd_commands, @@ -262,7 +265,8 @@ list_commands() -> _ = '_'}), [{A, B, C} || [A, B, C] <- Commands]. -%% @spec (Name::atom()) -> {Args::[aterm()], Result::rterm()} | {error, command_unknown} +-spec get_command_format(atom()) -> {[aterm()], rterm()} | {error, command_unknown}. + %% @doc Get the format of arguments and result of a command. get_command_format(Name) -> Matched = ets:match(ejabberd_commands, @@ -277,7 +281,8 @@ get_command_format(Name) -> {Args, Result} end. -%% @spec (Name::atom()) -> ejabberd_commands() | command_not_found +-spec get_command_definition(atom()) -> ejabberd_commands() | command_not_found. + %% @doc Get the definition record of a command. get_command_definition(Name) -> case ets:lookup(ejabberd_commands, Name) of @@ -314,6 +319,8 @@ execute_command2(Command, Arguments) -> ?DEBUG("Executing command ~p:~p with Args=~p", [Module, Function, Arguments]), apply(Module, Function, Arguments). +-spec get_tags_commands() -> [{string(), [string()]}]. + %% @spec () -> [{Tag::string(), [CommandName::string()]}] %% @doc Get all the tags and associated commands. get_tags_commands() -> @@ -377,6 +384,9 @@ check_access_commands(AccessCommands, Auth, Method, Command, Arguments) -> L when is_list(L) -> ok end. +-spec check_auth(noauth) -> noauth_provided; + ({binary(), binary(), binary()}) -> {ok, binary(), binary()}. + check_auth(noauth) -> no_auth_provided; check_auth({User, Server, Password}) -> @@ -391,7 +401,7 @@ check_access(all, _) -> check_access(Access, Auth) -> {ok, User, Server} = check_auth(Auth), %% Check this user has access permission - case acl:match_rule(Server, Access, jlib:make_jid(User, Server, "")) of + case acl:match_rule(Server, Access, jlib:make_jid(User, Server, <<"">>)) of allow -> true; deny -> false end. diff --git a/src/ejabberd_commands.hrl b/src/ejabberd_commands.hrl index 1ababc8be..116bb7357 100644 --- a/src/ejabberd_commands.hrl +++ b/src/ejabberd_commands.hrl @@ -19,10 +19,32 @@ %%% %%%---------------------------------------------------------------------- --record(ejabberd_commands, {name, tags = [], - desc = "", longdesc = "", - module, function, - args = [], result = rescode}). +-type aterm() :: {atom(), atype()}. +-type atype() :: integer | string | binary | + {tuple, [aterm()]} | {list, aterm()}. +-type rterm() :: {atom(), rtype()}. +-type rtype() :: integer | string | atom | + {tuple, [rterm()]} | {list, rterm()} | + rescode | restuple. + +-record(ejabberd_commands, + {name :: atom(), + tags = [] :: [atom()] | '_' | '$2', + desc = "" :: string() | '_' | '$3', + longdesc = "" :: string() | '_', + module :: atom(), + function :: atom(), + args = [] :: [aterm()] | '_' | '$1' | '$2', + result = {res, rescode} :: rterm() | '_' | '$2'}). + +-type ejabberd_commands() :: #ejabberd_commands{name :: atom(), + tags :: [atom()], + desc :: string(), + longdesc :: string(), + module :: atom(), + function :: atom(), + args :: [aterm()], + result :: rterm()}. %% @type ejabberd_commands() = #ejabberd_commands{ %% name = atom(), @@ -50,3 +72,4 @@ %% @type rterm() = {Name::atom(), Type::rtype()}. %% A result term is a tuple with the term name and the term type. + diff --git a/src/ejabberd_config.erl b/src/ejabberd_config.erl index e2a8633ee..0a4208b7f 100644 --- a/src/ejabberd_config.erl +++ b/src/ejabberd_config.erl @@ -29,9 +29,13 @@ -export([start/0, load_file/1, add_global_option/2, add_local_option/2, - get_global_option/1, get_local_option/1]). + get_global_option/2, get_local_option/2, + get_global_option/3, get_local_option/3]). -export([get_vh_by_auth_method/1]). -export([is_file_readable/1]). +-export([get_version/0, get_myhosts/0, get_mylang/0]). +-export([prepare_opt_val/4]). +-export([convert_table_to_binary/5]). -include("ejabberd.hrl"). -include("ejabberd_config.hrl"). @@ -96,11 +100,15 @@ load_file(File) -> %% in which the options 'include_config_file' were parsed %% and the terms in those files were included. %% @spec(string()) -> [term()] +%% @spec(iolist()) -> [term()] +get_plain_terms_file(File) when is_binary(File) -> + get_plain_terms_file(binary_to_list(File)); get_plain_terms_file(File1) -> File = get_absolute_path(File1), case file:consult(File) of {ok, Terms} -> - include_config_files(Terms); + BinTerms = strings_to_binary(Terms), + include_config_files(BinTerms); {error, {LineNumber, erl_parse, _ParseMessage} = Reason} -> ExitText = describe_config_problem(File, Reason, LineNumber), ?ERROR_MSG(ExitText, []), @@ -159,7 +167,7 @@ normalize_hosts(Hosts) -> normalize_hosts([], PrepHosts) -> lists:reverse(PrepHosts); normalize_hosts([Host|Hosts], PrepHosts) -> - case jlib:nodeprep(Host) of + case jlib:nodeprep(iolist_to_binary(Host)) of error -> ?ERROR_MSG("Can't load config file: " "invalid host name [~p]", [Host]), @@ -564,7 +572,6 @@ set_opts(State) -> exit("Error reading Mnesia database") end. - add_global_option(Opt, Val) -> mnesia:transaction(fun() -> mnesia:write(#config{key = Opt, @@ -577,23 +584,63 @@ add_local_option(Opt, Val) -> value = Val}) end). +-spec prepare_opt_val(any(), any(), check_fun(), any()) -> any(). -get_global_option(Opt) -> +prepare_opt_val(Opt, Val, F, Default) -> + Res = case F of + {Mod, Fun} -> + catch Mod:Fun(Val); + _ -> + catch F(Val) + end, + case Res of + {'EXIT', _} -> + ?INFO_MSG("Configuration problem:~n" + "** Option: ~s~n" + "** Invalid value: ~s~n" + "** Using as fallback: ~s", + [format_term(Opt), + format_term(Val), + format_term(Default)]), + Default; + _ -> + Res + end. + +-type check_fun() :: fun((any()) -> any()) | {module(), atom()}. + +-spec get_global_option(any(), check_fun()) -> any(). + +get_global_option(Opt, F) -> + get_global_option(Opt, F, undefined). + +-spec get_global_option(any(), check_fun(), any()) -> any(). + +get_global_option(Opt, F, Default) -> case ets:lookup(config, Opt) of [#config{value = Val}] -> - Val; + prepare_opt_val(Opt, Val, F, Default); _ -> - undefined + Default end. -get_local_option(Opt) -> +-spec get_local_option(any(), check_fun()) -> any(). + +get_local_option(Opt, F) -> + get_local_option(Opt, F, undefined). + +-spec get_local_option(any(), check_fun(), any()) -> any(). + +get_local_option(Opt, F, Default) -> case ets:lookup(local_config, Opt) of [#local_config{value = Val}] -> - Val; + prepare_opt_val(Opt, Val, F, Default); _ -> - undefined + Default end. +-spec get_vh_by_auth_method(atom()) -> [binary()]. + %% Return the list of hosts handled by a given module get_vh_by_auth_method(AuthMethod) -> mnesia:dirty_select(local_config, @@ -613,8 +660,25 @@ is_file_readable(Path) -> false end. +get_version() -> + list_to_binary(element(2, application:get_key(ejabberd, vsn))). + +-spec get_myhosts() -> [binary()]. + +get_myhosts() -> + ejabberd_config:get_global_option(hosts, fun(V) -> V end). + +-spec get_mylang() -> binary(). + +get_mylang() -> + ejabberd_config:get_global_option( + language, + fun iolist_to_binary/1, + <<"en">>). + replace_module(mod_announce_odbc) -> {mod_announce, odbc}; replace_module(mod_blocking_odbc) -> {mod_blocking, odbc}; +replace_module(mod_caps_odbc) -> {mod_caps, odbc}; replace_module(mod_irc_odbc) -> {mod_irc, odbc}; replace_module(mod_last_odbc) -> {mod_last, odbc}; replace_module(mod_muc_odbc) -> {mod_muc, odbc}; @@ -632,10 +696,161 @@ replace_modules(Modules) -> fun({Module, Opts}) -> case replace_module(Module) of {NewModule, DBType} -> + emit_deprecation_warning(Module, NewModule, DBType), NewOpts = [{db_type, DBType} | lists:keydelete(db_type, 1, Opts)], {NewModule, NewOpts}; NewModule -> + if Module /= NewModule -> + emit_deprecation_warning(Module, NewModule); + true -> + ok + end, {NewModule, Opts} end end, Modules). + +strings_to_binary([]) -> + []; +strings_to_binary(L) when is_list(L) -> + case is_string(L) of + true -> + list_to_binary(L); + false -> + strings_to_binary1(L) + end; +strings_to_binary(T) when is_tuple(T) -> + list_to_tuple(strings_to_binary(tuple_to_list(T))); +strings_to_binary(X) -> + X. + +strings_to_binary1([El|L]) -> + [strings_to_binary(El)|strings_to_binary1(L)]; +strings_to_binary1([]) -> + []; +strings_to_binary1(T) -> + T. + +is_string([C|T]) when (C >= 0) and (C =< 255) -> + is_string(T); +is_string([]) -> + true; +is_string(_) -> + false. + +binary_to_strings(B) when is_binary(B) -> + binary_to_list(B); +binary_to_strings([H|T]) -> + [binary_to_strings(H)|binary_to_strings(T)]; +binary_to_strings(T) when is_tuple(T) -> + list_to_tuple(binary_to_strings(tuple_to_list(T))); +binary_to_strings(T) -> + T. + +format_term(Bin) when is_binary(Bin) -> + io_lib:format("\"~s\"", [Bin]); +format_term(S) when is_list(S), S /= [] -> + case lists:all(fun(C) -> (C>=0) and (C=<255) end, S) of + true -> + io_lib:format("\"~s\"", [S]); + false -> + io_lib:format("~p", [binary_to_strings(S)]) + end; +format_term(T) -> + io_lib:format("~p", [binary_to_strings(T)]). + +-spec convert_table_to_binary(atom(), [atom()], atom(), + fun(), fun()) -> ok. + +convert_table_to_binary(Tab, Fields, Type, DetectFun, ConvertFun) -> + case is_table_still_list(Tab, DetectFun) of + true -> + ?INFO_MSG("Converting '~s' table from strings to binaries.", [Tab]), + TmpTab = list_to_atom(atom_to_list(Tab) ++ "_tmp_table"), + catch mnesia:delete_table(TmpTab), + case mnesia:create_table(TmpTab, + [{disc_only_copies, [node()]}, + {type, Type}, + {local_content, true}, + {record_name, Tab}, + {attributes, Fields}]) of + {atomic, ok} -> + mnesia:transform_table(Tab, ignore, Fields), + case mnesia:transaction( + fun() -> + mnesia:write_lock_table(TmpTab), + mnesia:foldl( + fun(R, _) -> + NewR = ConvertFun(R), + mnesia:dirty_write(TmpTab, NewR) + end, ok, Tab) + end) of + {atomic, ok} -> + mnesia:clear_table(Tab), + case mnesia:transaction( + fun() -> + mnesia:write_lock_table(Tab), + mnesia:foldl( + fun(R, _) -> + mnesia:dirty_write(R) + end, ok, TmpTab) + end) of + {atomic, ok} -> + mnesia:delete_table(TmpTab); + Err -> + report_and_stop(Tab, Err) + end; + Err -> + report_and_stop(Tab, Err) + end; + Err -> + report_and_stop(Tab, Err) + end; + false -> + ok + end. + +is_table_still_list(Tab, DetectFun) -> + is_table_still_list(Tab, DetectFun, mnesia:dirty_first(Tab)). + +is_table_still_list(_Tab, _DetectFun, '$end_of_table') -> + false; +is_table_still_list(Tab, DetectFun, Key) -> + Rs = mnesia:dirty_read(Tab, Key), + Res = lists:foldl(fun(_, true) -> + true; + (_, false) -> + false; + (R, _) -> + case DetectFun(R) of + '$next' -> + '$next'; + El -> + is_list(El) + end + end, '$next', Rs), + case Res of + true -> + true; + false -> + false; + '$next' -> + is_table_still_list(Tab, DetectFun, mnesia:dirty_next(Tab, Key)) + end. + +report_and_stop(Tab, Err) -> + ErrTxt = lists:flatten( + io_lib:format( + "Failed to convert '~s' table to binary: ~p", + [Tab, Err])), + ?CRITICAL_MSG(ErrTxt, []), + timer:sleep(1000), + halt(string:substr(ErrTxt, 1, 199)). + +emit_deprecation_warning(Module, NewModule, DBType) -> + ?WARNING_MSG("Module ~s is deprecated, use {~s, [{db_type, ~s}, ...]}" + " instead", [Module, NewModule, DBType]). + +emit_deprecation_warning(Module, NewModule) -> + ?WARNING_MSG("Module ~s is deprecated, use ~s instead", + [Module, NewModule]). diff --git a/src/ejabberd_config.hrl b/src/ejabberd_config.hrl index b0fa46aca..bf749dd19 100644 --- a/src/ejabberd_config.hrl +++ b/src/ejabberd_config.hrl @@ -19,10 +19,16 @@ %%% %%%---------------------------------------------------------------------- --record(config, {key, value}). --record(local_config, {key, value}). --record(state, {opts = [], - hosts = [], - override_local = false, - override_global = false, - override_acls = false}). +-record(config, {key :: any(), value :: any()}). + +-record(local_config, {key :: any(), value :: any()}). + +-type config() :: #config{}. +-type local_config() :: #local_config{}. + +-record(state, + {opts = [] :: [acl:acl() | config() | local_config()], + hosts = [] :: [binary()], + override_local = false :: boolean(), + override_global = false :: boolean(), + override_acls = false :: boolean()}). diff --git a/src/ejabberd_ctl.erl b/src/ejabberd_ctl.erl index 9b41b1463..2b4702176 100644 --- a/src/ejabberd_ctl.erl +++ b/src/ejabberd_ctl.erl @@ -72,10 +72,10 @@ start() -> _ -> case net_kernel:longnames() of true -> - SNode ++ "@" ++ inet_db:gethostname() ++ - "." ++ inet_db:res_option(domain); + lists:flatten([SNode, "@", inet_db:gethostname(), + ".", inet_db:res_option(domain)]); false -> - SNode ++ "@" ++ inet_db:gethostname(); + lists:flatten([SNode, "@", inet_db:gethostname()]); _ -> SNode end @@ -124,6 +124,8 @@ unregister_commands(CmdDescs, Module, Function) -> %% Process %%----------------------------- +-spec process([string()]) -> non_neg_integer(). + %% The commands status, stop and restart are defined here to ensure %% they are usable even if ejabberd is completely stopped. process(["status"]) -> @@ -159,7 +161,7 @@ process(["mnesia", "info"]) -> mnesia:info(), ?STATUS_SUCCESS; -process(["mnesia", Arg]) when is_list(Arg) -> +process(["mnesia", Arg]) -> case catch mnesia:system_info(list_to_atom(Arg)) of {'EXIT', Error} -> ?PRINT("Error: ~p~n", [Error]); Return -> ?PRINT("~p~n", [Return]) @@ -190,8 +192,9 @@ process(["help" | Mode]) -> print_usage_help(MaxC, ShCode), ?STATUS_SUCCESS; [CmdString | _] -> - CmdStringU = ejabberd_regexp:greplace(CmdString, "-", "_"), - print_usage_commands(CmdStringU, MaxC, ShCode), + CmdStringU = ejabberd_regexp:greplace( + list_to_binary(CmdString), <<"-">>, <<"_">>), + print_usage_commands(binary_to_list(CmdStringU), MaxC, ShCode), ?STATUS_SUCCESS end; @@ -214,30 +217,27 @@ process2(Args, AccessCommands) -> process2(Args, Auth, AccessCommands) -> case try_run_ctp(Args, Auth, AccessCommands) of {String, wrong_command_arguments} - when is_list(String) -> + when is_list(String) -> io:format(lists:flatten(["\n" | String]++["\n"])), [CommandString | _] = Args, process(["help" | [CommandString]]), {lists:flatten(String), ?STATUS_ERROR}; {String, Code} - when is_list(String) and is_integer(Code) -> + when is_list(String) and is_integer(Code) -> {lists:flatten(String), Code}; String - when is_list(String) -> + when is_list(String) -> {lists:flatten(String), ?STATUS_SUCCESS}; Code - when is_integer(Code) -> + when is_integer(Code) -> {"", Code}; Other -> {"Erroneous result: " ++ io_lib:format("~p", [Other]), ?STATUS_ERROR} end. get_accesscommands() -> - case ejabberd_config:get_local_option(ejabberdctl_access_commands) of - ACs when is_list(ACs) -> ACs; - _ -> [] - end. - + ejabberd_config:get_local_option(ejabberdctl_access_commands, + fun(V) when is_list(V) -> V end, []). %%----------------------------- %% Command calling @@ -281,8 +281,9 @@ try_call_command(Args, Auth, AccessCommands) -> %% @spec (Args::[string()], Auth, AccessCommands) -> string() | integer() | {string(), integer()} | {error, ErrorType} call_command([CmdString | Args], Auth, AccessCommands) -> - CmdStringU = ejabberd_regexp:greplace(CmdString, "-", "_"), - Command = list_to_atom(CmdStringU), + CmdStringU = ejabberd_regexp:greplace( + list_to_binary(CmdString), <<"-">>, <<"_">>), + Command = list_to_atom(binary_to_list(CmdStringU)), case ejabberd_commands:get_command_format(Command) of {error, command_unknown} -> {error, command_unknown}; @@ -331,10 +332,12 @@ format_args(Args, ArgsFormat) -> format_arg(Arg, integer) -> format_arg2(Arg, "~d"); +format_arg(Arg, binary) -> + list_to_binary(format_arg(Arg, string)); format_arg("", string) -> ""; format_arg(Arg, string) -> - NumChars = integer_to_list(string:len(Arg)), + NumChars = integer_to_list(length(Arg)), Parse = "~" ++ NumChars ++ "c", format_arg2(Arg, Parse). @@ -540,24 +543,25 @@ split_desc_segments(MaxL, Words) -> join(L, Words) -> join(L, Words, 0, [], []). -join(_L, [], _LenLastSeg, LastSeg, ResSeg) -> - ResSeg2 = [lists:reverse(LastSeg) | ResSeg], - lists:reverse(ResSeg2); -join(L, [Word | Words], LenLastSeg, LastSeg, ResSeg) -> - LWord = length(Word), - case LWord + LenLastSeg < L of - true -> - %% This word fits in the last segment - %% If this word ends with "\n", reset column counter - case string:str(Word, "\n") of - 0 -> - join(L, Words, LenLastSeg+LWord+1, [" ", Word | LastSeg], ResSeg); - _ -> - join(L, Words, LWord+1, [" ", Word | LastSeg], ResSeg) - end; - false -> - join(L, Words, LWord, [" ", Word], [lists:reverse(LastSeg) | ResSeg]) - end. +join(_Len, [], _CurSegLen, CurSeg, AllSegs) -> + lists:reverse([CurSeg | AllSegs]); +join(Len, [Word | Tail], CurSegLen, CurSeg, AllSegs) -> + WordLen = length(Word), + SegSize = WordLen + CurSegLen + 1, + {NewCurSeg, NewAllSegs, NewCurSegLen} = + if SegSize < Len -> + {[CurSeg, " ", Word], AllSegs, SegSize}; + true -> + {Word, [CurSeg | AllSegs], WordLen} + end, + NewLen = case string:str(Word, "\n") of + 0 -> + NewCurSegLen; + _ -> + 0 + end, + join(Len, Tail, NewLen, NewCurSeg, NewAllSegs). + format_command_lines(CALD, MaxCmdLen, MaxC, ShCode, dual) when MaxC - MaxCmdLen < 40 -> @@ -568,7 +572,8 @@ format_command_lines(CALD, MaxCmdLen, MaxC, ShCode, dual) -> lists:map( fun({Cmd, Args, CmdArgsL, Desc}) -> DescFmt = prepare_description(MaxCmdLen+4, MaxC, Desc), - [" ", ?B(Cmd), " ", [[?U(Arg), " "] || Arg <- Args], string:chars($\s, MaxCmdLen - CmdArgsL + 1), + [" ", ?B(Cmd), " ", [[?U(Arg), " "] || Arg <- Args], + string:chars($\s, MaxCmdLen - CmdArgsL + 1), DescFmt, "\n"] end, CALD); @@ -608,7 +613,8 @@ print_usage_tags(Tag, MaxC, ShCode) -> end, CommandsList = lists:map( fun(NameString) -> - C = ejabberd_commands:get_command_definition(list_to_atom(NameString)), + C = ejabberd_commands:get_command_definition( + list_to_atom(NameString)), #ejabberd_commands{name = Name, args = Args, desc = Desc} = C, @@ -689,10 +695,10 @@ filter_commands(All, SubString) -> end. filter_commands_regexp(All, Glob) -> - RegExp = ejabberd_regexp:sh_to_awk(Glob), + RegExp = ejabberd_regexp:sh_to_awk(list_to_binary(Glob)), lists:filter( fun(Command) -> - case ejabberd_regexp:run(Command, RegExp) of + case ejabberd_regexp:run(list_to_binary(Command), RegExp) of match -> true; nomatch -> diff --git a/src/ejabberd_ctl.hrl b/src/ejabberd_ctl.hrl index 27bb2489c..09a5287ee 100644 --- a/src/ejabberd_ctl.hrl +++ b/src/ejabberd_ctl.hrl @@ -20,6 +20,9 @@ %%%---------------------------------------------------------------------- -define(STATUS_SUCCESS, 0). --define(STATUS_ERROR, 1). --define(STATUS_USAGE, 2). --define(STATUS_BADRPC, 3). + +-define(STATUS_ERROR, 1). + +-define(STATUS_USAGE, 2). + +-define(STATUS_BADRPC, 3). diff --git a/src/ejabberd_frontend_socket.erl b/src/ejabberd_frontend_socket.erl index 93df685c7..98f305536 100644 --- a/src/ejabberd_frontend_socket.erl +++ b/src/ejabberd_frontend_socket.erl @@ -25,6 +25,7 @@ %%%---------------------------------------------------------------------- -module(ejabberd_frontend_socket). + -author('alexey@process-one.net'). -behaviour(gen_server). @@ -48,8 +49,8 @@ sockname/1, peername/1]). %% gen_server callbacks --export([init/1, handle_call/3, handle_cast/2, handle_info/2, - terminate/2, code_change/3]). +-export([init/1, handle_call/3, handle_cast/2, + handle_info/2, terminate/2, code_change/3]). -record(state, {sockmod, socket, receiver}). @@ -68,30 +69,30 @@ start_link(Module, SockMod, Socket, Opts, Receiver) -> start(Module, SockMod, Socket, Opts) -> case Module:socket_type() of - xml_stream -> - MaxStanzaSize = - case lists:keysearch(max_stanza_size, 1, Opts) of - {value, {_, Size}} -> Size; - _ -> infinity - end, - Receiver = ejabberd_receiver:start(Socket, SockMod, none, MaxStanzaSize), - case SockMod:controlling_process(Socket, Receiver) of - ok -> - ok; - {error, _Reason} -> - SockMod:close(Socket) - end, - supervisor:start_child(ejabberd_frontend_socket_sup, - [Module, SockMod, Socket, Opts, Receiver]); - raw -> - %{ok, Pid} = Module:start({SockMod, Socket}, Opts), - %case SockMod:controlling_process(Socket, Pid) of - % ok -> - % ok; - % {error, _Reason} -> - % SockMod:close(Socket) - %end - todo + xml_stream -> + MaxStanzaSize = case lists:keysearch(max_stanza_size, 1, + Opts) + of + {value, {_, Size}} -> Size; + _ -> infinity + end, + Receiver = ejabberd_receiver:start(Socket, SockMod, + none, MaxStanzaSize), + case SockMod:controlling_process(Socket, Receiver) of + ok -> ok; + {error, _Reason} -> SockMod:close(Socket) + end, + supervisor:start_child(ejabberd_frontend_socket_sup, + [Module, SockMod, Socket, Opts, Receiver]); + raw -> + %{ok, Pid} = Module:start({SockMod, Socket}, Opts), + %case SockMod:controlling_process(Socket, Pid) of + % ok -> + % ok; + % {error, _Reason} -> + % SockMod:close(Socket) + %end + todo end. starttls(FsmRef, _TLSOpts) -> @@ -108,8 +109,7 @@ compress(FsmRef) -> FsmRef. compress(FsmRef, Data) -> - gen_server:call(FsmRef, {compress, Data}), - FsmRef. + gen_server:call(FsmRef, {compress, Data}), FsmRef. reset_stream(FsmRef) -> gen_server:call(FsmRef, reset_stream). @@ -120,8 +120,7 @@ send(FsmRef, Data) -> change_shaper(FsmRef, Shaper) -> gen_server:call(FsmRef, {change_shaper, Shaper}). -monitor(FsmRef) -> - erlang:monitor(process, FsmRef). +monitor(FsmRef) -> erlang:monitor(process, FsmRef). get_sockmod(FsmRef) -> gen_server:call(FsmRef, get_sockmod). @@ -132,11 +131,9 @@ get_peer_certificate(FsmRef) -> get_verify_result(FsmRef) -> gen_server:call(FsmRef, get_verify_result). -close(FsmRef) -> - gen_server:call(FsmRef, close). +close(FsmRef) -> gen_server:call(FsmRef, close). -sockname(FsmRef) -> - gen_server:call(FsmRef, sockname). +sockname(FsmRef) -> gen_server:call(FsmRef, sockname). peername(_FsmRef) -> %% TODO: Frontend improvements planned by Aleksey @@ -156,7 +153,6 @@ peername(_FsmRef) -> %% Description: Initiates the server %%-------------------------------------------------------------------- init([Module, SockMod, Socket, Opts, Receiver]) -> - %% TODO: monitor the receiver Node = ejabberd_node_groups:get_closest_node(backend), {SockMod2, Socket2} = check_starttls(SockMod, Socket, Receiver, Opts), {ok, Pid} = @@ -188,7 +184,8 @@ handle_call({starttls, TLSOpts, Data}, _From, State) -> catch (State#state.sockmod):send( State#state.socket, Data), Reply = ok, - {reply, Reply, State#state{socket = TLSSocket, sockmod = tls}, + {reply, Reply, + State#state{socket = TLSSocket, sockmod = tls}, ?HIBERNATE_TIMEOUT}; handle_call(compress, _From, State) -> @@ -208,42 +205,35 @@ handle_call({compress, Data}, _From, State) -> catch (State#state.sockmod):send( State#state.socket, Data), Reply = ok, - {reply, Reply, State#state{socket = ZlibSocket, sockmod = ejabberd_zlib}, + {reply, Reply, + State#state{socket = ZlibSocket, sockmod = ejabberd_zlib}, ?HIBERNATE_TIMEOUT}; - handle_call(reset_stream, _From, State) -> ejabberd_receiver:reset_stream(State#state.receiver), Reply = ok, {reply, Reply, State, ?HIBERNATE_TIMEOUT}; - handle_call({send, Data}, _From, State) -> - catch (State#state.sockmod):send( - State#state.socket, Data), + catch (State#state.sockmod):send(State#state.socket, Data), Reply = ok, {reply, Reply, State, ?HIBERNATE_TIMEOUT}; - handle_call({change_shaper, Shaper}, _From, State) -> - ejabberd_receiver:change_shaper(State#state.receiver, Shaper), + ejabberd_receiver:change_shaper(State#state.receiver, + Shaper), Reply = ok, {reply, Reply, State, ?HIBERNATE_TIMEOUT}; - handle_call(get_sockmod, _From, State) -> Reply = State#state.sockmod, {reply, Reply, State, ?HIBERNATE_TIMEOUT}; - handle_call(get_peer_certificate, _From, State) -> Reply = tls:get_peer_certificate(State#state.socket), {reply, Reply, State, ?HIBERNATE_TIMEOUT}; - handle_call(get_verify_result, _From, State) -> Reply = tls:get_verify_result(State#state.socket), {reply, Reply, State, ?HIBERNATE_TIMEOUT}; - handle_call(close, _From, State) -> ejabberd_receiver:close(State#state.receiver), Reply = ok, {stop, normal, Reply, State}; - handle_call(sockname, _From, State) -> #state{sockmod = SockMod, socket = Socket} = State, Reply = @@ -254,21 +244,15 @@ handle_call(sockname, _From, State) -> SockMod:sockname(Socket) end, {reply, Reply, State, ?HIBERNATE_TIMEOUT}; - handle_call(peername, _From, State) -> #state{sockmod = SockMod, socket = Socket} = State, - Reply = - case SockMod of - gen_tcp -> - inet:peername(Socket); - _ -> - SockMod:peername(Socket) - end, + Reply = case SockMod of + gen_tcp -> inet:peername(Socket); + _ -> SockMod:peername(Socket) + end, {reply, Reply, State, ?HIBERNATE_TIMEOUT}; - handle_call(_Request, _From, State) -> - Reply = ok, - {reply, Reply, State, ?HIBERNATE_TIMEOUT}. + Reply = ok, {reply, Reply, State, ?HIBERNATE_TIMEOUT}. %%-------------------------------------------------------------------- %% Function: handle_cast(Msg, State) -> {noreply, State} | @@ -286,7 +270,8 @@ handle_cast(_Msg, State) -> %% Description: Handling all non call/cast messages %%-------------------------------------------------------------------- handle_info(timeout, State) -> - proc_lib:hibernate(gen_server, enter_loop, [?MODULE, [], State]), + proc_lib:hibernate(gen_server, enter_loop, + [?MODULE, [], State]), {noreply, State, ?HIBERNATE_TIMEOUT}; handle_info(_Info, State) -> {noreply, State, ?HIBERNATE_TIMEOUT}. @@ -298,15 +283,13 @@ handle_info(_Info, State) -> %% cleaning up. When it returns, the gen_server terminates with Reason. %% The return value is ignored. %%-------------------------------------------------------------------- -terminate(_Reason, _State) -> - ok. +terminate(_Reason, _State) -> ok. %%-------------------------------------------------------------------- %% Func: code_change(OldVsn, State, Extra) -> {ok, NewState} %% Description: Convert process state when code is changed %%-------------------------------------------------------------------- -code_change(_OldVsn, State, _Extra) -> - {ok, State}. +code_change(_OldVsn, State, _Extra) -> {ok, State}. %%-------------------------------------------------------------------- %%% Internal functions diff --git a/src/ejabberd_hooks.erl b/src/ejabberd_hooks.erl index 26573aa7d..e4f9f597b 100644 --- a/src/ejabberd_hooks.erl +++ b/src/ejabberd_hooks.erl @@ -67,58 +67,76 @@ start_link() -> gen_server:start_link({local, ejabberd_hooks}, ejabberd_hooks, [], []). -%% @spec (Hook::atom(), Function::function(), Seq::integer()) -> ok +-spec add(atom(), fun(), number()) -> any(). + %% @doc See add/4. add(Hook, Function, Seq) when is_function(Function) -> add(Hook, global, undefined, Function, Seq). +-spec add(atom(), binary() | atom(), fun() | atom() , number()) -> any(). add(Hook, Host, Function, Seq) when is_function(Function) -> add(Hook, Host, undefined, Function, Seq); -%% @spec (Hook::atom(), Module::atom(), Function::atom(), Seq::integer()) -> ok %% @doc Add a module and function to this hook. %% The integer sequence is used to sort the calls: low number is called before high number. add(Hook, Module, Function, Seq) -> add(Hook, global, Module, Function, Seq). +-spec add(atom(), binary() | global, atom(), atom() | fun(), number()) -> any(). + add(Hook, Host, Module, Function, Seq) -> gen_server:call(ejabberd_hooks, {add, Hook, Host, Module, Function, Seq}). +-spec add_dist(atom(), atom(), atom(), atom() | fun(), number()) -> any(). + add_dist(Hook, Node, Module, Function, Seq) -> gen_server:call(ejabberd_hooks, {add, Hook, global, Node, Module, Function, Seq}). +-spec add_dist(atom(), binary() | global, atom(), atom(), atom() | fun(), number()) -> any(). + add_dist(Hook, Host, Node, Module, Function, Seq) -> gen_server:call(ejabberd_hooks, {add, Hook, Host, Node, Module, Function, Seq}). -%% @spec (Hook::atom(), Function::function(), Seq::integer()) -> ok +-spec delete(atom(), fun(), number()) -> ok. + %% @doc See del/4. delete(Hook, Function, Seq) when is_function(Function) -> delete(Hook, global, undefined, Function, Seq). +-spec delete(atom(), binary() | atom(), atom() | fun(), number()) -> ok. + delete(Hook, Host, Function, Seq) when is_function(Function) -> delete(Hook, Host, undefined, Function, Seq); -%% @spec (Hook::atom(), Module::atom(), Function::atom(), Seq::integer()) -> ok %% @doc Delete a module and function from this hook. %% It is important to indicate exactly the same information than when the call was added. delete(Hook, Module, Function, Seq) -> delete(Hook, global, Module, Function, Seq). +-spec delete(atom(), binary() | global, atom(), atom() | fun(), number()) -> ok. + delete(Hook, Host, Module, Function, Seq) -> gen_server:call(ejabberd_hooks, {delete, Hook, Host, Module, Function, Seq}). +-spec delete_dist(atom(), atom(), atom(), atom() | fun(), number()) -> ok. + delete_dist(Hook, Node, Module, Function, Seq) -> delete_dist(Hook, global, Node, Module, Function, Seq). +-spec delete_dist(atom(), binary() | global, atom(), atom(), atom() | fun(), number()) -> ok. + delete_dist(Hook, Host, Node, Module, Function, Seq) -> gen_server:call(ejabberd_hooks, {delete, Hook, Host, Node, Module, Function, Seq}). -%% @spec (Hook::atom(), Args) -> ok +-spec run(atom(), list()) -> ok. + %% @doc Run the calls of this hook in order, don't care about function results. %% If a call returns stop, no more calls are performed. run(Hook, Args) -> run(Hook, global, Args). +-spec run(atom(), binary() | global, list()) -> ok. + run(Hook, Host, Args) -> case ets:lookup(hooks, {Hook, Host}) of [{_, Ls}] -> @@ -127,7 +145,8 @@ run(Hook, Host, Args) -> ok end. -%% @spec (Hook::atom(), Val, Args) -> Val | stopped | NewVal +-spec run_fold(atom(), any(), list()) -> any(). + %% @doc Run the calls of this hook in order. %% The arguments passed to the function are: [Val | Args]. %% The result of a call is used as Val for the next call. @@ -136,6 +155,8 @@ run(Hook, Host, Args) -> run_fold(Hook, Val, Args) -> run_fold(Hook, global, Val, Args). +-spec run_fold(atom(), binary() | global, any(), list()) -> any(). + run_fold(Hook, Host, Val, Args) -> case ets:lookup(hooks, {Hook, Host}) of [{_, Ls}] -> diff --git a/src/ejabberd_listener.erl b/src/ejabberd_listener.erl index 78f1691e1..9df8678b0 100644 --- a/src/ejabberd_listener.erl +++ b/src/ejabberd_listener.erl @@ -35,7 +35,8 @@ stop_listener/2, parse_listener_portip/2, add_listener/3, - delete_listener/2 + delete_listener/2, + validate_cfg/1 ]). -include("ejabberd.hrl"). @@ -53,7 +54,7 @@ init(_) -> {ok, {{one_for_one, 10, 1}, []}}. bind_tcp_ports() -> - case ejabberd_config:get_local_option(listen) of + case ejabberd_config:get_local_option(listen, fun validate_cfg/1) of undefined -> ignore; Ls -> @@ -77,7 +78,8 @@ bind_tcp_port(PortIP, Module, RawOpts) -> udp -> ok; _ -> ListenSocket = listen_tcp(PortIP, Module, SockOpts, Port, IPS), - ets:insert(listen_sockets, {PortIP, ListenSocket}) + ets:insert(listen_sockets, {PortIP, ListenSocket}), + ok end catch throw:{error, Error} -> @@ -85,7 +87,7 @@ bind_tcp_port(PortIP, Module, RawOpts) -> end. start_listeners() -> - case ejabberd_config:get_local_option(listen) of + case ejabberd_config:get_local_option(listen, fun validate_cfg/1) of undefined -> ignore; Ls -> @@ -215,17 +217,17 @@ parse_listener_portip(PortIP, Opts) -> case add_proto(PortIP, Opts) of {P, Prot} -> T = get_ip_tuple(IPOpt, IPVOpt), - S = inet_parse:ntoa(T), + S = jlib:ip_to_list(T), {P, T, S, Prot}; {P, T, Prot} when is_integer(P) and is_tuple(T) -> - S = inet_parse:ntoa(T), + S = jlib:ip_to_list(T), {P, T, S, Prot}; - {P, S, Prot} when is_integer(P) and is_list(S) -> - [S | _] = string:tokens(S, "/"), - {ok, T} = inet_parse:address(S), + {P, S, Prot} when is_integer(P) and is_binary(S) -> + [S | _] = str:tokens(S, <<"/">>), + {ok, T} = inet_parse:address(binary_to_list(S)), {P, T, S, Prot} end, - IPV = case size(IPT) of + IPV = case tuple_size(IPT) of 4 -> inet; 8 -> inet6 end, @@ -337,7 +339,7 @@ start_listener2(Port, Module, Opts) -> start_listener_sup(Port, Module, Opts). start_module_sup(_Port, Module) -> - Proc1 = gen_mod:get_module_proc("sup", Module), + Proc1 = gen_mod:get_module_proc(<<"sup">>, Module), ChildSpec1 = {Proc1, {ejabberd_tmp_sup, start_link, [Proc1, strip_frontend(Module)]}, @@ -357,7 +359,7 @@ start_listener_sup(Port, Module, Opts) -> supervisor:start_child(ejabberd_listeners, ChildSpec). stop_listeners() -> - Ports = ejabberd_config:get_local_option(listen), + Ports = ejabberd_config:get_local_option(listen, fun validate_cfg/1), lists:foreach( fun({PortIpNetp, Module, _Opts}) -> delete_listener(PortIpNetp, Module) @@ -390,7 +392,8 @@ add_listener(PortIP, Module, Opts) -> PortIP1 = {Port, IPT, Proto}, case start_listener(PortIP1, Module, Opts) of {ok, _Pid} -> - Ports = case ejabberd_config:get_local_option(listen) of + Ports = case ejabberd_config:get_local_option( + listen, fun validate_cfg/1) of undefined -> []; Ls -> @@ -420,7 +423,8 @@ delete_listener(PortIP, Module) -> delete_listener(PortIP, Module, Opts) -> {Port, IPT, _, _, Proto, _} = parse_listener_portip(PortIP, Opts), PortIP1 = {Port, IPT, Proto}, - Ports = case ejabberd_config:get_local_option(listen) of + Ports = case ejabberd_config:get_local_option( + listen, fun validate_cfg/1) of undefined -> []; Ls -> @@ -430,11 +434,16 @@ delete_listener(PortIP, Module, Opts) -> ejabberd_config:add_local_option(listen, Ports1), stop_listener(PortIP1, Module). + +-spec is_frontend({frontend, module} | module()) -> boolean(). + is_frontend({frontend, _Module}) -> true; is_frontend(_) -> false. %% @doc(FrontMod) -> atom() %% where FrontMod = atom() | {frontend, atom()} +-spec strip_frontend({frontend, module()} | module()) -> module(). + strip_frontend({frontend, Module}) -> Module; strip_frontend(Module) when is_atom(Module) -> Module. @@ -505,7 +514,7 @@ socket_error(Reason, PortIP, Module, SockOpts, Port, IPS) -> "IP address not available: " ++ IPS; eaddrinuse -> "IP address and port number already used: " - ++IPS++" "++integer_to_list(Port); + ++binary_to_list(IPS)++" "++integer_to_list(Port); _ -> format_error(Reason) end, @@ -520,3 +529,44 @@ format_error(Reason) -> ReasonStr -> ReasonStr end. + +-define(IS_CHAR(C), (is_integer(C) and (C >= 0) and (C =< 255))). +-define(IS_UINT(U), (is_integer(U) and (U >= 0) and (U =< 65535))). +-define(IS_PORT(P), (is_integer(P) and (P > 0) and (P =< 65535))). +-define(IS_TRANSPORT(T), ((T == tcp) or (T == udp))). + +-type transport() :: udp | tcp. +-type port_ip_transport() :: inet:port_number() | + {inet:port_number(), transport()} | + {inet:port_number(), inet:ip_address()} | + {inet:port_number(), inet:ip_address(), + transport()}. +-spec validate_cfg(list()) -> [{port_ip_transport(), module(), list()}]. + +validate_cfg(L) -> + lists:map( + fun({PortIPTransport, Mod, Opts}) when is_atom(Mod), is_list(Opts) -> + case PortIPTransport of + Port when ?IS_PORT(Port) -> + {Port, Mod, Opts}; + {Port, Trans} when ?IS_PORT(Port) and ?IS_TRANSPORT(Trans) -> + {{Port, Trans}, Mod, Opts}; + {Port, IP} when ?IS_PORT(Port) -> + {{Port, prepare_ip(IP)}, Mod, Opts}; + {Port, IP, Trans} when ?IS_PORT(Port) and ?IS_TRANSPORT(Trans) -> + {{Port, prepare_ip(IP), Trans}, Mod, Opts} + end + end, L). + +prepare_ip({A, B, C, D} = IP) + when ?IS_CHAR(A) and ?IS_CHAR(B) and ?IS_CHAR(C) and ?IS_CHAR(D) -> + IP; +prepare_ip({A, B, C, D, E, F, G, H} = IP) + when ?IS_UINT(A) and ?IS_UINT(B) and ?IS_UINT(C) and ?IS_UINT(D) + and ?IS_UINT(E) and ?IS_UINT(F) and ?IS_UINT(G) and ?IS_UINT(H) -> + IP; +prepare_ip(IP) when is_list(IP) -> + {ok, Addr} = inet_parse:address(IP), + Addr; +prepare_ip(IP) when is_binary(IP) -> + prepare_ip(binary_to_list(IP)). diff --git a/src/ejabberd_local.erl b/src/ejabberd_local.erl index 1fe7cb0a4..12dfea0c8 100644 --- a/src/ejabberd_local.erl +++ b/src/ejabberd_local.erl @@ -25,6 +25,7 @@ %%%---------------------------------------------------------------------- -module(ejabberd_local). + -author('alexey@process-one.net'). -behaviour(gen_server). @@ -32,30 +33,27 @@ %% API -export([start_link/0]). --export([route/3, - route_iq/4, - route_iq/5, - process_iq_reply/3, - register_iq_handler/4, - register_iq_handler/5, - register_iq_response_handler/4, - register_iq_response_handler/5, - unregister_iq_handler/2, - unregister_iq_response_handler/2, - refresh_iq_handlers/0, - bounce_resource_packet/3 - ]). +-export([route/3, route_iq/4, route_iq/5, + process_iq_reply/3, register_iq_handler/4, + register_iq_handler/5, register_iq_response_handler/4, + register_iq_response_handler/5, unregister_iq_handler/2, + unregister_iq_response_handler/2, refresh_iq_handlers/0, + bounce_resource_packet/3]). %% gen_server callbacks --export([init/1, handle_call/3, handle_cast/2, handle_info/2, - terminate/2, code_change/3]). +-export([init/1, handle_call/3, handle_cast/2, + handle_info/2, terminate/2, code_change/3]). -include("ejabberd.hrl"). + -include("jlib.hrl"). -record(state, {}). --record(iq_response, {id, module, function, timer}). +-record(iq_response, {id = <<"">> :: binary(), + module :: atom(), + function :: atom() | fun(), + timer = make_ref() :: reference()}). -define(IQTABLE, local_iqtable). @@ -70,65 +68,59 @@ %% Description: Starts the server %%-------------------------------------------------------------------- start_link() -> - gen_server:start_link({local, ?MODULE}, ?MODULE, [], []). + gen_server:start_link({local, ?MODULE}, ?MODULE, [], + []). process_iq(From, To, Packet) -> IQ = jlib:iq_query_info(Packet), case IQ of - #iq{xmlns = XMLNS} -> - Host = To#jid.lserver, - case ets:lookup(?IQTABLE, {XMLNS, Host}) of - [{_, Module, Function}] -> - ResIQ = Module:Function(From, To, IQ), - if - ResIQ /= ignore -> - ejabberd_router:route( - To, From, jlib:iq_to_xml(ResIQ)); - true -> - ok - end; - [{_, Module, Function, Opts}] -> - gen_iq_handler:handle(Host, Module, Function, Opts, - From, To, IQ); - [] -> - Err = jlib:make_error_reply( - Packet, ?ERR_FEATURE_NOT_IMPLEMENTED), - ejabberd_router:route(To, From, Err) - end; - reply -> - IQReply = jlib:iq_query_or_response_info(Packet), - process_iq_reply(From, To, IQReply); - _ -> - Err = jlib:make_error_reply(Packet, ?ERR_BAD_REQUEST), - ejabberd_router:route(To, From, Err), - ok + #iq{xmlns = XMLNS} -> + Host = To#jid.lserver, + case ets:lookup(?IQTABLE, {XMLNS, Host}) of + [{_, Module, Function}] -> + ResIQ = Module:Function(From, To, IQ), + if ResIQ /= ignore -> + ejabberd_router:route(To, From, jlib:iq_to_xml(ResIQ)); + true -> ok + end; + [{_, Module, Function, Opts}] -> + gen_iq_handler:handle(Host, Module, Function, Opts, + From, To, IQ); + [] -> + Err = jlib:make_error_reply(Packet, + ?ERR_FEATURE_NOT_IMPLEMENTED), + ejabberd_router:route(To, From, Err) + end; + reply -> + IQReply = jlib:iq_query_or_response_info(Packet), + process_iq_reply(From, To, IQReply); + _ -> + Err = jlib:make_error_reply(Packet, ?ERR_BAD_REQUEST), + ejabberd_router:route(To, From, Err), + ok end. process_iq_reply(From, To, #iq{id = ID} = IQ) -> case get_iq_callback(ID) of - {ok, undefined, Function} -> - Function(IQ), - ok; - {ok, Module, Function} -> - Module:Function(From, To, IQ), - ok; - _ -> - nothing + {ok, undefined, Function} -> Function(IQ), ok; + {ok, Module, Function} -> + Module:Function(From, To, IQ), ok; + _ -> nothing end. route(From, To, Packet) -> case catch do_route(From, To, Packet) of - {'EXIT', Reason} -> - ?ERROR_MSG("~p~nwhen processing: ~p", - [Reason, {From, To, Packet}]); - _ -> - ok + {'EXIT', Reason} -> + ?ERROR_MSG("~p~nwhen processing: ~p", + [Reason, {From, To, Packet}]); + _ -> ok end. route_iq(From, To, IQ, F) -> route_iq(From, To, IQ, F, undefined). -route_iq(From, To, #iq{type = Type} = IQ, F, Timeout) when is_function(F) -> +route_iq(From, To, #iq{type = Type} = IQ, F, Timeout) + when is_function(F) -> Packet = if Type == set; Type == get -> ID = randoms:get_string(), Host = From#jid.lserver, @@ -139,15 +131,16 @@ route_iq(From, To, #iq{type = Type} = IQ, F, Timeout) when is_function(F) -> end, ejabberd_router:route(From, To, Packet). -register_iq_response_handler(Host, ID, Module, Function) -> - register_iq_response_handler(Host, ID, Module, Function, undefined). +register_iq_response_handler(Host, ID, Module, + Function) -> + register_iq_response_handler(Host, ID, Module, Function, + undefined). -register_iq_response_handler(_Host, ID, Module, Function, Timeout0) -> +register_iq_response_handler(_Host, ID, Module, + Function, Timeout0) -> Timeout = case Timeout0 of - undefined -> - ?IQ_TIMEOUT; - N when is_integer(N), N > 0 -> - N + undefined -> ?IQ_TIMEOUT; + N when is_integer(N), N > 0 -> N end, TRef = erlang:start_timer(Timeout, ejabberd_local, ID), mnesia:dirty_write(#iq_response{id = ID, @@ -156,14 +149,15 @@ register_iq_response_handler(_Host, ID, Module, Function, Timeout0) -> timer = TRef}). register_iq_handler(Host, XMLNS, Module, Fun) -> - ejabberd_local ! {register_iq_handler, Host, XMLNS, Module, Fun}. + ejabberd_local ! + {register_iq_handler, Host, XMLNS, Module, Fun}. register_iq_handler(Host, XMLNS, Module, Fun, Opts) -> - ejabberd_local ! {register_iq_handler, Host, XMLNS, Module, Fun, Opts}. + ejabberd_local ! + {register_iq_handler, Host, XMLNS, Module, Fun, Opts}. unregister_iq_response_handler(_Host, ID) -> - catch get_iq_callback(ID), - ok. + catch get_iq_callback(ID), ok. unregister_iq_handler(Host, XMLNS) -> ejabberd_local ! {unregister_iq_handler, Host, XMLNS}. @@ -172,7 +166,8 @@ refresh_iq_handlers() -> ejabberd_local ! refresh_iq_handlers. bounce_resource_packet(From, To, Packet) -> - Err = jlib:make_error_reply(Packet, ?ERR_ITEM_NOT_FOUND), + Err = jlib:make_error_reply(Packet, + ?ERR_ITEM_NOT_FOUND), ejabberd_router:route(To, From, Err), stop. @@ -188,12 +183,15 @@ bounce_resource_packet(From, To, Packet) -> %% Description: Initiates the server %%-------------------------------------------------------------------- init([]) -> - lists:foreach( - fun(Host) -> - ejabberd_router:register_route(Host, {apply, ?MODULE, route}), - ejabberd_hooks:add(local_send_to_resource_hook, Host, - ?MODULE, bounce_resource_packet, 100) - end, ?MYHOSTS), + lists:foreach(fun (Host) -> + ejabberd_router:register_route(Host, + {apply, ?MODULE, + route}), + ejabberd_hooks:add(local_send_to_resource_hook, Host, + ?MODULE, bounce_resource_packet, + 100) + end, + ?MYHOSTS), catch ets:new(?IQTABLE, [named_table, public]), update_table(), mnesia:create_table(iq_response, @@ -212,70 +210,68 @@ init([]) -> %% Description: Handling call messages %%-------------------------------------------------------------------- handle_call(_Request, _From, State) -> - Reply = ok, - {reply, Reply, State}. - %%-------------------------------------------------------------------- %% Function: handle_cast(Msg, State) -> {noreply, State} | %% {noreply, State, Timeout} | %% {stop, Reason, State} %% Description: Handling cast messages %%-------------------------------------------------------------------- -handle_cast(_Msg, State) -> - {noreply, State}. - %%-------------------------------------------------------------------- %% Function: handle_info(Info, State) -> {noreply, State} | %% {noreply, State, Timeout} | %% {stop, Reason, State} %% Description: Handling all non call/cast messages %%-------------------------------------------------------------------- + Reply = ok, {reply, Reply, State}. + +handle_cast(_Msg, State) -> {noreply, State}. + handle_info({route, From, To, Packet}, State) -> case catch do_route(From, To, Packet) of - {'EXIT', Reason} -> - ?ERROR_MSG("~p~nwhen processing: ~p", - [Reason, {From, To, Packet}]); - _ -> - ok + {'EXIT', Reason} -> + ?ERROR_MSG("~p~nwhen processing: ~p", + [Reason, {From, To, Packet}]); + _ -> ok end, {noreply, State}; -handle_info({register_iq_handler, Host, XMLNS, Module, Function}, State) -> +handle_info({register_iq_handler, Host, XMLNS, Module, + Function}, + State) -> ets:insert(?IQTABLE, {{XMLNS, Host}, Module, Function}), catch mod_disco:register_feature(Host, XMLNS), {noreply, State}; -handle_info({register_iq_handler, Host, XMLNS, Module, Function, Opts}, State) -> - ets:insert(?IQTABLE, {{XMLNS, Host}, Module, Function, Opts}), +handle_info({register_iq_handler, Host, XMLNS, Module, + Function, Opts}, + State) -> + ets:insert(?IQTABLE, + {{XMLNS, Host}, Module, Function, Opts}), catch mod_disco:register_feature(Host, XMLNS), {noreply, State}; -handle_info({unregister_iq_handler, Host, XMLNS}, State) -> +handle_info({unregister_iq_handler, Host, XMLNS}, + State) -> case ets:lookup(?IQTABLE, {XMLNS, Host}) of - [{_, Module, Function, Opts}] -> - gen_iq_handler:stop_iq_handler(Module, Function, Opts); - _ -> - ok + [{_, Module, Function, Opts}] -> + gen_iq_handler:stop_iq_handler(Module, Function, Opts); + _ -> ok end, ets:delete(?IQTABLE, {XMLNS, Host}), catch mod_disco:unregister_feature(Host, XMLNS), {noreply, State}; handle_info(refresh_iq_handlers, State) -> - lists:foreach( - fun(T) -> - case T of - {{XMLNS, Host}, _Module, _Function, _Opts} -> - catch mod_disco:register_feature(Host, XMLNS); - {{XMLNS, Host}, _Module, _Function} -> - catch mod_disco:register_feature(Host, XMLNS); - _ -> - ok - end - end, ets:tab2list(?IQTABLE)), + lists:foreach(fun (T) -> + case T of + {{XMLNS, Host}, _Module, _Function, _Opts} -> + catch mod_disco:register_feature(Host, XMLNS); + {{XMLNS, Host}, _Module, _Function} -> + catch mod_disco:register_feature(Host, XMLNS); + _ -> ok + end + end, + ets:tab2list(?IQTABLE)), {noreply, State}; handle_info({timeout, _TRef, ID}, State) -> process_iq_timeout(ID), {noreply, State}; -handle_info(_Info, State) -> - {noreply, State}. - %%-------------------------------------------------------------------- %% Function: terminate(Reason, State) -> void() %% Description: This function is called by a gen_server when it is about to @@ -283,48 +279,43 @@ handle_info(_Info, State) -> %% cleaning up. When it returns, the gen_server terminates with Reason. %% The return value is ignored. %%-------------------------------------------------------------------- -terminate(_Reason, _State) -> - ok. - %%-------------------------------------------------------------------- %% Func: code_change(OldVsn, State, Extra) -> {ok, NewState} %% Description: Convert process state when code is changed %%-------------------------------------------------------------------- -code_change(_OldVsn, State, _Extra) -> - {ok, State}. - %%-------------------------------------------------------------------- %%% Internal functions %%-------------------------------------------------------------------- +handle_info(_Info, State) -> {noreply, State}. + +terminate(_Reason, _State) -> ok. + +code_change(_OldVsn, State, _Extra) -> {ok, State}. + do_route(From, To, Packet) -> - ?DEBUG("local route~n\tfrom ~p~n\tto ~p~n\tpacket ~P~n", + ?DEBUG("local route~n\tfrom ~p~n\tto ~p~n\tpacket " + "~P~n", [From, To, Packet, 8]), - if - To#jid.luser /= "" -> - ejabberd_sm:route(From, To, Packet); - To#jid.lresource == "" -> - {xmlelement, Name, _Attrs, _Els} = Packet, - case Name of - "iq" -> - process_iq(From, To, Packet); - "message" -> - ok; - "presence" -> - ok; - _ -> - ok - end; - true -> - {xmlelement, _Name, Attrs, _Els} = Packet, - case xml:get_attr_s("type", Attrs) of - "error" -> ok; - "result" -> ok; - _ -> - ejabberd_hooks:run(local_send_to_resource_hook, - To#jid.lserver, - [From, To, Packet]) - end - end. + if To#jid.luser /= <<"">> -> + ejabberd_sm:route(From, To, Packet); + To#jid.lresource == <<"">> -> + #xmlel{name = Name} = Packet, + case Name of + <<"iq">> -> process_iq(From, To, Packet); + <<"message">> -> ok; + <<"presence">> -> ok; + _ -> ok + end; + true -> + #xmlel{attrs = Attrs} = Packet, + case xml:get_attr_s(<<"type">>, Attrs) of + <<"error">> -> ok; + <<"result">> -> ok; + _ -> + ejabberd_hooks:run(local_send_to_resource_hook, + To#jid.lserver, [From, To, Packet]) + end + end. update_table() -> case catch mnesia:table_info(iq_response, attributes) of @@ -365,13 +356,7 @@ process_iq_timeout() -> cancel_timer(TRef) -> case erlang:cancel_timer(TRef) of - false -> - receive - {timeout, TRef, _} -> - ok - after 0 -> - ok - end; - _ -> - ok + false -> + receive {timeout, TRef, _} -> ok after 0 -> ok end; + _ -> ok end. diff --git a/src/ejabberd_node_groups.erl b/src/ejabberd_node_groups.erl index 84c1d69ca..371a1bc28 100644 --- a/src/ejabberd_node_groups.erl +++ b/src/ejabberd_node_groups.erl @@ -60,20 +60,20 @@ start_link() -> join(Name) -> PG = {?MODULE, Name}, - ?PG2:create(PG), - ?PG2:join(PG, whereis(?MODULE)). + pg2:create(PG), + pg2:join(PG, whereis(?MODULE)). leave(Name) -> PG = {?MODULE, Name}, - ?PG2:leave(PG, whereis(?MODULE)). + pg2:leave(PG, whereis(?MODULE)). get_members(Name) -> PG = {?MODULE, Name}, - [node(P) || P <- ?PG2:get_members(PG)]. + [node(P) || P <- pg2:get_members(PG)]. get_closest_node(Name) -> PG = {?MODULE, Name}, - node(?PG2:get_closest_pid(PG)). + node(pg2:get_closest_pid(PG)). %%==================================================================== %% gen_server callbacks @@ -88,7 +88,7 @@ get_closest_node(Name) -> %%-------------------------------------------------------------------- init([]) -> {FE, BE} = - case ejabberd_config:get_local_option(node_type) of + case ejabberd_config:get_local_option(node_type, fun(N) -> N end) of frontend -> {true, false}; backend -> diff --git a/src/ejabberd_piefxis.erl b/src/ejabberd_piefxis.erl index b193dc67c..668f1ba62 100644 --- a/src/ejabberd_piefxis.erl +++ b/src/ejabberd_piefxis.erl @@ -3,6 +3,10 @@ %%% Author : Pablo Polvorin, Vidal Santiago Martinez %%% Purpose : XEP-0227: Portable Import/Export Format for XMPP-IM Servers %%% Created : 17 Jul 2008 by Pablo Polvorin +%%%------------------------------------------------------------------- +%%% @author Evgeniy Khramtsov +%%% @copyright (C) 2012, Evgeniy Khramtsov +%%% @doc %%% %%% %%% ejabberd, Copyright (C) 2002-2013 ProcessOne @@ -31,239 +35,85 @@ %%% - XEP-227: 6. Security Considerations %%% - Other schemas of XInclude are not tested, and may not be imported correctly. %%% - If a host has many users, split that host in XML files with 50 users each. - %%%% Headers -module(ejabberd_piefxis). +%% API -export([import_file/1, export_server/1, export_host/2]). --record(parsing_state, {parser, host, dir}). +-define(CHUNK_SIZE, 1024*20). %20k -include("ejabberd.hrl"). +-include("jlib.hrl"). +-include("mod_privacy.hrl"). +-include("mod_roster.hrl"). %%-include_lib("exmpp/include/exmpp.hrl"). %%-include_lib("exmpp/include/exmpp_client.hrl"). %% Copied from exmpp header files: --define(NS_ROSTER, "jabber:iq:roster"). --define(NS_VCARD, "vcard-temp"). --record(xmlcdata, { - cdata = <<>> - }). --record(xmlattr, { - ns = undefined, - name, - value - }). --record(xmlel, { - ns = undefined, - declared_ns = [], - name, - attrs = [], - children = [] - }). --record(iq, { - kind, - type, - id, - ns, - payload, - error, - lang, - iq_ns - }). --record(xmlendtag, { - ns = undefined, - name - }). - - %% Copied from mod_private.erl --record(private_storage, {usns, xml}). - %%-define(ERROR_MSG(M,Args),io:format(M,Args)). %%-define(INFO_MSG(M,Args),ok). - --define(CHUNK_SIZE,1024*20). %20k - --define(BTL, binary_to_list). --define(LTB, list_to_binary). - --define(NS_XINCLUDE, 'http://www.w3.org/2001/XInclude'). - %%%================================== - %%%% Import file +-define(NS_PIE, <<"urn:xmpp:pie:0">>). +-define(NS_PIEFXIS, <<"http://www.xmpp.org/extensions/xep-0227.html#ns">>). +-define(NS_XI, <<"http://www.w3.org/2001/XInclude">>). -import_file(FileName) -> - _ = #xmlattr{}, %% this stupid line is only to prevent compilation warning about "recod xmlattr is unused" - import_file(FileName, 2). +-record(state, {xml_stream_state :: xml_stream:xml_stream_state(), + user = <<"">> :: binary(), + server = <<"">> :: binary(), + fd :: file:io_device(), + dir = <<"">> :: binary()}). -import_file(FileName, RootDepth) -> - try_start_exmpp(), - Dir = filename:dirname(FileName), - {ok, IO} = try_open_file(FileName), - Parser = exmpp_xml:start_parser([{max_size,infinity}, - {root_depth, RootDepth}, - {emit_endtag,true}]), - read_chunks(IO, #parsing_state{parser=Parser, dir=Dir}), - file:close(IO), - exmpp_xml:stop_parser(Parser). - -try_start_exmpp() -> - try exmpp:start() - catch - error:{already_started, exmpp} -> ok; - error:undef -> throw({error, exmpp_not_installed}) - end. - -try_open_file(FileName) -> - case file:open(FileName,[read,binary]) of - {ok, IO} -> {ok, IO}; - {error, enoent} -> throw({error, {file_not_found, FileName}}) - end. +-type state() :: #state{}. %%File could be large.. we read it in chunks -read_chunks(IO,State) -> - case file:read(IO,?CHUNK_SIZE) of - {ok,Chunk} -> - NewState = process_chunk(Chunk,State), - read_chunks(IO,NewState); - eof -> - ok - end. +%%%=================================================================== +%%% API +%%%=================================================================== +import_file(FileName) -> + import_file(FileName, #state{}). -process_chunk(Chunk,S =#parsing_state{parser=Parser}) -> - case exmpp_xml:parse(Parser,Chunk) of - continue -> - S; - XMLElements -> - process_elements(XMLElements,S) +-spec import_file(binary(), state()) -> ok | {error, atom()}. + +import_file(FileName, State) -> + case file:open(FileName, [read, binary]) of + {ok, Fd} -> + Dir = filename:dirname(FileName), + XMLStreamState = xml_stream:new(self(), infinity), + Res = process(State#state{xml_stream_state = XMLStreamState, + fd = Fd, + dir = Dir}), + file:close(Fd), + Res; + {error, Reason} -> + ErrTxt = file:format_error(Reason), + ?ERROR_MSG("Failed to open file '~s': ~s", [FileName, ErrTxt]), + {error, Reason} end. %%%================================== %%%% Process Elements - -process_elements(Elements,State) -> - lists:foldl(fun process_element/2,State,Elements). - %%%================================== %%%% Process Element - -process_element(El=#xmlel{name=user, ns=_XMLNS}, - State=#parsing_state{host=Host}) -> - case add_user(El,Host) of - ok -> ok; - {error, _Other} -> error - end, - State; - -process_element(H=#xmlel{name=host},State) -> - State#parsing_state{host=?BTL(exmpp_xml:get_attribute(H, <<"jid">>, none))}; - -process_element(#xmlel{name='server-data'},State) -> - State; - -process_element(El=#xmlel{name=include, ns=?NS_XINCLUDE}, State=#parsing_state{dir=Dir}) -> - case exmpp_xml:get_attribute(El, <<"href">>, none) of - none -> - ok; - HrefB -> - Href = binary_to_list(HrefB), - %%?INFO_MSG("Parse also this file: ~n~p", [Href]), - FileName = filename:join([Dir, Href]), - import_file(FileName, 1), - Href - end, - State; - -process_element(#xmlcdata{cdata = _CData},State) -> - State; - -process_element(#xmlendtag{ns = _NS, name='server-data'},State) -> - State; - -process_element(#xmlendtag{ns = _NS, name=_Name},State) -> - State; - -process_element(El,State) -> - io:format("Warning!: unknown element found: ~p ~n",[El]), - State. - %%%================================== %%%% Add user - -add_user(El, Domain) -> - User = exmpp_xml:get_attribute(El, <<"name">>, none), - PasswordFormat = exmpp_xml:get_attribute(El, <<"password-format">>, <<"plaintext">>), - Password = exmpp_xml:get_attribute(El, <<"password">>, none), - add_user(El, Domain, User, PasswordFormat, Password). - %% @spec (El::xmlel(), Domain::string(), User::binary(), Password::binary() | none) %% -> ok | {error, ErrorText::string()} %% @doc Add a new user to the database. %% If user already exists, it will be only updated. -add_user(El, Domain, UserBinary, <<"plaintext">>, none) -> - User = ?BTL(UserBinary), - io:format("Account ~s@~s will not be created, updating it...~n", - [User, Domain]), - io:format(""), - populate_user_with_elements(El, Domain, User), - ok; -add_user(El, Domain, UserBinary, PasswordFormat, PasswordBinary) -> - User = ?BTL(UserBinary), - Password2 = prepare_password(PasswordFormat, PasswordBinary, El), - case create_user(User,Password2,Domain) of - ok -> - populate_user_with_elements(El, Domain, User), - ok; - {atomic, exists} -> - io:format("Account ~s@~s already exists, updating it...~n", - [User, Domain]), - io:format(""), - populate_user_with_elements(El, Domain, User), - ok; - {error, Other} -> - ?ERROR_MSG("Error adding user ~s@~s: ~p~n", [User, Domain, Other]), - {error, Other} - end. - -prepare_password(<<"plaintext">>, PasswordBinary, _El) -> - ?BTL(PasswordBinary); -prepare_password(<<"scram">>, none, El) -> - ScramEl = exmpp_xml:get_element(El, 'scram-hash'), - #scram{storedkey = base64:decode(exmpp_xml:get_attribute( - ScramEl, <<"stored-key">>, none)), - serverkey = base64:decode(exmpp_xml:get_attribute( - ScramEl, <<"server-key">>, none)), - salt = base64:decode(exmpp_xml:get_attribute( - ScramEl, <<"salt">>, none)), - iterationcount = list_to_integer(exmpp_xml:get_attribute_as_list( - ScramEl, <<"iteration-count">>, - ?SCRAM_DEFAULT_ITERATION_COUNT)) - }. - -populate_user_with_elements(El, Domain, User) -> - exmpp_xml:foreach( - fun (_,Child) -> - populate_user(User,Domain,Child) - end, - El). +-spec export_server(binary()) -> any(). %% @spec (User::string(), Password::string(), Domain::string()) %% -> ok | {atomic, exists} | {error, not_allowed} %% @doc Create a new user -create_user(User,Password,Domain) -> - case ejabberd_auth:try_register(User,Domain,Password) of - {atomic,ok} -> ok; - {atomic, exists} -> {atomic, exists}; - {error, not_allowed} -> {error, not_allowed}; - Other -> {error, Other} - end. +export_server(Dir) -> + export_hosts(?MYHOSTS, Dir). %%%================================== %%%% Populate user - %% @spec (User::string(), Domain::string(), El::xml()) %% -> ok | {error, not_found} %% @@ -286,28 +136,10 @@ create_user(User,Password,Domain) -> %% %% %% ''' +-spec export_host(binary(), binary()) -> any(). -populate_user(User,Domain,El=#xmlel{name='query', ns='jabber:iq:roster'}) -> - io:format("Trying to add/update roster list...",[]), - case loaded_module(Domain, mod_roster) of - {ok, _DBType} -> - case mod_roster:set_items(User, Domain, - exmpp_xml:xmlel_to_xmlelement(El)) of - {atomic, ok} -> - io:format(" DONE.~n",[]), - ok; - _ -> - io:format(" ERROR.~n",[]), - ?ERROR_MSG("Error trying to add a new user: ~s ~n", - [exmpp_xml:document_to_list(El)]), - {error, not_found} - end; - E -> io:format(" ERROR: ~p~n",[E]), - ?ERROR_MSG("No modules loaded [mod_roster] ~s ~n", - [exmpp_xml:document_to_list(El)]), - {error, not_found} - end; - +export_host(Dir, Host) -> + export_hosts([Host], Dir). %% @spec User = String with the user name %% Domain = String with a domain name @@ -328,180 +160,471 @@ populate_user(User,Domain,El=#xmlel{name='query', ns='jabber:iq:roster'}) -> %% %% %% ''' - -populate_user(User,Domain,El=#xmlel{name='vCard', ns='vcard-temp'}) -> - io:format("Trying to add/update vCards...",[]), - case loaded_module(Domain, mod_vcard) of - {ok, _} -> FullUser = jid_to_old_jid(exmpp_jid:make(User, Domain)), - IQ = iq_to_old_iq(#iq{type = set, payload = El}), - case mod_vcard:process_sm_iq(FullUser, FullUser , IQ) of - {error,_Err} -> - io:format(" ERROR.~n",[]), - ?ERROR_MSG("Error processing vcard ~s : ~p ~n", - [exmpp_xml:document_to_list(El), _Err]); - _ -> - io:format(" DONE.~n",[]), ok - end; - _ -> - io:format(" ERROR.~n",[]), - ?ERROR_MSG("No modules loaded [mod_vcard] ~s ~n", - [exmpp_xml:document_to_list(El)]), - {error, not_found} - end; +%%%=================================================================== +%%% Internal functions +%%%=================================================================== +export_hosts(Hosts, Dir) -> + FnT = make_filename_template(), + DFn = make_main_basefilename(Dir, FnT), + case file:open(DFn, [raw, write]) of + {ok, Fd} -> + print(Fd, make_piefxis_xml_head()), + print(Fd, make_piefxis_server_head()), + FilesAndHosts = [{make_host_filename(FnT, Host), Host} + || Host <- Hosts], + lists:foreach( + fun({FnH, _}) -> + print(Fd, make_xinclude(FnH)) + end, FilesAndHosts), + print(Fd, make_piefxis_server_tail()), + print(Fd, make_piefxis_xml_tail()), + file:close(Fd), + lists:foldl( + fun({FnH, Host}, ok) -> + export_host(Dir, FnH, Host); + (_, Err) -> + Err + end, ok, FilesAndHosts); + {error, Reason} -> + ErrTxt = file:format_error(Reason), + ?ERROR_MSG("Failed to open file '~s': ~s", [DFn, ErrTxt]), + {error, Reason} + end. %% @spec User = String with the user name %% Domain = String with a domain name %% El = Sub XML element with offline messages values %% @ret ok | {error, not_found} %% @doc Read off-line message from the XML and send it to the server - -populate_user(User,Domain,El=#xmlel{name='offline-messages'}) -> - io:format("Trying to add/update offline-messages...",[]), - case loaded_module(Domain, mod_offline) of - {ok, _DBType} -> - ok = exmpp_xml:foreach( - fun (_Element, {xmlcdata, _}) -> - ok; - (_Element, Child) -> - From = exmpp_xml:get_attribute(Child, <<"from">>,none), - FullFrom = jid_to_old_jid(exmpp_jid:parse(From)), - FullUser = jid_to_old_jid(exmpp_jid:make(User, - Domain)), - OldChild = exmpp_xml:xmlel_to_xmlelement(Child), - _R = mod_offline:store_packet(FullFrom, FullUser, OldChild) - end, El), io:format(" DONE.~n",[]); - _ -> - io:format(" ERROR.~n",[]), - ?ERROR_MSG("No modules loaded [mod_offline] ~s ~n", - [exmpp_xml:document_to_list(El)]), - {error, not_found} - end; +export_host(Dir, FnH, Host) -> + DFn = make_host_basefilename(Dir, FnH), + case file:open(DFn, [raw, write]) of + {ok, Fd} -> + print(Fd, make_piefxis_xml_head()), + print(Fd, make_piefxis_host_head(Host)), + Users = ejabberd_auth:get_vh_registered_users(Host), + case export_users(Users, Host, Fd) of + ok -> + print(Fd, make_piefxis_host_tail()), + print(Fd, make_piefxis_xml_tail()), + file:close(Fd), + ok; + Err -> + file:close(Fd), + file:delete(DFn), + Err + end; + {error, Reason} -> + ErrTxt = file:format_error(Reason), + ?ERROR_MSG("Failed to open file '~s': ~s", [DFn, ErrTxt]), + {error, Reason} + end. %% @spec User = String with the user name %% Domain = String with a domain name %% El = Sub XML element with private storage values %% @ret ok | {error, not_found} %% @doc Private storage parsing - -populate_user(User,Domain,El=#xmlel{name='query', ns='jabber:iq:private'}) -> - io:format("Trying to add/update private storage...",[]), - case loaded_module(Domain, mod_private) of - {ok, _DBType} -> - FullUser = jid_to_old_jid(exmpp_jid:make(User, Domain)), - IQ = iq_to_old_iq(#iq{type = set, - ns = 'jabber:iq:private', - kind = request, - iq_ns = 'jabberd:client', - payload = El}), - case mod_private:process_sm_iq(FullUser, FullUser, IQ ) of - {error, _Err} -> - io:format(" ERROR.~n",[]), - ?ERROR_MSG("Error processing private storage ~s : ~p ~n", - [exmpp_xml:document_to_list(El), _Err]); - _ -> io:format(" DONE.~n",[]), ok - end; - _ -> - io:format(" ERROR.~n",[]), - ?ERROR_MSG("No modules loaded [mod_private] ~s ~n", - [exmpp_xml:document_to_list(El)]), - {error, not_found} +export_users([{User, _S}|Users], Server, Fd) -> + case export_user(User, Server, Fd) of + ok -> + export_users(Users, Server, Fd); + Err -> + Err end; - -populate_user(_User, _Domain, #xmlcdata{cdata = _CData}) -> - ok; - -populate_user(_User, _Domain, _El) -> +export_users([], _Server, _Fd) -> ok. %%%================================== %%%% Utilities +export_user(User, Server, Fd) -> + Pass = ejabberd_auth:get_password_s(User, Server), + Els = get_offline(User, Server) ++ + get_vcard(User, Server) ++ + get_privacy(User, Server) ++ + get_roster(User, Server) ++ + get_private(User, Server), + print(Fd, xml:element_to_binary( + #xmlel{name = <<"user">>, + attrs = [{<<"name">>, User}, + {<<"password">>, Pass}], + children = Els})). -loaded_module(Domain, Module) -> - case gen_mod:is_loaded(Domain, Module) of - true -> - {ok, gen_mod:db_type(Domain, Module)}; - false -> - {error, not_found} +get_vcard(User, Server) -> + JID = jlib:make_jid(User, Server, <<>>), + case mod_vcard:process_sm_iq(JID, JID, #iq{type = get}) of + #iq{type = result, sub_el = [_|_] = VCardEls} -> + VCardEls; + _ -> + [] end. -jid_to_old_jid(Jid) -> - {jid, to_list(exmpp_jid:node_as_list(Jid)), - to_list(exmpp_jid:domain_as_list(Jid)), - to_list(exmpp_jid:resource_as_list(Jid)), - to_list(exmpp_jid:prep_node_as_list(Jid)), - to_list(exmpp_jid:prep_domain_as_list(Jid)), - to_list(exmpp_jid:prep_resource_as_list(Jid))}. - -iq_to_old_iq(#iq{id = ID, type = Type, lang = Lang, ns= NS, payload = El }) -> - {iq, to_list(ID), Type, to_list(NS), to_list(Lang), - exmpp_xml:xmlel_to_xmlelement(El)}. - -to_list(L) when is_list(L) -> L; -to_list(B) when is_binary(B) -> binary_to_list(B); -to_list(undefined) -> ""; -to_list(B) when is_atom(B) -> atom_to_list(B). - %%%================================== +get_offline(User, Server) -> + case mod_offline:get_offline_els(User, Server) of + [] -> + []; + Els -> + NewEls = lists:map( + fun(#xmlel{attrs = Attrs} = El) -> + NewAttrs = lists:keystore(<<"xmlns">>, 1, + Attrs, + {<<"xmlns">>, + <<"jabber:client">>}), + El#xmlel{attrs = NewAttrs} + end, Els), + [#xmlel{name = <<"offline-messages">>, children = NewEls}] + end. %%%% Export hosts +get_privacy(User, Server) -> + case mod_privacy:get_user_lists(User, Server) of + {ok, #privacy{default = Default, + lists = [_|_] = Lists}} -> + XLists = lists:map( + fun({Name, Items}) -> + XItems = lists:map( + fun mod_privacy:item_to_xml/1, Items), + #xmlel{name = <<"list">>, + attrs = [{<<"name">>, Name}], + children = XItems} + end, Lists), + DefaultEl = case Default of + none -> + []; + _ -> + [#xmlel{name = <<"default">>, + attrs = [{<<"name">>, Default}]}] + end, + [#xmlel{name = <<"query">>, + attrs = [{<<"xmlns">>, ?NS_PRIVACY}], + children = DefaultEl ++ XLists}]; + _ -> + [] + end. %% @spec (Dir::string(), Hosts::[string()]) -> ok -export_hosts(Dir, Hosts) -> - try_start_exmpp(), +get_roster(User, Server) -> + JID = jlib:make_jid(User, Server, <<>>), + case mod_roster:get_roster(User, Server) of + [_|_] = Items -> + Subs = + lists:flatmap( + fun(#roster{ask = Ask, + askmessage = Msg} = R) + when Ask == in; Ask == both -> + Status = if is_binary(Msg) -> (Msg); + true -> <<"">> + end, + [#xmlel{name = <<"presence">>, + attrs = + [{<<"from">>, + jlib:jid_to_string(R#roster.jid)}, + {<<"to">>, jlib:jid_to_string(JID)}, + {<<"xmlns">>, <<"jabber:client">>}, + {<<"type">>, <<"subscribe">>}], + children = + [#xmlel{name = <<"status">>, + attrs = [], + children = + [{xmlcdata, Status}]}]}]; + (_) -> + [] + end, Items), + Rs = lists:flatmap( + fun(#roster{ask = in, subscription = none}) -> + []; + (R) -> + [mod_roster:item_to_xml(R)] + end, Items), + [#xmlel{name = <<"query">>, + attrs = [{<<"xmlns">>, ?NS_ROSTER}], + children = Rs} | Subs]; + _ -> + [] + end. - FnT = make_filename_template(), - DFn = make_main_basefilename(Dir, FnT), +get_private(User, Server) -> + case mod_private:get_data(User, Server) of + [_|_] = Els -> + [#xmlel{name = <<"query">>, + attrs = [{<<"xmlns">>, ?NS_PRIVATE}], + children = Els}]; + _ -> + [] + end. - {ok, Fd} = file_open(DFn), - print(Fd, make_piefxis_xml_head()), - print(Fd, make_piefxis_server_head()), +process(#state{xml_stream_state = XMLStreamState, fd = Fd} = State) -> + case file:read(Fd, ?CHUNK_SIZE) of + {ok, Data} -> + NewXMLStreamState = xml_stream:parse(XMLStreamState, Data), + case process_els(State#state{xml_stream_state = + NewXMLStreamState}) of + {ok, NewState} -> + process(NewState); + Err -> + xml_stream:close(NewXMLStreamState), + Err + end; + eof -> + xml_stream:close(XMLStreamState), + ok + end. - FilesAndHosts = [{make_host_filename(FnT, Host), Host} || Host <- Hosts], - [print(Fd, make_xinclude(FnH)) || {FnH, _Host} <- FilesAndHosts], +process_els(State) -> + receive + {'$gen_event', El} -> + case process_el(El, State) of + {ok, NewState} -> + process_els(NewState); + Err -> + Err + end + after 0 -> + {ok, State} + end. - print(Fd, make_piefxis_server_tail()), - print(Fd, make_piefxis_xml_tail()), - file_close(Fd), +process_el({xmlstreamstart, <<"server-data">>, Attrs}, State) -> + case xml:get_attr_s(<<"xmlns">>, Attrs) of + ?NS_PIEFXIS -> + {ok, State}; + ?NS_PIE -> + {ok, State}; + NS -> + stop("Unknown 'server-data' namespace = ~s", [NS]) + end; +process_el({xmlstreamend, _}, State) -> + {ok, State}; +process_el({xmlstreamcdata, _}, State) -> + {ok, State}; +process_el({xmlstreamelement, #xmlel{name = <<"xi:include">>, + attrs = Attrs}}, + #state{dir = Dir, user = <<"">>} = State) -> + FileName = xml:get_attr_s(<<"href">>, Attrs), + case import_file(filename:join([Dir, FileName]), State) of + ok -> + {ok, State}; + Err -> + Err + end; +process_el({xmlstreamstart, <<"host">>, Attrs}, State) -> + process_el({xmlstreamelement, #xmlel{name = <<"host">>, + attrs = Attrs}}, State); +process_el({xmlstreamelement, #xmlel{name = <<"host">>, + attrs = Attrs, + children = Els}}, State) -> + JIDS = xml:get_attr_s(<<"jid">>, Attrs), + case jlib:string_to_jid(JIDS) of + #jid{lserver = S} -> + case lists:member(S, ?MYHOSTS) of + true -> + process_users(Els, State#state{server = S}); + false -> + stop("Unknown host: ~s", [S]) + end; + error -> + stop("Invalid 'jid': ~s", [JIDS]) + end; +process_el({xmlstreamstart, <<"user">>, Attrs}, State = #state{server = S}) + when S /= <<"">> -> + process_el({xmlstreamelement, #xmlel{name = <<"user">>, attrs = Attrs}}, + State); +process_el({xmlstreamelement, #xmlel{name = <<"user">>} = El}, + State = #state{server = S}) when S /= <<"">> -> + process_user(El, State); +process_el({xmlstreamelement, El}, State = #state{server = S, user = U}) + when S /= <<"">>, U /= <<"">> -> + process_user_el(El, State); +process_el({xmlstreamelement, El}, _State) -> + stop("Unexpected tag: ~p", [El]); +process_el({xmlstreamstart, El, Attrs}, _State) -> + stop("Unexpected payload: ~p", [{El, Attrs}]); +process_el({xmlstreamerror, Err}, _State) -> + stop("Failed to process element = ~p", [Err]). - [export_host(Dir, FnH, Host) || {FnH, Host} <- FilesAndHosts], +process_users([#xmlel{} = El|Els], State) -> + case process_user(El, State) of + {ok, NewState} -> + process_users(Els, NewState); + Err -> + Err + end; +process_users([_|Els], State) -> + process_users(Els, State); +process_users([], State) -> + {ok, State}. - ok. +process_user(#xmlel{name = <<"user">>, attrs = Attrs, children = Els}, + #state{server = LServer} = State) -> + Name = xml:get_attr_s(<<"name">>, Attrs), + Pass = xml:get_attr_s(<<"password">>, Attrs), + case jlib:nodeprep(Name) of + error -> + stop("Invalid 'user': ~s", [Name]); + LUser -> + case ejabberd_auth:try_register(LUser, LServer, Pass) of + {atomic, _} -> + process_user_els(Els, State#state{user = LUser}); + Err -> + stop("Failed to create user '~s': ~p", [Name, Err]) + end + end. + +process_user_els([#xmlel{} = El|Els], State) -> + case process_user_el(El, State) of + {ok, NewState} -> + process_user_els(Els, NewState); + Err -> + Err + end; +process_user_els([_|Els], State) -> + process_user_els(Els, State); +process_user_els([], State) -> + {ok, State}. + +process_user_el(#xmlel{name = Name, attrs = Attrs, children = Els} = El, + State) -> + case {Name, xml:get_attr_s(<<"xmlns">>, Attrs)} of + {<<"query">>, ?NS_ROSTER} -> + process_roster(El, State); + {<<"query">>, ?NS_PRIVACY} -> + %% Make sure elements go before and + NewEls = lists:reverse(lists:keysort(#xmlel.name, Els)), + process_privacy_el(El#xmlel{children = NewEls}, State); + {<<"query">>, ?NS_PRIVATE} -> + process_private(El, State); + {<<"vCard">>, ?NS_VCARD} -> + process_vcard(El, State); + {<<"offline-messages">>, _} -> + process_offline_msgs(Els, State); + {<<"presence">>, <<"jabber:client">>} -> + process_presence(El, State); + _ -> + {ok, State} + end. + +process_privacy_el(#xmlel{children = [#xmlel{} = SubEl|SubEls]} = El, State) -> + case process_privacy(#xmlel{children = [SubEl]}, State) of + {ok, NewState} -> + process_privacy_el(El#xmlel{children = SubEls}, NewState); + Err -> + Err + end; +process_privacy_el(#xmlel{children = [_|SubEls]} = El, State) -> + process_privacy_el(El#xmlel{children = SubEls}, State); +process_privacy_el(#xmlel{children = []}, State) -> + {ok, State}. + +process_offline_msgs([#xmlel{} = El|Els], State) -> + case process_offline_msg(El, State) of + {ok, NewState} -> + process_offline_msgs(Els, NewState); + Err -> + Err + end; +process_offline_msgs([_|Els], State) -> + process_offline_msgs(Els, State); +process_offline_msgs([], State) -> + {ok, State}. + +process_roster(El, State = #state{user = U, server = S}) -> + case mod_roster:set_items(U, S, El) of + {atomic, _} -> + {ok, State}; + Err -> + stop("Failed to write roster: ~p", [Err]) + end. %%%================================== %%%% Export server +process_privacy(El, State = #state{user = U, server = S}) -> + JID = jlib:make_jid(U, S, <<"">>), + case mod_privacy:process_iq_set( + [], JID, JID, #iq{type = set, sub_el = El}) of + {error, _} = Err -> + stop("Failed to write privacy: ~p", [Err]); + _ -> + {ok, State} + end. %% @spec (Dir::string()) -> ok -export_server(Dir) -> - Hosts = ?MYHOSTS, - export_hosts(Dir, Hosts). +process_private(El, State = #state{user = U, server = S}) -> + JID = jlib:make_jid(U, S, <<"">>), + case mod_private:process_sm_iq( + JID, JID, #iq{type = set, sub_el = El}) of + #iq{type = result} -> + {ok, State}; + Err -> + stop("Failed to write private: ~p", [Err]) + end. %%%================================== %%%% Export host +process_vcard(El, State = #state{user = U, server = S}) -> + JID = jlib:make_jid(U, S, <<"">>), + case mod_vcard:process_sm_iq( + JID, JID, #iq{type = set, sub_el = El}) of + #iq{type = result} -> + {ok, State}; + Err -> + stop("Failed to write vcard: ~p", [Err]) + end. %% @spec (Dir::string(), Host::string()) -> ok -export_host(Dir, Host) -> - Hosts = [Host], - export_hosts(Dir, Hosts). +process_offline_msg(El, State = #state{user = U, server = S}) -> + FromS = xml:get_attr_s(<<"from">>, El#xmlel.attrs), + case jlib:string_to_jid(FromS) of + #jid{} = From -> + To = jlib:make_jid(U, S, <<>>), + NewEl = jlib:replace_from_to(From, To, El), + case catch mod_offline:store_packet(From, To, NewEl) of + {'EXIT', _} = Err -> + stop("Failed to store offline message: ~p", [Err]); + _ -> + {ok, State} + end; + _ -> + stop("Invalid 'from' = ~s", [FromS]) + end. %% @spec (Dir::string(), Fn::string(), Host::string()) -> ok -export_host(Dir, FnH, Host) -> +process_presence(El, #state{user = U, server = S} = State) -> + FromS = xml:get_attr_s(<<"from">>, El#xmlel.attrs), + case jlib:string_to_jid(FromS) of + #jid{} = From -> + To = jlib:make_jid(U, S, <<>>), + NewEl = jlib:replace_from_to(From, To, El), + ejabberd_router:route(From, To, NewEl), + {ok, State}; + _ -> + stop("Invalid 'from' = ~s", [FromS]) + end. - DFn = make_host_basefilename(Dir, FnH), +stop(Fmt, Args) -> + ?ERROR_MSG(Fmt, Args), + {error, import_failed}. - {ok, Fd} = file_open(DFn), - print(Fd, make_piefxis_xml_head()), - print(Fd, make_piefxis_host_head(Host)), +make_filename_template() -> + {{Year, Month, Day}, {Hour, Minute, Second}} = calendar:local_time(), + list_to_binary( + io_lib:format("~4..0w~2..0w~2..0w-~2..0w~2..0w~2..0w", + [Year, Month, Day, Hour, Minute, Second])). - Users = ejabberd_auth:get_vh_registered_users(Host), - [export_user(Fd, Username, Host) || {Username, _Host} <- Users], - timer:sleep(500), % Delay to ensure ERROR_MSG are displayed in the shell +make_main_basefilename(Dir, FnT) -> + Filename2 = <>, + filename:join([Dir, Filename2]). - print(Fd, make_piefxis_host_tail()), - print(Fd, make_piefxis_xml_tail()), - file_close(Fd). +%% @doc Make the filename for the host. +%% Example: ``(<<"20080804-231550">>, <<"jabber.example.org">>) -> +%% <<"20080804-231550_jabber_example_org.xml">>'' +make_host_filename(FnT, Host) -> + Host2 = str:join(str:tokens(Host, <<".">>), <<"_">>), + <>. %%%================================== %%%% PIEFXIS formatting +make_host_basefilename(Dir, FnT) -> + filename:join([Dir, FnT]). %% @spec () -> string() make_piefxis_xml_head() -> @@ -513,9 +636,8 @@ make_piefxis_xml_tail() -> %% @spec () -> string() make_piefxis_server_head() -> - "". + io_lib:format("", + [?NS_PIE, ?NS_XI]). %% @spec () -> string() make_piefxis_server_tail() -> @@ -523,10 +645,8 @@ make_piefxis_server_tail() -> %% @spec (Host::string()) -> string() make_piefxis_host_head(Host) -> - NSString = - " xmlns='http://www.xmpp.org/extensions/xep-0227.html#ns'" - " xmlns:xi='http://www.w3.org/2001/XInclude'", - io_lib:format("", [NSString, Host]). + io_lib:format("", + [?NS_PIE, ?NS_XI, Host]). %% @spec () -> string() make_piefxis_host_tail() -> @@ -539,196 +659,26 @@ make_xinclude(Fn) -> %%%================================== %%%% Export user - %% @spec (Fd, Username::string(), Host::string()) -> ok %% @doc Extract user information and print it. -export_user(Fd, Username, Host) -> - try extract_user(Username, Host) of - UserString -> - print(Fd, UserString) - catch - E1:E2 -> - ?ERROR_MSG("The account ~s@~s is not exported because a problem " - "was found in it:~n~p: ~p", [Username, Host, E1, E2]) - end. - %% @spec (Username::string(), Host::string()) -> string() -extract_user(Username, Host) -> - Password = ejabberd_auth:get_password(Username, Host), - PasswordStr = build_password_string(Password), - UserInfo = [extract_user_info(InfoName, Username, Host) || InfoName <- [roster, offline, private, vcard]], - UserInfoString = lists:flatten(UserInfo), - io_lib:format("", - [Username, PasswordStr, UserInfoString]). - -build_password_string({StoredKey, ServerKey, Salt, IterationCount}) -> - io_lib:format("password-format='scram'>" - " ", - [base64:encode_to_string(StoredKey), - base64:encode_to_string(ServerKey), - base64:encode_to_string(Salt), - IterationCount]); -build_password_string(Password) when is_list(Password) -> - io_lib:format("password-format='plaintext' password='~s'>", [Password]). - %% @spec (InfoName::atom(), Username::string(), Host::string()) -> string() -extract_user_info(roster, Username, Host) -> - case loaded_module(Host, mod_roster) of - {ok, _DBType} -> - From = To = jlib:make_jid(Username, Host, ""), - SubelGet = {xmlelement, "query", [{"xmlns",?NS_ROSTER}], []}, - %%IQGet = #iq{type=get, xmlns=?NS_ROSTER, payload=SubelGet}, % this is for 3.0.0 version - IQGet = {iq, "", get, ?NS_ROSTER, "" , SubelGet}, - Res = mod_roster:process_local_iq(From, To, IQGet), - %%[El] = Res#iq.payload, % this is for 3.0.0 version - {iq, _, result, _, _, Els} = Res, - case Els of - [El] -> exmpp_xml:document_to_list(El); - [] -> "" - end; - _E -> - "" - end; - -extract_user_info(offline, Username, Host) -> - case loaded_module(Host, mod_offline) of - {ok, mnesia} -> - Els = mnesia_pop_offline_messages([], Username, Host), - case Els of - [] -> ""; - Els -> - OfEl = {xmlelement, "offline-messages", [], Els}, - exmpp_xml:document_to_list(OfEl) - end; - {ok, odbc} -> - ""; - _E -> - "" - end; - -extract_user_info(private, Username, Host) -> - case loaded_module(Host, mod_private) of - {ok, mnesia} -> - get_user_private_mnesia(Username, Host); - {ok, odbc} -> - ""; - _E -> - "" - end; - -extract_user_info(vcard, Username, Host) -> - case loaded_module(Host, mod_vcard) of - {ok, _DBType} -> - From = To = jlib:make_jid(Username, Host, ""), - SubelGet = {xmlelement, "vCard", [{"xmlns",?NS_VCARD}], []}, - %%IQGet = #iq{type=get, xmlns=?NS_VCARD, payload=SubelGet}, % this is for 3.0.0 version - IQGet = {iq, "", get, ?NS_VCARD, "" , SubelGet}, - Res = mod_vcard:process_sm_iq(From, To, IQGet), - %%[El] = Res#iq.payload, % this is for 3.0.0 version - {iq, _, result, _, _, Els} = Res, - case Els of - [El] -> exmpp_xml:document_to_list(El); - [] -> "" - end; - _E -> - "" - end. - %%%================================== %%%% Interface with ejabberd offline storage - %% Copied from mod_offline.erl and customized --record(offline_msg, {us, timestamp, expire, from, to, packet}). -mnesia_pop_offline_messages(Ls, User, Server) -> - LUser = jlib:nodeprep(User), - LServer = jlib:nameprep(Server), - US = {LUser, LServer}, - F = fun() -> - Rs = mnesia:wread({offline_msg, US}), - %%mnesia:delete({offline_msg, US}), - Rs - end, - case mnesia:transaction(F) of - {atomic, Rs} -> - TS = now(), - Ls ++ lists:map( - fun(R) -> - {xmlelement, Name, Attrs, Els} = R#offline_msg.packet, - FromString = jlib:jid_to_string(R#offline_msg.from), - Attrs2 = lists:keystore("from", 1, Attrs, {"from", FromString}), - Attrs3 = lists:keystore("xmlns", 1, Attrs2, {"xmlns", "jabber:client"}), - {xmlelement, Name, Attrs3, - Els ++ - [jlib:timestamp_to_xml( - calendar:now_to_universal_time( - R#offline_msg.timestamp))]} - end, - lists:filter( - fun(R) -> - case R#offline_msg.expire of - never -> - true; - TimeStamp -> - TS < TimeStamp - end - end, - lists:keysort(#offline_msg.timestamp, Rs))); - _ -> - Ls - end. - %%%================================== %%%% Interface with ejabberd private storage - -get_user_private_mnesia(Username, Host) -> - ListNsEl = mnesia:dirty_select(private_storage, - [{#private_storage{usns={Username, Host, '$1'}, xml = '$2'}, - [], ['$$']}]), - Els = [exmpp_xml:document_to_list(El) || [_Ns, El] <- ListNsEl], - case lists:flatten(Els) of - "" -> ""; - ElsString -> - io_lib:format("~s", [ElsString]) - end. - %%%================================== %%%% Disk file access - %% @spec () -> string() -make_filename_template() -> - {{Year, Month, Day}, {Hour, Minute, Second}} = calendar:local_time(), - lists:flatten( - io_lib:format("~4..0w~2..0w~2..0w-~2..0w~2..0w~2..0w", - [Year, Month, Day, Hour, Minute, Second])). - %% @spec (Dir::string(), FnT::string()) -> string() -make_main_basefilename(Dir, FnT) -> - Filename2 = filename:flatten([FnT, ".xml"]), - filename:join([Dir, Filename2]). - %% @spec (FnT::string(), Host::string()) -> FnH::string() %% @doc Make the filename for the host. %% Example: ``("20080804-231550", "jabber.example.org") -> "20080804-231550_jabber_example_org.xml"'' -make_host_filename(FnT, Host) -> - Host2 = string:join(string:tokens(Host, "."), "_"), - filename:flatten([FnT, "_", Host2, ".xml"]). - -make_host_basefilename(Dir, FnT) -> - filename:join([Dir, FnT]). - %% @spec (Fn::string()) -> {ok, Fd} -file_open(Fn) -> - file:open(Fn, [write]). - %% @spec (Fd) -> ok -file_close(Fd) -> - file:close(Fd). - %% @spec (Fd, String::string()) -> ok print(Fd, String) -> - io:format(Fd, String, []). - %%%================================== - %%% vim: set filetype=erlang tabstop=8 foldmarker=%%%%,%%%= foldmethod=marker: + file:write(Fd, String). diff --git a/src/ejabberd_rdbms.erl b/src/ejabberd_rdbms.erl index d0b20e6f7..abb17974c 100644 --- a/src/ejabberd_rdbms.erl +++ b/src/ejabberd_rdbms.erl @@ -25,54 +25,51 @@ %%%---------------------------------------------------------------------- -module(ejabberd_rdbms). + -author('alexey@process-one.net'). -export([start/0]). + -include("ejabberd.hrl"). start() -> - %% Check if ejabberd has been compiled with ODBC case catch ejabberd_odbc_sup:module_info() of - {'EXIT',{undef,_}} -> - ?INFO_MSG("ejabberd has not been compiled with relational database support. Skipping database startup.", []); - _ -> - %% If compiled with ODBC, start ODBC on the needed host - start_hosts() + {'EXIT', {undef, _}} -> + ?INFO_MSG("ejabberd has not been compiled with " + "relational database support. Skipping " + "database startup.", + []); + _ -> start_hosts() end. %% Start relationnal DB module on the nodes where it is needed start_hosts() -> - lists:foreach( - fun(Host) -> - case needs_odbc(Host) of - true -> start_odbc(Host); - false -> ok - end - end, ?MYHOSTS). + lists:foreach(fun (Host) -> + case needs_odbc(Host) of + true -> start_odbc(Host); + false -> ok + end + end, + ?MYHOSTS). %% Start the ODBC module on the given host start_odbc(Host) -> - Supervisor_name = gen_mod:get_module_proc(Host, ejabberd_odbc_sup), - ChildSpec = - {Supervisor_name, - {ejabberd_odbc_sup, start_link, [Host]}, - transient, - infinity, - supervisor, - [ejabberd_odbc_sup]}, + Supervisor_name = gen_mod:get_module_proc(Host, + ejabberd_odbc_sup), + ChildSpec = {Supervisor_name, + {ejabberd_odbc_sup, start_link, [Host]}, transient, + infinity, supervisor, [ejabberd_odbc_sup]}, case supervisor:start_child(ejabberd_sup, ChildSpec) of - {ok, _PID} -> - ok; - _Error -> - ?ERROR_MSG("Start of supervisor ~p failed:~n~p~nRetrying...~n", [Supervisor_name, _Error]), - start_odbc(Host) + {ok, _PID} -> ok; + _Error -> + ?ERROR_MSG("Start of supervisor ~p failed:~n~p~nRetrying." + "..~n", + [Supervisor_name, _Error]), + start_odbc(Host) end. %% Returns true if we have configured odbc_server for the given host needs_odbc(Host) -> LHost = jlib:nameprep(Host), - case ejabberd_config:get_local_option({odbc_server, LHost}) of - undefined -> - false; - _ -> true - end. + ejabberd_config:get_local_option( + {odbc_server, LHost}, fun(_) -> true end, false). diff --git a/src/ejabberd_receiver.erl b/src/ejabberd_receiver.erl index 7e93feeb9..c9ed6b350 100644 --- a/src/ejabberd_receiver.erl +++ b/src/ejabberd_receiver.erl @@ -25,6 +25,7 @@ %%%---------------------------------------------------------------------- -module(ejabberd_receiver). + -author('alexey@process-one.net'). -behaviour(gen_server). @@ -41,18 +42,19 @@ close/1]). %% gen_server callbacks --export([init/1, handle_call/3, handle_cast/2, handle_info/2, - terminate/2, code_change/3]). +-export([init/1, handle_call/3, handle_cast/2, + handle_info/2, terminate/2, code_change/3]). -include("ejabberd.hrl"). --record(state, {socket, - sock_mod, - shaper_state, - c2s_pid, - max_stanza_size, - xml_stream_state, - timeout}). +-record(state, + {socket :: inet:socket() | tls:tls_socket() | ejabberd_zlib:zlib_socket(), + sock_mod = gen_tcp :: gen_tcp | tls | ejabberd_zlib, + shaper_state = none :: shaper:shaper(), + c2s_pid :: pid(), + max_stanza_size = infinity :: non_neg_integer() | infinity, + xml_stream_state :: xml_stream:xml_stream_state(), + timeout = infinity:: timeout()}). -define(HIBERNATE_TIMEOUT, 90000). @@ -63,9 +65,16 @@ %% Function: start_link() -> {ok,Pid} | ignore | {error,Error} %% Description: Starts the server %%-------------------------------------------------------------------- +-spec start_link(inet:socket(), atom(), shaper:shaper(), + non_neg_integer() | infinity) -> ignore | + {error, any()} | + {ok, pid()}. + start_link(Socket, SockMod, Shaper, MaxStanzaSize) -> - gen_server:start_link( - ?MODULE, [Socket, SockMod, Shaper, MaxStanzaSize], []). + gen_server:start_link(?MODULE, + [Socket, SockMod, Shaper, MaxStanzaSize], []). + +-spec start(inet:socket(), atom(), shaper:shaper()) -> undefined | pid(). %%-------------------------------------------------------------------- %% Function: start() -> {ok,Pid} | ignore | {error,Error} @@ -74,30 +83,46 @@ start_link(Socket, SockMod, Shaper, MaxStanzaSize) -> start(Socket, SockMod, Shaper) -> start(Socket, SockMod, Shaper, infinity). +-spec start(inet:socket(), atom(), shaper:shaper(), + non_neg_integer() | infinity) -> undefined | pid(). + start(Socket, SockMod, Shaper, MaxStanzaSize) -> - {ok, Pid} = supervisor:start_child( - ejabberd_receiver_sup, - [Socket, SockMod, Shaper, MaxStanzaSize]), + {ok, Pid} = + supervisor:start_child(ejabberd_receiver_sup, + [Socket, SockMod, Shaper, MaxStanzaSize]), Pid. +-spec change_shaper(pid(), shaper:shaper()) -> ok. + change_shaper(Pid, Shaper) -> gen_server:cast(Pid, {change_shaper, Shaper}). -reset_stream(Pid) -> - do_call(Pid, reset_stream). +-spec reset_stream(pid()) -> ok | {error, any()}. + +reset_stream(Pid) -> do_call(Pid, reset_stream). + +-spec starttls(pid(), iodata()) -> {ok, tls:tls_socket()} | {error, any()}. starttls(Pid, TLSSocket) -> do_call(Pid, {starttls, TLSSocket}). +-spec compress(pid(), iodata() | undefined) -> {error, any()} | + {ok, ejabberd_zlib:zlib_socket()}. + compress(Pid, ZlibSocket) -> do_call(Pid, {compress, ZlibSocket}). +-spec become_controller(pid(), pid()) -> ok | {error, any()}. + become_controller(Pid, C2SPid) -> do_call(Pid, {become_controller, C2SPid}). +-spec close(pid()) -> ok. + close(Pid) -> gen_server:cast(Pid, close). + %%==================================================================== %% gen_server callbacks %%==================================================================== @@ -112,16 +137,13 @@ close(Pid) -> init([Socket, SockMod, Shaper, MaxStanzaSize]) -> ShaperState = shaper:new(Shaper), Timeout = case SockMod of - ssl -> - 20; - _ -> - infinity + ssl -> 20; + _ -> infinity end, - {ok, #state{socket = Socket, - sock_mod = SockMod, - shaper_state = ShaperState, - max_stanza_size = MaxStanzaSize, - timeout = Timeout}}. + {ok, + #state{socket = Socket, sock_mod = SockMod, + shaper_state = ShaperState, + max_stanza_size = MaxStanzaSize, timeout = Timeout}}. %%-------------------------------------------------------------------- %% Function: %% handle_call(Request, From, State) -> {reply, Reply, State} | @@ -137,11 +159,12 @@ handle_call({starttls, TLSSocket}, _From, c2s_pid = C2SPid, max_stanza_size = MaxStanzaSize} = State) -> close_stream(XMLStreamState), - NewXMLStreamState = xml_stream:new(C2SPid, MaxStanzaSize), + NewXMLStreamState = xml_stream:new(C2SPid, + MaxStanzaSize), NewState = State#state{socket = TLSSocket, sock_mod = tls, xml_stream_state = NewXMLStreamState}, - case tls:recv_data(TLSSocket, "") of + case tls:recv_data(TLSSocket, <<"">>) of {ok, TLSData} -> {reply, ok, process_data(TLSData, NewState), ?HIBERNATE_TIMEOUT}; {error, _Reason} -> @@ -152,11 +175,12 @@ handle_call({compress, ZlibSocket}, _From, c2s_pid = C2SPid, max_stanza_size = MaxStanzaSize} = State) -> close_stream(XMLStreamState), - NewXMLStreamState = xml_stream:new(C2SPid, MaxStanzaSize), + NewXMLStreamState = xml_stream:new(C2SPid, + MaxStanzaSize), NewState = State#state{socket = ZlibSocket, sock_mod = ejabberd_zlib, xml_stream_state = NewXMLStreamState}, - case ejabberd_zlib:recv_data(ZlibSocket, "") of + case ejabberd_zlib:recv_data(ZlibSocket, <<"">>) of {ok, ZlibData} -> {reply, ok, process_data(ZlibData, NewState), ?HIBERNATE_TIMEOUT}; {error, _Reason} -> @@ -164,12 +188,14 @@ handle_call({compress, ZlibSocket}, _From, end; handle_call(reset_stream, _From, #state{xml_stream_state = XMLStreamState, - c2s_pid = C2SPid, - max_stanza_size = MaxStanzaSize} = State) -> + c2s_pid = C2SPid, max_stanza_size = MaxStanzaSize} = + State) -> close_stream(XMLStreamState), - NewXMLStreamState = xml_stream:new(C2SPid, MaxStanzaSize), + NewXMLStreamState = xml_stream:new(C2SPid, + MaxStanzaSize), Reply = ok, - {reply, Reply, State#state{xml_stream_state = NewXMLStreamState}, + {reply, Reply, + State#state{xml_stream_state = NewXMLStreamState}, ?HIBERNATE_TIMEOUT}; handle_call({become_controller, C2SPid}, _From, State) -> XMLStreamState = xml_stream:new(C2SPid, State#state.max_stanza_size), @@ -179,8 +205,7 @@ handle_call({become_controller, C2SPid}, _From, State) -> Reply = ok, {reply, Reply, NewState, ?HIBERNATE_TIMEOUT}; handle_call(_Request, _From, State) -> - Reply = ok, - {reply, Reply, State, ?HIBERNATE_TIMEOUT}. + Reply = ok, {reply, Reply, State, ?HIBERNATE_TIMEOUT}. %%-------------------------------------------------------------------- %% Function: handle_cast(Msg, State) -> {noreply, State} | @@ -190,9 +215,9 @@ handle_call(_Request, _From, State) -> %%-------------------------------------------------------------------- handle_cast({change_shaper, Shaper}, State) -> NewShaperState = shaper:new(Shaper), - {noreply, State#state{shaper_state = NewShaperState}, ?HIBERNATE_TIMEOUT}; -handle_cast(close, State) -> - {stop, normal, State}; + {noreply, State#state{shaper_state = NewShaperState}, + ?HIBERNATE_TIMEOUT}; +handle_cast(close, State) -> {stop, normal, State}; handle_cast(_Msg, State) -> {noreply, State, ?HIBERNATE_TIMEOUT}. @@ -203,45 +228,42 @@ handle_cast(_Msg, State) -> %% Description: Handling all non call/cast messages %%-------------------------------------------------------------------- handle_info({Tag, _TCPSocket, Data}, - #state{socket = Socket, - sock_mod = SockMod} = State) - when (Tag == tcp) or (Tag == ssl) or (Tag == ejabberd_xml) -> + #state{socket = Socket, sock_mod = SockMod} = State) + when (Tag == tcp) or (Tag == ssl) or + (Tag == ejabberd_xml) -> case SockMod of - tls -> - case tls:recv_data(Socket, Data) of - {ok, TLSData} -> - {noreply, process_data(TLSData, State), - ?HIBERNATE_TIMEOUT}; - {error, _Reason} -> - {stop, normal, State} - end; - ejabberd_zlib -> - case ejabberd_zlib:recv_data(Socket, Data) of - {ok, ZlibData} -> - {noreply, process_data(ZlibData, State), - ?HIBERNATE_TIMEOUT}; - {error, _Reason} -> - {stop, normal, State} - end; - _ -> - {noreply, process_data(Data, State), ?HIBERNATE_TIMEOUT} + tls -> + case tls:recv_data(Socket, Data) of + {ok, TLSData} -> + {noreply, process_data(TLSData, State), + ?HIBERNATE_TIMEOUT}; + {error, _Reason} -> {stop, normal, State} + end; + ejabberd_zlib -> + case ejabberd_zlib:recv_data(Socket, Data) of + {ok, ZlibData} -> + {noreply, process_data(ZlibData, State), + ?HIBERNATE_TIMEOUT}; + {error, _Reason} -> {stop, normal, State} + end; + _ -> + {noreply, process_data(Data, State), ?HIBERNATE_TIMEOUT} end; handle_info({Tag, _TCPSocket}, State) - when (Tag == tcp_closed) or (Tag == ssl_closed) -> + when (Tag == tcp_closed) or (Tag == ssl_closed) -> {stop, normal, State}; handle_info({Tag, _TCPSocket, Reason}, State) - when (Tag == tcp_error) or (Tag == ssl_error) -> + when (Tag == tcp_error) or (Tag == ssl_error) -> case Reason of - timeout -> - {noreply, State, ?HIBERNATE_TIMEOUT}; - _ -> - {stop, normal, State} + timeout -> {noreply, State, ?HIBERNATE_TIMEOUT}; + _ -> {stop, normal, State} end; handle_info({timeout, _Ref, activate}, State) -> activate_socket(State), {noreply, State, ?HIBERNATE_TIMEOUT}; handle_info(timeout, State) -> - proc_lib:hibernate(gen_server, enter_loop, [?MODULE, [], State]), + proc_lib:hibernate(gen_server, enter_loop, + [?MODULE, [], State]), {noreply, State, ?HIBERNATE_TIMEOUT}; handle_info(_Info, State) -> {noreply, State, ?HIBERNATE_TIMEOUT}. @@ -253,14 +275,14 @@ handle_info(_Info, State) -> %% cleaning up. When it returns, the gen_server terminates with Reason. %% The return value is ignored. %%-------------------------------------------------------------------- -terminate(_Reason, #state{xml_stream_state = XMLStreamState, - c2s_pid = C2SPid} = State) -> +terminate(_Reason, + #state{xml_stream_state = XMLStreamState, + c2s_pid = C2SPid} = + State) -> close_stream(XMLStreamState), - if - C2SPid /= undefined -> - gen_fsm:send_event(C2SPid, closed); - true -> - ok + if C2SPid /= undefined -> + gen_fsm:send_event(C2SPid, closed); + true -> ok end, catch (State#state.sock_mod):close(State#state.socket), ok. @@ -269,8 +291,7 @@ terminate(_Reason, #state{xml_stream_state = XMLStreamState, %% Func: code_change(OldVsn, State, Extra) -> {ok, NewState} %% Description: Convert process state when code is changed %%-------------------------------------------------------------------- -code_change(_OldVsn, State, _Extra) -> - {ok, State}. +code_change(_OldVsn, State, _Extra) -> {ok, State}. %%-------------------------------------------------------------------- %%% Internal functions @@ -278,48 +299,44 @@ code_change(_OldVsn, State, _Extra) -> activate_socket(#state{socket = Socket, sock_mod = SockMod}) -> - PeerName = - case SockMod of - gen_tcp -> - inet:setopts(Socket, [{active, once}]), - inet:peername(Socket); - _ -> - SockMod:setopts(Socket, [{active, once}]), - SockMod:peername(Socket) - end, + PeerName = case SockMod of + gen_tcp -> + inet:setopts(Socket, [{active, once}]), + inet:peername(Socket); + _ -> + SockMod:setopts(Socket, [{active, once}]), + SockMod:peername(Socket) + end, case PeerName of - {error, _Reason} -> - self() ! {tcp_closed, Socket}; - {ok, _} -> - ok + {error, _Reason} -> self() ! {tcp_closed, Socket}; + {ok, _} -> ok end. %% Data processing for connectors directly generating xmlelement in %% Erlang data structure. %% WARNING: Shaper does not work with Erlang data structure. process_data([], State) -> - activate_socket(State), - State; -process_data([Element|Els], #state{c2s_pid = C2SPid} = State) - when element(1, Element) == xmlelement; - element(1, Element) == xmlstreamstart; - element(1, Element) == xmlstreamelement; - element(1, Element) == xmlstreamend -> - if - C2SPid == undefined -> - State; - true -> - catch gen_fsm:send_event(C2SPid, element_wrapper(Element)), - process_data(Els, State) + activate_socket(State), State; +process_data([Element | Els], + #state{c2s_pid = C2SPid} = State) + when element(1, Element) == xmlel; + element(1, Element) == xmlstreamstart; + element(1, Element) == xmlstreamelement; + element(1, Element) == xmlstreamend -> + if C2SPid == undefined -> State; + true -> + catch gen_fsm:send_event(C2SPid, + element_wrapper(Element)), + process_data(Els, State) end; %% Data processing for connectors receivind data as string. process_data(Data, #state{xml_stream_state = XMLStreamState, - shaper_state = ShaperState, - c2s_pid = C2SPid} = State) -> - ?DEBUG("Received XML on stream = ~p", [binary_to_list(Data)]), + shaper_state = ShaperState, c2s_pid = C2SPid} = + State) -> + ?DEBUG("Received XML on stream = ~p", [(Data)]), XMLStreamState1 = xml_stream:parse(XMLStreamState, Data), - {NewShaperState, Pause} = shaper:update(ShaperState, size(Data)), + {NewShaperState, Pause} = shaper:update(ShaperState, byte_size(Data)), if C2SPid == undefined -> ok; @@ -336,20 +353,16 @@ process_data(Data, %% speaking directly Erlang XML), we wrap it inside the same %% xmlstreamelement coming from the XML parser. element_wrapper(XMLElement) - when element(1, XMLElement) == xmlelement -> + when element(1, XMLElement) == xmlel -> {xmlstreamelement, XMLElement}; -element_wrapper(Element) -> - Element. +element_wrapper(Element) -> Element. -close_stream(undefined) -> - ok; +close_stream(undefined) -> ok; close_stream(XMLStreamState) -> xml_stream:close(XMLStreamState). do_call(Pid, Msg) -> case catch gen_server:call(Pid, Msg) of - {'EXIT', Why} -> - {error, Why}; - Res -> - Res + {'EXIT', Why} -> {error, Why}; + Res -> Res end. diff --git a/src/ejabberd_regexp.erl b/src/ejabberd_regexp.erl index d6210b562..6603ec626 100644 --- a/src/ejabberd_regexp.erl +++ b/src/ejabberd_regexp.erl @@ -25,48 +25,72 @@ %%%---------------------------------------------------------------------- -module(ejabberd_regexp). + -compile([export_all]). -exec(ReM, ReF, ReA, RgM, RgF, RgA) -> - try apply(ReM, ReF, ReA) - catch - error:undef -> - apply(RgM, RgF, RgA); - A:B -> - {error, {A, B}} +exec({ReM, ReF, ReA}, {RgM, RgF, RgA}) -> + try apply(ReM, ReF, ReA) catch + error:undef -> apply(RgM, RgF, RgA); + A:B -> {error, {A, B}} end. +-spec run(binary(), binary()) -> match | nomatch | {error, any()}. + run(String, Regexp) -> - case exec(re, run, [String, Regexp, [{capture, none}]], regexp, first_match, [String, Regexp]) of - {match, _, _} -> match; - {match, _} -> match; - match -> match; - nomatch -> nomatch; - {error, Error} -> {error, Error} + case exec({re, run, [String, Regexp, [{capture, none}]]}, + {regexp, first_match, [binary_to_list(String), + binary_to_list(Regexp)]}) + of + {match, _, _} -> match; + {match, _} -> match; + match -> match; + nomatch -> nomatch; + {error, Error} -> {error, Error} end. +-spec split(binary(), binary()) -> [binary()]. + split(String, Regexp) -> - case exec(re, split, [String, Regexp, [{return, list}]], regexp, split, [String, Regexp]) of - {ok, FieldList} -> FieldList; - {error, Error} -> throw(Error); - A -> A + case exec({re, split, [String, Regexp, [{return, binary}]]}, + {regexp, split, [binary_to_list(String), + binary_to_list(Regexp)]}) + of + {ok, FieldList} -> [iolist_to_binary(F) || F <- FieldList]; + {error, Error} -> throw(Error); + A -> A end. +-spec replace(binary(), binary(), binary()) -> binary(). + replace(String, Regexp, New) -> - case exec(re, replace, [String, Regexp, New, [{return, list}]], regexp, sub, [String, Regexp, New]) of - {ok, NewString, _RepCount} -> NewString; - {error, Error} -> throw(Error); - A -> A + case exec({re, replace, [String, Regexp, New, [{return, binary}]]}, + {regexp, sub, [binary_to_list(String), + binary_to_list(Regexp), + binary_to_list(New)]}) + of + {ok, NewString, _RepCount} -> iolist_to_binary(NewString); + {error, Error} -> throw(Error); + A -> A end. +-spec greplace(binary(), binary(), binary()) -> binary(). + greplace(String, Regexp, New) -> - case exec(re, replace, [String, Regexp, New, [global, {return, list}]], regexp, sub, [String, Regexp, New]) of - {ok, NewString, _RepCount} -> NewString; - {error, Error} -> throw(Error); - A -> A + case exec({re, replace, [String, Regexp, New, [global, {return, binary}]]}, + {regexp, sub, [binary_to_list(String), + binary_to_list(Regexp), + binary_to_list(New)]}) + of + {ok, NewString, _RepCount} -> iolist_to_binary(NewString); + {error, Error} -> throw(Error); + A -> A end. +-spec sh_to_awk(binary()) -> binary(). + sh_to_awk(ShRegExp) -> - case exec(xmerl_regexp, sh_to_awk, [ShRegExp], regexp, sh_to_awk, [ShRegExp]) of - A -> A + case exec({xmerl_regexp, sh_to_awk, [binary_to_list(ShRegExp)]}, + {regexp, sh_to_awk, [binary_to_list(ShRegExp)]}) + of + A -> iolist_to_binary(A) end. diff --git a/src/ejabberd_router.erl b/src/ejabberd_router.erl index f1e70ad0f..8577d81ad 100644 --- a/src/ejabberd_router.erl +++ b/src/ejabberd_router.erl @@ -25,6 +25,7 @@ %%%---------------------------------------------------------------------- -module(ejabberd_router). + -author('alexey@process-one.net'). -behaviour(gen_server). @@ -44,13 +45,17 @@ -export([start_link/0]). %% gen_server callbacks --export([init/1, handle_call/3, handle_cast/2, handle_info/2, - terminate/2, code_change/3]). +-export([init/1, handle_call/3, handle_cast/2, + handle_info/2, terminate/2, code_change/3]). -include("ejabberd.hrl"). + -include("jlib.hrl"). +-type local_hint() :: undefined | integer() | {apply, atom(), atom()}. + -record(route, {domain, pid, local_hint}). + -record(state, {}). %%==================================================================== @@ -63,6 +68,7 @@ start_link() -> gen_server:start_link({local, ?MODULE}, ?MODULE, [], []). +-spec route(jid(), jid(), xmlel()) -> ok. route(From, To, Packet) -> case catch do_route(From, To, Packet) of @@ -75,119 +81,131 @@ route(From, To, Packet) -> %% Route the error packet only if the originating packet is not an error itself. %% RFC3920 9.3.1 +-spec route_error(jid(), jid(), xmlel(), xmlel()) -> ok. + route_error(From, To, ErrPacket, OrigPacket) -> - {xmlelement, _Name, Attrs, _Els} = OrigPacket, - case "error" == xml:get_attr_s("type", Attrs) of - false -> - route(From, To, ErrPacket); - true -> - ok + #xmlel{attrs = Attrs} = OrigPacket, + case <<"error">> == xml:get_attr_s(<<"type">>, Attrs) of + false -> route(From, To, ErrPacket); + true -> ok end. +-spec register_route(binary()) -> term(). + register_route(Domain) -> register_route(Domain, undefined). +-spec register_route(binary(), local_hint()) -> term(). + register_route(Domain, LocalHint) -> case jlib:nameprep(Domain) of - error -> - erlang:error({invalid_domain, Domain}); - LDomain -> - Pid = self(), - case get_component_number(LDomain) of - undefined -> - F = fun() -> - mnesia:write(#route{domain = LDomain, - pid = Pid, - local_hint = LocalHint}) - end, - mnesia:transaction(F); - N -> - F = fun() -> - case mnesia:wread({route, LDomain}) of - [] -> - mnesia:write( - #route{domain = LDomain, - pid = Pid, - local_hint = 1}), - lists:foreach( - fun(I) -> - mnesia:write( - #route{domain = LDomain, - pid = undefined, - local_hint = I}) - end, lists:seq(2, N)); - Rs -> - lists:any( - fun(#route{pid = undefined, - local_hint = I} = R) -> - mnesia:write( - #route{domain = LDomain, - pid = Pid, - local_hint = I}), - mnesia:delete_object(R), - true; - (_) -> - false - end, Rs) - end - end, - mnesia:transaction(F) - end + error -> erlang:error({invalid_domain, Domain}); + LDomain -> + Pid = self(), + case get_component_number(LDomain) of + undefined -> + F = fun () -> + mnesia:write(#route{domain = LDomain, pid = Pid, + local_hint = LocalHint}) + end, + mnesia:transaction(F); + N -> + F = fun () -> + case mnesia:wread({route, LDomain}) of + [] -> + mnesia:write(#route{domain = LDomain, + pid = Pid, + local_hint = 1}), + lists:foreach(fun (I) -> + mnesia:write(#route{domain + = + LDomain, + pid + = + undefined, + local_hint + = + I}) + end, + lists:seq(2, N)); + Rs -> + lists:any(fun (#route{pid = undefined, + local_hint = I} = + R) -> + mnesia:write(#route{domain = + LDomain, + pid = + Pid, + local_hint + = + I}), + mnesia:delete_object(R), + true; + (_) -> false + end, + Rs) + end + end, + mnesia:transaction(F) + end end. +-spec register_routes([binary()]) -> ok. + register_routes(Domains) -> - lists:foreach(fun(Domain) -> - register_route(Domain) - end, Domains). + lists:foreach(fun (Domain) -> register_route(Domain) + end, + Domains). + +-spec unregister_route(binary()) -> term(). unregister_route(Domain) -> case jlib:nameprep(Domain) of - error -> - erlang:error({invalid_domain, Domain}); - LDomain -> - Pid = self(), - case get_component_number(LDomain) of - undefined -> - F = fun() -> - case mnesia:match_object( - #route{domain = LDomain, - pid = Pid, - _ = '_'}) of - [R] -> - mnesia:delete_object(R); - _ -> - ok - end - end, - mnesia:transaction(F); - _ -> - F = fun() -> - case mnesia:match_object(#route{domain=LDomain, - pid = Pid, - _ = '_'}) of - [R] -> - I = R#route.local_hint, - mnesia:write( - #route{domain = LDomain, - pid = undefined, - local_hint = I}), - mnesia:delete_object(R); - _ -> - ok - end - end, - mnesia:transaction(F) - end + error -> erlang:error({invalid_domain, Domain}); + LDomain -> + Pid = self(), + case get_component_number(LDomain) of + undefined -> + F = fun () -> + case mnesia:match_object(#route{domain = LDomain, + pid = Pid, _ = '_'}) + of + [R] -> mnesia:delete_object(R); + _ -> ok + end + end, + mnesia:transaction(F); + _ -> + F = fun () -> + case mnesia:match_object(#route{domain = LDomain, + pid = Pid, _ = '_'}) + of + [R] -> + I = R#route.local_hint, + mnesia:write(#route{domain = LDomain, + pid = undefined, + local_hint = I}), + mnesia:delete_object(R); + _ -> ok + end + end, + mnesia:transaction(F) + end end. -unregister_routes(Domains) -> - lists:foreach(fun(Domain) -> - unregister_route(Domain) - end, Domains). +-spec unregister_routes([binary()]) -> ok. +unregister_routes(Domains) -> + lists:foreach(fun (Domain) -> unregister_route(Domain) + end, + Domains). + +-spec dirty_get_all_routes() -> [binary()]. dirty_get_all_routes() -> - lists:usort(mnesia:dirty_all_keys(route)) -- ?MYHOSTS. + lists:usort(mnesia:dirty_all_keys(route)) -- (?MYHOSTS). + +-spec dirty_get_all_domains() -> [binary()]. dirty_get_all_domains() -> lists:usort(mnesia:dirty_all_keys(route)). @@ -207,17 +225,14 @@ dirty_get_all_domains() -> init([]) -> update_tables(), mnesia:create_table(route, - [{ram_copies, [node()]}, - {type, bag}, - {attributes, - record_info(fields, route)}]), + [{ram_copies, [node()]}, {type, bag}, + {attributes, record_info(fields, route)}]), mnesia:add_table_copy(route, node(), ram_copies), mnesia:subscribe({table, route, simple}), - lists:foreach( - fun(Pid) -> - erlang:monitor(process, Pid) - end, - mnesia:dirty_select(route, [{{route, '_', '$1', '_'}, [], ['$1']}])), + lists:foreach(fun (Pid) -> erlang:monitor(process, Pid) + end, + mnesia:dirty_select(route, + [{{route, '_', '$1', '_'}, [], ['$1']}])), {ok, #state{}}. %%-------------------------------------------------------------------- @@ -230,8 +245,7 @@ init([]) -> %% Description: Handling call messages %%-------------------------------------------------------------------- handle_call(_Request, _From, State) -> - Reply = ok, - {reply, Reply, State}. + Reply = ok, {reply, Reply, State}. %%-------------------------------------------------------------------- %% Function: handle_cast(Msg, State) -> {noreply, State} | @@ -239,8 +253,7 @@ handle_call(_Request, _From, State) -> %% {stop, Reason, State} %% Description: Handling cast messages %%-------------------------------------------------------------------- -handle_cast(_Msg, State) -> - {noreply, State}. +handle_cast(_Msg, State) -> {noreply, State}. %%-------------------------------------------------------------------- %% Function: handle_info(Info, State) -> {noreply, State} | @@ -250,39 +263,35 @@ handle_cast(_Msg, State) -> %%-------------------------------------------------------------------- handle_info({route, From, To, Packet}, State) -> case catch do_route(From, To, Packet) of - {'EXIT', Reason} -> - ?ERROR_MSG("~p~nwhen processing: ~p", - [Reason, {From, To, Packet}]); - _ -> - ok + {'EXIT', Reason} -> + ?ERROR_MSG("~p~nwhen processing: ~p", + [Reason, {From, To, Packet}]); + _ -> ok end, {noreply, State}; -handle_info({mnesia_table_event, {write, #route{pid = Pid}, _ActivityId}}, +handle_info({mnesia_table_event, + {write, #route{pid = Pid}, _ActivityId}}, State) -> - erlang:monitor(process, Pid), - {noreply, State}; + erlang:monitor(process, Pid), {noreply, State}; handle_info({'DOWN', _Ref, _Type, Pid, _Info}, State) -> - F = fun() -> - Es = mnesia:select( - route, - [{#route{pid = Pid, _ = '_'}, - [], - ['$_']}]), - lists:foreach( - fun(E) -> - if - is_integer(E#route.local_hint) -> - LDomain = E#route.domain, - I = E#route.local_hint, - mnesia:write( - #route{domain = LDomain, - pid = undefined, - local_hint = I}), - mnesia:delete_object(E); - true -> - mnesia:delete_object(E) - end - end, Es) + F = fun () -> + Es = mnesia:select(route, + [{#route{pid = Pid, _ = '_'}, [], ['$_']}]), + lists:foreach(fun (E) -> + if is_integer(E#route.local_hint) -> + LDomain = E#route.domain, + I = E#route.local_hint, + mnesia:write(#route{domain = + LDomain, + pid = + undefined, + local_hint = + I}), + mnesia:delete_object(E); + true -> mnesia:delete_object(E) + end + end, + Es) end, mnesia:transaction(F), {noreply, State}; @@ -310,107 +319,93 @@ code_change(_OldVsn, State, _Extra) -> %%% Internal functions %%-------------------------------------------------------------------- do_route(OrigFrom, OrigTo, OrigPacket) -> - ?DEBUG("route~n\tfrom ~p~n\tto ~p~n\tpacket ~p~n", + ?DEBUG("route~n\tfrom ~p~n\tto ~p~n\tpacket " + "~p~n", [OrigFrom, OrigTo, OrigPacket]), case ejabberd_hooks:run_fold(filter_packet, - {OrigFrom, OrigTo, OrigPacket}, []) of - {From, To, Packet} -> - LDstDomain = To#jid.lserver, - case mnesia:dirty_read(route, LDstDomain) of - [] -> - ejabberd_s2s:route(From, To, Packet); - [R] -> - Pid = R#route.pid, - if - node(Pid) == node() -> - case R#route.local_hint of - {apply, Module, Function} -> - Module:Function(From, To, Packet); - _ -> - Pid ! {route, From, To, Packet} - end; - is_pid(Pid) -> - Pid ! {route, From, To, Packet}; - true -> - drop - end; - Rs -> - Value = case ejabberd_config:get_local_option( - {domain_balancing, LDstDomain}) of - undefined -> now(); - random -> now(); - source -> jlib:jid_tolower(From); - destination -> jlib:jid_tolower(To); - bare_source -> - jlib:jid_remove_resource( - jlib:jid_tolower(From)); - bare_destination -> - jlib:jid_remove_resource( - jlib:jid_tolower(To)) - end, - case get_component_number(LDstDomain) of - undefined -> - case [R || R <- Rs, node(R#route.pid) == node()] of - [] -> - R = lists:nth(erlang:phash(Value, length(Rs)), Rs), - Pid = R#route.pid, - if - is_pid(Pid) -> - Pid ! {route, From, To, Packet}; - true -> - drop - end; - LRs -> - R = lists:nth(erlang:phash(Value, length(LRs)), LRs), - Pid = R#route.pid, - case R#route.local_hint of - {apply, Module, Function} -> - Module:Function(From, To, Packet); - _ -> - Pid ! {route, From, To, Packet} - end - end; - _ -> - SRs = lists:ukeysort(#route.local_hint, Rs), - R = lists:nth(erlang:phash(Value, length(SRs)), SRs), + {OrigFrom, OrigTo, OrigPacket}, []) + of + {From, To, Packet} -> + LDstDomain = To#jid.lserver, + case mnesia:dirty_read(route, LDstDomain) of + [] -> ejabberd_s2s:route(From, To, Packet); + [R] -> + Pid = R#route.pid, + if node(Pid) == node() -> + case R#route.local_hint of + {apply, Module, Function} -> + Module:Function(From, To, Packet); + _ -> Pid ! {route, From, To, Packet} + end; + is_pid(Pid) -> Pid ! {route, From, To, Packet}; + true -> drop + end; + Rs -> + Value = case + ejabberd_config:get_local_option({domain_balancing, + LDstDomain}, fun(D) when is_atom(D) -> D end) + of + undefined -> now(); + random -> now(); + source -> jlib:jid_tolower(From); + destination -> jlib:jid_tolower(To); + bare_source -> + jlib:jid_remove_resource(jlib:jid_tolower(From)); + bare_destination -> + jlib:jid_remove_resource(jlib:jid_tolower(To)) + end, + case get_component_number(LDstDomain) of + undefined -> + case [R || R <- Rs, node(R#route.pid) == node()] of + [] -> + R = lists:nth(erlang:phash(Value, str:len(Rs)), Rs), Pid = R#route.pid, - if - is_pid(Pid) -> - Pid ! {route, From, To, Packet}; - true -> - drop + if is_pid(Pid) -> Pid ! {route, From, To, Packet}; + true -> drop + end; + LRs -> + R = lists:nth(erlang:phash(Value, str:len(LRs)), + LRs), + Pid = R#route.pid, + case R#route.local_hint of + {apply, Module, Function} -> + Module:Function(From, To, Packet); + _ -> Pid ! {route, From, To, Packet} end - end - end; - drop -> - ok + end; + _ -> + SRs = lists:ukeysort(#route.local_hint, Rs), + R = lists:nth(erlang:phash(Value, str:len(SRs)), SRs), + Pid = R#route.pid, + if is_pid(Pid) -> Pid ! {route, From, To, Packet}; + true -> drop + end + end + end; + drop -> ok end. get_component_number(LDomain) -> - case ejabberd_config:get_local_option( - {domain_balancing_component_number, LDomain}) of - N when is_integer(N), - N > 1 -> - N; - _ -> - undefined + case + ejabberd_config:get_local_option({domain_balancing_component_number, + LDomain}, fun(D) -> D end) + of + N when is_integer(N), N > 1 -> N; + _ -> undefined end. + update_tables() -> case catch mnesia:table_info(route, attributes) of - [domain, node, pid] -> - mnesia:delete_table(route); - [domain, pid] -> - mnesia:delete_table(route); - [domain, pid, local_hint] -> - ok; - {'EXIT', _} -> - ok + [domain, node, pid] -> mnesia:delete_table(route); + [domain, pid] -> mnesia:delete_table(route); + [domain, pid, local_hint] -> ok; + {'EXIT', _} -> ok end, - case lists:member(local_route, mnesia:system_info(tables)) of - true -> - mnesia:delete_table(local_route); - false -> - ok + case lists:member(local_route, + mnesia:system_info(tables)) + of + true -> mnesia:delete_table(local_route); + false -> ok end. diff --git a/src/ejabberd_s2s.erl b/src/ejabberd_s2s.erl index b06a7ab6c..0832d1dfd 100644 --- a/src/ejabberd_s2s.erl +++ b/src/ejabberd_s2s.erl @@ -25,49 +25,52 @@ %%%---------------------------------------------------------------------- -module(ejabberd_s2s). + -author('alexey@process-one.net'). -behaviour(gen_server). %% API --export([start_link/0, - route/3, - have_connection/1, - has_key/2, - get_connections_pids/1, - try_register/1, - remove_connection/3, - find_connection/2, - dirty_get_connections/0, - allow_host/2, - incoming_s2s_number/0, - outgoing_s2s_number/0, +-export([start_link/0, route/3, have_connection/1, + has_key/2, get_connections_pids/1, try_register/1, + remove_connection/3, find_connection/2, + dirty_get_connections/0, allow_host/2, + incoming_s2s_number/0, outgoing_s2s_number/0, clean_temporarily_blocked_table/0, list_temporarily_blocked_hosts/0, - external_host_overloaded/1, - is_temporarly_blocked/1 - ]). + external_host_overloaded/1, is_temporarly_blocked/1]). %% gen_server callbacks --export([init/1, handle_call/3, handle_cast/2, handle_info/2, - terminate/2, code_change/3]). +-export([init/1, handle_call/3, handle_cast/2, + handle_info/2, terminate/2, code_change/3]). + %% ejabberd API -export([get_info_s2s_connections/1]). -include("ejabberd.hrl"). + -include("jlib.hrl"). + -include("ejabberd_commands.hrl"). -define(DEFAULT_MAX_S2S_CONNECTIONS_NUMBER, 1). + -define(DEFAULT_MAX_S2S_CONNECTIONS_NUMBER_PER_NODE, 1). -define(S2S_OVERLOAD_BLOCK_PERIOD, 60). + %% once a server is temporarly blocked, it stay blocked for 60 seconds --record(s2s, {fromto, pid, key}). +-record(s2s, {fromto = {<<"">>, <<"">>} :: {binary(), binary()}, + pid = self() :: pid() | '_', + key = <<"">> :: binary() | '_'}). + -record(state, {}). --record(temporarily_blocked, {host, timestamp}). +-record(temporarily_blocked, {host = <<"">> :: binary(), + timestamp = now() :: erlang:timestamp()}). + +-type temporarily_blocked() :: #temporarily_blocked{}. %%==================================================================== %% API @@ -77,57 +80,73 @@ %% Description: Starts the server %%-------------------------------------------------------------------- start_link() -> - gen_server:start_link({local, ?MODULE}, ?MODULE, [], []). + gen_server:start_link({local, ?MODULE}, ?MODULE, [], + []). + +-spec route(jid(), jid(), xmlel()) -> ok. route(From, To, Packet) -> case catch do_route(From, To, Packet) of - {'EXIT', Reason} -> - ?ERROR_MSG("~p~nwhen processing: ~p", - [Reason, {From, To, Packet}]); - _ -> - ok + {'EXIT', Reason} -> + ?ERROR_MSG("~p~nwhen processing: ~p", + [Reason, {From, To, Packet}]); + _ -> ok end. clean_temporarily_blocked_table() -> - mnesia:clear_table(temporarily_blocked). + mnesia:clear_table(temporarily_blocked). + +-spec list_temporarily_blocked_hosts() -> [temporarily_blocked()]. + list_temporarily_blocked_hosts() -> - ets:tab2list(temporarily_blocked). + ets:tab2list(temporarily_blocked). + +-spec external_host_overloaded(binary()) -> {aborted, any()} | {atomic, ok}. external_host_overloaded(Host) -> - ?INFO_MSG("Disabling connections from ~s for ~p seconds", [Host, ?S2S_OVERLOAD_BLOCK_PERIOD]), - mnesia:transaction( fun() -> - mnesia:write(#temporarily_blocked{host = Host, timestamp = now()}) - end). + ?INFO_MSG("Disabling connections from ~s for ~p " + "seconds", + [Host, ?S2S_OVERLOAD_BLOCK_PERIOD]), + mnesia:transaction(fun () -> + mnesia:write(#temporarily_blocked{host = Host, + timestamp = + now()}) + end). + +-spec is_temporarly_blocked(binary()) -> boolean(). is_temporarly_blocked(Host) -> - case mnesia:dirty_read(temporarily_blocked, Host) of - [] -> false; - [#temporarily_blocked{timestamp = T}=Entry] -> - case timer:now_diff(now(), T) of - N when N > ?S2S_OVERLOAD_BLOCK_PERIOD * 1000 * 1000 -> - mnesia:dirty_delete_object(Entry), - false; - _ -> - true - end - end. + case mnesia:dirty_read(temporarily_blocked, Host) of + [] -> false; + [#temporarily_blocked{timestamp = T} = Entry] -> + case timer:now_diff(now(), T) of + N when N > (?S2S_OVERLOAD_BLOCK_PERIOD) * 1000 * 1000 -> + mnesia:dirty_delete_object(Entry), false; + _ -> true + end + end. +-spec remove_connection({binary(), binary()}, + pid(), binary()) -> {atomic, ok} | + ok | + {aborted, any()}. remove_connection(FromTo, Pid, Key) -> - case catch mnesia:dirty_match_object(s2s, #s2s{fromto = FromTo, - pid = Pid, - _ = '_'}) of - [#s2s{pid = Pid, key = Key}] -> - F = fun() -> - mnesia:delete_object(#s2s{fromto = FromTo, - pid = Pid, - key = Key}) - end, - mnesia:transaction(F); - _ -> - ok + case catch mnesia:dirty_match_object(s2s, + #s2s{fromto = FromTo, pid = Pid, + _ = '_'}) + of + [#s2s{pid = Pid, key = Key}] -> + F = fun () -> + mnesia:delete_object(#s2s{fromto = FromTo, pid = Pid, + key = Key}) + end, + mnesia:transaction(F); + _ -> ok end. +-spec have_connection({binary(), binary()}) -> boolean(). + have_connection(FromTo) -> case catch mnesia:dirty_read(s2s, FromTo) of [_] -> @@ -136,6 +155,8 @@ have_connection(FromTo) -> false end. +-spec has_key({binary(), binary()}, binary()) -> boolean(). + has_key(FromTo, Key) -> case mnesia:dirty_select(s2s, [{#s2s{fromto = FromTo, key = Key, _ = '_'}, @@ -147,6 +168,8 @@ has_key(FromTo, Key) -> true end. +-spec get_connections_pids({binary(), binary()}) -> [pid()]. + get_connections_pids(FromTo) -> case catch mnesia:dirty_read(s2s, FromTo) of L when is_list(L) -> @@ -155,33 +178,32 @@ get_connections_pids(FromTo) -> [] end. +-spec try_register({binary(), binary()}) -> {key, binary()} | false. + try_register(FromTo) -> Key = randoms:get_string(), MaxS2SConnectionsNumber = max_s2s_connections_number(FromTo), MaxS2SConnectionsNumberPerNode = max_s2s_connections_number_per_node(FromTo), - F = fun() -> + F = fun () -> L = mnesia:read({s2s, FromTo}), - NeededConnections = needed_connections_number( - L, MaxS2SConnectionsNumber, - MaxS2SConnectionsNumberPerNode), - if - NeededConnections > 0 -> - mnesia:write(#s2s{fromto = FromTo, - pid = self(), - key = Key}), - {key, Key}; - true -> - false + NeededConnections = needed_connections_number(L, + MaxS2SConnectionsNumber, + MaxS2SConnectionsNumberPerNode), + if NeededConnections > 0 -> + mnesia:write(#s2s{fromto = FromTo, pid = self(), + key = Key}), + {key, Key}; + true -> false end end, case mnesia:transaction(F) of - {atomic, Res} -> - Res; - _ -> - false + {atomic, Res} -> Res; + _ -> false end. +-spec dirty_get_connections() -> [{binary(), binary()}]. + dirty_get_connections() -> mnesia:dirty_all_keys(s2s). @@ -239,15 +261,13 @@ handle_info({mnesia_system_event, {mnesia_down, Node}}, State) -> {noreply, State}; handle_info({route, From, To, Packet}, State) -> case catch do_route(From, To, Packet) of - {'EXIT', Reason} -> - ?ERROR_MSG("~p~nwhen processing: ~p", - [Reason, {From, To, Packet}]); - _ -> - ok + {'EXIT', Reason} -> + ?ERROR_MSG("~p~nwhen processing: ~p", + [Reason, {From, To, Packet}]); + _ -> ok end, {noreply, State}; -handle_info(_Info, State) -> - {noreply, State}. +handle_info(_Info, State) -> {noreply, State}. %%-------------------------------------------------------------------- %% Function: terminate(Reason, State) -> void() @@ -284,76 +304,79 @@ clean_table_from_bad_node(Node) -> mnesia:async_dirty(F). do_route(From, To, Packet) -> - ?DEBUG("s2s manager~n\tfrom ~p~n\tto ~p~n\tpacket ~P~n", - [From, To, Packet, 8]), + ?DEBUG("s2s manager~n\tfrom ~p~n\tto ~p~n\tpacket " + "~P~n", + [From, To, Packet, 8]), case find_connection(From, To) of - {atomic, Pid} when is_pid(Pid) -> - ?DEBUG("sending to process ~p~n", [Pid]), - {xmlelement, Name, Attrs, Els} = Packet, - NewAttrs = jlib:replace_from_to_attrs(jlib:jid_to_string(From), - jlib:jid_to_string(To), - Attrs), - #jid{lserver = MyServer} = From, - ejabberd_hooks:run( - s2s_send_packet, - MyServer, - [From, To, Packet]), - send_element(Pid, {xmlelement, Name, NewAttrs, Els}), - ok; - {aborted, _Reason} -> - case xml:get_tag_attr_s("type", Packet) of - "error" -> ok; - "result" -> ok; - _ -> - Err = jlib:make_error_reply( - Packet, ?ERR_SERVICE_UNAVAILABLE), - ejabberd_router:route(To, From, Err) - end, - false + {atomic, Pid} when is_pid(Pid) -> + ?DEBUG("sending to process ~p~n", [Pid]), + #xmlel{name = Name, attrs = Attrs, children = Els} = + Packet, + NewAttrs = + jlib:replace_from_to_attrs(jlib:jid_to_string(From), + jlib:jid_to_string(To), Attrs), + #jid{lserver = MyServer} = From, + ejabberd_hooks:run(s2s_send_packet, MyServer, + [From, To, Packet]), + send_element(Pid, + #xmlel{name = Name, attrs = NewAttrs, children = Els}), + ok; + {aborted, _Reason} -> + case xml:get_tag_attr_s(<<"type">>, Packet) of + <<"error">> -> ok; + <<"result">> -> ok; + _ -> + Err = jlib:make_error_reply(Packet, + ?ERR_SERVICE_UNAVAILABLE), + ejabberd_router:route(To, From, Err) + end, + false end. +-spec find_connection(jid(), jid()) -> {aborted, any()} | {atomic, pid()}. + find_connection(From, To) -> #jid{lserver = MyServer} = From, #jid{lserver = Server} = To, FromTo = {MyServer, Server}, - MaxS2SConnectionsNumber = max_s2s_connections_number(FromTo), + MaxS2SConnectionsNumber = + max_s2s_connections_number(FromTo), MaxS2SConnectionsNumberPerNode = max_s2s_connections_number_per_node(FromTo), ?DEBUG("Finding connection for ~p~n", [FromTo]), case catch mnesia:dirty_read(s2s, FromTo) of - {'EXIT', Reason} -> - {aborted, Reason}; - [] -> - %% We try to establish all the connections if the host is not a - %% service and if the s2s host is not blacklisted or - %% is in whitelist: - case not is_service(From, To) andalso allow_host(MyServer, Server) of - true -> - NeededConnections = needed_connections_number( - [], MaxS2SConnectionsNumber, - MaxS2SConnectionsNumberPerNode), - open_several_connections( - NeededConnections, MyServer, - Server, From, FromTo, - MaxS2SConnectionsNumber, MaxS2SConnectionsNumberPerNode); - false -> - {aborted, error} - end; - L when is_list(L) -> - NeededConnections = needed_connections_number( - L, MaxS2SConnectionsNumber, - MaxS2SConnectionsNumberPerNode), - if - NeededConnections > 0 -> - %% We establish the missing connections for this pair. - open_several_connections( - NeededConnections, MyServer, - Server, From, FromTo, - MaxS2SConnectionsNumber, MaxS2SConnectionsNumberPerNode); - true -> - %% We choose a connexion from the pool of opened ones. - {atomic, choose_connection(From, L)} - end + {'EXIT', Reason} -> {aborted, Reason}; + [] -> + %% We try to establish all the connections if the host is not a + %% service and if the s2s host is not blacklisted or + %% is in whitelist: + case not is_service(From, To) andalso + allow_host(MyServer, Server) + of + true -> + NeededConnections = needed_connections_number([], + MaxS2SConnectionsNumber, + MaxS2SConnectionsNumberPerNode), + open_several_connections(NeededConnections, MyServer, + Server, From, FromTo, + MaxS2SConnectionsNumber, + MaxS2SConnectionsNumberPerNode); + false -> {aborted, error} + end; + L when is_list(L) -> + NeededConnections = needed_connections_number(L, + MaxS2SConnectionsNumber, + MaxS2SConnectionsNumberPerNode), + if NeededConnections > 0 -> + %% We establish the missing connections for this pair. + open_several_connections(NeededConnections, MyServer, + Server, From, FromTo, + MaxS2SConnectionsNumber, + MaxS2SConnectionsNumberPerNode); + true -> + %% We choose a connexion from the pool of opened ones. + {atomic, choose_connection(From, L)} + end end. choose_connection(From, Connections) -> @@ -361,29 +384,26 @@ choose_connection(From, Connections) -> choose_pid(From, Pids) -> Pids1 = case [P || P <- Pids, node(P) == node()] of - [] -> Pids; - Ps -> Ps + [] -> Pids; + Ps -> Ps end, - % Use sticky connections based on the JID of the sender (whithout - % the resource to ensure that a muc room always uses the same - % connection) - Pid = lists:nth(erlang:phash(jlib:jid_remove_resource(From), length(Pids1)), - Pids1), + Pid = + lists:nth(erlang:phash(jlib:jid_remove_resource(From), + length(Pids1)), + Pids1), ?DEBUG("Using ejabberd_s2s_out ~p~n", [Pid]), Pid. -open_several_connections(N, MyServer, Server, From, FromTo, - MaxS2SConnectionsNumber, +open_several_connections(N, MyServer, Server, From, + FromTo, MaxS2SConnectionsNumber, MaxS2SConnectionsNumberPerNode) -> - ConnectionsResult = - [new_connection(MyServer, Server, From, FromTo, - MaxS2SConnectionsNumber, MaxS2SConnectionsNumberPerNode) - || _N <- lists:seq(1, N)], + ConnectionsResult = [new_connection(MyServer, Server, + From, FromTo, MaxS2SConnectionsNumber, + MaxS2SConnectionsNumberPerNode) + || _N <- lists:seq(1, N)], case [PID || {atomic, PID} <- ConnectionsResult] of - [] -> - hd(ConnectionsResult); - PIDs -> - {atomic, choose_pid(From, PIDs)} + [] -> hd(ConnectionsResult); + PIDs -> {atomic, choose_pid(From, PIDs)} end. new_connection(MyServer, Server, From, FromTo, @@ -393,41 +413,38 @@ new_connection(MyServer, Server, From, FromTo, MyServer, Server, {new, Key}), F = fun() -> L = mnesia:read({s2s, FromTo}), - NeededConnections = needed_connections_number( - L, MaxS2SConnectionsNumber, - MaxS2SConnectionsNumberPerNode), - if - NeededConnections > 0 -> - mnesia:write(#s2s{fromto = FromTo, - pid = Pid, - key = Key}), - ?INFO_MSG("New s2s connection started ~p", [Pid]), - Pid; - true -> - choose_connection(From, L) + NeededConnections = needed_connections_number(L, + MaxS2SConnectionsNumber, + MaxS2SConnectionsNumberPerNode), + if NeededConnections > 0 -> + mnesia:write(#s2s{fromto = FromTo, pid = Pid, + key = Key}), + ?INFO_MSG("New s2s connection started ~p", [Pid]), + Pid; + true -> choose_connection(From, L) end end, TRes = mnesia:transaction(F), case TRes of - {atomic, Pid} -> - ejabberd_s2s_out:start_connection(Pid); - _ -> - ejabberd_s2s_out:stop_connection(Pid) + {atomic, Pid} -> ejabberd_s2s_out:start_connection(Pid); + _ -> ejabberd_s2s_out:stop_connection(Pid) end, TRes. max_s2s_connections_number({From, To}) -> - case acl:match_rule( - From, max_s2s_connections, jlib:make_jid("", To, "")) of - Max when is_integer(Max) -> Max; - _ -> ?DEFAULT_MAX_S2S_CONNECTIONS_NUMBER + case acl:match_rule(From, max_s2s_connections, + jlib:make_jid(<<"">>, To, <<"">>)) + of + Max when is_integer(Max) -> Max; + _ -> ?DEFAULT_MAX_S2S_CONNECTIONS_NUMBER end. max_s2s_connections_number_per_node({From, To}) -> - case acl:match_rule( - From, max_s2s_connections_per_node, jlib:make_jid("", To, "")) of - Max when is_integer(Max) -> Max; - _ -> ?DEFAULT_MAX_S2S_CONNECTIONS_NUMBER_PER_NODE + case acl:match_rule(From, max_s2s_connections_per_node, + jlib:make_jid(<<"">>, To, <<"">>)) + of + Max when is_integer(Max) -> Max; + _ -> ?DEFAULT_MAX_S2S_CONNECTIONS_NUMBER_PER_NODE end. needed_connections_number(Ls, MaxS2SConnectionsNumber, @@ -443,45 +460,46 @@ needed_connections_number(Ls, MaxS2SConnectionsNumber, %% -------------------------------------------------------------------- is_service(From, To) -> LFromDomain = From#jid.lserver, - case ejabberd_config:get_local_option({route_subdomains, LFromDomain}) of - s2s -> % bypass RFC 3920 10.3 - false; - _ -> - Hosts = ?MYHOSTS, - P = fun(ParentDomain) -> lists:member(ParentDomain, Hosts) end, - lists:any(P, parent_domains(To#jid.lserver)) + case ejabberd_config:get_local_option( + {route_subdomains, LFromDomain}, + fun(s2s) -> s2s end) of + s2s -> % bypass RFC 3920 10.3 + false; + undefined -> + Hosts = (?MYHOSTS), + P = fun (ParentDomain) -> + lists:member(ParentDomain, Hosts) + end, + lists:any(P, parent_domains(To#jid.lserver)) end. parent_domains(Domain) -> - lists:foldl( - fun(Label, []) -> - [Label]; - (Label, [Head | Tail]) -> - [Label ++ "." ++ Head, Head | Tail] - end, [], lists:reverse(string:tokens(Domain, "."))). - -send_element(Pid, El) -> - Pid ! {send_element, El}. + lists:foldl(fun (Label, []) -> [Label]; + (Label, [Head | Tail]) -> + [<