Skip to content

Commit

Permalink
chore: nb cleanup and more transformations
Browse files Browse the repository at this point in the history
  • Loading branch information
eschmidt42 committed Feb 20, 2024
1 parent 46f1ead commit 2c79664
Showing 1 changed file with 56 additions and 12 deletions.
68 changes: 56 additions & 12 deletions nbs/unet-isbi2012.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import seaborn as sns\n",
"import tifffile as tiff\n",
"import torch\n",
"import torch.nn as nn\n",
Expand Down Expand Up @@ -64,7 +65,6 @@
"metadata": {},
"outputs": [],
"source": [
"# X = Image.open(path_em_data / \"train-volume.tif\")\n",
"X = tiff.imread(path_em_data / \"train-volume.tif\")\n",
"Y = tiff.imread(path_em_data / \"train-labels.tif\")"
]
Expand Down Expand Up @@ -95,6 +95,15 @@
" plt.tight_layout()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"_x.mean(), _x.std()"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -152,13 +161,12 @@
" return self.n\n",
"\n",
" def __getitem__(self, idx: int) -> T.Tuple[torch.Tensor, torch.Tensor]:\n",
" # img = torch.from_numpy(self.X[idx])\n",
"\n",
" img = tv_tensors.Image(self.X[idx])\n",
" labels = tv_tensors.Mask(self.Y[idx] == 255) # , dtype=torch.bool\n",
" labels = tv_tensors.Mask(self.Y[idx] == 255)\n",
"\n",
" if self.transform:\n",
" img, labels = self.transform(img, labels)\n",
" # labels = self.transform(labels)\n",
"\n",
" return img, labels\n",
"\n",
Expand Down Expand Up @@ -186,7 +194,16 @@
"metadata": {},
"outputs": [],
"source": [
"trafos = vision_trafos_v2.RandomCrop(size=(64, 64))"
"trafos = vision_trafos_v2.Compose(\n",
" [\n",
" vision_trafos_v2.RandomAffine(degrees=0, shear=5),\n",
" vision_trafos_v2.RandomCrop(size=(64, 64)),\n",
" vision_trafos_v2.RandomVerticalFlip(),\n",
" vision_trafos_v2.RandomHorizontalFlip(),\n",
" vision_trafos_v2.ToDtype(torch.float32),\n",
" vision_trafos_v2.Normalize(mean=[0.0], std=[255.0]),\n",
" ]\n",
")"
]
},
{
Expand All @@ -206,10 +223,16 @@
"outputs": [],
"source": [
"def show_image_and_labels(image: torch.Tensor, labels: torch.Tensor):\n",
" fig, axs = plt.subplots(ncols=2)\n",
" fig, axs = plt.subplots(ncols=2, nrows=2)\n",
"\n",
" axs[0].imshow(img[0])\n",
" axs[1].imshow(labels)\n",
" axs[0, 0].imshow(image[0])\n",
" axs[0, 1].imshow(labels)\n",
"\n",
" sns.histplot(x=image[0].ravel(), ax=axs[1, 0])\n",
" sns.histplot(x=labels.ravel(), ax=axs[1, 1])\n",
"\n",
" print(image[0].ravel().mean(), image[0].ravel().std())\n",
" print(image[0].ravel().min(), image[0].ravel().max())\n",
"\n",
" plt.tight_layout()\n",
"\n",
Expand Down Expand Up @@ -438,7 +461,7 @@
"metadata": {},
"outputs": [],
"source": [
"learning_rate = 1e-4\n",
"learning_rate = 1e-3\n",
"\n",
"scheduler = optim.lr_scheduler.OneCycleLR(\n",
" optimizer=optimizer,\n",
Expand Down Expand Up @@ -529,7 +552,7 @@
"metadata": {},
"outputs": [],
"source": [
"losses_simple = loss_callback.get_losses().dropna(subset=[\"loss_valid\"])\n",
"losses_simple = loss_callback.get_losses_valid()\n",
"\n",
"losses_simple"
]
Expand Down Expand Up @@ -776,7 +799,7 @@
"metadata": {},
"outputs": [],
"source": [
"model = Model3Layers(res=True)"
"model = Model3Layers(res=False)"
]
},
{
Expand Down Expand Up @@ -922,10 +945,31 @@
"metadata": {},
"outputs": [],
"source": [
"losses_shallow = loss_callback.get_losses().dropna(subset=[\"loss_valid\"])\n",
"losses_shallow = loss_callback.get_losses_valid()\n",
"display(losses_simple.tail(), losses_shallow.tail())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig, ax = plt.subplots()\n",
"\n",
"sns.lineplot(\n",
" data=losses_simple, x=\"iteration\", y=\"loss_valid\", label=\"simple\", ax=ax\n",
")\n",
"sns.lineplot(\n",
" data=losses_shallow, x=\"iteration\", y=\"loss_valid\", label=\"u-net\", ax=ax\n",
")\n",
"\n",
"ax.legend(title=\"model\")\n",
"ax.set(yscale=\"log\")\n",
"\n",
"plt.tight_layout()"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down

0 comments on commit 2c79664

Please sign in to comment.