-
Notifications
You must be signed in to change notification settings - Fork 0
/
gui.py
77 lines (62 loc) · 2.51 KB
/
gui.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import tkinter as tk
import torch
import torch.nn.functional as F
from cnn import model
from datasets import transform_grayscale
from torchvision.transforms import ToPILImage
from datasets import val_dataset_grayscale
# Drawing buffer the classifier reads from: 128x128, 0.0 = background,
# 1.0 = ink (draw writes 1, erase writes 0 via update_tensor).
canvas_tensor = torch.zeros(128, 128)
# map_location="cpu": the GUI feeds CPU tensors, and this lets a checkpoint that
# was saved from a GPU model load on a CPU-only machine. load_state_dict then
# copies the weights onto whatever device the model's parameters live on.
model.load_state_dict(torch.load("mnist_cnn.pth", map_location="cpu"))
# Inference only: switch dropout / batch-norm layers (if the network has any)
# to evaluation mode so predictions are deterministic.
model.eval()
# Most recent predicted class index, or None when nothing has been predicted.
predicted = None
def whatModelSees():
    """Pop up an image of the current drawing as it is fed to the classifier.

    Downsamples ``canvas_tensor`` to 28x28 (the size update_tensor feeds the
    model), runs it through ``transform_grayscale`` and opens the result in
    the system image viewer. Debug aid only; nothing is returned.

    NOTE(review): if ``transform_grayscale`` normalizes to values outside
    [0, 1], ToPILImage's float->8-bit conversion will not display them
    faithfully — confirm what the transform does.
    """
    # (128, 128) -> (1, 1, 128, 128): interpolate expects batch and channel dims.
    as_batch = canvas_tensor[None, None]
    downsampled = F.interpolate(as_batch, size=(28, 28), mode='bilinear', align_corners=False)[0]
    preview = transform_grayscale(ToPILImage()(downsampled))
    ToPILImage()(preview).show()
def showExample():
    """Open the first validation-set image in the system image viewer.

    Useful for comparing a reference digit with what is drawn on the canvas.
    Not wired to any widget in this file.
    """
    sample, _label = val_dataset_grayscale[0]
    picture = ToPILImage()(sample)
    picture.show()
def update_label():
    """Show the latest prediction in the label widget and echo it to stdout."""
    shown = 'None' if predicted is None else predicted
    message = f"Predicted: {shown}"
    output_label.config(text=message)
    print(message)
def _cell_span(cell, limit):
    """Return the [start, stop) tensor indices covered by grid *cell*, clipped to [0, limit]."""
    start = cell * 4
    # Clip both ends so a cell dragged outside the canvas (negative or too-large
    # index) yields an empty span instead of wrapping via negative indexing.
    return min(max(start, 0), limit), min(max(start + 4, 0), limit)


def update_tensor(x, y, value):
    """Write *value* into the tensor patch under grid cell (x, y) and re-classify.

    Each canvas grid cell covers a 4x4 patch of ``canvas_tensor``: cell x spans
    tensor columns 4*x .. 4*x+3, matching the rectangle drawn for that cell.
    Cells outside the tensor are ignored. After the write, the drawing is
    downsampled to 28x28, classified, and the result is stored in the global
    ``predicted`` and shown via ``update_label``.

    Args:
        x: Column index of the canvas cell.
        y: Row index of the canvas cell.
        value: 1 to paint ink, 0 to erase.
    """
    global predicted
    rows, cols = canvas_tensor.shape
    r0, r1 = _cell_span(y, rows)
    c0, c1 = _cell_span(x, cols)
    canvas_tensor[r0:r1, c0:c1] = value

    # (128, 128) -> (1, 1, 128, 128) for interpolate, then drop the batch dim
    # so ToPILImage gets a (1, 28, 28) single-channel image.
    resized_tensor = F.interpolate(canvas_tensor.unsqueeze(0).unsqueeze(0), size=(28, 28), mode='bilinear', align_corners=False).squeeze(0)
    model_input = transform_grayscale(ToPILImage()(resized_tensor)).unsqueeze(0)
    # Pure inference: skip autograd bookkeeping, which would otherwise build a
    # graph on every mouse-motion event.
    with torch.no_grad():
        model_output = model(model_input)
    predicted = model_output.argmax(1).item()
    update_label()
def _cell_at(event):
    """Return the (column, row) index of the 16-px grid cell under the pointer."""
    return event.x // 16, event.y // 16


def draw(event):
    """Mouse-drag handler: paint the grid cell under the pointer white and mark it as ink."""
    x, y = _cell_at(event)
    tag = f"cell_{x}_{y}"
    # Replace any existing item for this cell so repeated motion events over the
    # same cell don't accumulate canvas items until the next clear().
    canvas.delete(tag)
    canvas.create_rectangle(x*16, y*16, (x+1)*16, (y+1)*16, fill="white", outline="white", tags=(tag,))
    update_tensor(x, y, 1)


def erase(event):
    """Mouse-drag handler: erase the grid cell under the pointer and clear it in the tensor."""
    x, y = _cell_at(event)
    # The canvas background is black, so removing the cell's white item erases
    # it visually without stacking a black rectangle on top.
    canvas.delete(f"cell_{x}_{y}")
    update_tensor(x, y, 0)
def clear():
    """Clear-button handler: wipe the canvas and drawing buffer and reset the prediction."""
    global predicted
    canvas.delete("all")
    canvas_tensor.zero_()
    # Forget the last prediction; otherwise the label keeps showing the digit
    # that was just erased.
    predicted = None
    update_label()
# --- Main window ---------------------------------------------------------------
app = tk.Tk()
app.title("128x128 Drawing Canvas")
app.geometry("800x800")

# Canvas side length in pixels: a 32x32 grid of 16-px cells. draw/erase work in
# 16-px cells and update_tensor scales a cell index by 4 into the 128x128
# canvas_tensor, so 32 cells span exactly the region the model sees. The size
# must also fit inside the 800x800 window. If the frame asks for more height
# than is available, Tk's packer gives it all the space and unmaps the Clear
# button and the prediction label packed after it.
CANVAS_PX = 32 * 16

frame = tk.Frame(app, width=CANVAS_PX, height=CANVAS_PX, bg="white")
frame.pack_propagate(False)  # keep the frame's fixed size instead of shrink-wrapping its child
frame.pack()
canvas = tk.Canvas(frame, width=CANVAS_PX, height=CANVAS_PX, bg="black")
canvas.pack()

clear_button = tk.Button(app, text="Clear", command=clear)
clear_button.pack()

# Drag with mouse button 1 to paint, button 3 to erase.
canvas.bind("<B1-Motion>", draw)
canvas.bind("<B3-Motion>", erase)

font = ('Helvetica', 24)
output_label = tk.Label(app, text="Predicted: None", font=font)
output_label.pack(side=tk.TOP, pady=10)

# Debug aid: 5 s after start-up, pop up the 28x28 image the model is fed.
app.after(5000, whatModelSees)
app.mainloop()