Skip to content

Commit

Permalink
Split service code (#2136)
Browse files Browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
wild-endeavor authored Jan 31, 2024
1 parent 63b7b29 commit e76efdc
Show file tree
Hide file tree
Showing 71 changed files with 843 additions and 965 deletions.
2 changes: 2 additions & 0 deletions dev-requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,5 @@ pandas
scikit-learn
types-requests
prometheus-client

orjson
10 changes: 5 additions & 5 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ adlfs==2023.9.0
# via flytekit
aiobotocore==2.5.4
# via s3fs
aiohttp==3.8.6
aiohttp==3.9.2
# via
# adlfs
# aiobotocore
Expand Down Expand Up @@ -69,9 +69,7 @@ cfgv==3.4.0
chardet==5.2.0
# via binaryornot
charset-normalizer==3.3.2
# via
# aiohttp
# requests
# via requests
click==8.1.7
# via
# cookiecutter
Expand Down Expand Up @@ -316,6 +314,8 @@ oauthlib==3.2.2
# requests-oauthlib
opt-einsum==3.3.0
# via tensorflow
orjson==3.9.12
# via -r dev-requirements.in
packaging==23.2
# via
# docker
Expand All @@ -329,7 +329,7 @@ parso==0.8.3
# via jedi
pexpect==4.8.0
# via ipython
pillow==10.1.0
pillow==10.2.0
# via -r dev-requirements.in
platformdirs==3.11.0
# via virtualenv
Expand Down
2 changes: 1 addition & 1 deletion doc-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ adlfs==2023.9.0
# via flytekit
aiobotocore==2.5.4
# via s3fs
aiohttp==3.9.1
aiohttp==3.9.2
# via
# adlfs
# aiobotocore
Expand Down
7 changes: 0 additions & 7 deletions flytekit/clients/friendly.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from flyteidl.admin import task_pb2 as _task_pb2
from flyteidl.admin import workflow_attributes_pb2 as _workflow_attributes_pb2
from flyteidl.admin import workflow_pb2 as _workflow_pb2
from flyteidl.artifact import artifacts_pb2
from flyteidl.service import dataproxy_pb2 as _data_proxy_pb2
from google.protobuf.duration_pb2 import Duration

Expand Down Expand Up @@ -1037,9 +1036,3 @@ def get_data(self, flyte_uri: str) -> _data_proxy_pb2.GetDataResponse:

resp = self._dataproxy_stub.GetData(req, metadata=self._metadata)
return resp

def create_artifact(self, request: artifacts_pb2.CreateArtifactRequest) -> artifacts_pb2.CreateArtifactResponse:
return self._artifact_stub.CreateArtifact(request)

def get_artifact(self, request: artifacts_pb2.GetArtifactRequest) -> artifacts_pb2.GetArtifactResponse:
return self._artifact_stub.GetArtifact(request)
2 changes: 0 additions & 2 deletions flytekit/clients/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import grpc
from flyteidl.admin.project_pb2 import ProjectListRequest
from flyteidl.admin.signal_pb2 import SignalList, SignalListRequest, SignalSetRequest, SignalSetResponse
from flyteidl.artifact import artifacts_pb2_grpc as artifact_service
from flyteidl.service import admin_pb2_grpc as _admin_service
from flyteidl.service import dataproxy_pb2 as _dataproxy_pb2
from flyteidl.service import dataproxy_pb2_grpc as dataproxy_service
Expand Down Expand Up @@ -53,7 +52,6 @@ def __init__(self, cfg: PlatformConfig, **kwargs):
self._stub = _admin_service.AdminServiceStub(self._channel)
self._signal = signal_service.SignalServiceStub(self._channel)
self._dataproxy_stub = dataproxy_service.DataProxyServiceStub(self._channel)
self._artifact_stub = artifact_service.ArtifactRegistryStub(self._channel)

logger.info(
f"Flyte Client configured -> {self._cfg.endpoint} in {'insecure' if self._cfg.insecure else 'secure'} mode."
Expand Down
10 changes: 10 additions & 0 deletions flytekit/clis/sdk_in_container/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,14 @@
callback=key_value_callback,
help="Environment variables to set in the container, of the format `ENV_NAME=ENV_VALUE`",
)
@click.option(
"--skip-errors",
"--skip-error",
default=False,
is_flag=True,
help="Skip errors during registration. This is useful when registering multiple packages and you want to skip "
"errors for some packages.",
)
@click.argument("package-or-module", type=click.Path(exists=True, readable=True, resolve_path=True), nargs=-1)
@click.pass_context
def register(
Expand All @@ -135,6 +143,7 @@ def register(
dry_run: bool,
activate_launchplans: bool,
env: typing.Optional[typing.Dict[str, str]],
skip_errors: bool,
):
"""
see help
Expand Down Expand Up @@ -187,6 +196,7 @@ def register(
env=env,
dry_run=dry_run,
activate_launchplans=activate_launchplans,
skip_errors=skip_errors,
)
except Exception as e:
raise e
5 changes: 5 additions & 0 deletions flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,9 @@ def _run(*args, **kwargs):

if not run_level_params.is_remote:
with FlyteContextManager.with_context(_update_flyte_context(run_level_params)):
if run_level_params.envvars:
for env_var, value in run_level_params.envvars.items():
os.environ[env_var] = value
output = entity(**inputs)
if inspect.iscoroutine(output):
# TODO: make eager mode workflows run with local-mode
Expand Down Expand Up @@ -792,6 +795,8 @@ def get_command(self, ctx, exe_entity):
is_workflow = False
if self._entities:
is_workflow = exe_entity in self._entities.workflows
if not os.path.exists(self._filename):
raise ValueError(f"File {self._filename} does not exist")
rel_path = os.path.relpath(self._filename)
if rel_path.startswith(".."):
raise ValueError(
Expand Down
8 changes: 6 additions & 2 deletions flytekit/clis/sdk_in_container/serve.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from concurrent import futures

import rich_click as click
from flyteidl.service.agent_pb2_grpc import add_AsyncAgentServiceServicer_to_server
from flyteidl.service.agent_pb2_grpc import (
add_AgentMetadataServiceServicer_to_server,
add_AsyncAgentServiceServicer_to_server,
)
from grpc import aio


Expand Down Expand Up @@ -49,7 +52,7 @@ def agent(_: click.Context, port, worker, timeout):

async def _start_grpc_server(port: int, worker: int, timeout: int):
click.secho("Starting up the server to expose the prometheus metrics...", fg="blue")
from flytekit.extend.backend.agent_service import AsyncAgentService
from flytekit.extend.backend.agent_service import AgentMetadataService, AsyncAgentService

try:
from prometheus_client import start_http_server
Expand All @@ -61,6 +64,7 @@ async def _start_grpc_server(port: int, worker: int, timeout: int):
server = aio.server(futures.ThreadPoolExecutor(max_workers=worker))

add_AsyncAgentServiceServicer_to_server(AsyncAgentService(), server)
add_AgentMetadataServiceServicer_to_server(AgentMetadataService(), server)

server.add_insecure_port(f"[::]:{port}")
await server.start()
Expand Down
16 changes: 16 additions & 0 deletions flytekit/configuration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,22 @@ def auto(cls, config_file: typing.Union[str, ConfigFile] = None) -> DataConfig:
)


@dataclass(init=True, repr=True, eq=True, frozen=True)
class LocalConfig(object):
"""
Any configuration specific to local runs.
"""

cache_enabled: bool = True

@classmethod
def auto(cls, config_file: typing.Union[str, ConfigFile] = None) -> LocalConfig:
config_file = get_config_file(config_file)
kwargs = {}
kwargs = set_if_exists(kwargs, "cache_enabled", _internal.Local.CACHE_ENABLED.read(config_file))
return LocalConfig(**kwargs)


@dataclass(init=True, repr=True, eq=True, frozen=True)
class Config(object):
"""
Expand Down
8 changes: 8 additions & 0 deletions flytekit/configuration/default_images.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import enum
import sys
import typing
from contextlib import suppress


class PythonVersion(enum.Enum):
Expand All @@ -26,6 +27,13 @@ class DefaultImages(object):

@classmethod
def default_image(cls) -> str:
from flytekit.configuration.plugin import get_plugin

with suppress(AttributeError):
default_image = get_plugin().get_default_image()
if default_image is not None:
return default_image

return cls.find_image_for()

@classmethod
Expand Down
5 changes: 5 additions & 0 deletions flytekit/configuration/internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ class AZURE(object):
CLIENT_SECRET = ConfigEntry(LegacyConfigEntry(SECTION, "client_secret"))


class Local(object):
SECTION = "local"
CACHE_ENABLED = ConfigEntry(LegacyConfigEntry(SECTION, "cache_enabled", bool))


class Credentials(object):
SECTION = "credentials"
COMMAND = ConfigEntry(LegacyConfigEntry(SECTION, "command", list), YamlConfigEntry("admin.command", list))
Expand Down
9 changes: 9 additions & 0 deletions flytekit/configuration/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ def configure_pyflyte_cli(main: Group) -> Group:
def secret_requires_group() -> bool:
"""Return True if secrets require group entry."""

@staticmethod
def get_default_image() -> Optional[str]:
"""Get default image. Return None to use the images from flytekit.configuration.DefaultImages"""


class FlytekitPlugin:
@staticmethod
Expand Down Expand Up @@ -71,6 +75,11 @@ def secret_requires_group() -> bool:
"""Return True if secrets require group entry."""
return True

@staticmethod
def get_default_image() -> Optional[str]:
"""Get default image. Return None to use the images from flytekit.configuration.DefaultImages"""
return None


def _get_plugin_from_entrypoint():
"""Get plugin from entrypoint."""
Expand Down
5 changes: 3 additions & 2 deletions flytekit/core/array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,9 @@ def __init__(
f = actual_task.lhs
else:
_, mod, f, _ = tracker.extract_task_module(cast(PythonFunctionTask, actual_task).task_function)
sorted_bounded_inputs = ",".join(sorted(self._bound_inputs))
h = hashlib.md5(
f"{collection_interface.__str__()}{concurrency}{min_successes}{min_success_ratio}".encode("utf-8")
f"{sorted_bounded_inputs}{concurrency}{min_successes}{min_success_ratio}".encode("utf-8")
).hexdigest()
self._name = f"{mod}.map_{f}_{h}-arraynode"

Expand Down Expand Up @@ -387,7 +388,7 @@ def load_task(self, loader_args: List[str], max_concurrency: int = 0) -> ArrayNo
def loader_args(self, settings: SerializationSettings, t: ArrayNodeMapTask) -> List[str]: # type:ignore
return [
"vars",
f'{",".join(t.bound_inputs)}',
f'{",".join(sorted(t.bound_inputs))}',
"resolver",
t.python_function_task.task_resolver.location,
*t.python_function_task.task_resolver.loader_args(settings, t.python_function_task),
Expand Down
Loading

0 comments on commit e76efdc

Please sign in to comment.