diff --git a/flytekit/__init__.py b/flytekit/__init__.py
index a57fb84a1fd..30e29210424 100644
--- a/flytekit/__init__.py
+++ b/flytekit/__init__.py
@@ -1,4 +1,4 @@
 from __future__ import absolute_import
 import flytekit.plugins
 
-__version__ = '0.1.5'
+__version__ = '0.1.6'
diff --git a/flytekit/engines/unit/engine.py b/flytekit/engines/unit/engine.py
index ffbcfbfc566..7040ee65ad8 100644
--- a/flytekit/engines/unit/engine.py
+++ b/flytekit/engines/unit/engine.py
@@ -7,6 +7,8 @@
 from datetime import datetime as _datetime
 from six import moves as _six_moves
+from google.protobuf.json_format import ParseDict as _ParseDict
+from flyteidl.plugins import qubole_pb2 as _qubole_pb2
 
 from flytekit.common import constants as _sdk_constants, utils as _common_utils
 from flytekit.common.exceptions import user as _user_exceptions, system as _system_exception
 from flytekit.common.types import helpers as _type_helpers
@@ -14,7 +16,7 @@
 from flytekit.engines import common as _common_engine
 from flytekit.engines.unit.mock_stats import MockStats
 from flytekit.interfaces.data import data_proxy as _data_proxy
-from flytekit.models import literals as _literals, array_job as _array_job
+from flytekit.models import literals as _literals, array_job as _array_job, qubole as _qubole_models
 from flytekit.models.core.identifier import WorkflowExecutionIdentifier
 
 
@@ -32,9 +34,12 @@ def get_task(self, sdk_task):
             return ReturnOutputsTask(sdk_task)
         elif sdk_task.type in {
             _sdk_constants.SdkTaskType.DYNAMIC_TASK,
-            _sdk_constants.SdkTaskType.BATCH_HIVE_TASK
         }:
             return DynamicTask(sdk_task)
+        elif sdk_task.type in {
+            _sdk_constants.SdkTaskType.BATCH_HIVE_TASK,
+        }:
+            return HiveTask(sdk_task)
         else:
             raise _user_exceptions.FlyteAssertion(
                 "Unit tests are not currently supported for tasks of type: {}".format(
@@ -76,7 +81,7 @@ def execute(self, inputs, context=None):
         Just execute the function and return the outputs as a user-readable dictionary.
         :param flytekit.models.literals.LiteralMap inputs:
         :param context:
-        :rtype: dict[Text, flytekit.models.common.FlyteIdlEntity]
+        :rtype: dict[Text,flytekit.models.common.FlyteIdlEntity]
         """
         with _TemporaryConfiguration(
             _os.path.join(_os.path.dirname(__file__), 'unit.config'),
@@ -84,12 +89,12 @@ def execute(self, inputs, context=None):
         ):
             with _common_utils.AutoDeletingTempDir("unit_test_dir") as working_directory:
                 with _data_proxy.LocalWorkingDirectoryContext(working_directory):
-                    return self._execute_user_code(inputs)
+                    return self._transform_for_user_output(self._execute_user_code(inputs))
 
     def _execute_user_code(self, inputs):
         """
         :param flytekit.models.literals.LiteralMap inputs:
-        :rtype: dict[Text, flytekit.models.common.FlyteIdlEntity]
+        :rtype: dict[Text,flytekit.models.common.FlyteIdlEntity]
         """
         with _common_utils.AutoDeletingTempDir("user_dir") as user_working_directory:
             return self.sdk_task.execute(
@@ -107,24 +112,32 @@ def _execute_user_code(self, inputs):
                 inputs
             )
 
+    def _transform_for_user_output(self, outputs):
+        """
+        Take whatever is returned from the task execution and convert to a reasonable output for the behavior of this
+        task's unit test.
+        :param dict[Text,flytekit.models.common.FlyteIdlEntity] outputs:
+        :rtype: T
+        """
+        return outputs
+
     def register(self, identifier, version):
         raise _user_exceptions.FlyteAssertion("You cannot register unit test tasks.")
 
 
 class ReturnOutputsTask(UnitTestEngineTask):
-    def execute(self, inputs, context=None):
+    def _transform_for_user_output(self, outputs):
         """
-        Just execute the function and return the outputs as a user-readable dictionary.
-        :param flytekit.models.literals.LiteralMap inputs:
-        :param context:
-        :rtype: dict[Text, T]
+        Just return the outputs as a user-readable dictionary.
+        :param dict[Text,flytekit.models.common.FlyteIdlEntity] outputs:
+        :rtype: T
         """
-        outputs = super(ReturnOutputsTask, self).execute(inputs)[_sdk_constants.OUTPUT_FILE_NAME]
+        literal_map = outputs[_sdk_constants.OUTPUT_FILE_NAME]
         return {
             name: _type_helpers.get_sdk_type_from_literal_type(
                 variable.type
             ).promote_from_model(
-                outputs.literals[name]
+                literal_map.literals[name]
             ).to_python_std()
             for name, variable in _six.iteritems(self.sdk_task.interface.outputs)
         }
@@ -135,7 +148,7 @@ class DynamicTask(ReturnOutputsTask):
     def _execute_user_code(self, inputs):
         """
         :param flytekit.models.literals.LiteralMap inputs:
-        :rtype: dict[Text, flytekit.models.common.FlyteIdlEntity]
+        :rtype: dict[Text,flytekit.models.common.FlyteIdlEntity]
         """
         results = super(DynamicTask, self)._execute_user_code(inputs)
         if _sdk_constants.FUTURES_FILE_NAME in results:
@@ -151,7 +164,7 @@ def _execute_user_code(self, inputs):
                 # TODO: futures.outputs should have the Schema instances.
                 # After schema is implemented, fill out random data into the random locations
                 # then check output in test function
-                # From Haytham even though we recommend people use typed schemas, they might not always do so...
+                # Even though we recommend people use typed schemas, they might not always do so...
                 # in which case it'll be impossible to predict the actual schema, we should support a
                 # way for unit test authors to provide fake data regardless
                 sub_task_output = None
@@ -201,7 +214,7 @@ def fulfil_bindings(binding_data, fulfilled_promises):
         fulfilled_promises
 
         :param _interface.BindingData binding_data:
-        :param dict[Text, T] fulfilled_promises:
+        :param dict[Text,T] fulfilled_promises:
         :rtype:
         """
         if binding_data.scalar:
@@ -228,3 +241,26 @@ def fulfil_bindings(binding_data, fulfilled_promises):
                 k: DynamicTask.fulfil_bindings(sub_binding_data, fulfilled_promises)
                 for k, sub_binding_data in _six.iteritems(binding_data.map.bindings)
             }))
