Skip to content

Commit

Permalink
Separate train and validation transformations (#168)
Browse files Browse the repository at this point in the history
* transform train val

* undo dataset path
  • Loading branch information
alexriedel1 authored Apr 6, 2022
1 parent 1c587df commit 15378f3
Show file tree
Hide file tree
Showing 11 changed files with 80 additions and 37 deletions.
9 changes: 6 additions & 3 deletions anomalib/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ def get_datamodule(config: Union[DictConfig, ListConfig]) -> LightningDataModule
test_batch_size=config.dataset.test_batch_size,
num_workers=config.dataset.num_workers,
seed=config.project.seed,
transform_config=config.dataset.transform_config,
transform_config_train=config.dataset.transform_config.train,
transform_config_val=config.dataset.transform_config.val,
create_validation_set=config.dataset.create_validation_set,
)
elif config.dataset.format.lower() == "btech":
Expand All @@ -59,7 +60,8 @@ def get_datamodule(config: Union[DictConfig, ListConfig]) -> LightningDataModule
test_batch_size=config.dataset.test_batch_size,
num_workers=config.dataset.num_workers,
seed=config.project.seed,
transform_config=config.dataset.transform_config,
transform_config_train=config.dataset.transform_config.train,
transform_config_val=config.dataset.transform_config.val,
create_validation_set=config.dataset.create_validation_set,
)
elif config.dataset.format.lower() == "folder":
Expand All @@ -76,7 +78,8 @@ def get_datamodule(config: Union[DictConfig, ListConfig]) -> LightningDataModule
train_batch_size=config.dataset.train_batch_size,
test_batch_size=config.dataset.test_batch_size,
num_workers=config.dataset.num_workers,
transform_config=config.dataset.transform_config,
transform_config_train=config.dataset.transform_config.train,
transform_config_val=config.dataset.transform_config.val,
create_validation_set=config.dataset.create_validation_set,
)
else:
Expand Down
26 changes: 17 additions & 9 deletions anomalib/data/btech.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,8 @@ def __init__(
train_batch_size: int = 32,
test_batch_size: int = 32,
num_workers: int = 8,
transform_config: Optional[Union[str, A.Compose]] = None,
transform_config_train: Optional[Union[str, A.Compose]] = None,
transform_config_val: Optional[Union[str, A.Compose]] = None,
seed: int = 0,
create_validation_set: bool = False,
) -> None:
Expand All @@ -283,7 +284,8 @@ def __init__(
train_batch_size: Training batch size.
test_batch_size: Testing batch size.
num_workers: Number of workers.
transform_config: Config for pre-processing.
transform_config_train: Config for pre-processing during training.
transform_config_val: Config for pre-processing during validation, testing and inference. Falls back to transform_config_train when None.
seed: seed used for the random subset splitting
create_validation_set: Create a validation subset in addition to the train and test subsets
Expand All @@ -296,7 +298,8 @@ def __init__(
... train_batch_size=32,
... test_batch_size=32,
... num_workers=8,
... transform_config=None,
... transform_config_train=None,
... transform_config_val=None,
... )
>>> datamodule.setup()
Expand All @@ -317,10 +320,15 @@ def __init__(
self.root = root if isinstance(root, Path) else Path(root)
self.category = category
self.dataset_path = self.root / self.category
self.transform_config = transform_config
self.transform_config_train = transform_config_train
self.transform_config_val = transform_config_val
self.image_size = image_size

self.pre_process = PreProcessor(config=self.transform_config, image_size=self.image_size)
if self.transform_config_train is not None and self.transform_config_val is None:
self.transform_config_val = self.transform_config_train

self.pre_process_train = PreProcessor(config=self.transform_config_train, image_size=self.image_size)
self.pre_process_val = PreProcessor(config=self.transform_config_val, image_size=self.image_size)

self.train_batch_size = train_batch_size
self.test_batch_size = test_batch_size
Expand Down Expand Up @@ -389,7 +397,7 @@ def setup(self, stage: Optional[str] = None) -> None:
self.train_data = BTech(
root=self.root,
category=self.category,
pre_process=self.pre_process,
pre_process=self.pre_process_train,
split="train",
seed=self.seed,
create_validation_set=self.create_validation_set,
Expand All @@ -399,7 +407,7 @@ def setup(self, stage: Optional[str] = None) -> None:
self.val_data = BTech(
root=self.root,
category=self.category,
pre_process=self.pre_process,
pre_process=self.pre_process_val,
split="val",
seed=self.seed,
create_validation_set=self.create_validation_set,
Expand All @@ -408,15 +416,15 @@ def setup(self, stage: Optional[str] = None) -> None:
self.test_data = BTech(
root=self.root,
category=self.category,
pre_process=self.pre_process,
pre_process=self.pre_process_val,
split="test",
seed=self.seed,
create_validation_set=self.create_validation_set,
)

if stage == "predict":
self.inference_data = InferenceDataset(
path=self.root, image_size=self.image_size, transform_config=self.transform_config
path=self.root, image_size=self.image_size, transform_config=self.transform_config_val
)

def train_dataloader(self) -> TRAIN_DATALOADERS:
Expand Down
26 changes: 18 additions & 8 deletions anomalib/data/folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,8 @@ def __init__(
train_batch_size: int = 32,
test_batch_size: int = 32,
num_workers: int = 8,
transform_config: Optional[Union[str, A.Compose]] = None,
transform_config_train: Optional[Union[str, A.Compose]] = None,
transform_config_val: Optional[Union[str, A.Compose]] = None,
create_validation_set: bool = False,
) -> None:
"""Folder Dataset PL Datamodule.
Expand All @@ -305,7 +306,11 @@ def __init__(
train_batch_size (int, optional): Training batch size. Defaults to 32.
test_batch_size (int, optional): Test batch size. Defaults to 32.
num_workers (int, optional): Number of workers. Defaults to 8.
transform_config (Optional[Union[str, A.Compose]], optional): Config for pre-processing.
transform_config_train (Optional[Union[str, A.Compose]], optional): Config for pre-processing
during training.
Defaults to None.
transform_config_val (Optional[Union[str, A.Compose]], optional): Config for pre-processing
during validation, testing and inference. Falls back to ``transform_config_train`` when None.
Defaults to None.
create_validation_set (bool, optional): Boolean to create a validation set from the test set.
Those wanting to create a validation set could set this flag to ``True``.
Expand Down Expand Up @@ -390,10 +395,15 @@ def __init__(
"Check your configuration."
)
self.task = task
self.transform_config = transform_config
self.transform_config_train = transform_config_train
self.transform_config_val = transform_config_val
self.image_size = image_size

self.pre_process = PreProcessor(config=self.transform_config, image_size=self.image_size)
if self.transform_config_train is not None and self.transform_config_val is None:
self.transform_config_val = self.transform_config_train

self.pre_process_train = PreProcessor(config=self.transform_config_train, image_size=self.image_size)
self.pre_process_val = PreProcessor(config=self.transform_config_val, image_size=self.image_size)

self.train_batch_size = train_batch_size
self.test_batch_size = test_batch_size
Expand Down Expand Up @@ -422,7 +432,7 @@ def setup(self, stage: Optional[str] = None) -> None:
split="train",
split_ratio=self.split_ratio,
mask_dir=self.mask_dir,
pre_process=self.pre_process,
pre_process=self.pre_process_train,
extensions=self.extensions,
task=self.task,
seed=self.seed,
Expand All @@ -436,7 +446,7 @@ def setup(self, stage: Optional[str] = None) -> None:
split="val",
split_ratio=self.split_ratio,
mask_dir=self.mask_dir,
pre_process=self.pre_process,
pre_process=self.pre_process_val,
extensions=self.extensions,
task=self.task,
seed=self.seed,
Expand All @@ -449,7 +459,7 @@ def setup(self, stage: Optional[str] = None) -> None:
split="test",
split_ratio=self.split_ratio,
mask_dir=self.mask_dir,
pre_process=self.pre_process,
pre_process=self.pre_process_val,
extensions=self.extensions,
task=self.task,
seed=self.seed,
Expand All @@ -458,7 +468,7 @@ def setup(self, stage: Optional[str] = None) -> None:

if stage == "predict":
self.inference_data = InferenceDataset(
path=self.root, image_size=self.image_size, transform_config=self.transform_config
path=self.root, image_size=self.image_size, transform_config=self.transform_config_val
)

def train_dataloader(self) -> TRAIN_DATALOADERS:
Expand Down
26 changes: 17 additions & 9 deletions anomalib/data/mvtec.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,8 @@ def __init__(
train_batch_size: int = 32,
test_batch_size: int = 32,
num_workers: int = 8,
transform_config: Optional[Union[str, A.Compose]] = None,
transform_config_train: Optional[Union[str, A.Compose]] = None,
transform_config_val: Optional[Union[str, A.Compose]] = None,
seed: int = 0,
create_validation_set: bool = False,
) -> None:
Expand All @@ -306,7 +307,8 @@ def __init__(
train_batch_size: Training batch size.
test_batch_size: Testing batch size.
num_workers: Number of workers.
transform_config: Config for pre-processing.
transform_config_train: Config for pre-processing during training.
transform_config_val: Config for pre-processing during validation, testing and inference. Falls back to transform_config_train when None.
seed: seed used for the random subset splitting
create_validation_set: Create a validation subset in addition to the train and test subsets
Expand All @@ -319,7 +321,8 @@ def __init__(
... train_batch_size=32,
... test_batch_size=32,
... num_workers=8,
... transform_config=None,
... transform_config_train=None,
... transform_config_val=None,
... )
>>> datamodule.setup()
Expand All @@ -340,10 +343,15 @@ def __init__(
self.root = root if isinstance(root, Path) else Path(root)
self.category = category
self.dataset_path = self.root / self.category
self.transform_config = transform_config
self.transform_config_train = transform_config_train
self.transform_config_val = transform_config_val
self.image_size = image_size

self.pre_process = PreProcessor(config=self.transform_config, image_size=self.image_size)
if self.transform_config_train is not None and self.transform_config_val is None:
self.transform_config_val = self.transform_config_train

self.pre_process_train = PreProcessor(config=self.transform_config_train, image_size=self.image_size)
self.pre_process_val = PreProcessor(config=self.transform_config_val, image_size=self.image_size)

self.train_batch_size = train_batch_size
self.test_batch_size = test_batch_size
Expand Down Expand Up @@ -392,7 +400,7 @@ def setup(self, stage: Optional[str] = None) -> None:
self.train_data = MVTec(
root=self.root,
category=self.category,
pre_process=self.pre_process,
pre_process=self.pre_process_train,
split="train",
seed=self.seed,
create_validation_set=self.create_validation_set,
Expand All @@ -402,7 +410,7 @@ def setup(self, stage: Optional[str] = None) -> None:
self.val_data = MVTec(
root=self.root,
category=self.category,
pre_process=self.pre_process,
pre_process=self.pre_process_val,
split="val",
seed=self.seed,
create_validation_set=self.create_validation_set,
Expand All @@ -411,15 +419,15 @@ def setup(self, stage: Optional[str] = None) -> None:
self.test_data = MVTec(
root=self.root,
category=self.category,
pre_process=self.pre_process,
pre_process=self.pre_process_val,
split="test",
seed=self.seed,
create_validation_set=self.create_validation_set,
)

if stage == "predict":
self.inference_data = InferenceDataset(
path=self.root, image_size=self.image_size, transform_config=self.transform_config
path=self.root, image_size=self.image_size, transform_config=self.transform_config_val
)

def train_dataloader(self) -> TRAIN_DATALOADERS:
Expand Down
4 changes: 3 additions & 1 deletion anomalib/models/cflow/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ dataset:
inference_batch_size: 16
fiber_batch_size: 64
num_workers: 8
transform_config: null
transform_config:
train: null
val: null
create_validation_set: false

model:
Expand Down
4 changes: 3 additions & 1 deletion anomalib/models/dfkde/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ dataset:
train_batch_size: 32
test_batch_size: 32
num_workers: 36
transform_config: null
transform_config:
train: null
val: null
create_validation_set: false

model:
Expand Down
4 changes: 3 additions & 1 deletion anomalib/models/dfm/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ dataset:
train_batch_size: 32
test_batch_size: 32
num_workers: 36
transform_config: null
transform_config:
train: null
val: null
create_validation_set: false

model:
Expand Down
4 changes: 3 additions & 1 deletion anomalib/models/ganomaly/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ dataset:
test_batch_size: 32
inference_batch_size: 32
num_workers: 32
transform_config: null
transform_config:
train: null
val: null
create_validation_set: false
tiling:
apply: true
Expand Down
4 changes: 3 additions & 1 deletion anomalib/models/padim/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ dataset:
train_batch_size: 32
test_batch_size: 32
num_workers: 36
transform_config: null
transform_config:
train: null
val: null
create_validation_set: false
tiling:
apply: false
Expand Down
6 changes: 4 additions & 2 deletions anomalib/models/patchcore/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ dataset:
image_size: 224
train_batch_size: 32
test_batch_size: 1
num_workers: 36
transform_config: null
num_workers: 0
transform_config:
train: null
val: null
create_validation_set: false
tiling:
apply: true
Expand Down
4 changes: 3 additions & 1 deletion anomalib/models/stfpm/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ dataset:
test_batch_size: 32
inference_batch_size: 32
num_workers: 36
transform_config: null
transform_config:
train: null
val: null
create_validation_set: false
tiling:
apply: false
Expand Down

0 comments on commit 15378f3

Please sign in to comment.