From 55d88b58bd236ca4759727143aad7a8f85b19c5a Mon Sep 17 00:00:00 2001 From: Lane Kolbly Date: Mon, 8 Apr 2019 18:19:10 -0500 Subject: [PATCH 1/2] Try loading existing checkpoint during training --- src/optimize.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/optimize.py b/src/optimize.py index a1cd174..4929f5b 100644 --- a/src/optimize.py +++ b/src/optimize.py @@ -89,7 +89,20 @@ def optimize(content_targets, style_target, content_weight, style_weight, # overall loss train_step = tf.train.AdamOptimizer(learning_rate).minimize(loss) - sess.run(tf.global_variables_initializer()) + + # Load an existing checkpoint, if one exists + saver = tf.train.Saver() + didLoad = False + save_dir = "/".join(save_path.split("/")[:-1]) + if os.path.isdir(save_dir): + ckpt = tf.train.get_checkpoint_state(save_dir) + if ckpt and ckpt.model_checkpoint_path: + saver.restore(sess, ckpt.model_checkpoint_path) + print("Loaded saved state") + didLoad = True + if not didLoad: + sess.run(tf.global_variables_initializer()) + import random uid = random.randint(1, 100) print("UID: %s" % uid) From e130c9c9abba8bfd83ebf7bedf665818ea809fcc Mon Sep 17 00:00:00 2001 From: Lane Kolbly Date: Mon, 8 Apr 2019 18:47:22 -0500 Subject: [PATCH 2/2] Factor loading checkpoints into utils --- evaluate.py | 24 +++++------------------- src/optimize.py | 14 +++----------- src/utils.py | 13 +++++++++++++ 3 files changed, 21 insertions(+), 30 deletions(-) diff --git a/evaluate.py b/evaluate.py index 6ddb026..f8f35c7 100644 --- a/evaluate.py +++ b/evaluate.py @@ -4,7 +4,7 @@ import transform, numpy as np, vgg, pdb, os import scipy.misc import tensorflow as tf -from utils import save_img, get_img, exists, list_files +from utils import save_img, get_img, exists, list_files, load_checkpoint from argparse import ArgumentParser from collections import defaultdict import time @@ -35,15 +35,8 @@ def ffwd_video(path_in, path_out, checkpoint_dir, device_t='/gpu:0', batch_size= name='img_placeholder') preds = transform.net(img_placeholder) - saver = tf.train.Saver() - if os.path.isdir(checkpoint_dir): - ckpt = tf.train.get_checkpoint_state(checkpoint_dir) - if ckpt and ckpt.model_checkpoint_path: - saver.restore(sess, ckpt.model_checkpoint_path) - else: - raise Exception("No checkpoint found...") - else: - saver.restore(sess, checkpoint_dir) + if not load_checkpoint(sess, checkpoint_dir): + raise Exception("No checkpoint found...") X = np.zeros(batch_shape, dtype=np.float32) @@ -91,15 +84,8 @@ def ffwd(data_in, paths_out, checkpoint_dir, device_t='/gpu:0', batch_size=4): name='img_placeholder') preds = transform.net(img_placeholder) - saver = tf.train.Saver() - if os.path.isdir(checkpoint_dir): - ckpt = tf.train.get_checkpoint_state(checkpoint_dir) - if ckpt and ckpt.model_checkpoint_path: - saver.restore(sess, ckpt.model_checkpoint_path) - else: - raise Exception("No checkpoint found...") - else: - saver.restore(sess, checkpoint_dir) + if not load_checkpoint(sess, checkpoint_dir): + raise Exception("No checkpoint found...") num_iters = int(len(paths_out)/batch_size) for i in range(num_iters): diff --git a/src/optimize.py b/src/optimize.py index 4929f5b..7e836df 100644 --- a/src/optimize.py +++ b/src/optimize.py @@ -3,7 +3,7 @@ import vgg, pdb, time import tensorflow as tf, numpy as np, os import transform -from utils import get_img +from utils import get_img, load_checkpoint STYLE_LAYERS = ('relu1_1', 'relu2_1', 'relu3_1', 'relu4_1', 'relu5_1') CONTENT_LAYER = 'relu4_2' @@ -91,16 +91,8 @@ def optimize(content_targets, style_target, content_weight, style_weight, train_step = tf.train.AdamOptimizer(learning_rate).minimize(loss) # Load an existing checkpoint, if one exists - saver = tf.train.Saver() - didLoad = False - save_dir = "/".join(save_path.split("/")[:-1]) - if os.path.isdir(save_dir): - ckpt = tf.train.get_checkpoint_state(save_dir) - if ckpt and ckpt.model_checkpoint_path: - saver.restore(sess, ckpt.model_checkpoint_path) - print("Loaded saved state") - didLoad = True - if not didLoad: + checkpoint_dir = "/".join(save_path.split("/")[:-1]) + if not load_checkpoint(sess, checkpoint_dir): sess.run(tf.global_variables_initializer()) import random diff --git a/src/utils.py b/src/utils.py index 36080f2..c197ec4 100644 --- a/src/utils.py +++ b/src/utils.py @@ -1,4 +1,5 @@ import scipy.misc, numpy as np, os, sys +import tensorflow as tf def save_img(out_path, img): img = np.clip(img, 0, 255).astype(np.uint8) @@ -31,3 +32,15 @@ def list_files(in_path): return files +def load_checkpoint(sess, checkpoint_dir): + saver = tf.train.Saver() + if os.path.isdir(checkpoint_dir): + ckpt = tf.train.get_checkpoint_state(checkpoint_dir) + if ckpt and ckpt.model_checkpoint_path: + saver.restore(sess, ckpt.model_checkpoint_path) + return True + else: + return False + else: + saver.restore(sess, checkpoint_dir) + return True