diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index 3d5017675ef..fe38f946f9b 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -161,6 +161,12 @@ def _dispatch_execute( logger.info(f"Engine folder written successfully to the output prefix {output_prefix}") logger.debug("Finished _dispatch_execute") + if os.environ.get("FLYTE_FAIL_ON_ERROR", "").lower() == "true" and _constants.ERROR_FILE_NAME in output_file_dict: + # This env is set by the flytepropeller + # AWS batch job get the status from the exit code, so once we catch the error, + # we should return the error code here + exit(1) + def get_one_of(*args) -> str: """ diff --git a/tests/flytekit/unit/bin/test_python_entrypoint.py b/tests/flytekit/unit/bin/test_python_entrypoint.py index 479ad9e7bd3..6a8b8c430e2 100644 --- a/tests/flytekit/unit/bin/test_python_entrypoint.py +++ b/tests/flytekit/unit/bin/test_python_entrypoint.py @@ -3,6 +3,7 @@ from collections import OrderedDict import mock +import pytest from flyteidl.core.errors_pb2 import ErrorDocument from flytekit.bin.entrypoint import _dispatch_execute, normalize_inputs, setup_execution @@ -110,6 +111,37 @@ def verify_output(*args, **kwargs): assert mock_write_to_file.call_count == 1 +@mock.patch.dict(os.environ, {"FLYTE_FAIL_ON_ERROR": "True"}) +@mock.patch("flytekit.core.utils.load_proto_from_file") +@mock.patch("flytekit.core.data_persistence.FileAccessProvider.get_data") +@mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data") +@mock.patch("flytekit.core.utils.write_proto_to_file") +def test_dispatch_execute_return_error_code(mock_write_to_file, mock_upload_dir, mock_get_data, mock_load_proto): + mock_get_data.return_value = True + mock_upload_dir.return_value = True + + ctx = context_manager.FlyteContext.current_context() + with context_manager.FlyteContextManager.with_context( + ctx.with_execution_state( + ctx.execution_state.with_params(mode=context_manager.ExecutionState.Mode.TASK_EXECUTION) + ) + ) as ctx: + python_task = mock.MagicMock() + python_task.dispatch_execute.side_effect = Exception("random") + + empty_literal_map = _literal_models.LiteralMap({}).to_flyte_idl() + mock_load_proto.return_value = empty_literal_map + + def verify_output(*args, **kwargs): + assert isinstance(args[0], ErrorDocument) + + mock_write_to_file.side_effect = verify_output + + with pytest.raises(SystemExit) as cm: + _dispatch_execute(ctx, python_task, "inputs path", "outputs prefix") + pytest.assertEqual(cm.value.code, 1) + + # This function collects outputs instead of writing them to a file. # See flytekit.core.utils.write_proto_to_file for the original def get_output_collector(results: OrderedDict):