# Spaces: Runtime error
| from collections import OrderedDict | |
| import torch | |
| from models.model import GLPDepth | |
| from PIL import Image | |
| from torchvision import transforms | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
# Torch device the model and the input tensor are placed on (CPU-only here).
DEVICE = "cpu"
def load_mde_model(path, *, max_depth=700.0):
    """Build a GLPDepth model and load trained weights from a checkpoint.

    Args:
        path: Path of a checkpoint written with ``torch.save`` whose weights
            live under the ``'model_state_dict'`` key.
        max_depth: Forwarded to ``GLPDepth(max_depth=...)``; defaults to the
            700.0 this script was originally written for.

    Returns:
        The GLPDepth model on ``DEVICE`` with the weights loaded, switched to
        eval mode.

    Raises:
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
    """
    model = GLPDepth(max_depth=max_depth, is_train=False).to(DEVICE)

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code -- only load checkpoints from a trusted source.
    checkpoint = torch.load(path, map_location=torch.device(DEVICE))
    state_dict = checkpoint['model_state_dict']

    # Keys prefixed with 'module.' presumably come from a model saved while
    # wrapped in nn.DataParallel / DistributedDataParallel. Strip that exact
    # prefix per key rather than inspecting only the first key and cutting
    # 7 characters from every key (which mangles names such as
    # 'submodule...' and crashes on an empty state dict).
    prefix = 'module.'
    state_dict = OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )

    model.load_state_dict(state_dict)
    model.eval()
    return model
# Demo: run the depth model on one image and save a colour-mapped plot.
model = load_mde_model('best_model.ckpt')

preprocess = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
])

# Force 3-channel RGB: ToTensor turns a grayscale image into 1 channel and an
# RGBA/CMYK image into 4, which the model presumably does not accept.
# The context manager closes the underlying file once the pixels are copied.
with Image.open('demo_imgs/fake.jpg') as raw_img:
    input_img = raw_img.convert('RGB')

# Add a leading batch axis of size 1 before moving to the target device.
torch_img = preprocess(input_img).to(DEVICE).unsqueeze(0)

with torch.no_grad():
    output_patch = model(torch_img)

# 'pred_d' is presumably the predicted depth map; squeeze() drops the
# singleton axes (assumed batch/channel) so imshow receives a 2-D array.
output_patch = output_patch['pred_d'].squeeze().cpu().detach().numpy()
print(output_patch.shape)

plt.imshow(output_patch, cmap='jet', vmin=0, vmax=np.max(output_patch))
plt.colorbar()
plt.savefig('test.png')
# Release the figure so repeated runs in one process do not accumulate them.
plt.close()