Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix imageClassifier webgpu bug + add teachable machine image example #141

Merged
merged 4 commits into from
Jun 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions examples/ImageClassifier-teachable-machine/index.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
<!--
👋 Hello! This is an ml5.js example made and shared with ❤️.
Learn more about the ml5.js project: https://ml5js.org/
ml5.js license and Code of Conduct: https://github.com/ml5js/ml5-next-gen/blob/main/LICENSE.md

This example demonstrates detecting objects in a live video through ml5.imageClassifier + Teachable Machine.
-->

<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta http-equiv="X-UA-Compatible" content="IE=edge" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>ml5.js imageClassifier + Teachable Machine Example</title>
<script src="https://cdnjs.cloudflare.com/ajax/libs/p5.js/1.9.2/p5.min.js"></script>
<script src="../../dist/ml5.js"></script>
</head>
<body>
<script src="sketch.js"></script>
</body>
</html>
50 changes: 50 additions & 0 deletions examples/ImageClassifier-teachable-machine/sketch.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* 👋 Hello! This is an ml5.js example made and shared with ❤️.
* Learn more about the ml5.js project: https://ml5js.org/
* ml5.js license and Code of Conduct: https://github.com/ml5js/ml5-next-gen/blob/main/LICENSE.md
*
* This example demonstrates detecting objects in a live video through ml5.imageClassifier + Teachable Machine.
*/

// A variable to initialize the Image Classifier
let classifier;

// A variable to hold the video we want to classify
let video;

// Variable for displaying the results on the canvas
let label = "Model loading...";

let imageModelURL = "https://teachablemachine.withgoogle.com/models/bXy2kDNi/";

function preload() {
classifier = ml5.imageClassifier(imageModelURL + "model.json");
}

function setup() {
createCanvas(640, 480);

// Create the webcam video and hide it
video = createCapture(VIDEO, { flipped: true });
video.size(width, height);
video.hide();

// Start detecting objects in the video
classifier.classifyStart(video, gotResult);
}

function draw() {
// Each video frame is painted on the canvas
image(video, 0, 0);

// Printing class with the highest probability on the canvas
fill(0, 255, 0);
textSize(32);
text(label, 20, 50);
}

// A function to run when we get the results
function gotResult(results) {
// update label variable which is displayed on the canvas
label = results[0].label;
}
15 changes: 8 additions & 7 deletions src/ImageClassifier/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -196,12 +196,13 @@ class ImageClassifier {

if (this.modelUrl) {
await tf.nextFrame();
const predictedClasses = tf.tidy(() => {
const predictions = this.model.predict(imgToPredict);
return Array.from(predictions.as1D().dataSync());
});

const results = await predictedClasses
const predictions = this.model.predict(imgToPredict);
const predictionData = await predictions.as1D().data();
predictions.dispose();
const predictedClasses = Array.from(predictionData);

const results = predictedClasses
.map((confidence, index) => {
const label =
this.mapStringToIndex.length > 0 && this.mapStringToIndex[index]
Expand All @@ -220,9 +221,9 @@ class ImageClassifier {

// MobileNet uses className/probability instead of label/confidence.
if (this.modelName === "mobilenet") {
return results.map(result => ({
return results.map((result) => ({
label: result.className,
confidence: result.probability
confidence: result.probability,
}));
}

Expand Down