Skip to content

Commit

Permalink
Add support for @workflow docstring input/output variable description (
Browse files Browse the repository at this point in the history
…#562)

Signed-off-by: Sean Lin <[email protected]>
  • Loading branch information
mayitbeegh authored Jul 23, 2021
1 parent 7a4694a commit e646831
Show file tree
Hide file tree
Showing 10 changed files with 168 additions and 100 deletions.
2 changes: 1 addition & 1 deletion flytekit/clients/friendly.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def list_workflows_paginated(self, identifier, limit=100, token=None, filters=No

def get_workflow(self, id):
"""
This returns a single task for a given ID.
This returns a single workflow for a given ID.
:param flytekit.models.core.identifier.Identifier id: The ID representing a given task.
:raises: TODO
Expand Down
3 changes: 2 additions & 1 deletion flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
FlyteEntities,
SerializationSettings,
)
from flytekit.core.docstring import Docstring
from flytekit.core.interface import Interface, transform_interface_to_typed_interface
from flytekit.core.promise import (
Promise,
Expand Down Expand Up @@ -372,7 +373,7 @@ def __init__(
task_config: T,
interface: Optional[Interface] = None,
environment: Optional[Dict[str, str]] = None,
docstring: str = None,
docstring: Optional[Docstring] = None,
**kwargs,
):
"""
Expand Down
27 changes: 27 additions & 0 deletions flytekit/core/docstring.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from typing import Callable, Dict, Optional

from docstring_parser import parse


class Docstring(object):
def __init__(self, docstring: str = None, callable_: Callable = None):
if docstring is not None:
self._parsed_docstring = parse(docstring)
else:
self._parsed_docstring = parse(callable_.__doc__)

@property
def input_descriptions(self) -> Dict[str, str]:
return {p.arg_name: p.description for p in self._parsed_docstring.params}

@property
def output_descriptions(self) -> Dict[str, str]:
return {p.return_name: p.description for p in self._parsed_docstring.many_returns}

@property
def short_description(self) -> Optional[str]:
return self._parsed_docstring.short_description

@property
def long_description(self) -> Optional[str]:
return self._parsed_docstring.long_description
30 changes: 10 additions & 20 deletions flytekit/core/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@
from collections import OrderedDict
from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TypeVar, Union

from docstring_parser import parse

from flytekit.common.exceptions.user import FlyteValidationException
from flytekit.core import context_manager
from flytekit.core.docstring import Docstring
from flytekit.core.type_engine import TypeEngine
from flytekit.loggers import logger
from flytekit.models import interface as _interface_models
Expand Down Expand Up @@ -180,18 +179,22 @@ def transform_inputs_to_parameters(

def transform_interface_to_typed_interface(
interface: typing.Optional[Interface],
docstring: str = None,
docstring: Optional[Docstring] = None,
) -> typing.Optional[_interface_models.TypedInterface]:
"""
Transform the given simple python native interface to FlyteIDL's interface
"""
if interface is None:
return None
input_descriptions, output_description = get_variable_descriptions(docstring)

if docstring is None:
input_descriptions = output_descriptions = {}
else:
input_descriptions = docstring.input_descriptions
output_descriptions = remap_shared_output_descriptions(docstring.output_descriptions, interface.outputs)

inputs_map = transform_variable_map(interface.inputs, input_descriptions)
outputs_map = transform_variable_map(
interface.outputs, remap_shared_output_descriptions(output_description, interface.outputs)
)
outputs_map = transform_variable_map(interface.outputs, output_descriptions)
return _interface_models.TypedInterface(inputs_map, outputs_map)


Expand Down Expand Up @@ -354,19 +357,6 @@ def t(a: int, b: str) -> Dict[str, int]: ...
return {default_output_name(): return_annotation}


def get_variable_descriptions(docstring: str) -> Tuple[Dict[str, str], Optional[str]]:
"""
Takes a Python docstring, either from `function.__doc__` or `inpect.getdoc(function)`, and returns the descriptions of the input paramenters and the output values.
:param docstring: Python docstring in Sphinx reStructuredText style, Numpydoc style, or Google style.
:return: Dict of input parameter names mapping to their descriptions, and dict of output names mapping to their descriptions.
"""
parsed_docstring = parse(docstring)
return {p.arg_name: p.description for p in parsed_docstring.params}, {
p.return_name: p.description for p in parsed_docstring.many_returns
}


def remap_shared_output_descriptions(output_descriptions: Dict[str, str], outputs: Dict[str, Type]) -> Dict[str, str]:
"""
Deals with mixed styles of return value descriptions used in docstrings. If the docstring contains a single entry of return value description, that output description is shared by each output variable.
Expand Down
3 changes: 3 additions & 0 deletions flytekit/core/python_auto_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from flytekit.common.tasks.raw_container import _get_container_definition
from flytekit.core.base_task import PythonTask, TaskResolverMixin
from flytekit.core.context_manager import FlyteContextManager, ImageConfig, SerializationSettings
from flytekit.core.docstring import Docstring
from flytekit.core.resources import Resources, ResourceSpec
from flytekit.core.tracked_abc import FlyteTrackedABC
from flytekit.core.tracker import TrackedInstance
Expand Down Expand Up @@ -37,6 +38,7 @@ def __init__(
environment: Optional[Dict[str, str]] = None,
task_resolver: Optional[TaskResolverMixin] = None,
secret_requests: Optional[List[Secret]] = None,
docstring: Optional[Docstring] = None,
**kwargs,
):
"""
Expand Down Expand Up @@ -73,6 +75,7 @@ def __init__(
name=name,
task_config=task_config,
security_ctx=sec_ctx,
docstring=docstring,
**kwargs,
)
self._container_image = container_image
Expand Down
3 changes: 2 additions & 1 deletion flytekit/core/python_function_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from flytekit.common.exceptions import scopes as exception_scopes
from flytekit.core.base_task import Task, TaskResolverMixin
from flytekit.core.context_manager import ExecutionState, FastSerializationSettings, FlyteContext, FlyteContextManager
from flytekit.core.docstring import Docstring
from flytekit.core.interface import transform_signature_to_interface
from flytekit.core.python_auto_container import PythonAutoContainerTask, default_task_resolver
from flytekit.core.tracker import isnested, istestfunction
Expand Down Expand Up @@ -121,7 +122,7 @@ def __init__(
interface=mutated_interface,
task_config=task_config,
task_resolver=task_resolver,
docstring=task_function.__doc__,
docstring=Docstring(callable_=task_function),
**kwargs,
)

Expand Down
11 changes: 9 additions & 2 deletions flytekit/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
FlyteContextManager,
FlyteEntities,
)
from flytekit.core.docstring import Docstring
from flytekit.core.interface import (
Interface,
transform_inputs_to_parameters,
Expand Down Expand Up @@ -173,13 +174,14 @@ def __init__(
workflow_metadata: WorkflowMetadata,
workflow_metadata_defaults: WorkflowMetadataDefaults,
python_interface: Interface,
docstring: Optional[Docstring] = None,
**kwargs,
):
self._name = name
self._workflow_metadata = workflow_metadata
self._workflow_metadata_defaults = workflow_metadata_defaults
self._python_interface = python_interface
self._interface = transform_interface_to_typed_interface(python_interface)
self._interface = transform_interface_to_typed_interface(python_interface, docstring)
self._inputs = {}
self._unbound_inputs = set()
self._nodes = []
Expand Down Expand Up @@ -640,6 +642,7 @@ def __init__(
workflow_function: Callable,
metadata: Optional[WorkflowMetadata],
default_metadata: Optional[WorkflowMetadataDefaults],
docstring: Docstring = None,
):
name = f"{workflow_function.__module__}.{workflow_function.__name__}"
self._workflow_function = workflow_function
Expand All @@ -654,6 +657,7 @@ def __init__(
workflow_metadata=metadata,
workflow_metadata_defaults=default_metadata,
python_interface=native_interface,
docstring=docstring,
)

@property
Expand Down Expand Up @@ -794,7 +798,10 @@ def wrapper(fn):
workflow_metadata_defaults = WorkflowMetadataDefaults(interruptible)

workflow_instance = PythonFunctionWorkflow(
fn, metadata=workflow_metadata, default_metadata=workflow_metadata_defaults
fn,
metadata=workflow_metadata,
default_metadata=workflow_metadata_defaults,
docstring=Docstring(callable_=fn),
)
workflow_instance.compile()
return workflow_instance
Expand Down
95 changes: 95 additions & 0 deletions tests/flytekit/unit/core/test_docstring.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import typing

from flytekit.core.docstring import Docstring


def test_get_variable_descriptions():
# sphinx style
def z(a: int, b: str) -> typing.Tuple[int, str]:
"""
function z
longer description here
:param a: foo
:param b: bar
:return: ramen
"""
...

docstring = Docstring(callable_=z)
input_descriptions = docstring.input_descriptions
output_descriptions = docstring.output_descriptions
assert input_descriptions["a"] == "foo"
assert input_descriptions["b"] == "bar"
assert len(output_descriptions) == 1
assert next(iter(output_descriptions.items()))[1] == "ramen"
assert docstring.short_description == "function z"
assert docstring.long_description == "longer description here"

# numpy style
def z(a: int, b: str) -> typing.Tuple[int, str]:
"""
function z
longer description here
Parameters
----------
a : int
foo
b : str
bar
Returns
-------
out : tuple
ramen
"""
...

docstring = Docstring(callable_=z)
input_descriptions = docstring.input_descriptions
output_descriptions = docstring.output_descriptions
assert input_descriptions["a"] == "foo"
assert input_descriptions["b"] == "bar"
assert len(output_descriptions) == 1
assert next(iter(output_descriptions.items()))[1] == "ramen"
assert docstring.short_description == "function z"
assert docstring.long_description == "longer description here"

# google style
def z(a: int, b: str) -> typing.Tuple[int, str]:
"""function z
longer description here
Args:
a(int): foo
b(str): bar
Returns:
str: ramen
"""
...

docstring = Docstring(callable_=z)
input_descriptions = docstring.input_descriptions
output_descriptions = docstring.output_descriptions
assert input_descriptions["a"] == "foo"
assert input_descriptions["b"] == "bar"
assert len(output_descriptions) == 1
assert next(iter(output_descriptions.items()))[1] == "ramen"
assert docstring.short_description == "function z"
assert docstring.long_description == "longer description here"

# empty doc
def z(a: int, b: str) -> typing.Tuple[int, str]:
...

docstring = Docstring(callable_=z)
input_descriptions = docstring.input_descriptions
output_descriptions = docstring.output_descriptions
assert len(input_descriptions) == 0
assert len(output_descriptions) == 0
assert docstring.short_description is None
assert docstring.long_description is None
Loading

0 comments on commit e646831

Please sign in to comment.