Author: Not specified Language: text
Description: Not specified Timestamp: 2018-01-15 23:23:23 +0000
View raw paste Reply
  1. import numpy as np
  2. import argparse
  3. from scipy import misc
  4.  
  5. #caffe_root = '/home/amirul/weighted-deeplabv2/'
  6. caffe_root = '/home/amirul/deeplab-public-ver2/'
  7. import sys
  8. sys.path.insert(0, caffe_root + 'python')
  9.  
  10. import caffe
  11.  
  12. # Import arguments
  13. parser = argparse.ArgumentParser()
  14. parser.add_argument('--model', type=str, required=True)
  15. parser.add_argument('--weights', type=str, required=True)
  16. parser.add_argument('--iter', type=int, required=True)
  17. args = parser.parse_args()
  18.  
  19. caffe.set_mode_gpu()
  20. caffe.set_device(0)
  21.  
  22. net = caffe.Net(args.model,
  23.                 args.weights,
  24.                 caffe.TEST)
  25.  
  26. fname = '/mnt/vana/amirul/projects/semantics_saliency/data/pascals/test.txt'    ##### For val data
  27.  
  28. #fname = '/mnt/vana/amirul/projects/cvpr2017_seg/data/Pascal/test.txt'   #### For test data
  29.  
  30. with open(fname) as f:
  31.     labelFiles = f.read().splitlines()
  32.  
  33.  
  34. for i in range(0, args.iter):
  35.  
  36.     net.forward()
  37.    
  38.    
  39.     labelFile = labelFiles[i].split(' ')[1]
  40.     labelFile_sal = labelFiles[i].split(' ')[2]
  41.     label = misc.imread(labelFile)
  42.    
  43.     for s in range(0,1):
  44.        
  45.         pred_sal = net.blobs['prob_sal'].data
  46.         pred_sal = np.squeeze(pred_sal[0,:,:,:])
  47.         #ind_sal = np.argmax(pred_sal, axis=0)
  48.         ind_sal = pred_sal[1,:,:]
  49.         ind_sal = ind_sal[0:label.shape[0], 0:label.shape[1]]
  50.  
  51.        
  52.        
  53.        
  54.  
  55.         labelname = labelFile.split('Aug/')
  56.         labelname_sal = labelFile_sal.split('gt/')
  57.  
  58.        
  59.         #misc.toimage(ind_seg, cmin=0.0, cmax=255).save('/mnt/vana/amirul/projects/semantics_saliency/predictions/pascals/semantic/resnet101/' + labelname[1])
  60.         misc.toimage(ind_sal, cmin=0.0, cmax=255).save('/mnt/vana/amirul/projects/semantics_saliency/predictions/pascals/saliency/resnet101/' + labelname_sal[1])
  61.        
  62.        
  63.         print 'Processed: ', labelname[1]
  64.        
  65.            
  66. print 'Success!'
  67.  
View raw paste Reply