diff --git a/examples/ImageClassifier-teachable-machine/index.html b/examples/ImageClassifier-teachable-machine/index.html
new file mode 100644
index 00000000..4326e12e
--- /dev/null
+++ b/examples/ImageClassifier-teachable-machine/index.html
@@ -0,0 +1,22 @@
+
+
+
+
+
+
+
+
+ ml5.js imageClassifier + Teachable Machine Example
+
+
+
+
+
+
+
diff --git a/examples/ImageClassifier-teachable-machine/sketch.js b/examples/ImageClassifier-teachable-machine/sketch.js
new file mode 100644
index 00000000..63c39c69
--- /dev/null
+++ b/examples/ImageClassifier-teachable-machine/sketch.js
@@ -0,0 +1,50 @@
+/*
+ * 👋 Hello! This is an ml5.js example made and shared with ❤️.
+ * Learn more about the ml5.js project: https://ml5js.org/
+ * ml5.js license and Code of Conduct: https://github.com/ml5js/ml5-next-gen/blob/main/LICENSE.md
+ *
+ * This example demonstrates detecting objects in a live video through ml5.imageClassifier + Teachable Machine.
+ */
+
+// A variable to initialize the Image Classifier
+let classifier;
+
+// A variable to hold the video we want to classify
+let video;
+
+// Variable for displaying the results on the canvas
+let label = "Model loading...";
+
+let imageModelURL = "https://teachablemachine.withgoogle.com/models/bXy2kDNi/";
+
+function preload() {
+ classifier = ml5.imageClassifier(imageModelURL + "model.json");
+}
+
+function setup() {
+ createCanvas(640, 480);
+
+ // Create the webcam video and hide it
+ video = createCapture(VIDEO, { flipped: true });
+ video.size(width, height);
+ video.hide();
+
+ // Start detecting objects in the video
+ classifier.classifyStart(video, gotResult);
+}
+
+function draw() {
+ // Each video frame is painted on the canvas
+ image(video, 0, 0);
+
+ // Printing class with the highest probability on the canvas
+ fill(0, 255, 0);
+ textSize(32);
+ text(label, 20, 50);
+}
+
+// A function to run when we get the results
+function gotResult(results) {
+ // update label variable which is displayed on the canvas
+ label = results[0].label;
+}
diff --git a/src/ImageClassifier/index.js b/src/ImageClassifier/index.js
index 23a2e2d8..d9b0b85e 100644
--- a/src/ImageClassifier/index.js
+++ b/src/ImageClassifier/index.js
@@ -196,12 +196,13 @@ class ImageClassifier {
if (this.modelUrl) {
await tf.nextFrame();
- const predictedClasses = tf.tidy(() => {
- const predictions = this.model.predict(imgToPredict);
- return Array.from(predictions.as1D().dataSync());
- });
- const results = await predictedClasses
+ const predictions = this.model.predict(imgToPredict);
+ const predictionData = await predictions.as1D().data();
+ predictions.dispose();
+ const predictedClasses = Array.from(predictionData);
+
+ const results = predictedClasses
.map((confidence, index) => {
const label =
this.mapStringToIndex.length > 0 && this.mapStringToIndex[index]
@@ -220,9 +221,9 @@ class ImageClassifier {
// MobileNet uses className/probability instead of label/confidence.
if (this.modelName === "mobilenet") {
- return results.map(result => ({
+ return results.map((result) => ({
label: result.className,
- confidence: result.probability
+ confidence: result.probability,
}));
}