# Spann3R evaluation script: reconstructs test sequences (7Scenes / NRGBD / DTU)
# and reports accuracy / completion / normal-consistency metrics per dataset.
import argparse
import os
import os.path as osp
import time

import numpy as np
import open3d as o3d
import torch
from torch.utils.data import DataLoader

from dust3r.image_pairs import make_pairs
from dust3r.inference import inference
from dust3r.losses import L21
from dust3r.utils.geometry import geotrf

from spann3r.datasets import *
from spann3r.loss import Regr3D_t_ScaleShiftInv
from spann3r.model import Spann3R
from spann3r.tools.eval_recon import accuracy, completion
def get_args_parser():
    """Build the command-line parser for Spann3R evaluation.

    Returns:
        argparse.ArgumentParser: parser with checkpoint/experiment paths,
        scene-graph type, offline flag, device, and confidence threshold.
    """
    parser = argparse.ArgumentParser('Spann3R evaluation', add_help=False)

    # Checkpoint / experiment locations.
    parser.add_argument('--exp_path', type=str, default='./checkpoints',
                        help='Path to experiment folder')
    parser.add_argument('--exp_name', type=str, default='ckpt_best',
                        help='Path to experiment folder')
    parser.add_argument('--ckpt', type=str, default='spann3r.pth',
                        help='ckpt name')

    # Reconstruction options.
    parser.add_argument('--scenegraph_type', type=str, default='complete',
                        help='scenegraph type')
    parser.add_argument('--offline', action='store_true',
                        help='offline reconstruction')

    # Runtime options.
    parser.add_argument('--device', type=str, default='cuda:0',
                        help='device')
    parser.add_argument('--conf_thresh', type=float, default=0.0,
                        help='confidence threshold')

    return parser