+
+
+class HiveTask(DynamicTask):
+    def _transform_for_user_output(self, outputs):
+        """
+        Just execute the function and return the list of Hive queries returned.
+        :param dict[Text,flytekit.models.common.FlyteIdlEntity] outputs:
+        :rtype: list[Text]
+        """
+        futures = outputs.get(_sdk_constants.FUTURES_FILE_NAME)
+        if futures:
+            task_ids_to_defs = {
+                t.id.name: _qubole_models.QuboleHiveJob.from_flyte_idl(
+                    _ParseDict(t.custom, _qubole_pb2.QuboleHiveJob())
+                )
+                for t in futures.tasks
+            }
+            return [
+                q.query
+                for q in task_ids_to_defs[futures.nodes[0].task_node.reference_id.name].query_collection.queries
+            ]
+        else:
+            return []
diff --git a/setup.cfg b/setup.cfg
index 322b49873c6..b7e58d5f7b8 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -7,7 +7,7 @@ max-complexity=16
 [tool:pytest]
 norecursedirs = common workflows spark
 log_cli = true
-log_cli_level = 100
+log_cli_level = 20
 
 [pep8]
 max-line-length = 120
diff --git a/tests/flytekit/unit/use_scenarios/unit_testing/hive_tasks.py b/tests/flytekit/unit/use_scenarios/unit_testing/hive_tasks.py
new file mode 100644
index 00000000000..7fa59b2c4f2
--- /dev/null
+++ b/tests/flytekit/unit/use_scenarios/unit_testing/hive_tasks.py
@@ -0,0 +1,44 @@
+from __future__ import absolute_import
+from flytekit.sdk.tasks import hive_task
+import pytest
+
+
+def test_no_queries():
+    @hive_task
+    def test_hive_task(wf_params):
+        pass
+
+    assert test_hive_task.unit_test() == []
+
+
+def test_empty_list_queries():
+    @hive_task
+    def test_hive_task(wf_params):
+        return []
+
+    assert test_hive_task.unit_test() == []
+
+
+def test_one_query():
+    @hive_task
+    def test_hive_task(wf_params):
+        return "abc"
+
+    assert test_hive_task.unit_test() == ["abc"]
+
+
+def test_multiple_queries():
+    @hive_task
+    def test_hive_task(wf_params):
+        return ["abc", "cde"]
+
+    assert test_hive_task.unit_test() == ["abc", "cde"]
+
+
+def test_raise_exception():
+    @hive_task
+    def test_hive_task(wf_params):
+        raise FloatingPointError("Floating point error for some reason.")
+
+    with pytest.raises(FloatingPointError):
+        test_hive_task.unit_test()