Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RFC] Engine Refactor Proposal | Alternative 2 #3760

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions src/otx/cli/cli_poc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from jsonargparse import ActionConfigFile, ArgumentParser, namespace_to_dict

from otx.cli.utils.jsonargparse import get_short_docstring
from otx.engine import BaseEngine


class CLI:
"""CLI.

Limited CLI to show how the api does not change externally while retaining the ability to expose models from the
adapters.
"""

def __init__(self):
self.parser = ArgumentParser()
self.parser.add_argument(
"--config",
action=ActionConfigFile,
help="Configuration file in JSON format.",
)
self.add_subcommands()
self.run()

def subcommands(self):
return ["train", "test"]

def _get_model_classes(self):
classes = [engine.BASE_MODEL for engine in BaseEngine.__subclasses__()]
print(classes)
return tuple(classes)

def add_subcommands(self):
parser_subcommand = self.parser.add_subcommands()
for subcommand in self.subcommands():
subparser = ArgumentParser()
subparser.add_method_arguments(BaseEngine, subcommand, skip={"model"})
subparser.add_subclass_arguments(self._get_model_classes(), "model", required=False, fail_untyped=True)
fn = getattr(BaseEngine, subcommand)
description = get_short_docstring(fn)
parser_subcommand.add_subcommand(subcommand, subparser, help=description)

def run(self):
args = self.parser.parse_args()
args_dict = namespace_to_dict(args)
engine = BaseEngine()
# do something here


if __name__ == "__main__":
CLI()
8 changes: 6 additions & 2 deletions src/otx/engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .engine import Engine
from .anomalib import AnomalyEngine
from .base import BaseEngine
from .engine_poc import Engine
from .lightning import LightningEngine
from .ultralytics import UltralyticsEngine

__all__ = ["Engine"]
__all__ = ["BaseEngine", "Engine", "AnomalyEngine", "UltralyticsEngine", "LightningEngine"]
49 changes: 49 additions & 0 deletions src/otx/engine/anomalib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from anomalib.data import AnomalibDataModule
from anomalib.engine import Engine as AnomalibEngine
from anomalib.models import AnomalyModule

from otx.core.data.module import OTXDataModule
from otx.engine.base import METRICS, BaseEngine


def wrap_to_anomalib_datamodule(datamodule: OTXDataModule) -> AnomalibDataModule:
"""Mock function to wrap OTXDataModule to AnomalibDataModule."""
return AnomalibDataModule(
train=datamodule.train,
val=datamodule.val,
test=datamodule.test,
batch_size=datamodule.batch_size,
num_workers=datamodule.num_workers,
pin_memory=datamodule.pin_memory,
shuffle=datamodule.shuffle,
)


class AnomalyEngine(BaseEngine):
BASE_MODEL = AnomalyModule

def __init__(self, model: AnomalyModule, **kwargs):
self.model = model
self._engine = AnomalibEngine()

@classmethod
def is_valid_model(cls, model: AnomalyModule) -> bool:
return isinstance(model, AnomalyModule)

def train(
self,
datamodule: OTXDataModule | AnomalibDataModule,
**kwargs,
) -> METRICS:
if not isinstance(datamodule, AnomalibDataModule):
datamodule = wrap_to_anomalib_datamodule(datamodule)
print("Pseudo training...")

def test(
self,
datamodule: OTXDataModule | AnomalibDataModule,
**kwargs,
) -> METRICS:
if not isinstance(datamodule, AnomalibDataModule):
datamodule = wrap_to_anomalib_datamodule(datamodule)
print("Pseudo testing...")
50 changes: 50 additions & 0 deletions src/otx/engine/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from abc import ABC, abstractmethod
from typing import Any

from torch import nn

METRICS = dict[str, float]
ANNOTATIONS = Any


class BaseEngine(ABC):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As we talked about, the arguments in BaseEngine will be the same as those in the current otx Engine, only the Type will be made more general as needed, right?

BASE_MODEL: nn.Module # Use this to register models to the CLI

@classmethod
@abstractmethod
def is_valid_model(cls, model: nn.Module) -> bool:
pass

@abstractmethod
def train(self, model: nn.Module, **kwargs) -> METRICS:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's still a model here, which I think might be confusing for people looking at this PR.

pass

@abstractmethod
def test(self, **kwargs) -> METRICS:
pass

# @abstractmethod
# def predict(self, **kwargs) -> ANNOTATIONS:
# pass

# @abstractmethod
# def export(self, **kwargs) -> Path:
# pass

# @abstractmethod
# def optimize(self, **kwargs) -> Path:
# pass

# @abstractmethod
# def explain(self, **kwargs) -> list[Tensor]:
# pass

# @abstractmethod
# @classmethod
# def from_config(cls, **kwargs) -> "Backend":
# pass

# @abstractmethod
# @classmethod
# def from_model_name(cls, **kwargs) -> "Backend":
# pass
42 changes: 42 additions & 0 deletions src/otx/engine/engine_poc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import logging
from pathlib import Path

from torch import nn

from otx.engine.base import BaseEngine

from .base import BaseEngine

logger = logging.getLogger(__name__)


class AutoConfigurator:
"""Mock autoconfigurator for the engine."""

def __init__(self, model: nn.Module | None = None, data_root: Path | None = None, task: str | None = None):
self._engine = self._configure_engine(model) # ideally we want to pass the data_root and task as well
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, the role of auto-configuration is to check for task, data, and model inputs, regardless of the engine, and provide default settings for anything the user hasn't entered. Is there any reason to configure the engine internally? If it's just for the backend, it would be nice to have a different way to configure the default for each backend rather than configuring the engine directly. What do you think?


