chore: a bunch of changes due to renaming
eschmidt42 committed Feb 15, 2024
1 parent 48e3e02 commit 0c368ca
Showing 20 changed files with 491 additions and 313 deletions.
135 changes: 80 additions & 55 deletions nbs/cnn_autoencoder_fastai2022.ipynb
@@ -461,11 +461,11 @@
"metadata": {},
"outputs": [],
"source": [
"ds_train = rnnm_data.MNISTDatasetTrain(X0, y0)\n",
"ds_train = rnnm_data.MNISTDatasetWithLabels(X0, y0)\n",
"dl_train = DataLoader(\n",
" ds_train,\n",
" batch_size=1,\n",
" collate_fn=rnnm_data.mnist_collate_train,\n",
" collate_fn=rnnm_data.collate_mnist_dataset_to_block_with_labels,\n",
" shuffle=True,\n",
")\n",
"next(iter(dl_train))"
@@ -488,14 +488,9 @@
"source": [
"n_epochs = 1_000\n",
"lr = 1e-2\n",
"# optimizer = optim.SGD(model.parameters(), lr=lr)\n",
"\n",
"optimizer = optim.Adam(model.parameters(), lr=lr)\n",
"scheduler = optim.lr_scheduler.OneCycleLR(\n",
" optimizer=optimizer,\n",
" max_lr=lr,\n",
" epochs=n_epochs,\n",
" steps_per_epoch=len(dl_train),\n",
")\n",
"\n",
"loss = rnnm_losses.MSELossMNISTAutoencoder()\n",
"save_dir = Path(\"./models\")\n",
"\n",
@@ -510,18 +505,13 @@
" every_n=100, max_depth_search=4, name_patterns=(\".*conv\\d$\",)\n",
")\n",
"\n",
"\n",
"scheduler_callback = rnnm_learner.EveryBatchSchedulerCallback(scheduler)\n",
"callbacks = [\n",
" loss_callback,\n",
" activations_callback,\n",
" gradients_callback,\n",
" parameters_callback,\n",
" scheduler_callback,\n",
"]\n",
"\n",
"lr_find_callback = rnnm_learner.LRFinderCallback(1e-5, 10, 100)\n",
"\n",
"learner = rnnm_learner.Learner(\n",
" model,\n",
" optimizer,\n",
@@ -538,6 +528,8 @@
"metadata": {},
"outputs": [],
"source": [
"lr_find_callback = rnnm_learner.LRFinderCallback(1e-5, 10, 100)\n",
"\n",
"learner.find_learning_rate(\n",
" dl_train, n_epochs=200, lr_find_callback=lr_find_callback\n",
")"
@@ -552,6 +544,23 @@
"lr_find_callback.plot()"
]
},
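LRFinderCallback(1e-5, 10, 100) presumably sweeps the learning rate from 1e-5 up to 10 over 100 steps while logging the loss, and plot() then shows loss against learning rate so a value just below the divergence point can be picked. A minimal sketch of such a sweep, assuming geometric spacing (the argument semantics are an assumption, not confirmed by this diff):

```python
# Assumed semantics: LRFinderCallback(lr_min, lr_max, n_steps) tries
# geometrically spaced learning rates between lr_min and lr_max.
lr_min, lr_max, n_steps = 1e-5, 10.0, 100
lrs = [lr_min * (lr_max / lr_min) ** (i / (n_steps - 1)) for i in range(n_steps)]
print(lrs[0], lrs[-1])  # ~1e-05 ... ~10.0
```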
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"lr = 1e-2\n",
+"scheduler = optim.lr_scheduler.OneCycleLR(\n",
+" optimizer=optimizer,\n",
+" max_lr=lr,\n",
+" epochs=n_epochs,\n",
+" steps_per_epoch=len(dl_train),\n",
+")\n",
+"scheduler_callback = rnnm_learner.EveryBatchSchedulerCallback(scheduler)\n",
+"learner.update_callback(scheduler_callback)"
+]
+},
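Creating the scheduler in its own cell, after the learning-rate finder, matters for OneCycleLR: its schedule length is fixed at construction as epochs * steps_per_epoch, so building it only once the finder run is done keeps the finder from consuming (and distorting) steps of the one-cycle schedule. EveryBatchSchedulerCallback presumably steps the scheduler once per batch, which is what OneCycleLR expects; a hedged sketch (the real class lives in rnnm_learner, and the hook name is an assumption):

```python
# Hedged sketch of EveryBatchSchedulerCallback; the "after_batch" hook
# name is an assumption about the Learner's callback protocol.
class EveryBatchSchedulerCallback:
    def __init__(self, scheduler):
        self.scheduler = scheduler

    def after_batch(self):
        # OneCycleLR is designed to be stepped once per batch, not per epoch.
        self.scheduler.step()
```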
{
"cell_type": "code",
"execution_count": null,
@@ -643,6 +652,36 @@
"## Reproducing 10 digits"
]
},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"def draw_pair(img: torch.Tensor, img_pred: torch.Tensor):\n",
+" fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))\n",
+" ax = axs[0]\n",
+" ax.imshow(img, cmap=\"gray\")\n",
+" ax.set_title(\"Input image\")\n",
+" ax.axis(\"off\")\n",
+" ax = axs[1]\n",
+" ax.imshow(img_pred, cmap=\"gray\")\n",
+" ax.set_title(\"Reconstructed image\")\n",
+" ax.axis(\"off\")\n",
+" plt.show()\n",
+"\n",
+"\n",
+"def draw_n_pairs(\n",
+" input_features: torch.Tensor, x_pred: torch.Tensor, n: int = 5\n",
+"):\n",
+" _n = min(n, len(input_features))\n",
+" print(f\"Drawing {_n} pairs\")\n",
+" for i in range(_n):\n",
+" img = input_features[i].cpu()\n",
+" img_pred = x_pred[i]\n",
+" draw_pair(img, img_pred)"
+]
+},
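The plotting helpers are hoisted into their own cell ahead of their first use; the existing call site draw_n_pairs(test_features, x_pred, n=16) later in the notebook is unchanged. A quick smoke test with random data, assuming the notebook's usual torch and matplotlib imports:

```python
# Smoke test with random 28x28 "images"; n is clipped to the batch size.
imgs = torch.rand(3, 28, 28)
draw_n_pairs(imgs, imgs, n=5)  # prints "Drawing 3 pairs", shows 3 figures
```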
{
"cell_type": "code",
"execution_count": null,
@@ -914,30 +953,6 @@
"metadata": {},
"outputs": [],
"source": [
"def draw_pair(img: torch.Tensor, img_pred: torch.Tensor):\n",
" fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))\n",
" ax = axs[0]\n",
" ax.imshow(img, cmap=\"gray\")\n",
" ax.set_title(\"Input image\")\n",
" ax.axis(\"off\")\n",
" ax = axs[1]\n",
" ax.imshow(img_pred, cmap=\"gray\")\n",
" ax.set_title(\"Reconstructed image\")\n",
" ax.axis(\"off\")\n",
" plt.show()\n",
"\n",
"\n",
"def draw_n_pairs(\n",
" input_features: torch.Tensor, x_pred: torch.Tensor, n: int = 5\n",
"):\n",
" _n = min(n, len(input_features))\n",
" print(f\"Drawing {_n} pairs\")\n",
" for i in range(_n):\n",
" img = input_features[i].cpu()\n",
" img_pred = x_pred[i]\n",
" draw_pair(img, img_pred)\n",
"\n",
"\n",
"draw_n_pairs(test_features, x_pred, n=16)"
]
},
@@ -963,9 +978,9 @@
"metadata": {},
"outputs": [],
"source": [
"ds_train = rnnm_data.MNISTDatasetTrain(X0, y0)\n",
"ds_valid = rnnm_data.MNISTDatasetTrain(X1, y1)\n",
"ds_test = rnnm_data.MNISTDatasetTrain(X2, y2)"
"ds_train = rnnm_data.MNISTDatasetWithLabels(X0, y0)\n",
"ds_valid = rnnm_data.MNISTDatasetWithLabels(X1, y1)\n",
"ds_test = rnnm_data.MNISTDatasetWithLabels(X2, y2)"
]
},
{
@@ -977,19 +992,19 @@
"dl_train = DataLoader(\n",
" ds_train,\n",
" batch_size=256,\n",
" collate_fn=rnnm_data.mnist_collate_train,\n",
" collate_fn=rnnm_data.collate_mnist_dataset_to_block_with_labels,\n",
" shuffle=True,\n",
")\n",
"dl_valid = DataLoader(\n",
" ds_valid,\n",
" batch_size=500,\n",
" collate_fn=rnnm_data.mnist_collate_train,\n",
" collate_fn=rnnm_data.collate_mnist_dataset_to_block_with_labels,\n",
" shuffle=False,\n",
")\n",
"dl_test = DataLoader(\n",
" ds_test,\n",
" batch_size=500,\n",
" collate_fn=rnnm_data.mnist_collate_train,\n",
" collate_fn=rnnm_data.collate_mnist_dataset_to_block_with_labels,\n",
" shuffle=False,\n",
")"
]
@@ -1011,14 +1026,9 @@
"source": [
"n_epochs = 8\n",
"lr = 1e-3\n",
"# optimizer = optim.SGD(model.parameters(), lr=lr)\n",
"\n",
"optimizer = optim.Adam(model.parameters(), lr=lr)\n",
"scheduler = optim.lr_scheduler.OneCycleLR(\n",
" optimizer=optimizer,\n",
" max_lr=lr,\n",
" epochs=n_epochs,\n",
" steps_per_epoch=len(dl_train),\n",
")\n",
"\n",
"loss = rnnm_losses.MSELossMNISTAutoencoder()\n",
"save_dir = Path(\"./models\")\n",
"\n",
@@ -1033,17 +1043,13 @@
" every_n=100, max_depth_search=4, name_patterns=(\".*conv\\d$\",)\n",
")\n",
"\n",
"\n",
"scheduler_callback = rnnm_learner.EveryBatchSchedulerCallback(scheduler)\n",
"callbacks = [\n",
" loss_callback,\n",
" activations_callback,\n",
" gradients_callback,\n",
" parameters_callback,\n",
" scheduler_callback,\n",
"]\n",
"\n",
"lr_find_callback = rnnm_learner.LRFinderCallback(1e-5, 10, 100)\n",
"\n",
"learner = rnnm_learner.Learner(\n",
" model,\n",
@@ -1061,6 +1067,8 @@
"metadata": {},
"outputs": [],
"source": [
"lr_find_callback = rnnm_learner.LRFinderCallback(1e-5, 10, 100)\n",
"\n",
"learner.find_learning_rate(\n",
" dl_train, n_epochs=2, lr_find_callback=lr_find_callback\n",
")"
@@ -1075,6 +1083,23 @@
"lr_find_callback.plot()"
]
},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"lr = 1e-2\n",
+"scheduler = optim.lr_scheduler.OneCycleLR(\n",
+" optimizer=optimizer,\n",
+" max_lr=lr,\n",
+" epochs=n_epochs,\n",
+" steps_per_epoch=len(dl_train),\n",
+")\n",
+"scheduler_callback = rnnm_learner.EveryBatchSchedulerCallback(scheduler)\n",
+"learner.update_callback(scheduler_callback)"
+]
+},
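learner.update_callback(scheduler_callback) presumably replaces any previously registered callback of the same type rather than appending a duplicate, so re-running this cell does not stack schedulers. A rough guess at the semantics (not the repo's actual code):

```python
# Guessed semantics of Learner.update_callback: replace any callback of
# the same type, otherwise append.
def update_callback(self, callback) -> None:
    self.callbacks = [cb for cb in self.callbacks if type(cb) is not type(callback)]
    self.callbacks.append(callback)
```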
{
"cell_type": "code",
"execution_count": null,
4 changes: 2 additions & 2 deletions nbs/cnn_autoencoder_fastai2022_fashion.ipynb
@@ -461,11 +461,11 @@
"metadata": {},
"outputs": [],
"source": [
"ds_train = rnnm_data.MNISTDatasetTrain(X0, y0)\n",
"ds_train = rnnm_data.MNISTDatasetWithLabels(X0, y0)\n",
"dl_train = DataLoader(\n",
" ds_train,\n",
" batch_size=1,\n",
" collate_fn=rnnm_data.mnist_collate_train,\n",
" collate_fn=rnnm_data.collate_mnist_dataset_to_block_with_labels,\n",
" shuffle=True,\n",
")\n",
"next(iter(dl_train))"