-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvisualization.py
More file actions
112 lines (99 loc) · 3.25 KB
/
visualization.py
File metadata and controls
112 lines (99 loc) · 3.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import PySimpleGUI as sg
import os.path
from model import ColorizationNet
from utils import *
import torch
# ----- Display constants -----
# Size handed to each sg.Image element and to the prediction helper.
img_size = (350, 350)
# NOTE(review): not referenced by any code visible in this file; kept in case
# something else relies on it -- confirm and remove if truly unused.
img_box_size = (800, 350)
# Element keys of the two image panes; the event loop uses them to push
# new image data into the window.
image_orig_str = "-IMAGE_ORIG-"
image_pred_str = "-IMAGE-PRED-"

# ----- Left column: loss type, model checkpoint, image folder, file list, log -----
file_list_column = [
    [
        sg.Text("Select loss type"),
        sg.DropDown(['classification', 'regression'], key="-LOSS-", enable_events=True),
    ],
    [
        # The model picker starts disabled; the event loop enables it once a
        # loss type has been chosen.
        sg.Text("Select model"),
        sg.In(size=(25, 1), enable_events=True, key="-MODEL-", disabled=True),
        sg.FileBrowse(disabled=True, key='-MODEL_BROWSE-'),
    ],
    [
        sg.Text("Image Folder"),
        sg.In(size=(25, 1), enable_events=True, key="-FOLDER-"),
        sg.FolderBrowse(),
    ],
    [
        sg.Listbox(
            values=[], enable_events=True, size=(40, 20), key="-FILE LIST-"
        )
    ],
    [
        # Status line used by the event loop to report model-loading results.
        sg.Text('', key="-LOG-", size=(40, 2))
    ],
]

# ----- Right columns: ground-truth image and predicted colorization -----
image_viewer_column_original = [
    [sg.Text("True Image")],
    [sg.Image(size=img_size, key=image_orig_str)],
]
image_viewer_column_pred = [
    [sg.Text("Predicted Colorization")],
    [sg.Image(size=img_size, key=image_pred_str)],
]

# ----- Full layout -----
layout = [
    [
        sg.Column(file_list_column),
        sg.VSeperator(),
        sg.Column(image_viewer_column_original),
        sg.Column(image_viewer_column_pred),
    ]
]
window = sg.Window("Image Viewer", layout)
# Run the Event Loop
# No network exists until the user successfully loads a checkpoint. Keeping an
# explicit None lets the image handler detect that state instead of failing
# with a NameError.
model = None
while True:
    event, values = window.read()
    if event in ("Exit", sg.WIN_CLOSED):
        break

    if event == "-FOLDER-":
        # Folder name was filled in, make a list of the image files in it.
        folder = values["-FOLDER-"]
        try:
            file_list = os.listdir(folder)
        except (OSError, ValueError):
            # Missing/unreadable folder or a malformed path: show an empty list.
            file_list = []
        fnames = [
            f
            for f in file_list
            if os.path.isfile(os.path.join(folder, f))
            and f.lower().endswith((".jpg", ".png", ".gif"))
        ]
        window["-FILE LIST-"].update(fnames)

    elif event == "-MODEL-":
        # A model checkpoint has been selected: try to load it.
        checkpoint_path = values["-MODEL-"]
        if checkpoint_path == '':
            continue
        # Drop any previously loaded network so a failed load can never leave
        # a stale or uninitialised model behind for later predictions.
        model = None
        try:
            # NOTE: torch.load unpickles the file; only open trusted checkpoints.
            checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        except Exception as err:
            # UI boundary: report the failure instead of killing the window.
            window["-LOG-"].update(f"Error reading checkpoint file: {err}")
            continue
        candidate = ColorizationNet(values["-LOSS-"])
        try:
            candidate.load_state_dict(checkpoint['model_state_dict'])
        except Exception:
            window["-LOG-"].update('Error loading model: did you select the correct loss type?')
            continue
        # NOTE(review): the network is used as constructed; confirm that
        # get_img_prediction_as_tk puts it in inference mode if that matters.
        model = candidate
        window["-LOG-"].update("Model correctly loaded.")

    elif event == "-FILE LIST-":
        # A file was chosen from the listbox.
        selection = values["-FILE LIST-"]
        if not selection:
            # The listbox can fire with nothing selected (e.g. an empty list).
            continue
        filename = os.path.join(values["-FOLDER-"], selection[0])
        window[image_orig_str].update(data=get_img(filename))
        if model is None:
            window["-LOG-"].update("Load a model before requesting a colorization.")
            continue
        img_pred_tk = get_img_prediction_as_tk(model, filename, img_size)
        window[image_pred_str].update(data=img_pred_tk)

    elif event == "-LOSS-":
        # A loss type is now known, so the model picker can be enabled.
        window['-MODEL-'].update(disabled=False)
        window['-MODEL_BROWSE-'].update(disabled=False)

window.close()