-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlinear-regressions.js
93 lines (84 loc) · 2.38 KB
/
linear-regressions.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
const tf = require('@tensorflow/tfjs');
/**
 * Linear regression fitted with gradient descent on TensorFlow.js tensors.
 *
 * Features are standardised column-wise (mean/variance taken along axis 0 of
 * the first feature set the instance sees) and prefixed with a column of ones,
 * so `weights[0]` acts as the intercept.
 *
 * Tensor work is wrapped in `tf.tidy` so intermediates created on every
 * training iteration are released instead of accumulating.
 *
 * NOTE: the misspelt public names `mseHistroy`, `standerize` and the option
 * key `iteration` are kept as-is so existing callers continue to work.
 */
class LinearRegression {
  /**
   * @param {Array} features - Values handed to `tf.tensor`; one row per
   *   observation, one column per feature.
   * @param {Array} labels - Values handed to `tf.tensor`. Predictions have
   *   shape [rows, 1], so labels are presumably expected in that shape too
   *   (assumption - confirm against callers).
   * @param {Object} [options]
   * @param {number} [options.learningRate=0.1] - Initial step size; adjusted
   *   after every iteration by `updateLearningRate`.
   * @param {number} [options.iteration=1000] - Number of passes made by `train`.
   */
  constructor(features, labels, options) {
    this.features = this.processFeatures(features);
    this.labels = tf.tensor(labels);
    // Most recent MSE first (see mseRecord).
    this.mseHistroy = [];
    this.options = Object.assign({ learningRate: 0.1, iteration: 1000 }, options);
    this.weights = tf.zeros([this.features.shape[1], 1]);
  }

  /**
   * Performs one gradient-descent step:
   * weights -= learningRate * (featuresT x (features x weights - labels)) / rows.
   * The previous weights tensor is disposed once the new one exists.
   */
  gradientDescent() {
    const previousWeights = this.weights;
    this.weights = tf.tidy(() => {
      const currentGuesses = this.features.matMul(previousWeights);
      const differences = currentGuesses.sub(this.labels);
      const slopes = this.features
        .transpose()
        .matMul(differences)
        .div(this.features.shape[0]);
      return previousWeights.sub(slopes.mul(this.options.learningRate));
    });
    // Only release the old weights after the step succeeded.
    previousWeights.dispose();
  }

  /**
   * Runs `options.iteration` gradient-descent steps, recording the MSE and
   * adapting the learning rate after each one.
   */
  train() {
    for (let i = 0; i < this.options.iteration; i++) {
      this.gradientDescent();
      this.mseRecord();
      this.updateLearningRate();
    }
  }

  /**
   * Scores the model on held-out data.
   *
   * @param {Array} testFeatures - Raw (unstandardised) feature rows.
   * @param {Array} testLabels - Matching label rows.
   * @returns {number} `1 - res / tot`, where `res` is the sum of squared
   *   residuals and `tot` the sum of squared deviations of the labels from
   *   their mean (the coefficient of determination). Not finite when the test
   *   labels are all identical, because `tot` is then 0.
   */
  test(testFeatures, testLabels) {
    return tf.tidy(() => {
      const features = this.processFeatures(testFeatures);
      const labels = tf.tensor(testLabels);
      const predictions = features.matMul(this.weights);
      // dataSync() is used instead of get(), which newer tfjs releases lack.
      const res = labels.sub(predictions).pow(2).sum().dataSync()[0];
      const tot = labels.sub(labels.mean()).pow(2).sum().dataSync()[0];
      return 1 - res / tot;
    });
  }

  /**
   * @param {Array} observations - Raw feature rows.
   * @returns {tf.Tensor} Predictions, one row per observation. The caller owns
   *   the returned tensor and is responsible for disposing it.
   */
  predict(observations) {
    return tf.tidy(() => this.processFeatures(observations).matMul(this.weights));
  }

  /**
   * Converts raw rows into the model's design matrix: standardised columns
   * with a leading column of ones.
   *
   * The first call computes and stores the mean/variance (via `standerize`);
   * later calls reuse the stored statistics so new data is scaled the same way
   * as the training data.
   *
   * @param {Array} features - Values handed to `tf.tensor`.
   * @returns {tf.Tensor} Tensor of shape [rows, columns + 1].
   */
  processFeatures(features) {
    return tf.tidy(() => {
      const raw = tf.tensor(features);
      const standardized =
        this.mean && this.variance
          ? raw.sub(this.mean).div(this.variance.pow(0.5))
          : this.standerize(raw);
      return tf.ones([standardized.shape[0], 1]).concat(standardized, 1);
    });
  }

  /**
   * Computes per-column mean and variance of `features`, stores them on the
   * instance, and returns the standardised tensor.
   *
   * Columns with zero variance (constant columns) would cause a division by
   * zero, so their stored variance is replaced by 1; such columns standardise
   * to all zeros instead of NaN.
   *
   * @param {tf.Tensor} features
   * @returns {tf.Tensor} `(features - mean) / sqrt(variance)`.
   */
  standerize(features) {
    return tf.tidy(() => {
      const { mean, variance } = tf.moments(features, 0);
      // 1 where the variance is exactly 0, 0 everywhere else.
      const filler = variance.cast('bool').logicalNot().cast('float32');
      // keep(): these must outlive any enclosing tidy scope.
      this.mean = tf.keep(mean);
      this.variance = tf.keep(variance.add(filler));
      return features.sub(this.mean).div(this.variance.pow(0.5));
    });
  }

  /**
   * Computes the mean squared error of the current weights on the training
   * data and pushes it to the FRONT of `mseHistroy` (index 0 is the newest).
   */
  mseRecord() {
    const mse = tf.tidy(
      () =>
        this.features
          .matMul(this.weights)
          .sub(this.labels)
          .pow(2)
          .sum()
          .div(this.features.shape[0])
          .dataSync()[0]
    );
    this.mseHistroy.unshift(mse);
  }

  /**
   * Adapts the learning rate from the two most recent MSE values: halves it
   * when the error went up, otherwise grows it by 5%. Does nothing until at
   * least two values have been recorded.
   */
  updateLearningRate() {
    if (this.mseHistroy.length < 2) {
      return;
    }
    if (this.mseHistroy[0] > this.mseHistroy[1]) {
      this.options.learningRate /= 2;
    } else {
      this.options.learningRate *= 1.05;
    }
  }
}
// Expose the class as this CommonJS module's sole export.
module.exports = LinearRegression;