diff --git a/evaluate.py b/evaluate.py index 6ddb026..f8f35c7 100644 --- a/evaluate.py +++ b/evaluate.py @@ -4,7 +4,7 @@ import transform, numpy as np, vgg, pdb, os import scipy.misc import tensorflow as tf -from utils import save_img, get_img, exists, list_files +from utils import save_img, get_img, exists, list_files, load_checkpoint from argparse import ArgumentParser from collections import defaultdict import time @@ -35,15 +35,8 @@ def ffwd_video(path_in, path_out, checkpoint_dir, device_t='/gpu:0', batch_size= name='img_placeholder') preds = transform.net(img_placeholder) - saver = tf.train.Saver() - if os.path.isdir(checkpoint_dir): - ckpt = tf.train.get_checkpoint_state(checkpoint_dir) - if ckpt and ckpt.model_checkpoint_path: - saver.restore(sess, ckpt.model_checkpoint_path) - else: - raise Exception("No checkpoint found...") - else: - saver.restore(sess, checkpoint_dir) + if not load_checkpoint(sess, checkpoint_dir): + raise Exception("No checkpoint found...") X = np.zeros(batch_shape, dtype=np.float32) @@ -91,15 +84,8 @@ def ffwd(data_in, paths_out, checkpoint_dir, device_t='/gpu:0', batch_size=4): name='img_placeholder') preds = transform.net(img_placeholder) - saver = tf.train.Saver() - if os.path.isdir(checkpoint_dir): - ckpt = tf.train.get_checkpoint_state(checkpoint_dir) - if ckpt and ckpt.model_checkpoint_path: - saver.restore(sess, ckpt.model_checkpoint_path) - else: - raise Exception("No checkpoint found...") - else: - saver.restore(sess, checkpoint_dir) + if not load_checkpoint(sess, checkpoint_dir): + raise Exception("No checkpoint found...") num_iters = int(len(paths_out)/batch_size) for i in range(num_iters): diff --git a/src/optimize.py b/src/optimize.py index a1cd174..7e836df 100644 --- a/src/optimize.py +++ b/src/optimize.py @@ -3,7 +3,7 @@ import vgg, pdb, time import tensorflow as tf, numpy as np, os import transform -from utils import get_img +from utils import get_img, load_checkpoint STYLE_LAYERS = ('relu1_1', 'relu2_1', 'relu3_1', 'relu4_1', 'relu5_1') CONTENT_LAYER = 'relu4_2' @@ -89,7 +89,12 @@ def optimize(content_targets, style_target, content_weight, style_weight, # overall loss train_step = tf.train.AdamOptimizer(learning_rate).minimize(loss) - sess.run(tf.global_variables_initializer()) + + # Load an existing checkpoint, if one exists + checkpoint_dir = "/".join(save_path.split("/")[:-1]) + if not load_checkpoint(sess, checkpoint_dir): + sess.run(tf.global_variables_initializer()) + import random uid = random.randint(1, 100) print("UID: %s" % uid) diff --git a/src/utils.py b/src/utils.py index 36080f2..c197ec4 100644 --- a/src/utils.py +++ b/src/utils.py @@ -1,4 +1,5 @@ import scipy.misc, numpy as np, os, sys +import tensorflow as tf def save_img(out_path, img): img = np.clip(img, 0, 255).astype(np.uint8) @@ -31,3 +32,15 @@ def list_files(in_path): return files +def load_checkpoint(sess, checkpoint_dir): + saver = tf.train.Saver() + if os.path.isdir(checkpoint_dir): + ckpt = tf.train.get_checkpoint_state(checkpoint_dir) + if ckpt and ckpt.model_checkpoint_path: + saver.restore(sess, ckpt.model_checkpoint_path) + return True + else: + return False + else: + saver.restore(sess, checkpoint_dir) + return True