-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtraining.py
More file actions
79 lines (61 loc) · 3.02 KB
/
training.py
File metadata and controls
79 lines (61 loc) · 3.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# Mute TensorFlow's debugging information on the console.
# This assignment has to happen before TensorFlow is loaded; the project
# `utils` modules (which presumably import TensorFlow — confirm) are only
# imported later, inside the __main__ guard at the bottom of this file.
# NOTE(review): per TensorFlow's convention '3' is the most restrictive
# C++ log level (hides INFO, WARNING and ERROR) — verify for the TF version
# in use.
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import argparse
import sys
def parse_arguments(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, exactly as before.

    Returns:
        An ``argparse.Namespace`` with the attributes ``datasets``,
        ``output``, ``height``, ``width``, ``epochs``, ``gpus``, ``batch``,
        ``device``, ``model``, ``parallel`` and ``verbose``.

    argparse terminates the process (``SystemExit``) when a required option
    such as ``--datasets`` is missing or a value has the wrong type.
    """
    # The free-text summary belongs in ``description``; passing it as
    # ``usage`` would replace argparse's generated "usage: ..." synopsis.
    parser = argparse.ArgumentParser(
        description='A training script for the Live Whiteboard Coding neural network')
    parser.add_argument('--datasets', nargs='+', type=str, help='datasets to be used for training', required=True)
    parser.add_argument('-o', '--output', type=str, help='output directory for the model(without /)', default='bin')
    parser.add_argument('--height', type=int, default=28, help='height of the input image')
    parser.add_argument('--width', type=int, default=28, help='width of the input image')
    parser.add_argument('-e', '--epochs', type=int, default=10, help='number of epochs to train the model on')
    parser.add_argument('-g', '--gpus', type=int, default=1, help='number of gpus to be used')
    parser.add_argument('-b', '--batch', type=int, default=64, help='batch size for training')
    parser.add_argument('-d', '--device', type=str, default='/cpu:0', help='device to be used for training')
    parser.add_argument('-m', '--model', type=str, default='convolutional', help='keras model to be trained')
    parser.add_argument('-p', '--parallel', action='store_true', default=False, help='use multi gpu model')
    parser.add_argument('-v', '--verbose', type=int, default=1, help='verbose level for epochs')
    return parser.parse_args(argv)
def main(args=None):
    """Train the selected model on the given datasets and save it.

    Args:
        args: Parsed command-line namespace (as produced by
            ``parse_arguments``). When omitted, the command line is parsed
            here, so the original zero-argument ``main()`` call still works.

    This function looks up ``load_data``, ``build_model``, ``train`` and
    ``save_model_to_file`` as module globals at call time; in this file they
    are only imported inside the ``__main__`` guard.

    Exits with status 1 (message on stderr) when the output directory is
    empty or starts with '/', or when ``build_model`` returns no model.
    """
    if args is None:
        args = parse_arguments()

    # The output directory is resolved relative to this script, so reject
    # an empty or absolute path up front. Unlike the former
    # ``args.output[0] is '/'`` this neither relies on string identity nor
    # raises IndexError for an empty string.
    if not args.output:
        sys.exit('Please make sure that the output directory is not empty')
    if args.output.startswith('/'):
        sys.exit('Please make sure that the output directory has no leading \'/\'')

    if args.parallel:
        print("Using the multi gpu models. The target device will be neglected.")

    bin_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), args.output)
    # exist_ok=True avoids the check-then-create race of exists() + makedirs().
    os.makedirs(bin_dir, exist_ok=True)

    training_data = load_data(args.datasets)
    model, parallel_model = build_model(training_data=training_data,
                                        model_id=args.model,
                                        height=args.height,
                                        width=args.width,
                                        multi_gpu=args.parallel,
                                        gpus=args.gpus)
    if not model:
        sys.exit('Model {} does not exist.'.format(args.model))

    # Train the multi-GPU wrapper when build_model produced one; otherwise
    # train the plain model. All other training options are identical.
    train(parallel_model if parallel_model else model,
          training_data,
          epochs=args.epochs,
          batch_size=args.batch,
          device=args.device,
          parallel=args.parallel,
          verbose=args.verbose)

    # NOTE(review): the saved model is always the plain ``model``, and it is
    # saved under ``args.output``, while ``bin_dir`` above is created next to
    # this script. Confirm save_model_to_file resolves the path the same way.
    save_model_to_file(model, args.output)
if __name__ == '__main__':
    # Parse the command line first (bound at module level as `args`), so
    # `--help` and argument errors are reported before the project modules
    # below are imported.
    args = parse_arguments()
    # Deferred imports: they bind module-level globals that main() resolves
    # by name at call time, so main() only works when this file is run as a
    # script. NOTE(review): these modules presumably pull in TensorFlow/Keras
    # (see the TF_CPP_MIN_LOG_LEVEL setting at the top) — confirm.
    from utils.dataset import load_data
    from utils.model import build_model, save_model_to_file
    from utils.train import train
    main()