Skip to content

Commit

Permalink
Add docstring's param descriptions to task's and workflow's input and…
Browse files Browse the repository at this point in the history
… output variables (#557)

Signed-off-by: Sean Lin <[email protected]>

Co-authored-by: wild-endeavor <[email protected]>
  • Loading branch information
mayitbeegh and wild-endeavor authored Jul 23, 2021
1 parent 1f10fa5 commit 8c6764c
Show file tree
Hide file tree
Showing 16 changed files with 314 additions and 35 deletions.
13 changes: 9 additions & 4 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,16 @@ dockerpty==0.4.1
# via docker-compose
docopt==0.6.2
# via docker-compose
docstring-parser==0.9.1
# via
# -c requirements.txt
# flytekit
flake8==3.9.2
# via
# -r dev-requirements.in
# flake8-black
# flake8-isort
flake8-black==0.2.2
flake8-black==0.2.3
# via -r dev-requirements.in
flake8-isort==4.0.0
# via -r dev-requirements.in
Expand Down Expand Up @@ -136,7 +140,7 @@ markupsafe==2.0.1
# via
# -c requirements.txt
# jinja2
marshmallow==3.12.2
marshmallow==3.13.0
# via
# -c requirements.txt
# dataclasses-json
Expand Down Expand Up @@ -166,7 +170,7 @@ natsort==7.1.1
# via
# -c requirements.txt
# flytekit
numpy==1.21.0
numpy==1.21.1
# via
# -c requirements.txt
# pandas
Expand All @@ -183,7 +187,7 @@ paramiko==2.7.2
# via
# -c requirements.txt
# docker
pathspec==0.8.1
pathspec==0.9.0
# via
# -c requirements.txt
# black
Expand Down Expand Up @@ -315,6 +319,7 @@ texttable==1.6.4
toml==0.10.2
# via
# coverage
# flake8-black
# mypy
# pytest
tomli==1.0.4
Expand Down
18 changes: 10 additions & 8 deletions doc-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ appnope==0.1.2
# via
# ipykernel
# ipython
astroid==2.6.2
astroid==2.6.5
# via sphinx-autoapi
async-generator==1.10
# via nbclient
Expand All @@ -39,9 +39,9 @@ black==21.7b0
# via papermill
bleach==3.3.1
# via nbconvert
boto3==1.18.1
boto3==1.18.4
# via sagemaker-training
botocore==1.21.1
botocore==1.21.4
# via
# boto3
# s3transfer
Expand Down Expand Up @@ -70,7 +70,7 @@ css-html-js-minify==2.5.5
# via sphinx-material
dataclasses-json==0.5.4
# via flytekit
debugpy==1.3.0
debugpy==1.4.0
# via ipykernel
decorator==5.0.9
# via
Expand All @@ -84,6 +84,8 @@ dirhash==0.2.1
# via flytekit
docker-image-py==0.1.10
# via flytekit
docstring-parser==0.9.1
# via flytekit
docutils==0.16
# via sphinx
entrypoints==0.3
Expand Down Expand Up @@ -112,7 +114,7 @@ importlib-metadata==4.6.1
# via keyring
inotify_simple==1.2.1
# via sagemaker-training
ipykernel==6.0.2
ipykernel==6.0.3
# via flytekit
ipython==7.25.0
# via ipykernel
Expand Down Expand Up @@ -154,7 +156,7 @@ lxml==4.6.3
# via sphinx-material
markupsafe==2.0.1
# via jinja2
marshmallow==3.12.2
marshmallow==3.13.0
# via
# dataclasses-json
# marshmallow-enum
Expand Down Expand Up @@ -188,7 +190,7 @@ nbformat==5.1.3
# papermill
nest-asyncio==1.5.1
# via nbclient
numpy==1.21.0
numpy==1.21.1
# via
# flytekit
# pandas
Expand All @@ -209,7 +211,7 @@ paramiko==2.7.2
# via sagemaker-training
parso==0.8.2
# via jedi
pathspec==0.8.1
pathspec==0.9.0
# via
# black
# scantree
Expand Down
2 changes: 1 addition & 1 deletion flytekit/clients/friendly.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def list_workflows_paginated(self, identifier, limit=100, token=None, filters=No

def get_workflow(self, id):
"""
This returns a single task for a given ID.
This returns a single workflow for a given ID.
:param flytekit.models.core.identifier.Identifier id: The ID representing a given workflow.
:raises: TODO
Expand Down
4 changes: 3 additions & 1 deletion flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
FlyteEntities,
SerializationSettings,
)
from flytekit.core.docstring import Docstring
from flytekit.core.interface import Interface, transform_interface_to_typed_interface
from flytekit.core.promise import (
Promise,
Expand Down Expand Up @@ -372,6 +373,7 @@ def __init__(
task_config: T,
interface: Optional[Interface] = None,
environment: Optional[Dict[str, str]] = None,
docstring: Optional[Docstring] = None,
**kwargs,
):
"""
Expand All @@ -389,7 +391,7 @@ def __init__(
super().__init__(
task_type=task_type,
name=name,
interface=transform_interface_to_typed_interface(interface),
interface=transform_interface_to_typed_interface(interface, docstring),
**kwargs,
)
self._python_interface = interface if interface else Interface()
Expand Down
27 changes: 27 additions & 0 deletions flytekit/core/docstring.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from typing import Callable, Dict, Optional

from docstring_parser import parse


class Docstring(object):
    """
    Read-only view over a docstring parsed with ``docstring_parser.parse``.

    The text to parse is either given directly or taken from a callable's ``__doc__``. The parsed result is
    kept on the instance and exposed through the properties below.
    """

    def __init__(self, docstring: Optional[str] = None, callable_: Optional[Callable] = None):
        """
        :param docstring: Raw docstring text. Takes precedence over ``callable_`` when it is not None.
        :param callable_: Object whose ``__doc__`` is parsed when ``docstring`` is None.
        """
        if docstring is None and callable_ is not None:
            docstring = callable_.__doc__
        # When neither argument is supplied, parse None explicitly rather than reading ``None.__doc__``,
        # which would be NoneType's own docstring and has nothing to do with user code.
        self._parsed_docstring = parse(docstring)

    @property
    def input_descriptions(self) -> Dict[str, str]:
        """
        :return: Mapping of each documented parameter's ``arg_name`` to its ``description``.
        """
        return {p.arg_name: p.description for p in self._parsed_docstring.params}

    @property
    def output_descriptions(self) -> Dict[str, str]:
        """
        :return: Mapping of each documented return entry's ``return_name`` to its ``description``.
        """
        # NOTE(review): return_name is presumably None for an unnamed return entry, so such entries would
        # share a single key here - confirm against docstring_parser.
        return {p.return_name: p.description for p in self._parsed_docstring.many_returns}

    @property
    def short_description(self) -> Optional[str]:
        """
        :return: The parsed docstring's ``short_description``.
        """
        return self._parsed_docstring.short_description

    @property
    def long_description(self) -> Optional[str]:
        """
        :return: The parsed docstring's ``long_description``.
        """
        return self._parsed_docstring.long_description
32 changes: 28 additions & 4 deletions flytekit/core/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from flytekit.common.exceptions.user import FlyteValidationException
from flytekit.core import context_manager
from flytekit.core.docstring import Docstring
from flytekit.core.type_engine import TypeEngine
from flytekit.loggers import logger
from flytekit.models import interface as _interface_models
Expand Down Expand Up @@ -178,15 +179,22 @@ def transform_inputs_to_parameters(

def transform_interface_to_typed_interface(
    interface: typing.Optional[Interface],
    docstring: Optional[Docstring] = None,
) -> typing.Optional[_interface_models.TypedInterface]:
    """
    Convert a python native interface into FlyteIDL's typed interface.

    :param interface: The python native interface; when None, None is returned.
    :param docstring: Optional parsed docstring supplying descriptions for the input and output variables.
    :return: A TypedInterface built from the transformed input and output variable maps, or None.
    """
    if interface is None:
        return None

    if docstring is not None:
        in_docs = docstring.input_descriptions
        # A single documented return entry is spread over every output variable.
        out_docs = remap_shared_output_descriptions(docstring.output_descriptions, interface.outputs)
    else:
        in_docs, out_docs = {}, {}

    typed_inputs = transform_variable_map(interface.inputs, in_docs)
    typed_outputs = transform_variable_map(interface.outputs, out_docs)
    return _interface_models.TypedInterface(typed_inputs, typed_outputs)


Expand Down Expand Up @@ -253,15 +261,17 @@ def transform_signature_to_interface(signature: inspect.Signature) -> Interface:
return Interface(inputs, outputs, output_tuple_name=custom_name)


def transform_variable_map(
    variable_map: Dict[str, type], descriptions: Optional[Dict[str, str]] = None
) -> Dict[str, _interface_models.Variable]:
    """
    Given a map of str (names of inputs for instance) to their Python native types, return a map of the name to a
    Flyte Variable object with that type.

    :param variable_map: Variable names mapped to their Python native types. A falsy value yields an empty result.
    :param descriptions: Optional variable names mapped to descriptions. A variable without an entry uses its own
        name as the description.
    :return: Ordered map of variable name to the result of ``transform_type`` for that variable.
    """
    # Default to None rather than `{}`: a mutable default is a single dict object shared by every call.
    if descriptions is None:
        descriptions = {}

    res = OrderedDict()
    if variable_map:
        for k, v in variable_map.items():
            res[k] = transform_type(v, descriptions.get(k, k))

    return res

Expand Down Expand Up @@ -345,3 +355,17 @@ def t(a: int, b: str) -> Dict[str, int]: ...
# Handle all other single return types
logger.debug(f"Task returns unnamed native tuple {return_annotation}")
return {default_output_name(): return_annotation}


def remap_shared_output_descriptions(output_descriptions: Dict[str, str], outputs: Dict[str, Type]) -> Dict[str, str]:
    """
    Reconcile the mixed styles of return value descriptions used in docstrings.

    When the docstring carries exactly one return description, that description is applied to every output
    variable. Any other number of descriptions is returned untouched.

    :param output_descriptions: Dict of output variable names mapping to output description
    :param outputs: Interface outputs
    :return: Dict of output variable names mapping to shared output description
    """
    if len(output_descriptions) == 1:
        # Exactly one entry: its key is irrelevant, only the text is reused for each output.
        (shared_text,) = output_descriptions.values()
        return dict.fromkeys(outputs, shared_text)
    return output_descriptions
3 changes: 3 additions & 0 deletions flytekit/core/python_auto_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from flytekit.common.tasks.raw_container import _get_container_definition
from flytekit.core.base_task import PythonTask, TaskResolverMixin
from flytekit.core.context_manager import FlyteContextManager, ImageConfig, SerializationSettings
from flytekit.core.docstring import Docstring
from flytekit.core.resources import Resources, ResourceSpec
from flytekit.core.tracked_abc import FlyteTrackedABC
from flytekit.core.tracker import TrackedInstance
Expand Down Expand Up @@ -37,6 +38,7 @@ def __init__(
environment: Optional[Dict[str, str]] = None,
task_resolver: Optional[TaskResolverMixin] = None,
secret_requests: Optional[List[Secret]] = None,
docstring: Optional[Docstring] = None,
**kwargs,
):
"""
Expand Down Expand Up @@ -73,6 +75,7 @@ def __init__(
name=name,
task_config=task_config,
security_ctx=sec_ctx,
docstring=docstring,
**kwargs,
)
self._container_image = container_image
Expand Down
2 changes: 2 additions & 0 deletions flytekit/core/python_function_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from flytekit.common.exceptions import scopes as exception_scopes
from flytekit.core.base_task import Task, TaskResolverMixin
from flytekit.core.context_manager import ExecutionState, FastSerializationSettings, FlyteContext, FlyteContextManager
from flytekit.core.docstring import Docstring
from flytekit.core.interface import transform_signature_to_interface
from flytekit.core.python_auto_container import PythonAutoContainerTask, default_task_resolver
from flytekit.core.tracker import isnested, istestfunction
Expand Down Expand Up @@ -121,6 +122,7 @@ def __init__(
interface=mutated_interface,
task_config=task_config,
task_resolver=task_resolver,
docstring=Docstring(callable_=task_function),
**kwargs,
)

Expand Down
11 changes: 9 additions & 2 deletions flytekit/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
FlyteContextManager,
FlyteEntities,
)
from flytekit.core.docstring import Docstring
from flytekit.core.interface import (
Interface,
transform_inputs_to_parameters,
Expand Down Expand Up @@ -173,13 +174,14 @@ def __init__(
workflow_metadata: WorkflowMetadata,
workflow_metadata_defaults: WorkflowMetadataDefaults,
python_interface: Interface,
docstring: Optional[Docstring] = None,
**kwargs,
):
self._name = name
self._workflow_metadata = workflow_metadata
self._workflow_metadata_defaults = workflow_metadata_defaults
self._python_interface = python_interface
self._interface = transform_interface_to_typed_interface(python_interface)
self._interface = transform_interface_to_typed_interface(python_interface, docstring)
self._inputs = {}
self._unbound_inputs = set()
self._nodes = []
Expand Down Expand Up @@ -640,6 +642,7 @@ def __init__(
workflow_function: Callable,
metadata: Optional[WorkflowMetadata],
default_metadata: Optional[WorkflowMetadataDefaults],
docstring: Docstring = None,
):
name = f"{workflow_function.__module__}.{workflow_function.__name__}"
self._workflow_function = workflow_function
Expand All @@ -654,6 +657,7 @@ def __init__(
workflow_metadata=metadata,
workflow_metadata_defaults=default_metadata,
python_interface=native_interface,
docstring=docstring,
)

@property
Expand Down Expand Up @@ -794,7 +798,10 @@ def wrapper(fn):
workflow_metadata_defaults = WorkflowMetadataDefaults(interruptible)

workflow_instance = PythonFunctionWorkflow(
fn, metadata=workflow_metadata, default_metadata=workflow_metadata_defaults
fn,
metadata=workflow_metadata,
default_metadata=workflow_metadata_defaults,
docstring=Docstring(callable_=fn),
)
workflow_instance.compile()
return workflow_instance
Expand Down
Loading

0 comments on commit 8c6764c

Please sign in to comment.