Skip to content

Commit

Permalink
feat: add class ImageTransforms in transforms module #1
Browse files Browse the repository at this point in the history
  • Loading branch information
nurgoni committed May 25, 2024
1 parent a596374 commit ecf2286
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/lightning_classify/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .transforms import *
33 changes: 33 additions & 0 deletions src/lightning_classify/transforms/transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from PIL import Image

import torch
from torchvision import transforms


class ImageTransform:
def __init__(self, is_train: bool) -> None:
if is_train:
self.transform = transforms.Compose([
transforms.Resize(256),
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(10),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
),
])
else:
self.transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
),
])

def __call__(self, img: Image.Image) -> torch.Tensor:
return self.transform(img)

0 comments on commit ecf2286

Please sign in to comment.