@property
def engine(self) -> BaseEngine:
return self._engine

def _configure_engine(self, model: nn.Module) -> BaseEngine:
for engine in BaseEngine.__subclasses__():
if engine.is_valid_model(model):
logger.info(f"Using {engine.__name__} for model {model.__class__.__name__}")
return engine(model=model)
raise ValueError(f"Model {model} is not supported by any of the engines.")


class Engine:
"""Automatically selects the engine based on the model passed to the engine."""

def __new__(
cls,
model: nn.Module,
data_root: Path | None = None,
**kwargs,
) -> BaseEngine:
"""This takes in all the parameters that are currently passed to the OTX Engine's `__init__` method."""
autoconfigurator = AutoConfigurator(model, data_root=data_root, **kwargs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Engine -> AutoConfigurator -> Engine : I think their relationship with each other is strange.

return autoconfigurator.engine
126 changes: 126 additions & 0 deletions src/otx/engine/lightning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
from collections.abc import Iterator
from contextlib import contextmanager

from lightning.pytorch import Trainer

from otx.core.data.module import OTXDataModule
from otx.core.metrics import MetricCallable
from otx.core.model.base import OTXModel
from otx.core.types.task import OTXTaskType
from otx.core.utils.cache import TrainerArgumentsCache

from .base import BaseEngine


@contextmanager
def override_metric_callable(model: OTXModel, new_metric_callable: MetricCallable | None) -> Iterator[OTXModel]:
"""Override `OTXModel.metric_callable` to change the evaluation metric.

Args:
model: Model to override its metric callable
new_metric_callable: If not None, override the model's one with this. Otherwise, do not override.
"""
if new_metric_callable is None:
yield model
return

orig_metric_callable = model.metric_callable
try:
model.metric_callable = new_metric_callable
yield model
finally:
model.metric_callable = orig_metric_callable


class LightningEngine(BaseEngine):
"""OTX Engine.

This is a temporary name and we can change it later. It is basically a subset of what is currently present in the
original OTX Engine class (engine.py)
"""

BASE_MODEL = OTXModel

def __init__(
self,
datamodule: OTXDataModule | None = None,
model: OTXModel | str | None = None,
task: OTXTaskType | None = None,
**kwargs,
):
self._cache = TrainerArgumentsCache(**kwargs)
self.task = task
self._trainer: Trainer | None = None
self._datamodule: OTXDataModule = datamodule
self._model: OTXModel = model

@classmethod
def is_valid_model(cls, model: OTXModel) -> bool:
return isinstance(model, OTXModel)

def train(
self,
model: OTXModel | None = None,
datamodule: OTXDataModule | None = None,
max_epochs: int = 10,
deterministic: bool = True,
val_check_interval: int | float | None = 1,
metric: MetricCallable | None = None,
) -> dict[str, float]:
if model is not None:
self.model = model
if datamodule is not None:
self.datamodule = datamodule
self._build_trainer(
logger=None,
callbacks=None,
max_epochs=max_epochs,
deterministic=deterministic,
val_check_interval=val_check_interval,
)
print("Pseudo training...")
return {}

def test(self, **kwargs) -> dict[str, float]:
pass

@property
def trainer(self) -> Trainer:
"""Returns the trainer object associated with the engine.

To get this property, you should execute `Engine.train()` function first.

Returns:
Trainer: The trainer object.
"""
if self._trainer is None:
msg = "Please run train() first"
raise RuntimeError(msg)
return self._trainer

def _build_trainer(self, **kwargs) -> None:
"""Instantiate the trainer based on the model parameters."""
if self._cache.requires_update(**kwargs) or self._trainer is None:
self._cache.update(**kwargs)

kwargs = self._cache.args
self._trainer = Trainer(**kwargs)
self._cache.is_trainer_args_identical = True
self._trainer.task = self.task
self.work_dir = self._trainer.default_root_dir

@property
def model(self) -> OTXModel:
return self._model

@model.setter
def model(self, model: OTXModel) -> None:
self._model = model

@property
def datamodule(self) -> OTXDataModule:
return self._datamodule

@datamodule.setter
def datamodule(self, datamodule: OTXDataModule) -> None:
self._datamodule = datamodule
41 changes: 41 additions & 0 deletions src/otx/engine/ultralytics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from torch.utils.data import DataLoader
from ultralytics.engine.model import Model

from otx.core.data.module import OTXDataModule
from otx.engine.base import METRICS, BaseEngine


def wrap_to_ultralytics_dataset(datamodule: OTXDataModule) -> DataLoader:
"""Mock function to wrap OTXDataModule to ultralytics classification.

Ideally we want a general ultralytics dataset
"""
return DataLoader()


class UltralyticsEngine(BaseEngine):
BASE_MODEL = Model

def __init__(self, model: Model, **kwargs):
self.model = model

@classmethod
def is_valid_model(cls, model: Model) -> bool:
return isinstance(model, Model)

def train(
self,
datamodule: OTXDataModule | DataLoader,
max_epochs: int = 5,
**kwargs,
) -> METRICS:
if not isinstance(datamodule, DataLoader):
datamodule = wrap_to_ultralytics_dataset(datamodule)
print("Pseudo training...")
return {} # Metric

def test(self, datamodule: OTXDataModule | DataLoader, **kwargs) -> METRICS:
if not isinstance(datamodule, DataLoader):
datamodule = wrap_to_ultralytics_dataset(datamodule)
print("Pseudo testing...")
return {} # Metric
Loading