diff --git a/examples/ImageClassifier-teachable-machine/index.html b/examples/ImageClassifier-teachable-machine/index.html new file mode 100644 index 00000000..4326e12e --- /dev/null +++ b/examples/ImageClassifier-teachable-machine/index.html @@ -0,0 +1,22 @@ + + + + + + + + + ml5.js imageClassifier + Teachable Machine Example + + + + + + + diff --git a/examples/ImageClassifier-teachable-machine/sketch.js b/examples/ImageClassifier-teachable-machine/sketch.js new file mode 100644 index 00000000..63c39c69 --- /dev/null +++ b/examples/ImageClassifier-teachable-machine/sketch.js @@ -0,0 +1,50 @@ +/* + * 👋 Hello! This is an ml5.js example made and shared with ❤️. + * Learn more about the ml5.js project: https://ml5js.org/ + * ml5.js license and Code of Conduct: https://github.com/ml5js/ml5-next-gen/blob/main/LICENSE.md + * + * This example demonstrates detecting objects in a live video through ml5.imageClassifier + Teachable Machine. + */ + +// A variable to initialize the Image Classifier +let classifier; + +// A variable to hold the video we want to classify +let video; + +// Variable for displaying the results on the canvas +let label = "Model loading..."; + +let imageModelURL = "https://teachablemachine.withgoogle.com/models/bXy2kDNi/"; + +function preload() { + classifier = ml5.imageClassifier(imageModelURL + "model.json"); +} + +function setup() { + createCanvas(640, 480); + + // Create the webcam video and hide it + video = createCapture(VIDEO, { flipped: true }); + video.size(width, height); + video.hide(); + + // Start detecting objects in the video + classifier.classifyStart(video, gotResult); +} + +function draw() { + // Each video frame is painted on the canvas + image(video, 0, 0); + + // Printing class with the highest probability on the canvas + fill(0, 255, 0); + textSize(32); + text(label, 20, 50); +} + +// A function to run when we get the results +function gotResult(results) { + // update label variable which is displayed on the canvas + label = results[0].label; +} diff --git a/src/ImageClassifier/index.js b/src/ImageClassifier/index.js index 23a2e2d8..d9b0b85e 100644 --- a/src/ImageClassifier/index.js +++ b/src/ImageClassifier/index.js @@ -196,12 +196,13 @@ class ImageClassifier { if (this.modelUrl) { await tf.nextFrame(); - const predictedClasses = tf.tidy(() => { - const predictions = this.model.predict(imgToPredict); - return Array.from(predictions.as1D().dataSync()); - }); - const results = await predictedClasses + const predictions = this.model.predict(imgToPredict); + const predictionData = await predictions.as1D().data(); + predictions.dispose(); + const predictedClasses = Array.from(predictionData); + + const results = predictedClasses .map((confidence, index) => { const label = this.mapStringToIndex.length > 0 && this.mapStringToIndex[index] @@ -220,9 +221,9 @@ class ImageClassifier { // MobileNet uses className/probability instead of label/confidence. if (this.modelName === "mobilenet") { - return results.map(result => ({ + return results.map((result) => ({ label: result.className, - confidence: result.probability + confidence: result.probability, })); }