-
Notifications
You must be signed in to change notification settings - Fork 903
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
converts promise chain to async await in loadModel
- Loading branch information
Showing
1 changed file
with
91 additions
and
87 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,12 +7,12 @@ | |
Image Classifier using pre-trained networks | ||
*/ | ||
|
||
import * as tf from '@tensorflow/tfjs'; | ||
import * as mobilenet from '@tensorflow-models/mobilenet'; | ||
import * as darknet from './darknet'; | ||
import * as doodlenet from './doodlenet'; | ||
import callCallback from '../utils/callcallback'; | ||
import { imgToTensor } from '../utils/imageUtilities'; | ||
import * as tf from "@tensorflow/tfjs"; | ||
import * as mobilenet from "@tensorflow-models/mobilenet"; | ||
import * as darknet from "./darknet"; | ||
import * as doodlenet from "./doodlenet"; | ||
import callCallback from "../utils/callcallback"; | ||
import { imgToTensor } from "../utils/imageUtilities"; | ||
|
||
const DEFAULTS = { | ||
mobilenet: { | ||
|
@@ -22,12 +22,12 @@ const DEFAULTS = { | |
}, | ||
}; | ||
const IMAGE_SIZE = 224; | ||
const MODEL_OPTIONS = ['mobilenet', 'darknet', 'darknet-tiny', 'doodlenet']; | ||
const MODEL_OPTIONS = ["mobilenet", "darknet", "darknet-tiny", "doodlenet"]; | ||
|
||
class ImageClassifier { | ||
/** | ||
* Create an ImageClassifier. | ||
* @param {string} modelNameOrUrl - The name or the URL of the model to use. Current model name options | ||
* @param {string} modelNameOrUrl - The name or the URL of the model to use. Current model name options | ||
* are: 'mobilenet', 'darknet', 'darknet-tiny', and 'doodlenet'. | ||
* @param {HTMLVideoElement} video - An HTMLVideoElement. | ||
* @param {object} options - An object with options. | ||
|
@@ -37,26 +37,26 @@ class ImageClassifier { | |
this.video = video; | ||
this.model = null; | ||
this.mapStringToIndex = []; | ||
if (typeof modelNameOrUrl === 'string') { | ||
if (typeof modelNameOrUrl === "string") { | ||
if (MODEL_OPTIONS.includes(modelNameOrUrl)) { | ||
this.modelName = modelNameOrUrl; | ||
this.modelUrl = null; | ||
switch (this.modelName) { | ||
case 'mobilenet': | ||
case "mobilenet": | ||
this.modelToUse = mobilenet; | ||
this.version = options.version || DEFAULTS.mobilenet.version; | ||
this.alpha = options.alpha || DEFAULTS.mobilenet.alpha; | ||
this.topk = options.topk || DEFAULTS.mobilenet.topk; | ||
break; | ||
case 'darknet': | ||
this.version = 'reference'; // this a 28mb model | ||
case "darknet": | ||
this.version = "reference"; // this a 28mb model | ||
this.modelToUse = darknet; | ||
break; | ||
case 'darknet-tiny': | ||
this.version = 'tiny'; // this a 4mb model | ||
case "darknet-tiny": | ||
this.version = "tiny"; // this a 4mb model | ||
this.modelToUse = darknet; | ||
break; | ||
case 'doodlenet': | ||
case "doodlenet": | ||
this.modelToUse = doodlenet; | ||
break; | ||
default: | ||
|
@@ -82,43 +82,39 @@ class ImageClassifier { | |
} | ||
|
||
async loadModelFrom(path = null) { | ||
fetch(path) | ||
.then(r => r.json()) | ||
.then((r) => { | ||
if (r.ml5Specs) { | ||
this.mapStringToIndex = r.ml5Specs.mapStringToIndex; | ||
} | ||
}) | ||
// When loading model generated by Teachable Machine 2.0, the r.ml5Specs is missing, | ||
// which is causing imageClassifier failing to display lables. | ||
// In this case, labels are stored in path/./metadata.json | ||
// Therefore, I'm fetching the metadata and feeding the labels into this.mapStringToIndex | ||
// by Yang Yang, [email protected], Oct 2, 2019 | ||
.then(() => { | ||
if (this.mapStringToIndex.length === 0) { | ||
const split = path.split("/"); | ||
const prefix = split.slice(0, split.length - 1).join("/"); | ||
const metadataUrl = `${prefix}/metadata.json`; | ||
fetch(metadataUrl) | ||
.then((res) => { | ||
if (!res.ok) { | ||
console.log("Tried to fetch metadata.json, but it seems to be missing."); | ||
throw Error(res.statusText); | ||
} | ||
return res; | ||
}) | ||
.then(metadataJson => metadataJson.json()) | ||
.then((metadataJson) => { | ||
if (metadataJson.labels) { | ||
this.mapStringToIndex = metadataJson.labels; | ||
} | ||
}) | ||
.catch(() => console.log("Error when loading metadata.json")); | ||
try { | ||
|
||
let result; | ||
let data; | ||
if(path !== null){ | ||
result = await fetch(path); | ||
data = await result.json(); | ||
} | ||
|
||
if (data.ml5Specs) { | ||
this.mapStringToIndex = data.ml5Specs.mapStringToIndex; | ||
} | ||
if (this.mapStringToIndex.length === 0) { | ||
const split = path.split("/"); | ||
const prefix = split.slice(0, split.length - 1).join("/"); | ||
const metadataUrl = `${prefix}/metadata.json`; | ||
|
||
const metadataResponse = await fetch(metadataUrl); | ||
if (!metadataResponse.ok) { | ||
console.log("Tried to fetch metadata.json, but it seems to be missing."); | ||
// throw Error(metadataResponse.statusText); | ||
} else { | ||
const metadata = await metadataResponse.json(); | ||
if (metadata.labels) { | ||
this.mapStringToIndex = metadata.labels; | ||
} | ||
} | ||
}); | ||
// end of the Oct 2, 2019 fix | ||
this.model = await tf.loadLayersModel(path); | ||
return this.model; | ||
} | ||
this.model = await tf.loadLayersModel(path); | ||
return this.model; | ||
} catch (err) { | ||
return err; | ||
} | ||
} | ||
|
||
/** | ||
|
@@ -130,12 +126,10 @@ class ImageClassifier { | |
* @return {object} an object with {label, confidence}. | ||
*/ | ||
async classifyInternal(imgToPredict, numberOfClasses) { | ||
|
||
// Wait for the model to be ready | ||
await this.ready; | ||
await tf.nextFrame(); | ||
|
||
|
||
if (imgToPredict instanceof HTMLVideoElement && imgToPredict.readyState === 0) { | ||
const video = imgToPredict; | ||
// Wait for the video to be ready | ||
|
@@ -152,32 +146,37 @@ class ImageClassifier { | |
|
||
// Process the images | ||
const imageResize = [IMAGE_SIZE, IMAGE_SIZE]; | ||
|
||
if (this.modelUrl) { | ||
await tf.nextFrame(); | ||
const predictedClasses = tf.tidy(() => { | ||
const processedImg = imgToTensor(imgToPredict, imageResize); | ||
const predictions = this.model.predict(processedImg); | ||
return Array.from(predictions.as1D().dataSync()); | ||
}); | ||
|
||
const results = await predictedClasses.map((confidence, index) => { | ||
const label = (this.mapStringToIndex.length > 0 && this.mapStringToIndex[index]) ? this.mapStringToIndex[index] : index; | ||
return { | ||
label, | ||
confidence, | ||
}; | ||
}).sort((a, b) => b.confidence - a.confidence); | ||
|
||
const results = await predictedClasses | ||
.map((confidence, index) => { | ||
const label = | ||
this.mapStringToIndex.length > 0 && this.mapStringToIndex[index] | ||
? this.mapStringToIndex[index] | ||
: index; | ||
return { | ||
label, | ||
confidence, | ||
}; | ||
}) | ||
.sort((a, b) => b.confidence - a.confidence); | ||
return results; | ||
} | ||
} | ||
|
||
const processedImg = imgToTensor(imgToPredict, imageResize); | ||
const results = this.model | ||
.classify(processedImg, numberOfClasses) | ||
.then(classes => classes.map(c => ({ label: c.className, confidence: c.probability }))); | ||
|
||
processedImg.dispose(); | ||
|
||
return results; | ||
} | ||
|
||
|
@@ -196,41 +195,46 @@ class ImageClassifier { | |
let callback; | ||
|
||
// Handle the image to predict | ||
if (typeof inputNumOrCallback === 'function') { | ||
if (typeof inputNumOrCallback === "function") { | ||
imgToPredict = this.video; | ||
callback = inputNumOrCallback; | ||
} else if (typeof inputNumOrCallback === 'number') { | ||
} else if (typeof inputNumOrCallback === "number") { | ||
imgToPredict = this.video; | ||
numberOfClasses = inputNumOrCallback; | ||
} else if (inputNumOrCallback instanceof HTMLVideoElement | ||
|| inputNumOrCallback instanceof HTMLImageElement | ||
|| inputNumOrCallback instanceof HTMLCanvasElement | ||
|| inputNumOrCallback instanceof ImageData) { | ||
} else if ( | ||
inputNumOrCallback instanceof HTMLVideoElement || | ||
inputNumOrCallback instanceof HTMLImageElement || | ||
inputNumOrCallback instanceof HTMLCanvasElement || | ||
inputNumOrCallback instanceof ImageData | ||
) { | ||
imgToPredict = inputNumOrCallback; | ||
} else if ( | ||
typeof inputNumOrCallback === 'object' && | ||
(inputNumOrCallback.elt instanceof HTMLVideoElement | ||
|| inputNumOrCallback.elt instanceof HTMLImageElement | ||
|| inputNumOrCallback.elt instanceof HTMLCanvasElement | ||
|| inputNumOrCallback.elt instanceof ImageData) | ||
typeof inputNumOrCallback === "object" && | ||
(inputNumOrCallback.elt instanceof HTMLVideoElement || | ||
inputNumOrCallback.elt instanceof HTMLImageElement || | ||
inputNumOrCallback.elt instanceof HTMLCanvasElement || | ||
inputNumOrCallback.elt instanceof ImageData) | ||
) { | ||
imgToPredict = inputNumOrCallback.elt; // Handle p5.js image | ||
} else if (typeof inputNumOrCallback === 'object' && inputNumOrCallback.canvas instanceof HTMLCanvasElement) { | ||
} else if ( | ||
typeof inputNumOrCallback === "object" && | ||
inputNumOrCallback.canvas instanceof HTMLCanvasElement | ||
) { | ||
imgToPredict = inputNumOrCallback.canvas; // Handle p5.js image | ||
} else if (!(this.video instanceof HTMLVideoElement)) { | ||
// Handle unsupported input | ||
throw new Error( | ||
'No input image provided. If you want to classify a video, pass the video element in the constructor. ', | ||
"No input image provided. If you want to classify a video, pass the video element in the constructor. ", | ||
); | ||
} | ||
|
||
if (typeof numOrCallback === 'number') { | ||
if (typeof numOrCallback === "number") { | ||
numberOfClasses = numOrCallback; | ||
} else if (typeof numOrCallback === 'function') { | ||
} else if (typeof numOrCallback === "function") { | ||
callback = numOrCallback; | ||
} | ||
|
||
if (typeof cb === 'function') { | ||
if (typeof cb === "function") { | ||
callback = cb; | ||
} | ||
|
||
|
@@ -255,28 +259,28 @@ const imageClassifier = (modelName, videoOrOptionsOrCallback, optionsOrCallback, | |
let callback = cb; | ||
|
||
let model = modelName; | ||
if (typeof model !== 'string') { | ||
if (typeof model !== "string") { | ||
throw new Error('Please specify a model to use. E.g: "MobileNet"'); | ||
} else if (model.indexOf('http') === -1) { | ||
} else if (model.indexOf("http") === -1) { | ||
model = modelName.toLowerCase(); | ||
} | ||
|
||
if (videoOrOptionsOrCallback instanceof HTMLVideoElement) { | ||
video = videoOrOptionsOrCallback; | ||
} else if ( | ||
typeof videoOrOptionsOrCallback === 'object' && | ||
typeof videoOrOptionsOrCallback === "object" && | ||
videoOrOptionsOrCallback.elt instanceof HTMLVideoElement | ||
) { | ||
video = videoOrOptionsOrCallback.elt; // Handle a p5.js video element | ||
} else if (typeof videoOrOptionsOrCallback === 'object') { | ||
} else if (typeof videoOrOptionsOrCallback === "object") { | ||
options = videoOrOptionsOrCallback; | ||
} else if (typeof videoOrOptionsOrCallback === 'function') { | ||
} else if (typeof videoOrOptionsOrCallback === "function") { | ||
callback = videoOrOptionsOrCallback; | ||
} | ||
|
||
if (typeof optionsOrCallback === 'object') { | ||
if (typeof optionsOrCallback === "object") { | ||
options = optionsOrCallback; | ||
} else if (typeof optionsOrCallback === 'function') { | ||
} else if (typeof optionsOrCallback === "function") { | ||
callback = optionsOrCallback; | ||
} | ||
|
||
|