diff --git a/memoization/__init__.py b/memoization/__init__.py index b83ae8f..ba44cf1 100644 --- a/memoization/__init__.py +++ b/memoization/__init__.py @@ -1,6 +1,6 @@ import sys -__all__ = ['cached', 'CachingAlgorithmFlag', 'FIFO', 'LRU', 'LFU'] +__all__ = ['cached', 'suppress_warnings', 'CachingAlgorithmFlag', 'FIFO', 'LRU', 'LFU'] if (3, 4) <= sys.version_info < (4, 0): # for Python >=3.4 <4 from . import memoization as _memoization @@ -13,6 +13,7 @@ raise ImportError('Unsupported python version') else: cached = _memoization.cached + suppress_warnings = _memoization.suppress_warnings CachingAlgorithmFlag = _memoization.CachingAlgorithmFlag FIFO = _memoization.CachingAlgorithmFlag.FIFO LRU = _memoization.CachingAlgorithmFlag.LRU diff --git a/memoization/memoization.py b/memoization/memoization.py index 6749429..7900802 100644 --- a/memoization/memoization.py +++ b/memoization/memoization.py @@ -9,9 +9,12 @@ # Public symbols -__all__ = ['cached', 'CachingAlgorithmFlag'] +__all__ = ['cached', 'suppress_warnings', 'CachingAlgorithmFlag'] __version__ = '0.3.2' +# Whether warnings are enabled +_warning_enabled = True + # Insert the algorithm flags to the global namespace for convenience globals().update(CachingAlgorithmFlag.__members__) @@ -95,22 +98,19 @@ def calculate_performance(employee): if custom_key_maker is not None and not hasattr(custom_key_maker, '__call__'): raise TypeError('Expected custom_key_maker to be callable or None') - # Warn on zero-argument functions - user_function_info = inspect.getfullargspec(user_function) - if len(user_function_info.args) == 0 and user_function_info.varargs is None and user_function_info.varkw is None \ - and max_size is None and ttl is None: - warnings.warn('It makes no sense to do memoization on a function without arguments', SyntaxWarning) - # Check custom key maker and wrap it if custom_key_maker is not None: - custom_key_maker_info = inspect.getfullargspec(custom_key_maker) - if custom_key_maker_info.args != user_function_info.args or \ - custom_key_maker_info.varargs != user_function_info.varargs or \ - custom_key_maker_info.varkw != user_function_info.varkw or \ - custom_key_maker_info.kwonlyargs != user_function_info.kwonlyargs or \ - custom_key_maker_info.defaults != user_function_info.defaults or \ - custom_key_maker_info.kwonlydefaults != user_function_info.kwonlydefaults: - raise TypeError('Expected custom_key_maker to have the same signature as the function being cached') + if _warning_enabled: + custom_key_maker_info = inspect.getfullargspec(custom_key_maker) + user_function_info = inspect.getfullargspec(user_function) + if custom_key_maker_info.args != user_function_info.args or \ + custom_key_maker_info.varargs != user_function_info.varargs or \ + custom_key_maker_info.varkw != user_function_info.varkw or \ + custom_key_maker_info.kwonlyargs != user_function_info.kwonlyargs or \ + custom_key_maker_info.defaults != user_function_info.defaults or \ + custom_key_maker_info.kwonlydefaults != user_function_info.kwonlydefaults: + warnings.warn('Expected custom_key_maker to have the same signature as the function being cached. ' + 'Call memoization.suppress_warnings() to remove this message.', SyntaxWarning) def custom_key_maker_wrapper(args, kwargs): return custom_key_maker(*args, **kwargs) @@ -124,6 +124,16 @@ def custom_key_maker_wrapper(args, kwargs): return update_wrapper(wrapper, user_function) # update wrapper to make it look like the original function +def suppress_warnings(should_warn=False): + """ + Disable/Enable warnings when @cached is used + + :param should_warn: Whether warnings should be shown (False by default) + """ + global _warning_enabled + _warning_enabled = should_warn + + def _create_cached_wrapper(user_function, max_size, ttl, algorithm, thread_safe, order_independent, custom_key_maker): """ Factory that creates an actual executed function when a function is decorated with @cached @@ -145,5 +155,3 @@ def _create_cached_wrapper(user_function, max_size, ttl, algorithm, thread_safe, sys.stderr.write('python-memoization v' + __version__ + ': A powerful caching library for Python, with TTL support and multiple algorithm options.\n') sys.stderr.write('Go to https://github.com/lonelyenvoy/python-memoization for usage and more details.\n') - - diff --git a/memoization/memoization.pyi b/memoization/memoization.pyi index e7f45dd..08448b4 100644 --- a/memoization/memoization.pyi +++ b/memoization/memoization.pyi @@ -20,4 +20,6 @@ def cached(max_size: Optional[int] = ..., @overload def cached(user_function: T = ...) -> CachedFunction[T]: ... +def suppress_warnings(should_warn: bool = ...) -> None: ... + CachingAlgorithmFlag = CachingAlgorithmFlagType diff --git a/test.py b/test.py index 9084375..e78f00d 100644 --- a/test.py +++ b/test.py @@ -7,8 +7,9 @@ from threading import Thread from threading import Lock import inspect +import warnings -from memoization import cached, CachingAlgorithmFlag +from memoization import cached, suppress_warnings, CachingAlgorithmFlag from memoization.caching.general.keys_order_dependent import make_key exec_times = {} # executed time of each tested function @@ -392,6 +393,38 @@ def test_memoization_must_preserve_type_signature(self): self.assertEqual(inspect.getfullargspec(f23), inspect.getfullargspec(f27)) self.assertEqual(inspect.getfullargspec(f23), inspect.getfullargspec(f28)) + def test_memoization_with_custom_key_maker_and_inconsistent_type_signature(self): + def inconsistent_custom_key_maker(*args, **kwargs): + return args[0] + + def should_show_warning(): + with warnings.catch_warnings(record=True) as caught_warnings: + warnings.simplefilter('always') + + @cached(max_size=5, custom_key_maker=inconsistent_custom_key_maker) + def f(a=1, *b, c=2, **d): + return a, b, c, d + + self.assertEqual(len(caught_warnings), 1) + self.assertEqual(caught_warnings[0].category, SyntaxWarning) + self.assertTrue('signature' in str(caught_warnings[0].message)) + + def should_not_show_warning(): + with warnings.catch_warnings(record=True) as caught_warnings: + warnings.simplefilter('always') + + @cached(max_size=5, custom_key_maker=inconsistent_custom_key_maker) + def f(a=1, *b, c=2, **d): + return a, b, c, d + + self.assertEqual(len(caught_warnings), 0) + + should_show_warning() + suppress_warnings(should_warn=False) + should_not_show_warning() + suppress_warnings(should_warn=True) + should_show_warning() + def _general_test(self, tested_function, algorithm, hits, misses, in_cache, not_in_cache): # clear exec_times[tested_function.__name__] = 0