From f582d746f8ea5272b873205ced0bff2e24741ea6 Mon Sep 17 00:00:00 2001 From: joeyklee Date: Fri, 27 Mar 2020 17:49:26 -0400 Subject: [PATCH] converts promise chain to async await in loadModel --- src/ImageClassifier/index.js | 178 ++++++++++++++++++----------------- 1 file changed, 91 insertions(+), 87 deletions(-) diff --git a/src/ImageClassifier/index.js b/src/ImageClassifier/index.js index 418cf6ddb..7cb24e609 100644 --- a/src/ImageClassifier/index.js +++ b/src/ImageClassifier/index.js @@ -7,12 +7,12 @@ Image Classifier using pre-trained networks */ -import * as tf from '@tensorflow/tfjs'; -import * as mobilenet from '@tensorflow-models/mobilenet'; -import * as darknet from './darknet'; -import * as doodlenet from './doodlenet'; -import callCallback from '../utils/callcallback'; -import { imgToTensor } from '../utils/imageUtilities'; +import * as tf from "@tensorflow/tfjs"; +import * as mobilenet from "@tensorflow-models/mobilenet"; +import * as darknet from "./darknet"; +import * as doodlenet from "./doodlenet"; +import callCallback from "../utils/callcallback"; +import { imgToTensor } from "../utils/imageUtilities"; const DEFAULTS = { mobilenet: { @@ -22,12 +22,12 @@ const DEFAULTS = { }, }; const IMAGE_SIZE = 224; -const MODEL_OPTIONS = ['mobilenet', 'darknet', 'darknet-tiny', 'doodlenet']; +const MODEL_OPTIONS = ["mobilenet", "darknet", "darknet-tiny", "doodlenet"]; class ImageClassifier { /** * Create an ImageClassifier. - * @param {string} modelNameOrUrl - The name or the URL of the model to use. Current model name options + * @param {string} modelNameOrUrl - The name or the URL of the model to use. Current model name options * are: 'mobilenet', 'darknet', 'darknet-tiny', and 'doodlenet'. * @param {HTMLVideoElement} video - An HTMLVideoElement. * @param {object} options - An object with options. @@ -37,26 +37,26 @@ class ImageClassifier { this.video = video; this.model = null; this.mapStringToIndex = []; - if (typeof modelNameOrUrl === 'string') { + if (typeof modelNameOrUrl === "string") { if (MODEL_OPTIONS.includes(modelNameOrUrl)) { this.modelName = modelNameOrUrl; this.modelUrl = null; switch (this.modelName) { - case 'mobilenet': + case "mobilenet": this.modelToUse = mobilenet; this.version = options.version || DEFAULTS.mobilenet.version; this.alpha = options.alpha || DEFAULTS.mobilenet.alpha; this.topk = options.topk || DEFAULTS.mobilenet.topk; break; - case 'darknet': - this.version = 'reference'; // this a 28mb model + case "darknet": + this.version = "reference"; // this a 28mb model this.modelToUse = darknet; break; - case 'darknet-tiny': - this.version = 'tiny'; // this a 4mb model + case "darknet-tiny": + this.version = "tiny"; // this a 4mb model this.modelToUse = darknet; break; - case 'doodlenet': + case "doodlenet": this.modelToUse = doodlenet; break; default: @@ -82,43 +82,39 @@ class ImageClassifier { } async loadModelFrom(path = null) { - fetch(path) - .then(r => r.json()) - .then((r) => { - if (r.ml5Specs) { - this.mapStringToIndex = r.ml5Specs.mapStringToIndex; - } - }) - // When loading model generated by Teachable Machine 2.0, the r.ml5Specs is missing, - // which is causing imageClassifier failing to display lables. - // In this case, labels are stored in path/./metadata.json - // Therefore, I'm fetching the metadata and feeding the labels into this.mapStringToIndex - // by Yang Yang, yy2473@nyu.edu, Oct 2, 2019 - .then(() => { - if (this.mapStringToIndex.length === 0) { - const split = path.split("/"); - const prefix = split.slice(0, split.length - 1).join("/"); - const metadataUrl = `${prefix}/metadata.json`; - fetch(metadataUrl) - .then((res) => { - if (!res.ok) { - console.log("Tried to fetch metadata.json, but it seems to be missing."); - throw Error(res.statusText); - } - return res; - }) - .then(metadataJson => metadataJson.json()) - .then((metadataJson) => { - if (metadataJson.labels) { - this.mapStringToIndex = metadataJson.labels; - } - }) - .catch(() => console.log("Error when loading metadata.json")); + try { + + let result; + let data; + if(path !== null){ + result = await fetch(path); + data = await result.json(); + } + + if (data.ml5Specs) { + this.mapStringToIndex = data.ml5Specs.mapStringToIndex; + } + if (this.mapStringToIndex.length === 0) { + const split = path.split("/"); + const prefix = split.slice(0, split.length - 1).join("/"); + const metadataUrl = `${prefix}/metadata.json`; + + const metadataResponse = await fetch(metadataUrl); + if (!metadataResponse.ok) { + console.log("Tried to fetch metadata.json, but it seems to be missing."); + // throw Error(metadataResponse.statusText); + } else { + const metadata = await metadataResponse.json(); + if (metadata.labels) { + this.mapStringToIndex = metadata.labels; + } } - }); - // end of the Oct 2, 2019 fix - this.model = await tf.loadLayersModel(path); - return this.model; + } + this.model = await tf.loadLayersModel(path); + return this.model; + } catch (err) { + return err; + } } /** @@ -130,12 +126,10 @@ class ImageClassifier { * @return {object} an object with {label, confidence}. */ async classifyInternal(imgToPredict, numberOfClasses) { - // Wait for the model to be ready await this.ready; await tf.nextFrame(); - if (imgToPredict instanceof HTMLVideoElement && imgToPredict.readyState === 0) { const video = imgToPredict; // Wait for the video to be ready @@ -152,7 +146,7 @@ class ImageClassifier { // Process the images const imageResize = [IMAGE_SIZE, IMAGE_SIZE]; - + if (this.modelUrl) { await tf.nextFrame(); const predictedClasses = tf.tidy(() => { @@ -160,16 +154,21 @@ class ImageClassifier { const predictions = this.model.predict(processedImg); return Array.from(predictions.as1D().dataSync()); }); - - const results = await predictedClasses.map((confidence, index) => { - const label = (this.mapStringToIndex.length > 0 && this.mapStringToIndex[index]) ? this.mapStringToIndex[index] : index; - return { - label, - confidence, - }; - }).sort((a, b) => b.confidence - a.confidence); + + const results = await predictedClasses + .map((confidence, index) => { + const label = + this.mapStringToIndex.length > 0 && this.mapStringToIndex[index] + ? this.mapStringToIndex[index] + : index; + return { + label, + confidence, + }; + }) + .sort((a, b) => b.confidence - a.confidence); return results; - } + } const processedImg = imgToTensor(imgToPredict, imageResize); const results = this.model @@ -177,7 +176,7 @@ class ImageClassifier { .then(classes => classes.map(c => ({ label: c.className, confidence: c.probability }))); processedImg.dispose(); - + return results; } @@ -196,41 +195,46 @@ class ImageClassifier { let callback; // Handle the image to predict - if (typeof inputNumOrCallback === 'function') { + if (typeof inputNumOrCallback === "function") { imgToPredict = this.video; callback = inputNumOrCallback; - } else if (typeof inputNumOrCallback === 'number') { + } else if (typeof inputNumOrCallback === "number") { imgToPredict = this.video; numberOfClasses = inputNumOrCallback; - } else if (inputNumOrCallback instanceof HTMLVideoElement - || inputNumOrCallback instanceof HTMLImageElement - || inputNumOrCallback instanceof HTMLCanvasElement - || inputNumOrCallback instanceof ImageData) { + } else if ( + inputNumOrCallback instanceof HTMLVideoElement || + inputNumOrCallback instanceof HTMLImageElement || + inputNumOrCallback instanceof HTMLCanvasElement || + inputNumOrCallback instanceof ImageData + ) { imgToPredict = inputNumOrCallback; } else if ( - typeof inputNumOrCallback === 'object' && - (inputNumOrCallback.elt instanceof HTMLVideoElement - || inputNumOrCallback.elt instanceof HTMLImageElement - || inputNumOrCallback.elt instanceof HTMLCanvasElement - || inputNumOrCallback.elt instanceof ImageData) + typeof inputNumOrCallback === "object" && + (inputNumOrCallback.elt instanceof HTMLVideoElement || + inputNumOrCallback.elt instanceof HTMLImageElement || + inputNumOrCallback.elt instanceof HTMLCanvasElement || + inputNumOrCallback.elt instanceof ImageData) ) { imgToPredict = inputNumOrCallback.elt; // Handle p5.js image - } else if (typeof inputNumOrCallback === 'object' && inputNumOrCallback.canvas instanceof HTMLCanvasElement) { + } else if ( + typeof inputNumOrCallback === "object" && + inputNumOrCallback.canvas instanceof HTMLCanvasElement + ) { imgToPredict = inputNumOrCallback.canvas; // Handle p5.js image } else if (!(this.video instanceof HTMLVideoElement)) { // Handle unsupported input throw new Error( - 'No input image provided. If you want to classify a video, pass the video element in the constructor. ', + "No input image provided. If you want to classify a video, pass the video element in the constructor. ", ); } - if (typeof numOrCallback === 'number') { + if (typeof numOrCallback === "number") { numberOfClasses = numOrCallback; - } else if (typeof numOrCallback === 'function') { + } else if (typeof numOrCallback === "function") { callback = numOrCallback; } - if (typeof cb === 'function') { + if (typeof cb === "function") { callback = cb; } @@ -255,28 +259,28 @@ const imageClassifier = (modelName, videoOrOptionsOrCallback, optionsOrCallback, let callback = cb; let model = modelName; - if (typeof model !== 'string') { + if (typeof model !== "string") { throw new Error('Please specify a model to use. E.g: "MobileNet"'); - } else if (model.indexOf('http') === -1) { + } else if (model.indexOf("http") === -1) { model = modelName.toLowerCase(); } if (videoOrOptionsOrCallback instanceof HTMLVideoElement) { video = videoOrOptionsOrCallback; } else if ( - typeof videoOrOptionsOrCallback === 'object' && + typeof videoOrOptionsOrCallback === "object" && videoOrOptionsOrCallback.elt instanceof HTMLVideoElement ) { video = videoOrOptionsOrCallback.elt; // Handle a p5.js video element - } else if (typeof videoOrOptionsOrCallback === 'object') { + } else if (typeof videoOrOptionsOrCallback === "object") { options = videoOrOptionsOrCallback; - } else if (typeof videoOrOptionsOrCallback === 'function') { + } else if (typeof videoOrOptionsOrCallback === "function") { callback = videoOrOptionsOrCallback; } - if (typeof optionsOrCallback === 'object') { + if (typeof optionsOrCallback === "object") { options = optionsOrCallback; - } else if (typeof optionsOrCallback === 'function') { + } else if (typeof optionsOrCallback === "function") { callback = optionsOrCallback; }