Skip to content
This repository has been archived by the owner on Dec 27, 2023. It is now read-only.

Commit

Permalink
Merge pull request #5 from restbai/workshop
Browse files Browse the repository at this point in the history
Notebook for the workshop - final version of the notebook
  • Loading branch information
AlbertSuarez authored Apr 29, 2022
2 parents 511e659 + 739b457 commit 8377891
Showing 1 changed file with 24 additions and 22 deletions.
46 changes: 24 additions & 22 deletions workshop/places365_train.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
},
"outputs": [],
"source": [
"# !gdown 1Hkk2HNvnh2cZqIcOGuxpxUkDSDh-QW86\n",
"# 1200:\n",
"#!gdown 1Hkk2HNvnh2cZqIcOGuxpxUkDSDh-QW86\n",
"# 300:\n",
"!gdown 1y-LdQ_4dbOip6sBgZ-Ub1FI6Hh5kl3h1"
]
},
Expand Down Expand Up @@ -189,33 +191,24 @@
},
"outputs": [],
"source": [
"class Places365Model(tf.keras.Model):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.base_model = tf.keras.applications.MobileNetV3Small(\n",
"inputs = tf.keras.Input(shape=(224, 224, 3))\n",
"x = tf.keras.applications.MobileNetV3Small(\n",
" input_shape=(224, 224, 3),\n",
" include_top=False,\n",
" weights=None,\n",
" classes=3,\n",
" pooling=\"avg\",\n",
" minimalistic=True\n",
" )\n",
" self.fc = tf.keras.layers.Dense(3, activation='softmax')\n",
"\n",
" def call(self, x):\n",
" x = self.base_model(x)\n",
" return self.fc(x)\n",
"\n",
"\n",
"model = Places365Model()\n",
" )(inputs)\n",
"outputs = tf.keras.layers.Dense(3, activation=\"softmax\")(x)\n",
"model = tf.keras.Model(inputs=inputs, outputs=outputs)\n",
"\n",
"# updating momentum of the BatchNorm layers\n",
"for layer in model.base_model.layers:\n",
"for layer in model.layers:\n",
" if isinstance(layer, tf.keras.layers.BatchNormalization):\n",
" layer.momentum = 0.5\n",
"\n",
"optimizer = tf.keras.optimizers.Adam(learning_rate=0.005)\n",
"\n",
"optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)\n",
"model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics='accuracy')"
]
},
Expand All @@ -238,7 +231,7 @@
"source": [
"history = model.fit(\n",
" train_ds,\n",
" epochs=32)"
" epochs=20)"
]
},
{
Expand Down Expand Up @@ -399,17 +392,26 @@
" img, label = gt\n",
" predicted_class = np.argmax(result)\n",
" if predicted_class != label:\n",
" images_to_plot.append((img, label, predicted_class))\n",
"\n",
" images_to_plot.append((img, label, predicted_class))"
]
},
{
"cell_type": "code",
"source": [
"fig = plt.figure(figsize=(128., 128.))\n",
"grid = ImageGrid(fig, 111, nrows_ncols=(math.ceil(len(images_to_plot) / 4), 4), axes_pad=0.6,)\n",
"\n",
"for ax, im in zip(grid, images_to_plot):\n",
" ax.set_title(f\"True: {class_names[im[1]]}, predicted: {class_names[im[2]]}\", fontdict=None, loc='center', color = \"k\", fontsize=15)\n",
" ax.set_title(f\"True: {class_names[im[1]]}, predicted: {class_names[im[2]]}\", fontdict=None, loc='center', color = \"k\", fontsize=70)\n",
" ax.imshow(im[0] / 255)\n",
"\n",
"plt.show()"
]
],
"metadata": {
"id": "J8plRcodbaXT"
},
"execution_count": null,
"outputs": []
}
],
"metadata": {
Expand Down

0 comments on commit 8377891

Please sign in to comment.