-
Notifications
You must be signed in to change notification settings - Fork 5
/
cross-validate.ts
40 lines (34 loc) · 1.25 KB
/
cross-validate.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import { NeuralNetwork, CrossValidate } from 'brain.js';
// XOR truth table. The four rows are appended twice to simulate a larger
// dataset, so there is enough data to train and validate with.
const trainingData: Array<{ input: number[]; output: number[] }> = [];
for (let copy = 0; copy < 2; copy += 1) {
  trainingData.push(
    { input: [0, 1], output: [1] },
    { input: [0, 0], output: [0] },
    { input: [1, 1], output: [0] },
    { input: [1, 0], output: [1] },
  );
}
/**
 * Options for every NeuralNetwork the cross-validator creates.
 * `hiddenLayers` lists the size of each hidden layer: here a single layer of 3.
 * (The previous eslint-disable for consistent-type-assertions was removed:
 * this literal contains no type assertion, so the directive suppressed nothing.)
 */
const netOptions = {
  hiddenLayers: [3],
};
/**
 * Training options handed to `crossValidate.train`.
 * - `iterations`: upper bound on training iterations per run.
 * - `log`: receives the training status objects brain.js reports and prints
 *   them. The parameter is `unknown` rather than `any` because it is only
 *   forwarded to `console.log` and never inspected here.
 */
const trainingOptions = {
  iterations: 20000,
  log: (details: unknown) => console.log(details),
};
// Cross-validate: the factory supplies a fresh NeuralNetwork configured with
// `netOptions`. The returned stats object is printed as-is.
const crossValidate = new CrossValidate(() => new NeuralNetwork(netOptions));
const stats = crossValidate.train(trainingData, trainingOptions);
console.log(stats);

// Materialize a standalone network from the cross-validation result.
const net = crossValidate.toNeuralNetwork();

// Each probe pairs the printed label with the XOR input to feed the network.
// Trailing numbers are outputs from one sample run; expect them to vary.
const probes: Array<[string, number[]]> = [
  ['0 XOR 1: ', [0, 1]], // sample: 0.987
  ['0 XOR 0: ', [0, 0]], // sample: 0.058
  ['1 XOR 1: ', [1, 1]], // sample: 0.087
  ['1 XOR 0: ', [1, 0]], // sample: 0.934
];

// Run every probe first, then report, in the same order as the table.
const outputs = probes.map(([, input]) => net.run(input));
probes.forEach(([label], index) => {
  console.log(label, outputs[index]);
});