-
Notifications
You must be signed in to change notification settings - Fork 1
/
ID3_classification.py
33 lines (32 loc) · 1.22 KB
/
ID3_classification.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
#coding:utf-8
"""
python 3
sklearn 0.18
"""
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score,confusion_matrix,classification_report
import input_data
import numpy as np
import pickle
# Load MNIST through the project-local input_data helper. one_hot=False asks for
# plain labels rather than one-hot vectors (presumably integer class ids 0-9 —
# confirm against input_data).
mnist = input_data.read_data_sets('mnist/', one_hot=False)
x = mnist.train.images
y = mnist.train.labels
# Hold out 20% of the training set as a validation split. This is a single
# train/validation split, NOT cross-validation. random_state=0 makes the split
# reproducible across runs, matching the classifier's random_state=0 below.
train_data, validation_data, train_labels, validation_labels = train_test_split(
    x, y, test_size=0.2, random_state=0)
# Train a decision tree that chooses splits by information gain
# (criterion='entropy'), the split measure used by ID3. Per the scikit-learn docs,
# DecisionTreeClassifier builds binary CART-style trees, so this approximates
# rather than reproduces classic multiway ID3.
clf = DecisionTreeClassifier(random_state=0, splitter='best', criterion='entropy')
clf.fit(train_data, train_labels)
# The 20% hold-out split would otherwise be discarded unused; score the model on
# it so the split actually informs evaluation (e.g. reveals overfitting).
print('validation accuracy is:',
      accuracy_score(validation_labels, clf.predict(validation_data)))
# Evaluate on the first n_test images of the test set.
n_test = 1000
test_images = mnist.test.images[:n_test]
test_labels = mnist.test.labels[:n_test]
# predict() accepts a whole 2-D batch, so classify every sample in one call.
# The old per-row loop (with "epoch" progress prints) was slow, was not training
# epochs at all, and produced a list of 1-element arrays instead of a flat array.
# Slicing also tolerates a test set shorter than n_test.
predictions = clf.predict(test_images)
print(confusion_matrix(test_labels, predictions))
print(classification_report(test_labels, predictions))
print('test accuracy is:', accuracy_score(test_labels, predictions))
# Serialize the fitted classifier to id3.pickle so it can be reused without
# retraining. Only load this file from a trusted source: unpickling can execute
# arbitrary code.
with open('id3.pickle', 'wb') as model_file:
    pickle.dump(clf, model_file)