-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #45 from andreped/augmentation
Macenko stain augmentation
- Loading branch information
Showing
19 changed files
with
456 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
import cv2 | ||
import matplotlib.pyplot as plt | ||
import torchstain | ||
import torch | ||
from torchvision import transforms | ||
import time | ||
import os | ||
|
||
|
||
size = 1024 | ||
dir_path = os.path.dirname(os.path.abspath(__file__)) | ||
target = cv2.resize(cv2.cvtColor(cv2.imread(dir_path + "/../data/target.png"), cv2.COLOR_BGR2RGB), (size, size)) | ||
to_transform = cv2.resize(cv2.cvtColor(cv2.imread(dir_path + "/../data/source.png"), cv2.COLOR_BGR2RGB), (size, size)) | ||
|
||
T = transforms.Compose([ | ||
transforms.ToTensor(), | ||
transforms.Lambda(lambda x: x*255) | ||
]) | ||
|
||
t_to_transform = T(to_transform) | ||
|
||
# setup augmentors for the different backends | ||
augmentor = torchstain.augmentors.MacenkoAugmentor(backend='numpy') | ||
augmentor.fit(to_transform) | ||
|
||
tf_augmentor = torchstain.augmentors.MacenkoAugmentor(backend='tensorflow') | ||
tf_augmentor.fit(t_to_transform) | ||
|
||
torch_augmentor = torchstain.augmentors.MacenkoAugmentor(backend='torch') | ||
torch_augmentor.fit(t_to_transform) | ||
|
||
|
||
print("NUMPY" + "-"*20) | ||
|
||
plt.figure() | ||
plt.suptitle('numpy augmentor') | ||
plt.subplot(4, 4, 1) | ||
plt.title('Original') | ||
plt.axis('off') | ||
plt.imshow(to_transform) | ||
|
||
for i in range(16): | ||
# generate augmented sample | ||
result = augmentor.augment() | ||
|
||
plt.subplot(4, 4, i + 1) | ||
if i == 1: | ||
plt.title('Augmented ->') | ||
plt.axis('off') | ||
plt.imshow(result) | ||
|
||
plt.show() | ||
|
||
|
||
print("TensorFlow (TF)" + "-"*20) | ||
|
||
plt.figure() | ||
plt.suptitle('tf augmentor') | ||
plt.subplot(4, 4, 1) | ||
plt.title('Original') | ||
plt.axis('off') | ||
plt.imshow(to_transform) | ||
|
||
for i in range(16): | ||
# generate augmented sample | ||
result = tf_augmentor.augment() | ||
|
||
plt.subplot(4, 4, i + 1) | ||
if i == 1: | ||
plt.title('Augmented ->') | ||
plt.axis('off') | ||
plt.imshow(result) | ||
|
||
plt.show() | ||
|
||
|
||
print("Torch" + "-"*20) | ||
|
||
plt.figure() | ||
plt.suptitle('torch augmentor') | ||
plt.subplot(4, 4, 1) | ||
plt.title('Original') | ||
plt.axis('off') | ||
plt.imshow(to_transform) | ||
|
||
for i in range(16): | ||
# generate augmented sample | ||
result = torch_augmentor.augment() | ||
|
||
plt.subplot(4, 4, i + 1) | ||
if i == 1: | ||
plt.title('Augmented ->') | ||
plt.axis('off') | ||
plt.imshow(result) | ||
|
||
plt.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
__version__ = '1.3.0' | ||
|
||
from torchstain.base import normalizers | ||
from torchstain.base import augmentors, normalizers |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
from torchstain.base import normalizers | ||
from torchstain.base import augmentors, normalizers |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .he_augmentor import HEAugmentor | ||
from .macenko import MacenkoAugmentor |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
class HEAugmentor: | ||
def fit(self, I): | ||
pass | ||
|
||
def augment(self): | ||
raise Exception('Abstract method') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
def MacenkoAugmentor(backend='torch', sigma1=0.2, sigma2=0.2): | ||
if backend == 'numpy': | ||
from torchstain.numpy.augmentors import NumpyMacenkoAugmentor | ||
return NumpyMacenkoAugmentor(sigma1=sigma1, sigma2=sigma2) | ||
elif backend == "torch": | ||
from torchstain.torch.augmentors import TorchMacenkoAugmentor | ||
return TorchMacenkoAugmentor(sigma1=sigma1, sigma2=sigma2) | ||
elif backend == "tensorflow": | ||
from torchstain.tf.augmentors import TensorFlowMacenkoAugmentor | ||
return TensorFlowMacenkoAugmentor(sigma1=sigma1, sigma2=sigma2) | ||
else: | ||
raise Exception(f'Unknown backend {backend}') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
from torchstain.numpy import normalizers, utils | ||
from torchstain.numpy import augmentors, normalizers, utils |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .macenko import NumpyMacenkoAugmentor |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
import numpy as np | ||
from torchstain.base.augmentors import HEAugmentor | ||
|
||
""" | ||
Source code adapted from: https://github.com/schaugf/HEnorm_python | ||
Original implementation: https://github.com/mitkovetta/staining-normalization | ||
""" | ||
class NumpyMacenkoAugmentor(HEAugmentor): | ||
def __init__(self, sigma1=0.2, sigma2=0.2): | ||
super().__init__() | ||
|
||
self.sigma1 = sigma1 | ||
self.sigma2 = sigma2 | ||
|
||
self.I = None | ||
|
||
self.HERef = np.array([[0.5626, 0.2159], | ||
[0.7201, 0.8012], | ||
[0.4062, 0.5581]]) | ||
self.maxCRef = np.array([1.9705, 1.0308]) | ||
|
||
def __convert_rgb2od(self, I, Io=240, beta=0.15): | ||
# calculate optical density | ||
OD = -np.log((I.astype(float) + 1) / Io) | ||
|
||
# remove transparent pixels | ||
ODhat = OD[~np.any(OD < beta, axis=1)] | ||
|
||
return OD, ODhat | ||
|
||
def __find_HE(self, ODhat, eigvecs, alpha): | ||
#project on the plane spanned by the eigenvectors corresponding to the two | ||
# largest eigenvalues | ||
That = ODhat.dot(eigvecs[:,1:3]) | ||
|
||
phi = np.arctan2(That[:,1],That[:,0]) | ||
|
||
minPhi = np.percentile(phi, alpha) | ||
maxPhi = np.percentile(phi, 100-alpha) | ||
|
||
vMin = eigvecs[:, 1:3].dot(np.array([(np.cos(minPhi), np.sin(minPhi))]).T) | ||
vMax = eigvecs[:, 1:3].dot(np.array([(np.cos(maxPhi), np.sin(maxPhi))]).T) | ||
|
||
# a heuristic to make the vector corresponding to hematoxylin first and the | ||
# one corresponding to eosin second | ||
if vMin[0] > vMax[0]: | ||
HE = np.array((vMin[:,0], vMax[:,0])).T | ||
else: | ||
HE = np.array((vMax[:,0], vMin[:,0])).T | ||
|
||
return HE | ||
|
||
def __find_concentration(self, OD, HE): | ||
# rows correspond to channels (RGB), columns to OD values | ||
Y = np.reshape(OD, (-1, 3)).T | ||
|
||
# determine concentrations of the individual stains | ||
C = np.linalg.lstsq(HE, Y, rcond=None)[0] | ||
|
||
return C | ||
|
||
def __compute_matrices(self, I, Io, alpha, beta): | ||
I = I.reshape((-1, 3)) | ||
|
||
OD, ODhat = self.__convert_rgb2od(I, Io=Io, beta=beta) | ||
|
||
# compute eigenvectors | ||
_, eigvecs = np.linalg.eigh(np.cov(ODhat.T)) | ||
|
||
HE = self.__find_HE(ODhat, eigvecs, alpha) | ||
|
||
C = self.__find_concentration(OD, HE) | ||
|
||
# normalize stain concentrations | ||
maxC = np.array([np.percentile(C[0,:], 99), np.percentile(C[1,:], 99)]) | ||
|
||
return HE, C, maxC | ||
|
||
def fit(self, I, Io=240, alpha=1, beta=0.15): | ||
HE, C, maxC = self.__compute_matrices(I, Io, alpha, beta) | ||
|
||
# keep these as we will use them for augmentation | ||
self.I = I | ||
self.HERef = HE | ||
self.CRef = C | ||
self.maxCRef = maxC | ||
|
||
def augment(self, Io=240, alpha=1, beta=0.15): | ||
I = self.I | ||
h, w, c = I.shape | ||
I = I.reshape((-1, 3)) | ||
|
||
HE, C, maxC = self.__compute_matrices(I, Io, alpha, beta) | ||
|
||
maxC = np.divide(maxC, self.maxCRef) | ||
C2 = np.divide(C, maxC[:, np.newaxis]) | ||
|
||
# introduce noise to the concentrations | ||
for i in range(C2.shape[0]): | ||
C2[i, :] *= np.random.uniform(1 - self.sigma1, 1 + self.sigma1) # multiplicative | ||
C2[i, :] += np.random.uniform(-self.sigma2, self.sigma2) # additative | ||
|
||
# recreate the image using reference mixing matrix | ||
Iaug = np.multiply(Io, np.exp(-self.HERef.dot(C2))) | ||
Iaug[Iaug > 255] = 255 | ||
Iaug = np.reshape(Iaug.T, (h, w, c)).astype(np.uint8) | ||
|
||
return Iaug |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
from torchstain.tf import normalizers, utils | ||
from torchstain.tf import augmentors, normalizers, utils |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .macenko import TensorFlowMacenkoAugmentor |
Oops, something went wrong.