import numpy as np
import argparse
from scipy import misc
#caffe_root = '/home/amirul/weighted-deeplabv2/'
caffe_root = '/home/amirul/deeplab-public-ver2/'
import sys
sys.path.insert(0, caffe_root + 'python')
import caffe
# Default locations, kept from the original hard-coded values; each can be
# overridden on the command line (--list / --out-dir).
DEFAULT_LIST = '/mnt/vana/amirul/projects/semantics_saliency/data/pascals/test.txt'  # val data
# List previously used for the test split:
# '/mnt/vana/amirul/projects/cvpr2017_seg/data/Pascal/test.txt'
DEFAULT_OUT_DIR = ('/mnt/vana/amirul/projects/semantics_saliency/predictions/'
                   'pascals/saliency/resnet101/')


def split_list_line(line):
    """Return (label_path, saliency_gt_path) from one line of the list file.

    Fields are whitespace-separated (any run of spaces/tabs is one separator);
    fields 1 and 2 are returned. Field 0 is not used here -- presumably the
    input image path; verify against the data layer's source list.

    Raises:
        IndexError: if the line has fewer than three fields.
    """
    fields = line.split()
    return fields[1], fields[2]


def path_after(path, marker):
    """Return everything in *path* after the first occurrence of *marker*.

    Turns a ground-truth path such as '.../gt/2008_000123.png' into the
    relative output name '2008_000123.png' (sub-directories after the marker
    are kept intact).

    Raises:
        ValueError: if *marker* does not occur in *path*.
    """
    _, found, tail = path.partition(marker)
    if not found:
        raise ValueError('%r does not contain %r' % (path, marker))
    return tail


def main():
    """Run the net over the list file and save one saliency map per line.

    Each iteration performs one ``net.forward()``, reads the matching line of
    the list file, crops channel 1 of the 'prob_sal' blob to the size of the
    ground-truth label image and writes it under the output directory using
    the saliency ground-truth file's name (the part after 'gt/').
    """
    import os  # local import: only needed for output path joining

    parser = argparse.ArgumentParser(
        description='Save saliency predictions of a Caffe model.')
    parser.add_argument('--model', type=str, required=True)
    parser.add_argument('--weights', type=str, required=True)
    parser.add_argument('--iter', type=int, required=True,
                        help='number of forward passes (= list lines) to process')
    parser.add_argument('--list', dest='list_file', default=DEFAULT_LIST,
                        help='list file whose lines are "<x> <label> <saliency_gt>"')
    parser.add_argument('--out-dir', dest='out_dir', default=DEFAULT_OUT_DIR,
                        help='directory the saliency maps are written to')
    parser.add_argument('--gpu', type=int, default=0, help='GPU device id')
    args = parser.parse_args()

    with open(args.list_file) as f:
        list_lines = f.read().splitlines()
    # Each forward pass is paired with one list line; fail before the
    # (expensive) network load instead of an IndexError part-way through.
    if args.iter > len(list_lines):
        parser.error('--iter=%d exceeds the %d lines in %s'
                     % (args.iter, len(list_lines), args.list_file))

    caffe.set_mode_gpu()
    caffe.set_device(args.gpu)
    net = caffe.Net(args.model, args.weights, caffe.TEST)

    for line in list_lines[:args.iter]:
        # NOTE(review): pairing forward pass i with list line i assumes the
        # TEST data layer reads this same list, in order, one image per
        # batch -- confirm against the prototxt.
        net.forward()
        label_path, sal_gt_path = split_list_line(line)
        # The label image is read only for its height/width: the network
        # output is cropped to that size (presumably it was padded).
        label = misc.imread(label_path)
        # Batch item 0, channel 1 of the 4-D 'prob_sal' blob, cropped to the
        # label size. Indexing directly (no squeeze) keeps the channel axis
        # even if some other dimension happens to be 1.
        sal = net.blobs['prob_sal'].data[0, 1, :label.shape[0], :label.shape[1]]
        out_path = os.path.join(args.out_dir, path_after(sal_gt_path, 'gt/'))
        # NOTE(review): toimage linearly maps [cmin, cmax] onto [0, 255]; with
        # cmin=0, cmax=255 values are written unscaled and rounded to uint8,
        # so if 'prob_sal' holds probabilities in [0, 1] the saved image is a
        # {0, 1} mask (threshold 0.5). Kept as-is -- confirm this is intended.
        misc.toimage(sal, cmin=0.0, cmax=255).save(out_path)
        print('Processed: %s' % out_path)
    print('Success!')


if __name__ == '__main__':
    main()