-
Notifications
You must be signed in to change notification settings - Fork 1.3k
/
Copy pathdecisiontrees.py
92 lines (74 loc) · 3.07 KB
/
decisiontrees.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import graphviz
import itertools
import random
from sklearn.tree import DecisionTreeClassifier, export_graphviz
from sklearn.preprocessing import OneHotEncoder
# The possible values of each feature column, keyed by column name
classes = {
    'supplies': ['low', 'med', 'high'],
    'weather': ['raining', 'cloudy', 'sunny'],
    'worked?': ['yes', 'no'],
}

# Our example data from the documentation; each row is [supplies, weather, worked?]
data = [
    ['low', 'sunny', 'yes'],
    ['high', 'sunny', 'yes'],
    ['med', 'cloudy', 'yes'],
    ['low', 'raining', 'yes'],
    ['low', 'cloudy', 'no'],
    ['high', 'sunny', 'no'],
    ['high', 'raining', 'no'],
    ['med', 'cloudy', 'yes'],
    ['low', 'raining', 'yes'],
    ['low', 'raining', 'no'],
    ['med', 'sunny', 'no'],
    ['high', 'sunny', 'yes'],
]

# Our target variable, whether someone went shopping (one entry per row of `data`)
target = ['yes', 'no', 'no', 'no', 'yes', 'no', 'no', 'no', 'no', 'yes', 'yes', 'no']

# Number of random observations generated for the prediction demo below
NUM_PREDICTIONS = 5

# Scikit learn can't handle categorical data, so form numeric representations of the above data
# Categorical data support may be added in the future: https://github.com/scikit-learn/scikit-learn/pull/4899
# The category lists are passed explicitly so the encoded column order is fixed:
# supplies options first, then weather, then worked?.
categories = [classes[column] for column in ('supplies', 'weather', 'worked?')]
encoder = OneHotEncoder(categories=categories)
x_data = encoder.fit_transform(data)

# Form and fit our decision tree to the now-encoded data
classifier = DecisionTreeClassifier()
tree = classifier.fit(x_data, target)

# Now that we have our decision tree, let's predict some outcomes from random data.
# Each generated row picks one random option per feature column, in column order.
prediction_data = [
    [random.choice(options) for options in categories]
    for _ in range(NUM_PREDICTIONS)
]

# Use our tree to predict the outcome of the random values; the rows go through
# the already-fitted encoder so they get the same columns as the training data
prediction_results = tree.predict(encoder.transform(prediction_data))
# =============================================================================
# Output code
def format_array(arr):
    """Return the items of ``arr`` rendered as adjacent table cells.

    Each item becomes ``"| "`` followed by the item left-aligned and padded
    to a minimum width of 10 characters; longer items are not truncated.
    The cells are concatenated with no separator, so an empty ``arr`` gives
    an empty string.
    """
    as_cell = "| {:<10}".format
    return "".join(map(as_cell, arr))
def print_table(data, results):
    """Print observations and their outcomes as a fixed-width text table.

    The header row is built from the keys of the module-level ``classes``
    mapping followed by a final "went shopping?" column, and is framed above
    and below by a dashed rule of the same length. Rows are numbered from 1.
    A blank line is printed after the last row.

    Args:
        data: Sequence of rows. Each row is concatenated with a one-element
            list, so it must itself be a ``list`` of cell values.
        results: Indexable outcomes, where ``results[i]`` belongs to
            ``data[i]``.
    """
    # Header and rows share one formatter for the day column so the columns
    # line up; a 4-character "day " label would sit one character left of the
    # 5-character row cells.
    day_cell = "{:<5}".format

    header = day_cell("day") + format_array(list(classes) + ["went shopping?"])
    rule = "-" * len(header)
    print(rule)
    print(header)
    print(rule)

    for day, row in enumerate(data):
        print(day_cell(day + 1) + format_array(row + [results[day]]))

    # Trailing blank line separates this table from whatever is printed next
    print("")
# One label per one-hot option, grouped by feature in the order
# supplies, weather, worked (the '?' is dropped from the last prefix)
feature_names = [
    "{}-{}".format(prefix, option)
    for prefix, column in (
        ("supplies", "supplies"),
        ("weather", "weather"),
        ("worked", "worked?"),
    )
    for option in classes[column]
]

# Shows a visualization of the decision tree using graphviz
# Note that sklearn is unable to generate non-binary trees, so the splits are
# based on individual options in each class rather than on whole classes
dot_data = export_graphviz(
    tree,
    feature_names=feature_names,
    filled=True,
    proportion=True,
)
graph = graphviz.Source(dot_data)
graph.render(filename='decision_tree', cleanup=True, view=True)

# Display our training and prediction data and results, one titled table each
for title, rows, outcomes in (
    ("Training Data:", data, target),
    ("Predicted Random Results:", prediction_data, prediction_results),
):
    print(title)
    print_table(rows, outcomes)