Skip to content

Commit

Permalink
chore: minor lecun 1990 nb corrections
Browse files Browse the repository at this point in the history
  • Loading branch information
eschmidt42 committed Feb 1, 2024
1 parent 7ddd23a commit d02d16f
Showing 1 changed file with 12 additions and 33 deletions.
45 changes: 12 additions & 33 deletions nbs/convolution_lecun1990.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -934,7 +934,8 @@
"outputs": [],
"source": [
"ds = conv_lecun1990.DigitsDataset(X0, y0)\n",
"ds_test = conv_lecun1990.DigitsDataset(X1, y1)"
"ds_valid = conv_lecun1990.DigitsDataset(X1, y1)\n",
"ds_test = conv_lecun1990.DigitsDataset(X2, y2)"
]
},
{
Expand All @@ -945,6 +946,7 @@
"source": [
"batch_size = 1\n",
"dataloader = DataLoader(ds, batch_size=batch_size, shuffle=False)\n",
"dataloader_valid = DataLoader(ds_valid, batch_size=500, shuffle=False)\n",
"dataloader_test = DataLoader(ds_test, batch_size=500, shuffle=False)"
]
},
Expand Down Expand Up @@ -1036,7 +1038,7 @@
" with torch.no_grad():\n",
" model.eval()\n",
" ys_pred, ys_true = [], []\n",
" for xb, yb in dataloader_test:\n",
" for xb, yb in dataloader_valid:\n",
" xb = xb.to(device)\n",
" yb = yb.to(device)\n",
" yb = conv_lecun1990.densify_y(yb)\n",
Expand Down Expand Up @@ -1354,13 +1356,13 @@
")\n",
"dl_valid = DataLoader(\n",
" ds_valid,\n",
" batch_size=20,\n",
" batch_size=200,\n",
" collate_fn=rnnm_data.mnist_collate_train,\n",
" shuffle=False,\n",
")\n",
"dl_test = DataLoader(\n",
" ds_test,\n",
" batch_size=20,\n",
" batch_size=200,\n",
" collate_fn=rnnm_data.mnist_collate_train,\n",
" shuffle=False,\n",
")"
Expand Down Expand Up @@ -1392,9 +1394,10 @@
"metadata": {},
"outputs": [],
"source": [
"n_epochs = 20\n",
"lr = 0.2\n",
"optimizer = optim.SGD(model.parameters(), lr=lr) # , momentum=1e-3\n",
"n_epochs = 10\n",
"lr = 1e-3\n",
"# optimizer = optim.SGD(model.parameters(), lr=lr) # , momentum=1e-3\n",
"optimizer = optim.Adam(model.parameters(), lr=lr)\n",
"# scheduler = optim.lr_scheduler.OneCycleLR(\n",
"# optimizer=optimizer,\n",
"# max_lr=lr,\n",
Expand Down Expand Up @@ -1444,15 +1447,6 @@
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# TODO: add early stopping"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -1524,21 +1518,6 @@
"computing test set performance"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ds_test = rnnm_data.MNISTDatasetTrain(X1, y1)\n",
"dl_test = DataLoader(\n",
" ds_test,\n",
" batch_size=500,\n",
" collate_fn=rnnm_data.mnist_collate_train,\n",
" shuffle=False,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -1584,8 +1563,8 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"* Accuracy: 92.65%\n",
"* Error rate: 7.35%"
"* Accuracy: 95.15%\n",
"* Error rate: 4.85%"
]
},
{
Expand Down

0 comments on commit d02d16f

Please sign in to comment.