-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathInference_blue.py
More file actions
86 lines (74 loc) · 5.44 KB
/
Inference_blue.py
File metadata and controls
86 lines (74 loc) · 5.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# Run prediction and generate a pixel-wise annotation for every pixel in the image using a fully convolutional neural net
# Output is saved as label images, and as label images overlaid on the original image
# 1) Make sure you have a trained model in logs_dir (see Train.py for creating a trained model)
# 2) Set Image_Dir to the folder where the input images for prediction are located
# 3) Set the number of classes in NUM_CLASSES
# 4) Set Pred_Dir to the folder where you want the output annotated images to be saved
# 5) Run the script
#--------------------------------------------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np
import scipy.misc as misc
import sys
import BuildNetVgg16
import TensorflowUtils
import os
import Data_Reader
import OverrlayLabelOnImage as Overlay
import CheckVGG16Model
# ------------------------------ Inference settings ------------------------------
# Model-related settings
NUM_CLASSES = 4  # Number of classes
model_path = "Model_Zoo/vgg16.npy"  # Path to the pretrained VGG16 model used for the encoder
logs_dir = "logs/"  # Logs directory holding the trained model and related information

# Input / output locations
Image_Dir = "Data_Zoo/test/"  # Folder with the test images to run prediction on
Pred_Dir = "Output/"  # Directory where the output predictions will be written
NameEnd = ""  # Optional string added to the end of the output file names

# Visualisation
w = 0.6  # Weight of the label overlay on the image

# Check that the pretrained VGG16 model is available and, if not, try to download it.
# Runs at import time, after model_path has been defined.
CheckVGG16Model.CheckVGG16(model_path)
################################################################################################################################################################################
def main(argv=None):
    """Run semantic-segmentation inference on every image in ``Image_Dir``.

    Builds the VGG16-based net, restores the latest checkpoint found in
    ``logs_dir`` and writes one predicted label image (PNG, uint8) per input
    image into ``Pred_Dir/Label``.

    Args:
        argv: Unused; kept so existing callers that pass it keep working.

    Raises:
        SystemExit: With exit status 1 when no trained checkpoint can be
            found in ``logs_dir``.
    """
    # ------------------------- Placeholders for the net inputs -------------------------
    keep_prob = tf.placeholder(tf.float32, name="keep_probabilty")  # Dropout keep probability
    # Input image batch; only the last dimension (channels) is fixed, at 87.
    # NOTE(review): the original comment described the last dimension as RGB,
    # which does not match 87 channels -- confirm this matches the trained model.
    image = tf.placeholder(tf.float32, shape=[None, None, None, 87], name="input_image")

    # ------------------------- Build net -------------------------
    Net = BuildNetVgg16.BUILD_NET_VGG16(vgg16_npy_path=model_path)  # Create class instance for the net
    Net.build(image, NUM_CLASSES, keep_prob)  # Build net and load initial (pre-training) weights

    # ------------------------- Data reader for the test images -------------------------
    ValidReader = Data_Reader.Data_Reader(Image_Dir, BatchSize=1)

    # Output folders: the overlay folder is still created even though overlay
    # saving is currently disabled, so the output layout stays unchanged.
    overlay_dir = os.path.join(Pred_Dir, "OverLay")
    label_dir = os.path.join(Pred_Dir, "Label")

    # The context manager guarantees the session is closed on every exit path,
    # including the sys.exit() below.
    with tf.Session() as sess:
        # ------------------------- Load trained model (see Train.py) -------------------------
        print("Setting up Saver...")
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        ckpt = tf.train.get_checkpoint_state(logs_dir)
        if ckpt and ckpt.model_checkpoint_path:  # A trained model exists: restore it
            saver.restore(sess, ckpt.model_checkpoint_path)
            print("Model restored...")
        else:
            # ckpt may be None here, so report the directory that was searched
            # instead of dereferencing ckpt.
            print("ERROR NO TRAINED MODEL IN: " + logs_dir + " See Train.py for creating train network ")
            sys.exit(1)  # Non-zero status so calling scripts can detect the failure

        # ------------------------- Create output directories -------------------------
        for out_dir in (Pred_Dir, overlay_dir, label_dir):
            if not os.path.exists(out_dir):
                os.makedirs(out_dir)
        print("Running Predictions:")
        print("Saving output to:" + Pred_Dir)

        # ------------------------- Predict every image -------------------------
        fim = 0  # Number of images processed so far (progress report only)
        print("Start Predicting " + str(ValidReader.NumFiles) + " images")
        # Presumably ReadNextBatchClean() advances ValidReader.itr -- the loop
        # relies on that to terminate; verify against Data_Reader.
        while ValidReader.itr < ValidReader.NumFiles:
            print(str(fim * 100.0 / ValidReader.NumFiles) + "%")
            fim += 1

            # The name must be read before the batch is loaded, while
            # ValidReader.itr still refers to the current image.
            FileName = ValidReader.OrderedFiles[ValidReader.itr]  # Input image name
            Images = ValidReader.ReadNextBatchClean()  # Load the test image batch

            # Predict the annotation with dropout disabled (keep_prob = 1.0).
            LabelPred = sess.run(Net.Pred, feed_dict={image: Images, keep_prob: 1.0})

            # Save the predicted labels. The optional NameEnd suffix goes before
            # the extension so the file still ends in ".png", and splitext copes
            # with input extensions of any length.
            # NOTE(review): scipy.misc.imsave is gone from recent SciPy
            # releases -- confirm the pinned SciPy version still provides it.
            base_name = os.path.splitext(FileName)[0]
            label_path = os.path.join(label_dir, base_name + NameEnd + ".png")
            misc.imsave(label_path, LabelPred[0].astype(np.uint8))
##################################################################################################################################################
# Module-level entry point: there is no __main__ guard, so inference starts
# as soon as this file is executed (or imported).
main()  # Run script
print("Finished")