Skip to content

Commit

Permalink
[nn-refactor] Refactor and reimplementation of NeuralNetwork (#749)
Browse files Browse the repository at this point in the history
* adds temp DiyNeuralNetwork - refactoring NeuralNetwork class

* adds updates to refactor

* refactoring nn-refactor

* adds features for compile and add layers

* rm console log

* adds train interface

* adds basic predict

* adds blank functions in data class

* update nn class

* adds nn compile handling

* updates function name

* adds data loading functions // todo - clean up

* add recursive findEntries function and data loading functions

* adds formatRawData function

* adds .addData() function

* adds saveData function

* adds handling for onehot and counting input and output units

* adds code comments

* adds concat to this.meta

* changed name to createMetaDataFromData

* adds convertRawToTensors

* adds functions for calculating stats

* adds normalization and conversion to tensor handling

* adds .summarizeData

* adds data handling to index

* updates summarizeData function to explicitly set meta

* updates and adds functions

* updates predict function

* adds classify() with meta

* adds metadata handling and data functions

* adds loadData with options in init

* adds major updates to initiation and defaults

* adds boolean flags to check status to configure nn

* adds addData function to index

* adds support for auto labeling inputs and outputs for blank nn

* code cleanup and function name change

* flattens array in convertRawToTensors

* flattens inputs

* flatten array always

* adds isOneHotEncodedOrNormalized

* updates predict and classify functions and output format

* updates param handling in predict and classify

* code cleanup

* adds save function

* code cleanup

* adds first pass at loading data

* fixes missing isNormalized flag in meta

* moves loading functions to respective class

* moves files to NeuralNetwork

* moves files to NeuralNetwork and rm diyNN

* rms console.log

* check if metadata and warmedup are true before normalization

* adds unnormalize function to nn predict

* return unNormalized value

* adds loadData() and changes to loadDataFromUrl

* adds saveData to index

* adds modelUrl to constructor options in index

* cleans up predict and classify

* fix reference to unNormalizeValue

* code cleanup

* adds looping to format data for prediction and predictMultiple and classifyMultiple

* adds layer handling for options

* adds tfvis to index and ml5 root

* adds debug flag in options

* adds vis and fixes input formatting

* adds model summary

* adds comments and reorders code

* refactoring functions with 3 datatypes in mind: number, string, array

* adds data handling updates

* adds handling tensors

* adds process up to training

* fixes breaking training

* adds full working poc

* fix addData check

* adds updates to api and notes to fix with functions

* adds createMetadata in index

* adds image handling in classify function

* adds method to not exceed call stack of min and max

* fixes loadData issue

* adds first header name for min and max

* code cleanup

* removes unused functions

* fixes setDataRaw

* code clean up, organization, and adds method binding to constructor

* adds methods to constructor, adds comments, and cleans up

* adds methods to constructor for nndata

* adds methods to constructor, code cleanup, and organization
  • Loading branch information
joeyklee authored Dec 12, 2019
1 parent 538ec86 commit f749d4e
Show file tree
Hide file tree
Showing 9 changed files with 2,053 additions and 1,323 deletions.
249 changes: 249 additions & 0 deletions src/NeuralNetwork/NeuralNetwork.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
import * as tf from '@tensorflow/tfjs';
import callCallback from '../utils/callcallback';
import { saveBlob } from '../utils/io';

/**
 * Thin wrapper around a tf.Sequential model: layer/compile/train/predict
 * plus save/load helpers. Holds status flags that callers (the ml5 index)
 * check before normalizing data or training.
 */
class NeuralNetwork {
  constructor() {
    // status flags
    this.isTrained = false; // set true once trainInternal() finishes
    this.isCompiled = false; // set true once compile() has run
    this.isLayered = false; // set true once the model has >= 2 layers

    // the underlying tf.js model (created in init())
    this.model = null;

    // bind methods so they can be detached/passed as callbacks
    // without losing `this`
    this.init = this.init.bind(this);
    this.createModel = this.createModel.bind(this);
    this.addLayer = this.addLayer.bind(this);
    this.compile = this.compile.bind(this);
    this.setOptimizerFunction = this.setOptimizerFunction.bind(this);
    this.train = this.train.bind(this);
    this.trainInternal = this.trainInternal.bind(this);
    this.predict = this.predict.bind(this);
    this.classify = this.classify.bind(this);
    this.save = this.save.bind(this);
    this.load = this.load.bind(this);

    // initialize
    this.init();
  }

  /**
   * initialize by creating the model
   */
  init() {
    this.createModel();
  }

  /**
   * creates a sequential model
   * uses switch/case for a potential future where different formats are supported
   * @param {*} _type - model type; every unknown value currently falls back to 'sequential'
   * @returns {tf.Sequential} the newly created model
   */
  createModel(_type = 'sequential') {
    switch (_type.toLowerCase()) {
      case 'sequential':
      default:
        // only 'sequential' is supported today; the default arm keeps
        // unknown values behaving exactly like 'sequential'
        this.model = tf.sequential();
        return this.model;
    }
  }

  /**
   * add a layer to the model
   * if the model has 2 or more layers, flip the isLayered flag
   * @param {*} _layerOptions - a constructed tf layer (passed straight to model.add)
   */
  addLayer(_layerOptions) {
    const layer = _layerOptions || {};
    this.model.add(layer);

    // a usable model needs at least an input and an output layer
    if (this.model.layers.length >= 2) {
      this.isLayered = true;
    }
  }

  /**
   * compile the model and flip the isCompiled flag
   * @param {*} _modelOptions - passed straight to tf model.compile()
   */
  compile(_modelOptions) {
    this.model.compile(_modelOptions);
    this.isCompiled = true;
  }

  /**
   * Set the optimizer function given the learning rate as a parameter
   * @param {*} learningRate
   * @param {*} optimizer - a tf optimizer factory, e.g. tf.train.sgd
   * @returns {*} the constructed optimizer
   */
  setOptimizerFunction(learningRate, optimizer) {
    return optimizer.call(this, learningRate);
  }

  /**
   * Calls trainInternal() and invokes the callback when finished
   * @param {*} _options - training options (see trainInternal)
   * @param {*} _cb - optional node-style callback
   * @returns {Promise} resolves when training completes
   */
  train(_options, _cb) {
    return callCallback(this.trainInternal(_options), _cb);
  }

  /**
   * Train the model. Disposes the input/output tensors when done.
   * @param {*} _options - {inputs, outputs, batchSize, epochs, shuffle,
   *                        validationSplit, whileTraining}
   */
  async trainInternal(_options) {
    const { inputs: xs, outputs: ys, batchSize, epochs, shuffle, validationSplit, whileTraining } =
      _options;

    await this.model.fit(xs, ys, {
      batchSize,
      epochs,
      shuffle,
      validationSplit,
      callbacks: whileTraining,
    });

    // the tensors are owned by this call once training starts — free them
    xs.dispose();
    ys.dispose();

    this.isTrained = true;
  }

  /**
   * returns the prediction as an array
   * disposes both the output tensor and the input tensor
   * @param {*} _inputs - an input tensor
   * @returns {Promise<Array>} prediction values
   */
  async predict(_inputs) {
    const output = tf.tidy(() => {
      return this.model.predict(_inputs);
    });
    const result = await output.array();

    output.dispose();
    _inputs.dispose();

    return result;
  }

  /**
   * classify is the same as .predict()
   * @param {*} _inputs - an input tensor
   * @returns {Promise<Array>} prediction values
   */
  async classify(_inputs) {
    return this.predict(_inputs);
  }

  // predictMultiple
  // classifyMultiple
  // are the same as .predict()

  /**
   * save the model topology and weights via downloadable blobs
   * @param {*} nameOrCb - model name (string) or callback (function)
   * @param {*} cb - callback when a name was given as the first argument
   * @returns {Promise} resolves only after both files have been written
   */
  async save(nameOrCb, cb) {
    let modelName = 'model';
    let callback;

    if (typeof nameOrCb === 'function') {
      callback = nameOrCb;
    } else if (typeof nameOrCb === 'string') {
      modelName = nameOrCb;
      if (typeof cb === 'function') {
        callback = cb;
      }
    }

    // FIX: await the save — previously the promise was dropped, so callers
    // awaiting save() resumed before any file was actually written
    await this.model.save(
      tf.io.withSaveHandler(async data => {
        this.weightsManifest = {
          modelTopology: data.modelTopology,
          weightsManifest: [
            {
              paths: [`./${modelName}.weights.bin`],
              weights: data.weightSpecs,
            },
          ],
        };

        await saveBlob(data.weightData, `${modelName}.weights.bin`, 'application/octet-stream');
        await saveBlob(JSON.stringify(this.weightsManifest), `${modelName}.json`, 'text/plain');
        if (callback) {
          callback();
        }
      }),
    );
  }

  /**
   * loads the model and weights from a FileList, a {model, metadata, weights}
   * URL object, or a plain model URL/path
   * @param {*} filesOrPath
   * @param {*} callback - optional, invoked after the model is loaded
   * @returns {Promise<tf.LayersModel>} the loaded model
   */
  async load(filesOrPath = null, callback) {
    if (filesOrPath instanceof FileList) {
      const files = await Promise.all(
        Array.from(filesOrPath).map(async file => {
          if (file.name.includes('.json') && !file.name.includes('_meta')) {
            return { name: 'model', file };
          } else if (file.name.includes('.json') && file.name.includes('_meta.json')) {
            const modelMetadata = await file.text();
            return { name: 'metadata', file: modelMetadata };
          } else if (file.name.includes('.bin')) {
            return { name: 'weights', file };
          }
          return { name: null, file: null };
        }),
      );

      const model = files.find(item => item.name === 'model');
      const weights = files.find(item => item.name === 'weights');

      // FIX: guard against an incomplete file selection — previously a missing
      // entry produced an opaque "Cannot read property 'file' of undefined"
      if (!model || !weights) {
        throw new Error('load() requires both a model .json file and a .weights.bin file');
      }

      // load the model
      this.model = await tf.loadLayersModel(tf.io.browserFiles([model.file, weights.file]));
    } else if (filesOrPath instanceof Object) {
      // filesOrPath = {model: URL, metadata: URL, weights: URL}

      let modelJson = await fetch(filesOrPath.model);
      modelJson = await modelJson.text();
      const modelJsonFile = new File([modelJson], 'model.json', { type: 'application/json' });

      let weightsBlob = await fetch(filesOrPath.weights);
      weightsBlob = await weightsBlob.blob();
      const weightsBlobFile = new File([weightsBlob], 'model.weights.bin', {
        type: 'application/macbinary',
      });

      this.model = await tf.loadLayersModel(tf.io.browserFiles([modelJsonFile, weightsBlobFile]));
    } else {
      this.model = await tf.loadLayersModel(filesOrPath);
    }

    // a loaded model is ready to use as-is
    this.isCompiled = true;
    this.isLayered = true;
    this.isTrained = true;

    if (callback) {
      callback();
    }
    return this.model;
  }
}
export default NeuralNetwork;
Loading

0 comments on commit f749d4e

Please sign in to comment.