Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 5 additions & 19 deletions evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import transform, numpy as np, vgg, pdb, os
import scipy.misc
import tensorflow as tf
from utils import save_img, get_img, exists, list_files
from utils import save_img, get_img, exists, list_files, load_checkpoint
from argparse import ArgumentParser
from collections import defaultdict
import time
Expand Down Expand Up @@ -35,15 +35,8 @@ def ffwd_video(path_in, path_out, checkpoint_dir, device_t='/gpu:0', batch_size=
name='img_placeholder')

preds = transform.net(img_placeholder)
saver = tf.train.Saver()
if os.path.isdir(checkpoint_dir):
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
else:
raise Exception("No checkpoint found...")
else:
saver.restore(sess, checkpoint_dir)
if not load_checkpoint(sess, checkpoint_dir):
raise Exception("No checkpoint found...")

X = np.zeros(batch_shape, dtype=np.float32)

Expand Down Expand Up @@ -91,15 +84,8 @@ def ffwd(data_in, paths_out, checkpoint_dir, device_t='/gpu:0', batch_size=4):
name='img_placeholder')

preds = transform.net(img_placeholder)
saver = tf.train.Saver()
if os.path.isdir(checkpoint_dir):
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
else:
raise Exception("No checkpoint found...")
else:
saver.restore(sess, checkpoint_dir)
if not load_checkpoint(sess, checkpoint_dir):
raise Exception("No checkpoint found...")

num_iters = int(len(paths_out)/batch_size)
for i in range(num_iters):
Expand Down
9 changes: 7 additions & 2 deletions src/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import vgg, pdb, time
import tensorflow as tf, numpy as np, os
import transform
from utils import get_img
from utils import get_img, load_checkpoint

STYLE_LAYERS = ('relu1_1', 'relu2_1', 'relu3_1', 'relu4_1', 'relu5_1')
CONTENT_LAYER = 'relu4_2'
Expand Down Expand Up @@ -89,7 +89,12 @@ def optimize(content_targets, style_target, content_weight, style_weight,

# overall loss
train_step = tf.train.AdamOptimizer(learning_rate).minimize(loss)
sess.run(tf.global_variables_initializer())

# Load an existing checkpoint, if one exists
checkpoint_dir = "/".join(save_path.split("/")[:-1])
if not load_checkpoint(sess, checkpoint_dir):
sess.run(tf.global_variables_initializer())

import random
uid = random.randint(1, 100)
print("UID: %s" % uid)
Expand Down
13 changes: 13 additions & 0 deletions src/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import scipy.misc, numpy as np, os, sys
import tensorflow as tf

def save_img(out_path, img):
img = np.clip(img, 0, 255).astype(np.uint8)
Expand Down Expand Up @@ -31,3 +32,15 @@ def list_files(in_path):

return files

def load_checkpoint(sess, checkpoint_dir):
saver = tf.train.Saver()
if os.path.isdir(checkpoint_dir):
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
return True
else:
return False
else:
saver.restore(sess, checkpoint_dir)
return True