diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 166b1277..bd4b6529 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,3 +1,9 @@ +`Unreleased`_ +============= + +- Add a new ``default`` optional argument to ``logger.catch()``, it should be the returned value by the decorated function in case an error occurred (`#272 `_). + + `0.5.0`_ (2020-05-17) ===================== diff --git a/loguru/_logger.py b/loguru/_logger.py index 18e458a3..f2160705 100644 --- a/loguru/_logger.py +++ b/loguru/_logger.py @@ -1078,6 +1078,7 @@ def catch( reraise=False, onerror=None, exclude=None, + default=None, message="An error has been caught in function '{record[function]}', " "process '{record[process].name}' ({record[process].id}), " "thread '{record[thread].name}' ({record[thread].id}):" @@ -1108,6 +1109,9 @@ def catch( exclude : |Exception|, optional A type of exception (or a tuple of types) that will be purposely ignored and hence propagated to the caller without being logged. + default : optional + The value to be returned by the decorated function if an error occurred without being + re-raised. message : |str|, optional The message that will be automatically logged if an exception occurs. Note that it will be formatted with the ``record`` attribute. @@ -1196,18 +1200,21 @@ def __call__(_, function): async def catch_wrapper(*args, **kwargs): with catcher: return await function(*args, **kwargs) + return default elif inspect.isgeneratorfunction(function): def catch_wrapper(*args, **kwargs): with catcher: return (yield from function(*args, **kwargs)) + return default else: def catch_wrapper(*args, **kwargs): with catcher: return function(*args, **kwargs) + return default functools.update_wrapper(catch_wrapper, function) return catch_wrapper diff --git a/tests/test_catch_exceptions.py b/tests/test_catch_exceptions.py index 0bdb0365..1077bdea 100644 --- a/tests/test_catch_exceptions.py +++ b/tests/test_catch_exceptions.py @@ -623,3 +623,38 @@ def foo(x, y, z): with pytest.raises(StopIteration, match=r"3"): next(f) + + +def test_decorate_generator_with_error(): + @logger.catch + def foo(): + for i in range(3): + 1 / (2 - i) + yield i + + assert list(foo()) == [0, 1] + + +def test_default_with_function(): + @logger.catch(default=42) + def foo(): + 1 / 0 + + assert foo() == 42 + + +def test_default_with_generator(): + @logger.catch(default=42) + def foo(): + yield 1 / 0 + + with pytest.raises(StopIteration, match=r"42"): + next(foo()) + + +def test_default_with_coroutine(): + @logger.catch(default=42) + async def foo(): + return 1 / 0 + + assert asyncio.run(foo()) == 42