qijie.wei commited on
Commit
6735c2f
·
1 Parent(s): c5f4ee2
Files changed (1) hide show
  1. inference.py +4 -3
inference.py CHANGED
@@ -28,7 +28,7 @@ class Inference(object):
28
  def inference(self, image, modality):
29
  assert modality in self.modality_mapping, "Modality '{}' not supported".format(modality)
30
 
31
- image = self.load_image(image)
32
  modality_idx = self.modality_mapping[modality]
33
  modality_idx = torch.tensor([modality_idx])
34
  with torch.no_grad():
@@ -36,19 +36,20 @@ class Inference(object):
36
  output = output.data.cpu().numpy()[0][0]
37
  output = sigmoid(output) * 255
38
  output = output.astype(np.uint8)
 
39
  return output
40
 
41
  def load_image(self, image):
42
  # Load the image and preprocess it
43
  if isinstance(image, str):
44
  image = cv2.imread(image)[:, :, [2, 1, 0]]
45
- #image = image
46
  image = cv2.resize(image, (self.model_params['size_w'], self.model_params['size_h']))
47
  image = image.astype(np.float32) / 255.0
48
  image = np.transpose(image, (2, 0, 1))
49
  image = np.expand_dims(image, axis=0)
50
  image = torch.tensor(image)
51
- return image
52
 
53
  def load_model(self):
54
  print('Loading model from {}'.format(self.model_path))
 
28
  def inference(self, image, modality):
29
  assert modality in self.modality_mapping, "Modality '{}' not supported".format(modality)
30
 
31
+ image, raw_h, raw_w = self.load_image(image)
32
  modality_idx = self.modality_mapping[modality]
33
  modality_idx = torch.tensor([modality_idx])
34
  with torch.no_grad():
 
36
  output = output.data.cpu().numpy()[0][0]
37
  output = sigmoid(output) * 255
38
  output = output.astype(np.uint8)
39
+ output = cv2.resize(output, (raw_w, raw_h))
40
  return output
41
 
42
  def load_image(self, image):
43
  # Load the image and preprocess it
44
  if isinstance(image, str):
45
  image = cv2.imread(image)[:, :, [2, 1, 0]]
46
+ raw_h, raw_w = image.shape[:2]
47
  image = cv2.resize(image, (self.model_params['size_w'], self.model_params['size_h']))
48
  image = image.astype(np.float32) / 255.0
49
  image = np.transpose(image, (2, 0, 1))
50
  image = np.expand_dims(image, axis=0)
51
  image = torch.tensor(image)
52
+ return image, raw_h, raw_w
53
 
54
  def load_model(self):
55
  print('Loading model from {}'.format(self.model_path))