From 64259cf42dd811286ab0d1d975323db15c631a31 Mon Sep 17 00:00:00 2001 From: LudwigStumpp Date: Sun, 29 Mar 2020 21:00:13 +0200 Subject: [PATCH] added interactive linear regression example for NeuralNetwork class --- .../index.html | 39 ++++++++ .../sketch.js | 91 +++++++++++++++++++ 2 files changed, 130 insertions(+) create mode 100644 examples/NeuralNetwork/NeuralNetwork_Interactive_Regression/index.html create mode 100644 examples/NeuralNetwork/NeuralNetwork_Interactive_Regression/sketch.js diff --git a/examples/NeuralNetwork/NeuralNetwork_Interactive_Regression/index.html b/examples/NeuralNetwork/NeuralNetwork_Interactive_Regression/index.html new file mode 100644 index 000000000..e74bb3f4e --- /dev/null +++ b/examples/NeuralNetwork/NeuralNetwork_Interactive_Regression/index.html @@ -0,0 +1,39 @@ + + + Interactive Regression Example - Neural Network + + + + + + + + +

Interactive Regression Example - Neural Network

+

Instructions

+ +
    +
  1. Add data by clicking inside the canvas
  2. +
  3. Edit learning rate = amount that the weights are updated during training
  4. +
  5. Edit number of hidden units = number of hidden units in dense layer
  6. +
  7. Edit number of training epochs = number of iterations over entire dataset
  8. +
  9. Edit number of batch size = number of data points to work through before updating the internal model parameters
  10. +
  11. Click on the Train button
  12. +
+ +
+ +
+ +
+ +
+ +
+ +
+ + + + + \ No newline at end of file diff --git a/examples/NeuralNetwork/NeuralNetwork_Interactive_Regression/sketch.js b/examples/NeuralNetwork/NeuralNetwork_Interactive_Regression/sketch.js new file mode 100644 index 000000000..29ded68df --- /dev/null +++ b/examples/NeuralNetwork/NeuralNetwork_Interactive_Regression/sketch.js @@ -0,0 +1,91 @@ +let trainData = []; +let neuralNetwork; + +// selectors for inputs and button +const inputNeuralNetworkLearningRate = document.getElementById('neuralNetworkLearningRate'); +const inputNeuralNetworkHiddenUnits = document.getElementById('neuralNetworkHiddenUnits'); +const inputTrainEpochs = document.getElementById('trainEpochs'); +const inputTrainBatchSize = document.getElementById('trainBatchSize'); +const buttonStartTrain = document.getElementById('startTraining'); + +const canvasSize = 800; + +// options for NeuralNetwork +const options = { + inputs: 1, + outputs: 1, + debug: true, + task: 'regression', + learningRate: 0.25, + hiddenUnits: 20, +} + +// training params +const trainParams = { + validationSplit: 0, + epochs: 100, + batchSize: 64, +} + +buttonStartTrain.addEventListener("click", () => { + // get input data + options.learningRate = parseFloat(inputNeuralNetworkLearningRate.value); + options.hiddenUnits = parseInt(inputNeuralNetworkHiddenUnits.value); + trainParams.epochs = parseInt(inputTrainEpochs.value); + trainParams.batchSize = parseInt(inputTrainBatchSize.value); + + // and start the training + startTraining(); +}); + +function setup() { + createCanvas(canvasSize, canvasSize); + background(220); +} + +function mouseClicked() { + if (mouseY > 50) { + circle(mouseX, mouseY, 10); + trainData.push([mouseX, mouseY]); + } +} + +function startTraining() { + // check if train data + if (trainData.length == 0) { + alert('Please add some training data by clicking inside the canvas'); + return; + } + + neuralNetwork = ml5.neuralNetwork(options); + + // add training data + for (let i = 0; i < trainData.length; i++) { + neuralNetwork.addData([trainData[i][0]], [trainData[i][1]]); + } + + neuralNetwork.normalizeData(); + neuralNetwork.train(trainParams, doneTraining); +} + +function doneTraining() { + // build x-bases to calculate corresponding y-values. We take every x-value possible of the canvas to make it look like a line + xMany = []; + for (x = 0; x <= canvasSize; x++) { + xMany.push([x]); + } + + // predict corresponding y-values and show as circle + neuralNetwork.predictMultiple(xMany, (error, results) => { + if (error) { + console.log(error); + } else { + console.log(results); + for (let i = 0; i < results.length; i++){ + x = xMany[i][0]; + y = results[i][0]['value']; + circle(x, y, 1); + } + } + }); +} \ No newline at end of file