Skip to content

Commit

Permalink
converts promise chain to async await in loadModel
Browse files Browse the repository at this point in the history
  • Loading branch information
joeyklee authored and bomanimc committed Mar 29, 2020
1 parent 2b7b902 commit f582d74
Showing 1 changed file with 91 additions and 87 deletions.
178 changes: 91 additions & 87 deletions src/ImageClassifier/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
Image Classifier using pre-trained networks
*/

import * as tf from '@tensorflow/tfjs';
import * as mobilenet from '@tensorflow-models/mobilenet';
import * as darknet from './darknet';
import * as doodlenet from './doodlenet';
import callCallback from '../utils/callcallback';
import { imgToTensor } from '../utils/imageUtilities';
import * as tf from "@tensorflow/tfjs";
import * as mobilenet from "@tensorflow-models/mobilenet";
import * as darknet from "./darknet";
import * as doodlenet from "./doodlenet";
import callCallback from "../utils/callcallback";
import { imgToTensor } from "../utils/imageUtilities";

const DEFAULTS = {
mobilenet: {
Expand All @@ -22,12 +22,12 @@ const DEFAULTS = {
},
};
const IMAGE_SIZE = 224;
const MODEL_OPTIONS = ['mobilenet', 'darknet', 'darknet-tiny', 'doodlenet'];
const MODEL_OPTIONS = ["mobilenet", "darknet", "darknet-tiny", "doodlenet"];

class ImageClassifier {
/**
* Create an ImageClassifier.
* @param {string} modelNameOrUrl - The name or the URL of the model to use. Current model name options
* @param {string} modelNameOrUrl - The name or the URL of the model to use. Current model name options
* are: 'mobilenet', 'darknet', 'darknet-tiny', and 'doodlenet'.
* @param {HTMLVideoElement} video - An HTMLVideoElement.
* @param {object} options - An object with options.
Expand All @@ -37,26 +37,26 @@ class ImageClassifier {
this.video = video;
this.model = null;
this.mapStringToIndex = [];
if (typeof modelNameOrUrl === 'string') {
if (typeof modelNameOrUrl === "string") {
if (MODEL_OPTIONS.includes(modelNameOrUrl)) {
this.modelName = modelNameOrUrl;
this.modelUrl = null;
switch (this.modelName) {
case 'mobilenet':
case "mobilenet":
this.modelToUse = mobilenet;
this.version = options.version || DEFAULTS.mobilenet.version;
this.alpha = options.alpha || DEFAULTS.mobilenet.alpha;
this.topk = options.topk || DEFAULTS.mobilenet.topk;
break;
case 'darknet':
this.version = 'reference'; // this a 28mb model
case "darknet":
this.version = "reference"; // this a 28mb model
this.modelToUse = darknet;
break;
case 'darknet-tiny':
this.version = 'tiny'; // this a 4mb model
case "darknet-tiny":
this.version = "tiny"; // this a 4mb model
this.modelToUse = darknet;
break;
case 'doodlenet':
case "doodlenet":
this.modelToUse = doodlenet;
break;
default:
Expand All @@ -82,43 +82,39 @@ class ImageClassifier {
}

async loadModelFrom(path = null) {
fetch(path)
.then(r => r.json())
.then((r) => {
if (r.ml5Specs) {
this.mapStringToIndex = r.ml5Specs.mapStringToIndex;
}
})
// When loading model generated by Teachable Machine 2.0, the r.ml5Specs is missing,
// which is causing imageClassifier failing to display lables.
// In this case, labels are stored in path/./metadata.json
// Therefore, I'm fetching the metadata and feeding the labels into this.mapStringToIndex
// by Yang Yang, [email protected], Oct 2, 2019
.then(() => {
if (this.mapStringToIndex.length === 0) {
const split = path.split("/");
const prefix = split.slice(0, split.length - 1).join("/");
const metadataUrl = `${prefix}/metadata.json`;
fetch(metadataUrl)
.then((res) => {
if (!res.ok) {
console.log("Tried to fetch metadata.json, but it seems to be missing.");
throw Error(res.statusText);
}
return res;
})
.then(metadataJson => metadataJson.json())
.then((metadataJson) => {
if (metadataJson.labels) {
this.mapStringToIndex = metadataJson.labels;
}
})
.catch(() => console.log("Error when loading metadata.json"));
try {

let result;
let data;
if(path !== null){
result = await fetch(path);
data = await result.json();
}

if (data.ml5Specs) {
this.mapStringToIndex = data.ml5Specs.mapStringToIndex;
}
if (this.mapStringToIndex.length === 0) {
const split = path.split("/");
const prefix = split.slice(0, split.length - 1).join("/");
const metadataUrl = `${prefix}/metadata.json`;

const metadataResponse = await fetch(metadataUrl);
if (!metadataResponse.ok) {
console.log("Tried to fetch metadata.json, but it seems to be missing.");
// throw Error(metadataResponse.statusText);
} else {
const metadata = await metadataResponse.json();
if (metadata.labels) {
this.mapStringToIndex = metadata.labels;
}
}
});
// end of the Oct 2, 2019 fix
this.model = await tf.loadLayersModel(path);
return this.model;
}
this.model = await tf.loadLayersModel(path);
return this.model;
} catch (err) {
return err;
}
}

/**
Expand All @@ -130,12 +126,10 @@ class ImageClassifier {
* @return {object} an object with {label, confidence}.
*/
async classifyInternal(imgToPredict, numberOfClasses) {

// Wait for the model to be ready
await this.ready;
await tf.nextFrame();


if (imgToPredict instanceof HTMLVideoElement && imgToPredict.readyState === 0) {
const video = imgToPredict;
// Wait for the video to be ready
Expand All @@ -152,32 +146,37 @@ class ImageClassifier {

// Process the images
const imageResize = [IMAGE_SIZE, IMAGE_SIZE];

if (this.modelUrl) {
await tf.nextFrame();
const predictedClasses = tf.tidy(() => {
const processedImg = imgToTensor(imgToPredict, imageResize);
const predictions = this.model.predict(processedImg);
return Array.from(predictions.as1D().dataSync());
});

const results = await predictedClasses.map((confidence, index) => {
const label = (this.mapStringToIndex.length > 0 && this.mapStringToIndex[index]) ? this.mapStringToIndex[index] : index;
return {
label,
confidence,
};
}).sort((a, b) => b.confidence - a.confidence);

const results = await predictedClasses
.map((confidence, index) => {
const label =
this.mapStringToIndex.length > 0 && this.mapStringToIndex[index]
? this.mapStringToIndex[index]
: index;
return {
label,
confidence,
};
})
.sort((a, b) => b.confidence - a.confidence);
return results;
}
}

const processedImg = imgToTensor(imgToPredict, imageResize);
const results = this.model
.classify(processedImg, numberOfClasses)
.then(classes => classes.map(c => ({ label: c.className, confidence: c.probability })));

processedImg.dispose();

return results;
}

Expand All @@ -196,41 +195,46 @@ class ImageClassifier {
let callback;

// Handle the image to predict
if (typeof inputNumOrCallback === 'function') {
if (typeof inputNumOrCallback === "function") {
imgToPredict = this.video;
callback = inputNumOrCallback;
} else if (typeof inputNumOrCallback === 'number') {
} else if (typeof inputNumOrCallback === "number") {
imgToPredict = this.video;
numberOfClasses = inputNumOrCallback;
} else if (inputNumOrCallback instanceof HTMLVideoElement
|| inputNumOrCallback instanceof HTMLImageElement
|| inputNumOrCallback instanceof HTMLCanvasElement
|| inputNumOrCallback instanceof ImageData) {
} else if (
inputNumOrCallback instanceof HTMLVideoElement ||
inputNumOrCallback instanceof HTMLImageElement ||
inputNumOrCallback instanceof HTMLCanvasElement ||
inputNumOrCallback instanceof ImageData
) {
imgToPredict = inputNumOrCallback;
} else if (
typeof inputNumOrCallback === 'object' &&
(inputNumOrCallback.elt instanceof HTMLVideoElement
|| inputNumOrCallback.elt instanceof HTMLImageElement
|| inputNumOrCallback.elt instanceof HTMLCanvasElement
|| inputNumOrCallback.elt instanceof ImageData)
typeof inputNumOrCallback === "object" &&
(inputNumOrCallback.elt instanceof HTMLVideoElement ||
inputNumOrCallback.elt instanceof HTMLImageElement ||
inputNumOrCallback.elt instanceof HTMLCanvasElement ||
inputNumOrCallback.elt instanceof ImageData)
) {
imgToPredict = inputNumOrCallback.elt; // Handle p5.js image
} else if (typeof inputNumOrCallback === 'object' && inputNumOrCallback.canvas instanceof HTMLCanvasElement) {
} else if (
typeof inputNumOrCallback === "object" &&
inputNumOrCallback.canvas instanceof HTMLCanvasElement
) {
imgToPredict = inputNumOrCallback.canvas; // Handle p5.js image
} else if (!(this.video instanceof HTMLVideoElement)) {
// Handle unsupported input
throw new Error(
'No input image provided. If you want to classify a video, pass the video element in the constructor. ',
"No input image provided. If you want to classify a video, pass the video element in the constructor. ",
);
}

if (typeof numOrCallback === 'number') {
if (typeof numOrCallback === "number") {
numberOfClasses = numOrCallback;
} else if (typeof numOrCallback === 'function') {
} else if (typeof numOrCallback === "function") {
callback = numOrCallback;
}

if (typeof cb === 'function') {
if (typeof cb === "function") {
callback = cb;
}

Expand All @@ -255,28 +259,28 @@ const imageClassifier = (modelName, videoOrOptionsOrCallback, optionsOrCallback,
let callback = cb;

let model = modelName;
if (typeof model !== 'string') {
if (typeof model !== "string") {
throw new Error('Please specify a model to use. E.g: "MobileNet"');
} else if (model.indexOf('http') === -1) {
} else if (model.indexOf("http") === -1) {
model = modelName.toLowerCase();
}

if (videoOrOptionsOrCallback instanceof HTMLVideoElement) {
video = videoOrOptionsOrCallback;
} else if (
typeof videoOrOptionsOrCallback === 'object' &&
typeof videoOrOptionsOrCallback === "object" &&
videoOrOptionsOrCallback.elt instanceof HTMLVideoElement
) {
video = videoOrOptionsOrCallback.elt; // Handle a p5.js video element
} else if (typeof videoOrOptionsOrCallback === 'object') {
} else if (typeof videoOrOptionsOrCallback === "object") {
options = videoOrOptionsOrCallback;
} else if (typeof videoOrOptionsOrCallback === 'function') {
} else if (typeof videoOrOptionsOrCallback === "function") {
callback = videoOrOptionsOrCallback;
}

if (typeof optionsOrCallback === 'object') {
if (typeof optionsOrCallback === "object") {
options = optionsOrCallback;
} else if (typeof optionsOrCallback === 'function') {
} else if (typeof optionsOrCallback === "function") {
callback = optionsOrCallback;
}

Expand Down

0 comments on commit f582d74

Please sign in to comment.