def main(args):
    """Run Spann3R reconstruction + evaluation over the benchmark datasets.

    For each dataset (7Scenes, NRGBD, DTU) and each test sequence:
    reconstruct the scene (online forward pass, or offline pairwise DUSt3R
    inference + global fusion when ``--offline``), align predictions to the
    ground truth, save point clouds, run ICP refinement, and log
    accuracy / completion / normal-consistency metrics.

    Args:
        args: parsed namespace from :func:`get_args_parser`.

    Raises:
        FileNotFoundError: if ``args.exp_path`` does not exist.
    """
    workspace = args.exp_path
    ckpt_path = osp.join(workspace, args.ckpt)

    if not osp.exists(workspace):
        raise FileNotFoundError(f"Workspace {workspace} not found")

    exp_path = osp.join(workspace, args.exp_name)
    os.makedirs(exp_path, exist_ok=True)

    datasets_all = {
        '7scenes': SevenScenes(split='test', ROOT="./data/7scenes",
                               resolution=224, num_seq=1, full_video=True, kf_every=20),
        'NRGBD': NRGBD(split='test', ROOT="./data/neural_rgbd",
                       resolution=224, num_seq=1, full_video=True, kf_every=40),
        'DTU': DTU(split='test', ROOT="./data/dtu_test",
                   resolution=224, num_seq=1, full_video=True, kf_every=5),
    }

    model = Spann3R(dus3r_name='./checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth',
                    use_feat=False).to(args.device)
    model.load_state_dict(torch.load(ckpt_path)['model'])
    model.eval()

    criterion = Regr3D_t_ScaleShiftInv(L21, norm_mode=False, gt_scale=True)

    with torch.no_grad():
        for name_data, dataset in datasets_all.items():
            save_path = osp.join(exp_path, name_data)
            if args.offline:
                # FIX: single-arg osp.join was a no-op wrapper around concatenation.
                save_path = save_path + '_offline'
            # FIX: dropped the duplicated os.makedirs call.
            os.makedirs(save_path, exist_ok=True)
            log_file = osp.join(save_path, 'logs.txt')

            # Per-dataset metric accumulators (mean and median variants).
            acc_all = 0
            acc_all_med = 0
            comp_all = 0
            comp_all_med = 0
            nc1_all = 0
            nc1_all_med = 0
            nc2_all = 0
            nc2_all_med = 0
            fps_all = []
            time_all = []

            dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)

            for i, batch in enumerate(dataloader):
                # Move every tensor the model/criterion needs onto the eval device.
                for view in batch:
                    for name in 'img pts3d valid_mask camera_pose camera_intrinsics F_matrix corres'.split():  # pseudo_focal
                        if name not in view:
                            continue
                        view[name] = view[name].to(args.device, non_blocking=True)

                print(f'Started reconstruction for {name_data} {i+1}/{len(dataloader)}')

                if args.offline:
                    # Offline mode: pairwise DUSt3R inference over a scene graph,
                    # then Spann3R's global offline fusion.
                    imgs_all = []
                    for j, view in enumerate(batch):
                        img = view['img']
                        # FIX: removed unused local `shape1`.
                        imgs_all.append(
                            dict(
                                img=img,
                                true_shape=torch.tensor(img.shape[2:]).unsqueeze(0),
                                idx=j,
                                instance=str(j),
                            )
                        )
                    start = time.time()
                    pairs = make_pairs(imgs_all, scene_graph=args.scenegraph_type, prefilter=None, symmetrize=True)
                    output = inference(pairs, model.dust3r, args.device, batch_size=2, verbose=True)
                    preds, preds_all, idx_used = model.offline_reconstruction(batch, output)
                    end = time.time()
                    # FIX: fresh index name `k` — the original comprehension variable
                    # shadowed the enumerate counter `i` used in the log lines below.
                    ordered_batch = [batch[k] for k in idx_used]
                else:
                    # Online mode: one sequential forward pass over the sequence.
                    start = time.time()
                    preds, preds_all = model.forward(batch)
                    end = time.time()
                    ordered_batch = batch

                fps = len(batch) / (end - start)
                print(f'Finished reconstruction for {name_data} {i+1}/{len(dataloader)}, FPS: {fps:.2f}')
                fps_all.append(fps)
                time_all.append(end - start)

                # Evaluation
                print(f'Evaluation for {name_data} {i+1}/{len(dataloader)}')
                gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = criterion.get_all_pts3d_t(ordered_batch, preds_all)
                pred_scale, gt_scale, pred_shift_z, gt_shift_z = (
                    monitoring['pred_scale'], monitoring['gt_scale'],
                    monitoring['pred_shift_z'], monitoring['gt_shift_z'])

                in_camera1 = None
                pts_all = []
                pts_gt_all = []
                images_all = []
                masks_all = []
                conf_all = []

                for j, view in enumerate(ordered_batch):
                    if in_camera1 is None:
                        # World frame is anchored at the first view's camera pose.
                        in_camera1 = view['camera_pose'][0].cpu()

                    image = view['img'].permute(0, 2, 3, 1).cpu().numpy()[0]
                    mask = view['valid_mask'].cpu().numpy()[0]
                    # pred_pts comes back split in two lists from the criterion;
                    # past the first list's length, fall back to the second's tail.
                    pts = pred_pts[0][j].cpu().numpy()[0] if j < len(pred_pts[0]) else pred_pts[1][-1].cpu().numpy()[0]
                    conf = preds[j]['conf'][0].cpu().data.numpy()
                    pts_gt = gt_pts[j].detach().cpu().numpy()[0]

                    #### Align predicted 3D points to the ground truth:
                    # undo the z-shift, then map camera-1 coords to the world frame.
                    shift_z = gt_shift_z.cpu().numpy().item()
                    pts[..., -1] += shift_z
                    pts = geotrf(in_camera1, pts)
                    pts_gt[..., -1] += shift_z
                    pts_gt = geotrf(in_camera1, pts_gt)

                    images_all.append((image[None, ...] + 1.0) / 2.0)  # [-1, 1] -> [0, 1]
                    pts_all.append(pts[None, ...])
                    pts_gt_all.append(pts_gt[None, ...])
                    masks_all.append(mask[None, ...])
                    conf_all.append(conf[None, ...])

                images_all = np.concatenate(images_all, axis=0)
                pts_all = np.concatenate(pts_all, axis=0)
                pts_gt_all = np.concatenate(pts_gt_all, axis=0)
                masks_all = np.concatenate(masks_all, axis=0)
                conf_all = np.concatenate(conf_all, axis=0)

                # `view` intentionally carries over from the loop: any view's label
                # yields the same scene prefix.
                scene_id = view['label'][0].rsplit('/', 1)[0]

                save_params = {
                    'images_all': images_all,
                    'pts_all': pts_all,
                    'pts_gt_all': pts_gt_all,
                    'masks_all': masks_all,
                    'conf_all': conf_all,
                }
                np.save(os.path.join(save_path, f"{scene_id.replace('/', '_')}.npy"), save_params)

                # ICP correspondence threshold: DTU is in mm-scale units.
                if 'DTU' in name_data:
                    threshold = 100
                else:
                    threshold = 0.1

                pts_all_masked = pts_all[masks_all > 0]
                pts_gt_all_masked = pts_gt_all[masks_all > 0]
                images_all_masked = images_all[masks_all > 0]

                pcd = o3d.geometry.PointCloud()
                pcd.points = o3d.utility.Vector3dVector(pts_all_masked.reshape(-1, 3))
                pcd.colors = o3d.utility.Vector3dVector(images_all_masked.reshape(-1, 3))
                o3d.io.write_point_cloud(os.path.join(save_path, f"{scene_id.replace('/', '_')}-mask.ply"), pcd)

                pcd_gt = o3d.geometry.PointCloud()
                pcd_gt.points = o3d.utility.Vector3dVector(pts_gt_all_masked.reshape(-1, 3))
                # FIX: colors are already normalized to [0, 1] above; the original
                # extra /255.0 rendered the GT cloud nearly black.
                pcd_gt.colors = o3d.utility.Vector3dVector(images_all_masked.reshape(-1, 3))
                o3d.io.write_point_cloud(os.path.join(save_path, f"{scene_id.replace('/', '_')}-gt.ply"), pcd_gt)

                # Rigidly refine the predicted cloud onto the GT with point-to-point ICP.
                trans_init = np.eye(4)
                reg_p2p = o3d.pipelines.registration.registration_icp(
                    pcd, pcd_gt, threshold, trans_init,
                    o3d.pipelines.registration.TransformationEstimationPointToPoint())
                pcd = pcd.transform(reg_p2p.transformation)

                pcd.estimate_normals()
                pcd_gt.estimate_normals()
                gt_normal = np.asarray(pcd_gt.normals)
                pred_normal = np.asarray(pcd.normals)

                acc, acc_med, nc1, nc1_med = accuracy(pcd_gt.points, pcd.points, gt_normal, pred_normal)
                comp, comp_med, nc2, nc2_med = completion(pcd_gt.points, pcd.points, gt_normal, pred_normal)

                # FIX: open the log via a context manager — the original leaked a
                # file handle on every write.
                with open(log_file, "a") as f:
                    print(f"Idx: {scene_id}, Acc: {acc}, Comp: {comp}, NC1: {nc1}, NC2: {nc2} - Acc_med: {acc_med}, Compc_med: {comp_med}, NC1c_med: {nc1_med}, NC2c_med: {nc2_med}", file=f)

                acc_all += acc
                comp_all += comp
                nc1_all += nc1
                nc2_all += nc2

                acc_all_med += acc_med
                comp_all_med += comp_med
                nc1_all_med += nc1_med
                nc2_all_med += nc2_med

                # release cuda memory
                torch.cuda.empty_cache()
                print(f"Finished evaluation for {name_data} {i+1}/{len(dataloader)}")

            # Get depth from pcd and run TSDFusion
            with open(log_file, "a") as f:
                print(f"Dataset: {name_data}, Accuracy: {acc_all/len(dataloader)}, Completion: {comp_all/len(dataloader)}, NC1: {nc1_all/len(dataloader)}, NC2: {nc2_all/len(dataloader)} - Acc_med: {acc_all_med/len(dataloader)}, Comp_med: {comp_all_med/len(dataloader)}, NC1_med: {nc1_all_med/len(dataloader)}, NC2_med: {nc2_all_med/len(dataloader)}", file=f)
                # FIX: the original averaged `fps` (a float left over from the last
                # iteration) — sum(fps) raised TypeError. Average fps_all instead.
                print(f"Average fps: {sum(fps_all) / len(fps_all)}, Average time: {sum(time_all) / len(time_all)}", file=f)
if __name__ == '__main__':
    # Parse CLI arguments and launch the evaluation run.
    cli_args = get_args_parser().parse_args()
    main(cli_args)