[RFC] Engine Refactor Proposal | Alternative 2 #3760
@@ -0,0 +1,50 @@
from jsonargparse import ActionConfigFile, ArgumentParser, namespace_to_dict

from otx.cli.utils.jsonargparse import get_short_docstring
from otx.engine import BaseEngine


class CLI:
    """CLI.

    Limited CLI to show how the API does not change externally while retaining the ability to expose models from the
    adapters.
    """

    def __init__(self):
        self.parser = ArgumentParser()
        self.parser.add_argument(
            "--config",
            action=ActionConfigFile,
            help="Configuration file in JSON format.",
        )
        self.add_subcommands()
        self.run()

    def subcommands(self):
        return ["train", "test"]

    def _get_model_classes(self):
        # Each engine adapter registers its model family through ``BASE_MODEL``,
        # so the CLI can expose all of them via ``BaseEngine.__subclasses__()``.
        classes = [engine.BASE_MODEL for engine in BaseEngine.__subclasses__()]
        print(classes)
        return tuple(classes)

    def add_subcommands(self):
        parser_subcommand = self.parser.add_subcommands()
        for subcommand in self.subcommands():
            subparser = ArgumentParser()
            subparser.add_method_arguments(BaseEngine, subcommand, skip={"model"})
            subparser.add_subclass_arguments(self._get_model_classes(), "model", required=False, fail_untyped=True)
            fn = getattr(BaseEngine, subcommand)
            description = get_short_docstring(fn)
            parser_subcommand.add_subcommand(subcommand, subparser, help=description)

    def run(self):
        args = self.parser.parse_args()
        args_dict = namespace_to_dict(args)
        # BaseEngine is abstract and cannot be instantiated directly; a concrete
        # engine would be selected here (e.g. via the auto-configuration below)
        # and the chosen subcommand dispatched to it.
        # do something here


if __name__ == "__main__":
    CLI()
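For intuition, here is a minimal, self-contained sketch of the jsonargparse pattern the CLI relies on, independent of OTX (the Greeter classes are made up, and the exact command-line syntax assumes jsonargparse's documented subclass-arguments behavior): add_subclass_arguments lets the user select any subclass by class path, exactly as the CLI above exposes each adapter's BASE_MODEL.

    from jsonargparse import ArgumentParser


    class Greeter:
        def __init__(self, name: str = "world"):
            self.name = name


    class LoudGreeter(Greeter):
        def __init__(self, name: str = "world", volume: int = 11):
            super().__init__(name)
            self.volume = volume


    parser = ArgumentParser()
    parser.add_subclass_arguments((Greeter,), "model", required=False, fail_untyped=True)
    # The user picks an implementation by class path and sets its init args:
    cfg = parser.parse_args(
        ["--model", "__main__.LoudGreeter", "--model.init_args.volume", "9"]
    )
    model = parser.instantiate_classes(cfg).model
    assert isinstance(model, LoudGreeter) and model.volume == 9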
@@ -0,0 +1,49 @@
from anomalib.data import AnomalibDataModule
from anomalib.engine import Engine as AnomalibEngine
from anomalib.models import AnomalyModule

from otx.core.data.module import OTXDataModule
from otx.engine.base import METRICS, BaseEngine


def wrap_to_anomalib_datamodule(datamodule: OTXDataModule) -> AnomalibDataModule:
    """Mock function to wrap OTXDataModule to AnomalibDataModule."""
    return AnomalibDataModule(
        train=datamodule.train,
        val=datamodule.val,
        test=datamodule.test,
        batch_size=datamodule.batch_size,
        num_workers=datamodule.num_workers,
        pin_memory=datamodule.pin_memory,
        shuffle=datamodule.shuffle,
    )


class AnomalyEngine(BaseEngine):
    BASE_MODEL = AnomalyModule

    def __init__(self, model: AnomalyModule, **kwargs):
        self.model = model
        self._engine = AnomalibEngine()

    @classmethod
    def is_valid_model(cls, model: AnomalyModule) -> bool:
        return isinstance(model, AnomalyModule)

    def train(
        self,
        datamodule: OTXDataModule | AnomalibDataModule,
        **kwargs,
    ) -> METRICS:
        if not isinstance(datamodule, AnomalibDataModule):
            datamodule = wrap_to_anomalib_datamodule(datamodule)
        print("Pseudo training...")
        return {}

    def test(
        self,
        datamodule: OTXDataModule | AnomalibDataModule,
        **kwargs,
    ) -> METRICS:
        if not isinstance(datamodule, AnomalibDataModule):
            datamodule = wrap_to_anomalib_datamodule(datamodule)
        print("Pseudo testing...")
        return {}
@@ -0,0 +1,50 @@
from abc import ABC, abstractmethod
from typing import Any

from torch import nn

METRICS = dict[str, float]
ANNOTATIONS = Any


class BaseEngine(ABC):
    BASE_MODEL: type[nn.Module]  # Use this to register models to the CLI

    @classmethod
    @abstractmethod
    def is_valid_model(cls, model: nn.Module) -> bool:
        pass

    @abstractmethod
    def train(self, model: nn.Module, **kwargs) -> METRICS:
        pass

    @abstractmethod
    def test(self, **kwargs) -> METRICS:
        pass

    # @abstractmethod
    # def predict(self, **kwargs) -> ANNOTATIONS:
    #     pass

    # @abstractmethod
    # def export(self, **kwargs) -> Path:
    #     pass

    # @abstractmethod
    # def optimize(self, **kwargs) -> Path:
    #     pass

    # @abstractmethod
    # def explain(self, **kwargs) -> list[Tensor]:
    #     pass

    # @classmethod
    # @abstractmethod
    # def from_config(cls, **kwargs) -> "Backend":
    #     pass

    # @classmethod
    # @abstractmethod
    # def from_model_name(cls, **kwargs) -> "Backend":
    #     pass

[Review comment on train()] There's still a model here, which I think might be confusing for people looking at this PR.
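To make the registration contract concrete, here is a minimal sketch of a third-party adapter (all Dummy* names are made up): subclassing BaseEngine is the only step needed for the model family to become discoverable by the CLI and the auto-configurator.

    from torch import nn

    from otx.engine.base import METRICS, BaseEngine


    class DummyModel(nn.Module):
        """Hypothetical model family handled by the adapter below."""


    class DummyEngine(BaseEngine):
        BASE_MODEL = DummyModel  # exposed via BaseEngine.__subclasses__()

        @classmethod
        def is_valid_model(cls, model: nn.Module) -> bool:
            return isinstance(model, DummyModel)

        def train(self, model: nn.Module, **kwargs) -> METRICS:
            return {}

        def test(self, **kwargs) -> METRICS:
            return {}


    # Once the module is imported, discovery works with no explicit registry:
    assert DummyModel in [e.BASE_MODEL for e in BaseEngine.__subclasses__()]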
@@ -0,0 +1,42 @@
import logging
from pathlib import Path

from torch import nn

from .base import BaseEngine

logger = logging.getLogger(__name__)


class AutoConfigurator:
    """Mock autoconfigurator for the engine."""

    def __init__(self, model: nn.Module | None = None, data_root: Path | None = None, task: str | None = None):
        self._engine = self._configure_engine(model)  # ideally we want to pass the data_root and task as well

    @property
    def engine(self) -> BaseEngine:
        return self._engine

    def _configure_engine(self, model: nn.Module) -> BaseEngine:
        for engine in BaseEngine.__subclasses__():
            if engine.is_valid_model(model):
                logger.info(f"Using {engine.__name__} for model {model.__class__.__name__}")
                return engine(model=model)
        raise ValueError(f"Model {model} is not supported by any of the engines.")


class Engine:
    """Automatically selects the engine based on the model passed to the engine."""

    def __new__(
        cls,
        model: nn.Module,
        data_root: Path | None = None,
        **kwargs,
    ) -> BaseEngine:
        """This takes in all the parameters that are currently passed to the OTX Engine's `__init__` method."""
        autoconfigurator = AutoConfigurator(model, data_root=data_root, **kwargs)
        return autoconfigurator.engine

[Review comment on __init__()] Currently, the role of auto-configuration is to check for task, data, and model inputs, regardless of the engine, and provide default settings for anything the user hasn't entered. Is there any reason to configure the engine internally? If it's just for the backend, it would be nice to have a different way to configure the default for each backend rather than configuring the engine directly. What do you think?

[Review comment on __new__()] Engine -> AutoConfigurator -> Engine: I think their relationship with each other is strange.
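A sketch of what the dispatch looks like from the user's side, assuming the anomalib adapter above has been imported so that AnomalyEngine is a known BaseEngine subclass:

    from anomalib.models import Padim

    engine = Engine(model=Padim())
    # __new__ hands off to AutoConfigurator, which walks BaseEngine.__subclasses__()
    # and returns the first engine whose is_valid_model() accepts the model.
    assert type(engine).__name__ == "AnomalyEngine"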
@@ -0,0 +1,126 @@
from collections.abc import Iterator
from contextlib import contextmanager

from lightning.pytorch import Trainer

from otx.core.data.module import OTXDataModule
from otx.core.metrics import MetricCallable
from otx.core.model.base import OTXModel
from otx.core.types.task import OTXTaskType
from otx.core.utils.cache import TrainerArgumentsCache

from .base import BaseEngine


@contextmanager
def override_metric_callable(model: OTXModel, new_metric_callable: MetricCallable | None) -> Iterator[OTXModel]:
    """Override `OTXModel.metric_callable` to change the evaluation metric.

    Args:
        model: Model to override its metric callable.
        new_metric_callable: If not None, override the model's one with this. Otherwise, do not override.
    """
    if new_metric_callable is None:
        yield model
        return

    orig_metric_callable = model.metric_callable
    try:
        model.metric_callable = new_metric_callable
        yield model
    finally:
        model.metric_callable = orig_metric_callable


class LightningEngine(BaseEngine):
    """OTX Engine.

    This is a temporary name and we can change it later. It is basically a subset of what is currently present in
    the original OTX Engine class (engine.py).
    """

    BASE_MODEL = OTXModel

    def __init__(
        self,
        datamodule: OTXDataModule | None = None,
        model: OTXModel | str | None = None,
        task: OTXTaskType | None = None,
        **kwargs,
    ):
        self._cache = TrainerArgumentsCache(**kwargs)
        self.task = task
        self._trainer: Trainer | None = None
        self._datamodule: OTXDataModule | None = datamodule
        self._model: OTXModel | str | None = model

    @classmethod
    def is_valid_model(cls, model: OTXModel) -> bool:
        return isinstance(model, OTXModel)

    def train(
        self,
        model: OTXModel | None = None,
        datamodule: OTXDataModule | None = None,
        max_epochs: int = 10,
        deterministic: bool = True,
        val_check_interval: int | float | None = 1,
        metric: MetricCallable | None = None,
    ) -> dict[str, float]:
        if model is not None:
            self.model = model
        if datamodule is not None:
            self.datamodule = datamodule
        self._build_trainer(
            logger=None,
            callbacks=None,
            max_epochs=max_epochs,
            deterministic=deterministic,
            val_check_interval=val_check_interval,
        )
        print("Pseudo training...")
        return {}

    def test(self, **kwargs) -> dict[str, float]:
        print("Pseudo testing...")
        return {}

    @property
    def trainer(self) -> Trainer:
        """Returns the trainer object associated with the engine.

        To get this property, you should execute `Engine.train()` function first.

        Returns:
            Trainer: The trainer object.
        """
        if self._trainer is None:
            msg = "Please run train() first"
            raise RuntimeError(msg)
        return self._trainer

    def _build_trainer(self, **kwargs) -> None:
        """Instantiate the trainer based on the model parameters."""
        if self._cache.requires_update(**kwargs) or self._trainer is None:
            self._cache.update(**kwargs)

            kwargs = self._cache.args
            self._trainer = Trainer(**kwargs)
            self._cache.is_trainer_args_identical = True
            self._trainer.task = self.task
            self.work_dir = self._trainer.default_root_dir

    @property
    def model(self) -> OTXModel:
        return self._model

    @model.setter
    def model(self, model: OTXModel) -> None:
        self._model = model

    @property
    def datamodule(self) -> OTXDataModule:
        return self._datamodule

    @datamodule.setter
    def datamodule(self, datamodule: OTXDataModule) -> None:
        self._datamodule = datamodule
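A small usage sketch of override_metric_callable; the stub model and metric below are hypothetical stand-ins that carry only the attribute the context manager touches:

    class _StubModel:
        metric_callable = None  # stand-in for OTXModel.metric_callable


    def dummy_metric(label_info):  # hypothetical MetricCallable
        raise NotImplementedError


    model = _StubModel()
    with override_metric_callable(model, dummy_metric) as patched:
        assert patched.metric_callable is dummy_metric  # overridden inside the block
    assert model.metric_callable is None  # original value restored on exit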
@@ -0,0 +1,41 @@
from torch.utils.data import DataLoader
from ultralytics.engine.model import Model

from otx.core.data.module import OTXDataModule
from otx.engine.base import METRICS, BaseEngine


def wrap_to_ultralytics_dataset(datamodule: OTXDataModule) -> DataLoader:
    """Mock function to wrap OTXDataModule to ultralytics classification.

    Ideally we want a general ultralytics dataset.
    """
    return DataLoader([])  # placeholder dataset; DataLoader requires one


class UltralyticsEngine(BaseEngine):
    BASE_MODEL = Model

    def __init__(self, model: Model, **kwargs):
        self.model = model

    @classmethod
    def is_valid_model(cls, model: Model) -> bool:
        return isinstance(model, Model)

    def train(
        self,
        datamodule: OTXDataModule | DataLoader,
        max_epochs: int = 5,
        **kwargs,
    ) -> METRICS:
        if not isinstance(datamodule, DataLoader):
            datamodule = wrap_to_ultralytics_dataset(datamodule)
        print("Pseudo training...")
        return {}  # Metric

    def test(self, datamodule: OTXDataModule | DataLoader, **kwargs) -> METRICS:
        if not isinstance(datamodule, DataLoader):
            datamodule = wrap_to_ultralytics_dataset(datamodule)
        print("Pseudo testing...")
        return {}  # Metric
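And a usage sketch for this adapter, assuming the ultralytics package's YOLO wrapper (a subclass of Model):

    from ultralytics import YOLO

    model = YOLO("yolov8n.pt")  # downloads the checkpoint on first use
    assert UltralyticsEngine.is_valid_model(model)
    engine = UltralyticsEngine(model=model)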
[Review comment on the PR] As we talked about, the arguments in BaseEngine will be the same as those in the current otx Engine, only the types will be made more general as needed, right?