Update inference_2.py
inference_2.py (+3, -3)
@@ -10,7 +10,7 @@ from models import image
 
 from onnx2pytorch import ConvertModel
 
-onnx_model = onnx.load('
+onnx_model = onnx.load('models/efficientnet.onnx')
 pytorch_model = ConvertModel(onnx_model)
 
 torch.manual_seed(42)
@@ -65,14 +65,14 @@ def get_args(parser):
 
 def load_img_modality_model(args):
     rgb_encoder = pytorch_model
-    ckpt = torch.load('
+    ckpt = torch.load('models/model.pth', map_location = torch.device('cpu'))
     rgb_encoder.load_state_dict(ckpt['rgb_encoder'], strict = True)
     rgb_encoder.eval()
     return rgb_encoder
 
 def load_spec_modality_model(args):
     spec_encoder = image.RawNet(args)
-    ckpt = torch.load('
+    ckpt = torch.load('models/model.pth', map_location = torch.device('cpu'))
     spec_encoder.load_state_dict(ckpt['spec_encoder'], strict = True)
     spec_encoder.eval()
     return spec_encoder
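For context, the change does two things: it loads the ONNX model from an explicit path, and it passes map_location=torch.device('cpu') to both torch.load calls so the checkpoint can be restored on a CPU-only Space even if it was saved from a GPU session (a common cause of a CUDA deserialization RuntimeError). Below is a minimal sketch of the loading flow after this commit; the paths, checkpoint keys, and the RawNet spectrogram encoder are taken from the diff, while the args object and the top-level imports are assumed context, not confirmed by the page.

# Sketch (assumptions noted above): loading both encoders after this commit.
import onnx
import torch
from onnx2pytorch import ConvertModel
from models import image

# Convert the EfficientNet ONNX graph into an equivalent PyTorch module.
onnx_model = onnx.load('models/efficientnet.onnx')
pytorch_model = ConvertModel(onnx_model)

torch.manual_seed(42)

def load_img_modality_model(args):
    # map_location forces all tensors onto the CPU, so the checkpoint loads
    # on a CPU-only machine even if it was saved from a CUDA session.
    rgb_encoder = pytorch_model
    ckpt = torch.load('models/model.pth', map_location=torch.device('cpu'))
    rgb_encoder.load_state_dict(ckpt['rgb_encoder'], strict=True)
    rgb_encoder.eval()
    return rgb_encoder

def load_spec_modality_model(args):
    # Same checkpoint file, different key: the spectrogram branch uses RawNet.
    spec_encoder = image.RawNet(args)
    ckpt = torch.load('models/model.pth', map_location=torch.device('cpu'))
    spec_encoder.load_state_dict(ckpt['spec_encoder'], strict=True)
    spec_encoder.eval()
    return spec_encoder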