abreza committed on
Commit 06b736b · 1 Parent(s): f15498a

Move Wan pipeline to CUDA at ZeroGPU inference time

app.py CHANGED
@@ -103,7 +103,6 @@ wan_pipeline = WanImageToVideoTTMPipeline.from_pretrained(
103
  )
104
  wan_pipeline.vae.enable_tiling()
105
  wan_pipeline.vae.enable_slicing()
106
- wan_pipeline.to("cuda")
107
 
108
 
109
 
@@ -218,7 +217,7 @@ def run_spatial_tracker(video_tensor: torch.Tensor):
218
  video_input = preprocess_image(video_tensor)[None].cuda()
219
 
220
  with torch.no_grad():
221
- with torch.cuda.amp.autocast(dtype=torch.bfloat16):
222
  predictions = vggt4track_model(video_input / 255)
223
  extrinsic = predictions["poses_pred"]
224
  intrinsic = predictions["intrs"]
@@ -293,6 +292,9 @@ def run_wan_ttm_generation(prompt, tweak_index, tstrong_index, first_frame_path,
293
  "毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
294
  )
295
 
 
 
 
296
  # Match resolution logic from run_wan.py
297
  max_area = 480 * 832
298
  mod_value = wan_pipeline.vae_scale_factor_spatial * \
 
103
  )
104
  wan_pipeline.vae.enable_tiling()
105
  wan_pipeline.vae.enable_slicing()
 
106
 
107
 
108
 
 
217
  video_input = preprocess_image(video_tensor)[None].cuda()
218
 
219
  with torch.no_grad():
220
+ with torch.amp.autocast('cuda', dtype=torch.bfloat16):
221
  predictions = vggt4track_model(video_input / 255)
222
  extrinsic = predictions["poses_pred"]
223
  intrinsic = predictions["intrs"]
 
292
  "毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
293
  )
294
 
295
+
296
+ wan_pipeline.to("cuda")
297
+
298
  # Match resolution logic from run_wan.py
299
  max_area = 480 * 832
300
  mod_value = wan_pipeline.vae_scale_factor_spatial * \
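The app.py change above defers `wan_pipeline.to("cuda")` from import time to the body of `run_wan_ttm_generation`, so the pipeline is only moved onto the GPU while a ZeroGPU device is actually attached to the request. A minimal sketch of that pattern, assuming the generation function is decorated with `@spaces.GPU` (the decorator is not visible in this diff) and using an illustrative Diffusers pipeline and checkpoint in place of `WanImageToVideoTTMPipeline`:

```python
# Sketch of the ZeroGPU pattern applied in app.py; pipeline, checkpoint, and
# argument names below are illustrative, not taken from this repository.
import spaces
import torch
from diffusers import WanImageToVideoPipeline  # stand-in for WanImageToVideoTTMPipeline

# Module level: load weights and configure the VAE memory savers. The commit
# removes the .to("cuda") call from this stage, so nothing touches the GPU here.
wan_pipeline = WanImageToVideoPipeline.from_pretrained(
    "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers", torch_dtype=torch.bfloat16
)
wan_pipeline.vae.enable_tiling()
wan_pipeline.vae.enable_slicing()

@spaces.GPU(duration=120)  # a GPU is attached only while this function runs
def run_wan_ttm_generation(image, prompt):
    # Moved here by the commit: a CUDA device is available inside the decorated
    # function, so the transfer happens at inference time rather than at startup.
    wan_pipeline.to("cuda")
    return wan_pipeline(image=image, prompt=prompt).frames[0]
```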
app_3rd/spatrack_utils/infer_track.py CHANGED
@@ -34,13 +34,13 @@ def get_tracker_predictor(output_dir: str, vo_points: int = 756, tracker_model=N
34
  """
35
  viz = True
36
  os.makedirs(output_dir, exist_ok=True)
37
-
38
  with open(config["cfg_dir"], "r") as f:
39
  cfg = yaml.load(f, Loader=yaml.FullLoader)
40
  cfg = easydict.EasyDict(cfg)
41
  cfg.out_dir = output_dir
42
  cfg.model.track_num = vo_points
43
-
44
  # Check if it's a local path or HuggingFace repo
45
  if tracker_model is not None:
46
  model = tracker_model
@@ -60,8 +60,8 @@ def get_tracker_predictor(output_dir: str, vo_points: int = 756, tracker_model=N
60
  model = Predictor.from_pretrained(checkpoint_path, model_cfg=cfg["model"])
61
  model.eval()
62
  model.to("cuda")
63
-
64
- viser = Visualizer(save_dir=cfg.out_dir, grayscale=True,
65
  fps=10, pad_value=0, tracks_leave_trace=5)
66
 
67
  return model, viser
@@ -83,11 +83,11 @@ def run_tracker(model, viser, temp_dir, video_name, grid_size, vo_points, fps=3)
83
  mask_path = os.path.join(temp_dir, f"{video_name}.png")
84
  out_dir = os.path.join(temp_dir, "results")
85
  os.makedirs(out_dir, exist_ok=True)
86
-
87
  # Load video using decord
88
  video_reader = decord.VideoReader(video_path)
89
  video_tensor = torch.from_numpy(video_reader.get_batch(range(len(video_reader))).asnumpy()).permute(0, 3, 1, 2) # Convert to tensor and permute to (N, C, H, W)
90
-
91
  # resize make sure the shortest side is 336
92
  h, w = video_tensor.shape[2:]
93
  scale = max(336 / h, 336 / w)
@@ -99,7 +99,7 @@ def run_tracker(model, viser, temp_dir, video_name, grid_size, vo_points, fps=3)
99
  intrs = None
100
  extrs = None
101
  data_npz_load = {}
102
-
103
  # Load and process mask
104
  if os.path.exists(mask_path):
105
  mask = cv2.imread(mask_path)
@@ -107,20 +107,20 @@ def run_tracker(model, viser, temp_dir, video_name, grid_size, vo_points, fps=3)
107
  mask = mask.sum(axis=-1)>0
108
  else:
109
  mask = np.ones_like(video_tensor[0,0].numpy())>0
110
-
111
  # Get frame dimensions and create grid points
112
  frame_H, frame_W = video_tensor.shape[2:]
113
  grid_pts = get_points_on_a_grid(grid_size, (frame_H, frame_W), device="cpu")
114
-
115
  # Sample mask values at grid points and filter out points where mask=0
116
  if os.path.exists(mask_path):
117
  grid_pts_int = grid_pts[0].long()
118
  mask_values = mask[grid_pts_int[...,1], grid_pts_int[...,0]]
119
  grid_pts = grid_pts[:, mask_values]
120
-
121
  query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[0].numpy()
122
 
123
- # run vggt
124
  if os.environ.get("VGGT_DIR", None) is not None:
125
  vggt_model = VGGT()
126
  vggt_model.load_state_dict(torch.load(VGGT_DIR))
@@ -128,7 +128,7 @@ def run_tracker(model, viser, temp_dir, video_name, grid_size, vo_points, fps=3)
128
  vggt_model = vggt_model.to("cuda")
129
  # process the image tensor
130
  video_tensor = preprocess_image(video_tensor)[None]
131
- with torch.cuda.amp.autocast(dtype=torch.bfloat16):
132
  # Predict attributes including cameras, depth maps, and point maps.
133
  aggregated_tokens_list, ps_idx = vggt_model.aggregator(video_tensor.cuda()/255)
134
  pose_enc = vggt_model.camera_head(aggregated_tokens_list)[-1]
@@ -154,12 +154,12 @@ def run_tracker(model, viser, temp_dir, video_name, grid_size, vo_points, fps=3)
154
  c2w_traj, intrs, point_map, conf_depth,
155
  track3d_pred, track2d_pred, vis_pred, conf_pred, video
156
  ) = model.forward(video_tensor, depth=depth_tensor,
157
- intrs=intrs, extrs=extrs,
158
  queries=query_xyt,
159
  fps=1, full_point=False, iters_track=4,
160
  query_no_BA=True, fixed_cam=False, stage=1,
161
- support_frame=len(video_tensor)-1, replace_ratio=0.2)
162
-
163
  # Resize results to avoid too large I/O Burden
164
  max_size = 336
165
  h, w = video.shape[2:]
@@ -174,12 +174,12 @@ def run_tracker(model, viser, temp_dir, video_name, grid_size, vo_points, fps=3)
174
  if depth_tensor is not None:
175
  depth_tensor = T.Resize((new_h, new_w))(depth_tensor)
176
  conf_depth = T.Resize((new_h, new_w))(conf_depth)
177
-
178
  # Visualize tracks
179
  viser.visualize(video=video[None],
180
  tracks=track2d_pred[None][...,:2],
181
  visibility=vis_pred[None],filename="test")
182
-
183
  # Save in tapip3d format
184
  data_npz_load["coords"] = (torch.einsum("tij,tnj->tni", c2w_traj[:,:3,:3], track3d_pred[:,:,:3].cpu()) + c2w_traj[:,:3,3][:,None,:]).numpy()
185
  data_npz_load["extrinsics"] = torch.inverse(c2w_traj).cpu().numpy()
@@ -190,5 +190,5 @@ def run_tracker(model, viser, temp_dir, video_name, grid_size, vo_points, fps=3)
190
  data_npz_load["confs"] = conf_pred.cpu().numpy()
191
  data_npz_load["confs_depth"] = conf_depth.cpu().numpy()
192
  np.savez(os.path.join(out_dir, f'result.npz'), **data_npz_load)
193
-
194
  print(f"Results saved to {out_dir}.\nTo visualize them with tapip3d, run: [bold yellow]python tapip3d_viz.py {out_dir}/result.npz[/bold yellow]")
 
34
  """
35
  viz = True
36
  os.makedirs(output_dir, exist_ok=True)
37
+
38
  with open(config["cfg_dir"], "r") as f:
39
  cfg = yaml.load(f, Loader=yaml.FullLoader)
40
  cfg = easydict.EasyDict(cfg)
41
  cfg.out_dir = output_dir
42
  cfg.model.track_num = vo_points
43
+
44
  # Check if it's a local path or HuggingFace repo
45
  if tracker_model is not None:
46
  model = tracker_model
 
60
  model = Predictor.from_pretrained(checkpoint_path, model_cfg=cfg["model"])
61
  model.eval()
62
  model.to("cuda")
63
+
64
+ viser = Visualizer(save_dir=cfg.out_dir, grayscale=True,
65
  fps=10, pad_value=0, tracks_leave_trace=5)
66
 
67
  return model, viser
 
83
  mask_path = os.path.join(temp_dir, f"{video_name}.png")
84
  out_dir = os.path.join(temp_dir, "results")
85
  os.makedirs(out_dir, exist_ok=True)
86
+
87
  # Load video using decord
88
  video_reader = decord.VideoReader(video_path)
89
  video_tensor = torch.from_numpy(video_reader.get_batch(range(len(video_reader))).asnumpy()).permute(0, 3, 1, 2) # Convert to tensor and permute to (N, C, H, W)
90
+
91
  # resize make sure the shortest side is 336
92
  h, w = video_tensor.shape[2:]
93
  scale = max(336 / h, 336 / w)
 
99
  intrs = None
100
  extrs = None
101
  data_npz_load = {}
102
+
103
  # Load and process mask
104
  if os.path.exists(mask_path):
105
  mask = cv2.imread(mask_path)
 
107
  mask = mask.sum(axis=-1)>0
108
  else:
109
  mask = np.ones_like(video_tensor[0,0].numpy())>0
110
+
111
  # Get frame dimensions and create grid points
112
  frame_H, frame_W = video_tensor.shape[2:]
113
  grid_pts = get_points_on_a_grid(grid_size, (frame_H, frame_W), device="cpu")
114
+
115
  # Sample mask values at grid points and filter out points where mask=0
116
  if os.path.exists(mask_path):
117
  grid_pts_int = grid_pts[0].long()
118
  mask_values = mask[grid_pts_int[...,1], grid_pts_int[...,0]]
119
  grid_pts = grid_pts[:, mask_values]
120
+
121
  query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[0].numpy()
122
 
123
+ # run vggt
124
  if os.environ.get("VGGT_DIR", None) is not None:
125
  vggt_model = VGGT()
126
  vggt_model.load_state_dict(torch.load(VGGT_DIR))
 
128
  vggt_model = vggt_model.to("cuda")
129
  # process the image tensor
130
  video_tensor = preprocess_image(video_tensor)[None]
131
+ with torch.amp.autocast('cuda', dtype=torch.bfloat16):
132
  # Predict attributes including cameras, depth maps, and point maps.
133
  aggregated_tokens_list, ps_idx = vggt_model.aggregator(video_tensor.cuda()/255)
134
  pose_enc = vggt_model.camera_head(aggregated_tokens_list)[-1]
 
154
  c2w_traj, intrs, point_map, conf_depth,
155
  track3d_pred, track2d_pred, vis_pred, conf_pred, video
156
  ) = model.forward(video_tensor, depth=depth_tensor,
157
+ intrs=intrs, extrs=extrs,
158
  queries=query_xyt,
159
  fps=1, full_point=False, iters_track=4,
160
  query_no_BA=True, fixed_cam=False, stage=1,
161
+ support_frame=len(video_tensor)-1, replace_ratio=0.2)
162
+
163
  # Resize results to avoid too large I/O Burden
164
  max_size = 336
165
  h, w = video.shape[2:]
 
174
  if depth_tensor is not None:
175
  depth_tensor = T.Resize((new_h, new_w))(depth_tensor)
176
  conf_depth = T.Resize((new_h, new_w))(conf_depth)
177
+
178
  # Visualize tracks
179
  viser.visualize(video=video[None],
180
  tracks=track2d_pred[None][...,:2],
181
  visibility=vis_pred[None],filename="test")
182
+
183
  # Save in tapip3d format
184
  data_npz_load["coords"] = (torch.einsum("tij,tnj->tni", c2w_traj[:,:3,:3], track3d_pred[:,:,:3].cpu()) + c2w_traj[:,:3,3][:,None,:]).numpy()
185
  data_npz_load["extrinsics"] = torch.inverse(c2w_traj).cpu().numpy()
 
190
  data_npz_load["confs"] = conf_pred.cpu().numpy()
191
  data_npz_load["confs_depth"] = conf_depth.cpu().numpy()
192
  np.savez(os.path.join(out_dir, f'result.npz'), **data_npz_load)
193
+
194
  print(f"Results saved to {out_dir}.\nTo visualize them with tapip3d, run: [bold yellow]python tapip3d_viz.py {out_dir}/result.npz[/bold yellow]")
inference.py CHANGED
@@ -38,7 +38,7 @@ if __name__ == "__main__":
38
  # fps
39
  fps = int(args.fps)
40
  mask_dir = args.data_dir + f"/{args.video_name}.png"
41
-
42
  vggt4track_model = VGGT4Track.from_pretrained("Yuxihenry/SpatialTrackerV2_Front")
43
  vggt4track_model.eval()
44
  vggt4track_model = vggt4track_model.to("cuda")
@@ -66,12 +66,12 @@ if __name__ == "__main__":
66
  # process the image tensor
67
  video_tensor = preprocess_image(video_tensor)[None]
68
  with torch.no_grad():
69
- with torch.cuda.amp.autocast(dtype=torch.bfloat16):
70
  # Predict attributes including cameras, depth maps, and point maps.
71
  predictions = vggt4track_model(video_tensor.cuda()/255)
72
  extrinsic, intrinsic = predictions["poses_pred"], predictions["intrs"]
73
  depth_map, depth_conf = predictions["points_map"][..., 2], predictions["unc_metric"]
74
-
75
  depth_tensor = depth_map.squeeze().cpu().numpy()
76
  extrs = np.eye(4)[None].repeat(len(depth_tensor), axis=0)
77
  extrs = extrinsic.squeeze().cpu().numpy()
@@ -82,7 +82,7 @@ if __name__ == "__main__":
82
  unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
83
 
84
  data_npz_load = {}
85
-
86
  if os.path.exists(mask_dir):
87
  mask_files = mask_dir
88
  mask = cv2.imread(mask_files)
@@ -90,11 +90,11 @@ if __name__ == "__main__":
90
  mask = mask.sum(axis=-1)>0
91
  else:
92
  mask = np.ones_like(video_tensor[0,0].numpy())>0
93
-
94
  # get all data pieces
95
  viz = True
96
  os.makedirs(out_dir, exist_ok=True)
97
-
98
  # with open(cfg_dir, "r") as f:
99
  # cfg = yaml.load(f, Loader=yaml.FullLoader)
100
  # cfg = easydict.EasyDict(cfg)
@@ -108,12 +108,12 @@ if __name__ == "__main__":
108
 
109
  # config the model; the track_num is the number of points in the grid
110
  model.spatrack.track_num = args.vo_points
111
-
112
  model.eval()
113
  model.to("cuda")
114
- viser = Visualizer(save_dir=out_dir, grayscale=True,
115
  fps=10, pad_value=0, tracks_leave_trace=5)
116
-
117
  grid_size = args.grid_size
118
 
119
  # get frame H W
@@ -124,13 +124,13 @@ if __name__ == "__main__":
124
  else:
125
  frame_H, frame_W = video_tensor.shape[2:]
126
  grid_pts = get_points_on_a_grid(grid_size, (frame_H, frame_W), device="cpu")
127
-
128
  # Sample mask values at grid points and filter out points where mask=0
129
  if os.path.exists(mask_dir):
130
  grid_pts_int = grid_pts[0].long()
131
  mask_values = mask[grid_pts_int[...,1], grid_pts_int[...,0]]
132
  grid_pts = grid_pts[:, mask_values]
133
-
134
  query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[0].numpy()
135
 
136
  # Run model inference
@@ -139,12 +139,12 @@ if __name__ == "__main__":
139
  c2w_traj, intrs, point_map, conf_depth,
140
  track3d_pred, track2d_pred, vis_pred, conf_pred, video
141
  ) = model.forward(video_tensor, depth=depth_tensor,
142
- intrs=intrs, extrs=extrs,
143
  queries=query_xyt,
144
  fps=1, full_point=False, iters_track=4,
145
  query_no_BA=True, fixed_cam=False, stage=1, unc_metric=unc_metric,
146
- support_frame=len(video_tensor)-1, replace_ratio=0.2)
147
-
148
  # resize the results to avoid too large I/O Burden
149
  # depth and image, the maximum side is 336
150
  max_size = 336
@@ -169,7 +169,7 @@ if __name__ == "__main__":
169
  tracks=track2d_pred[None][...,:2],
170
  visibility=vis_pred[None],filename="test")
171
 
172
- # save as the tapip3d format
173
  data_npz_load["coords"] = (torch.einsum("tij,tnj->tni", c2w_traj[:,:3,:3], track3d_pred[:,:,:3].cpu()) + c2w_traj[:,:3,3][:,None,:]).numpy()
174
  data_npz_load["extrinsics"] = torch.inverse(c2w_traj).cpu().numpy()
175
  data_npz_load["intrinsics"] = intrs.cpu().numpy()
 
38
  # fps
39
  fps = int(args.fps)
40
  mask_dir = args.data_dir + f"/{args.video_name}.png"
41
+
42
  vggt4track_model = VGGT4Track.from_pretrained("Yuxihenry/SpatialTrackerV2_Front")
43
  vggt4track_model.eval()
44
  vggt4track_model = vggt4track_model.to("cuda")
 
66
  # process the image tensor
67
  video_tensor = preprocess_image(video_tensor)[None]
68
  with torch.no_grad():
69
+ with torch.amp.autocast('cuda', dtype=torch.bfloat16):
70
  # Predict attributes including cameras, depth maps, and point maps.
71
  predictions = vggt4track_model(video_tensor.cuda()/255)
72
  extrinsic, intrinsic = predictions["poses_pred"], predictions["intrs"]
73
  depth_map, depth_conf = predictions["points_map"][..., 2], predictions["unc_metric"]
74
+
75
  depth_tensor = depth_map.squeeze().cpu().numpy()
76
  extrs = np.eye(4)[None].repeat(len(depth_tensor), axis=0)
77
  extrs = extrinsic.squeeze().cpu().numpy()
 
82
  unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
83
 
84
  data_npz_load = {}
85
+
86
  if os.path.exists(mask_dir):
87
  mask_files = mask_dir
88
  mask = cv2.imread(mask_files)
 
90
  mask = mask.sum(axis=-1)>0
91
  else:
92
  mask = np.ones_like(video_tensor[0,0].numpy())>0
93
+
94
  # get all data pieces
95
  viz = True
96
  os.makedirs(out_dir, exist_ok=True)
97
+
98
  # with open(cfg_dir, "r") as f:
99
  # cfg = yaml.load(f, Loader=yaml.FullLoader)
100
  # cfg = easydict.EasyDict(cfg)
 
108
 
109
  # config the model; the track_num is the number of points in the grid
110
  model.spatrack.track_num = args.vo_points
111
+
112
  model.eval()
113
  model.to("cuda")
114
+ viser = Visualizer(save_dir=out_dir, grayscale=True,
115
  fps=10, pad_value=0, tracks_leave_trace=5)
116
+
117
  grid_size = args.grid_size
118
 
119
  # get frame H W
 
124
  else:
125
  frame_H, frame_W = video_tensor.shape[2:]
126
  grid_pts = get_points_on_a_grid(grid_size, (frame_H, frame_W), device="cpu")
127
+
128
  # Sample mask values at grid points and filter out points where mask=0
129
  if os.path.exists(mask_dir):
130
  grid_pts_int = grid_pts[0].long()
131
  mask_values = mask[grid_pts_int[...,1], grid_pts_int[...,0]]
132
  grid_pts = grid_pts[:, mask_values]
133
+
134
  query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[0].numpy()
135
 
136
  # Run model inference
 
139
  c2w_traj, intrs, point_map, conf_depth,
140
  track3d_pred, track2d_pred, vis_pred, conf_pred, video
141
  ) = model.forward(video_tensor, depth=depth_tensor,
142
+ intrs=intrs, extrs=extrs,
143
  queries=query_xyt,
144
  fps=1, full_point=False, iters_track=4,
145
  query_no_BA=True, fixed_cam=False, stage=1, unc_metric=unc_metric,
146
+ support_frame=len(video_tensor)-1, replace_ratio=0.2)
147
+
148
  # resize the results to avoid too large I/O Burden
149
  # depth and image, the maximum side is 336
150
  max_size = 336
 
169
  tracks=track2d_pred[None][...,:2],
170
  visibility=vis_pred[None],filename="test")
171
 
172
+ # save as the tapip3d format
173
  data_npz_load["coords"] = (torch.einsum("tij,tnj->tni", c2w_traj[:,:3,:3], track3d_pred[:,:,:3].cpu()) + c2w_traj[:,:3,3][:,None,:]).numpy()
174
  data_npz_load["extrinsics"] = torch.inverse(c2w_traj).cpu().numpy()
175
  data_npz_load["intrinsics"] = intrs.cpu().numpy()
models/SpaTrackV2/models/utils.py CHANGED
@@ -95,7 +95,7 @@ class AverageMeter(object):
95
  return fmtstr.format(**self.__dict__)
96
 
97
 
98
- def procrustes_analysis(X0,X1): # [N,3]
99
  # translation
100
  t0 = X0.mean(dim=0,keepdim=True)
101
  t1 = X1.mean(dim=0,keepdim=True)
@@ -218,7 +218,7 @@ def get_EFP(pred_cameras, image_size, B, S, default_focal=False):
218
 
219
  intrinsics = create_intri_matrix(focal_length, principal_point)
220
  return extrinsics, intrinsics
221
-
222
  def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
223
  """
224
  Convert rotations given as quaternions to rotation matrices.
@@ -278,7 +278,7 @@ def pose_encoding_to_camera(
278
  # Now converted back
279
  focal_length = (log_focal_length + log_focal_length_bias).exp()
280
  # clamp to avoid weird fl values
281
- focal_length = torch.clamp(focal_length,
282
  min=min_focal_length, max=max_focal_length)
283
  elif pose_encoding_type == "absT_quaR_OneFL":
284
  # 3 for absT, 4 for quaR, 1 for absFL
@@ -287,7 +287,7 @@ def pose_encoding_to_camera(
287
  quaternion_R = pose_encoding_reshaped[:, 3:7]
288
  R = quaternion_to_matrix(quaternion_R)
289
  focal_length = pose_encoding_reshaped[:, 7:8]
290
- focal_length = torch.clamp(focal_length,
291
  min=min_focal_length, max=max_focal_length)
292
  else:
293
  raise ValueError(f"Unknown pose encoding {pose_encoding_type}")
@@ -316,7 +316,7 @@ def pose_encoding_to_camera(
316
 
317
  R = extrinsics_4x4[:, :3, :3].clone()
318
  abs_T = extrinsics_4x4[:, :3, 3].clone()
319
-
320
  if return_dict:
321
  return {"focal_length": focal_length, "R": R, "T": abs_T}
322
 
@@ -326,7 +326,7 @@ def pose_encoding_to_camera(
326
 
327
 
328
  def camera_to_pose_encoding(
329
- camera, pose_encoding_type="absT_quaR_logFL",
330
  log_focal_length_bias=1.8, min_focal_length=0.1, max_focal_length=30
331
  ):
332
  """
@@ -359,7 +359,7 @@ def camera_to_pose_encoding(
359
  return pose_encoding
360
 
361
 
362
- def init_pose_enc(B: int,
363
  S: int, pose_encoding_type: str="absT_quaR_logFL",
364
  device: Optional[torch.device]=None):
365
  """
@@ -378,7 +378,7 @@ def init_pose_enc(B: int,
378
  C = 8
379
  else:
380
  raise ValueError(f"Unknown pose encoding {pose_encoding_type}")
381
-
382
  pose_enc = torch.zeros(B, S, C, device=device)
383
  pose_enc[..., :3] = 0 # absT
384
  pose_enc[..., 3] = 1 # quaR
@@ -389,7 +389,7 @@ def first_pose_enc_norm(pose_enc: torch.Tensor,
389
  pose_encoding_type: str="absT_quaR_OneFL",
390
  pose_mode: str = "W2C"):
391
  """
392
- make sure the poses in on window are normalized by the first frame, where the
393
  first frame transformation is the Identity Matrix.
394
  NOTE: Poses are all W2C
395
  args:
@@ -403,23 +403,23 @@ def first_pose_enc_norm(pose_enc: torch.Tensor,
403
  pose_enc, pose_encoding_type=pose_encoding_type,
404
  to_OpenCV=False
405
  ) #NOTE: the camera parameters are not in NDC
406
-
407
  R = pred_cameras.R # [B*S, 3, 3]
408
  T = pred_cameras.T # [B*S, 3]
409
-
410
  Tran_M = torch.cat([R, T.unsqueeze(-1)], dim=-1) # [B*S, 3, 4]
411
  extra_ = torch.tensor([[[0, 0, 0, 1]]],
412
  device=Tran_M.device).expand(Tran_M.shape[0], -1, -1)
413
  Tran_M = torch.cat([Tran_M, extra_
414
  ], dim=1)
415
  Tran_M = rearrange(Tran_M, '(b s) c d -> b s c d', b=B)
416
-
417
  # Take the first frame as the base of world coordinate
418
  if pose_mode == "C2W":
419
  Tran_M_new = (Tran_M[:,:1,...].inverse())@Tran_M
420
  elif pose_mode == "W2C":
421
  Tran_M_new = Tran_M@(Tran_M[:,:1,...].inverse())
422
-
423
  Tran_M_new = rearrange(Tran_M_new, 'b s c d -> (b s) c d')
424
 
425
  R_ = Tran_M_new[:, :3, :3]
@@ -429,7 +429,7 @@ def first_pose_enc_norm(pose_enc: torch.Tensor,
429
  pred_cameras.R = R_
430
  pred_cameras.T = T_
431
  pose_enc_norm = camera_to_pose_encoding(pred_cameras,
432
- pose_encoding_type=pose_encoding_type)
433
  pose_enc_norm = rearrange(pose_enc_norm, '(b s) c -> b s c', b=B)
434
  return pose_enc_norm
435
 
@@ -439,7 +439,7 @@ def first_pose_enc_denorm(
439
  pose_encoding_type: str="absT_quaR_OneFL",
440
  pose_mode: str = "W2C"):
441
  """
442
- make sure the poses in on window are de-normalized by the first frame, where the
443
  first frame transformation is the Identity Matrix.
444
  args:
445
  pose_enc: [B S C]
@@ -457,7 +457,7 @@ def first_pose_enc_denorm(
457
  ) #NOTE: the camera parameters are not in NDC
458
  R = pred_cameras.R # [B*(1+S), 3, 3]
459
  T = pred_cameras.T # [B*(1+S), 3]
460
-
461
  Tran_M = torch.cat([R, T.unsqueeze(-1)], dim=-1) # [B*(1+S), 3, 4]
462
  extra_ = torch.tensor([[[0, 0, 0, 1]]],
463
  device=Tran_M.device).expand(Tran_M.shape[0], -1, -1)
@@ -470,7 +470,7 @@ def first_pose_enc_denorm(
470
  Tran_M_new = Tran_M_1st@Tran_M_new
471
  elif pose_mode == "W2C":
472
  Tran_M_new = Tran_M_new@Tran_M_1st
473
-
474
  Tran_M_new_ = torch.cat([Tran_M_1st, Tran_M_new], dim=1)
475
  R_ = Tran_M_new_[..., :3, :3].view(-1, 3, 3)
476
  T_ = Tran_M_new_[..., :3, 3].view(-1, 3)
@@ -481,7 +481,7 @@ def first_pose_enc_denorm(
481
 
482
  # Cameras to Pose encoding
483
  pose_enc_denorm = camera_to_pose_encoding(pred_cameras,
484
- pose_encoding_type=pose_encoding_type)
485
  pose_enc_denorm = rearrange(pose_enc_denorm, '(b s) c -> b s c', b=B)
486
  return pose_enc_denorm[:, 1:]
487
 
@@ -560,7 +560,7 @@ def median_loss(prediction, target, mask, Bs):
560
  target_nm, a_norm_gt, b_norm_gt = normalize_prediction_robust(target.float(), mask, Bs)
561
  depth_loss = nn.functional.l1_loss(prediction_nm[mask], target_nm[mask])
562
  scale = b_norm_gt/b_norm
563
- shift = a_norm_gt - a_norm*scale
564
  return depth_loss, scale, shift, prediction_nm, target_nm
565
 
566
  def reduction_batch_based(image_loss, M):
@@ -593,7 +593,7 @@ class ScaleAndShiftInvariantLoss(nn.Module):
593
 
594
  def forward(self, prediction, target, mask, Bs,
595
  interpolate=True, return_interpolated=False):
596
-
597
  if prediction.shape[-1] != target.shape[-1] and interpolate:
598
  prediction = nn.functional.interpolate(prediction, target.shape[-2:], mode='bilinear', align_corners=True)
599
  intr_input = prediction
@@ -602,7 +602,7 @@ class ScaleAndShiftInvariantLoss(nn.Module):
602
 
603
  prediction, target, mask = prediction.squeeze(), target.squeeze(), mask.squeeze()
604
  assert prediction.shape == target.shape, f"Shape mismatch: Expected same shape but got {prediction.shape} and {target.shape}."
605
-
606
 
607
  scale, shift = compute_scale_and_shift(prediction, target, mask)
608
  a_norm = scale.view(Bs, -1, 1, 1).mean(dim=1, keepdim=True)
@@ -634,7 +634,7 @@ class GradientLoss(nn.Module):
634
 
635
  for scale in range(self.__scales):
636
  step = pow(2, scale)
637
- l1_ln, a_nm, b_nm = ScaleAndShiftInvariantLoss_fn(prediction[:, ::step, ::step],
638
  target[:, ::step, ::step], mask[:, ::step, ::step], 1)
639
  total += l1_ln
640
  a_nm = a_nm.squeeze().detach() # [B, 1, 1]
@@ -663,7 +663,7 @@ def gradient_loss(prediction, target, mask, reduction=reduction_batch_based):
663
  image_loss = torch.sum(grad_x, (1, 2)) + torch.sum(grad_y, (1, 2))
664
 
665
  return reduction(image_loss, M)
666
-
667
  def loss_fn(
668
  poses_preds: List[torch.Tensor],
669
  poses_pred_all: List[torch.Tensor],
@@ -700,7 +700,7 @@ def loss_fn(
700
  if logger is not None:
701
  if poses_preds_ij.max()>5e1:
702
  logger.info(f"pose_pred_max_and_mean: {poses_preds_ij.max(), poses_preds_ij.mean()}")
703
-
704
  trans_loss = (poses_preds_ij[...,:3] - poses_gt_i_norm[...,:3]).abs().sum(dim=-1).mean()
705
  rot_loss = (poses_preds_ij[...,3:7] - poses_gt_i_norm[...,3:7]).abs().sum(dim=-1).mean()
706
  focal_loss = (poses_preds_ij[...,7:] - poses_gt_i_norm[...,7:]).abs().sum(dim=-1).mean()
@@ -714,7 +714,7 @@ def loss_fn(
714
  logger_tf.add_scalar(f"loss@pose/rot_iter{idx}",
715
  rot_loss, global_step=global_step)
716
  logger_tf.add_scalar(f"loss@pose/focal_iter{idx}",
717
- focal_loss, global_step=global_step)
718
  # compute the uncertainty loss
719
  with torch.no_grad():
720
  pose_loss_dist = (poses_preds_ij-poses_gt_i_norm).detach().abs()
@@ -726,9 +726,9 @@ def loss_fn(
726
  unc_loss,
727
  global_step=global_step)
728
  # if logger is not None:
729
- # logger.info(f"pose_loss: {pose_loss}, unc_loss: {unc_loss}")
730
  # total loss
731
- loss_total += 0.1*unc_loss + 2*pose_loss
732
 
733
  poses_gt_norm = poses_gt
734
  pose_all_loss = 0.0
@@ -743,7 +743,7 @@ def loss_fn(
743
  prev_loss = (trans_loss + rot_loss + focal_loss)
744
  else:
745
  des_loss = (trans_loss + rot_loss + focal_loss) - prev_loss
746
- prev_loss = trans_loss + rot_loss + focal_loss
747
  logger_tf.add_scalar(f"loss@global_pose/des_iter{idx}",
748
  des_loss, global_step=global_step)
749
  logger_tf.add_scalar(f"loss@global_pose/trans_iter{idx}",
@@ -751,20 +751,20 @@ def loss_fn(
751
  logger_tf.add_scalar(f"loss@global_pose/rot_iter{idx}",
752
  rot_loss, global_step=global_step)
753
  logger_tf.add_scalar(f"loss@global_pose/focal_iter{idx}",
754
- focal_loss, global_step=global_step)
755
  if torch.isnan((trans_loss + rot_loss + focal_loss)).any():
756
  pose_all_loss += 0
757
  else:
758
  pose_all_loss += i_weight*(trans_loss + rot_loss + focal_loss)
759
-
760
  # if logger is not None:
761
- # logger.info(f"global_pose_loss: {pose_all_loss}")
762
 
763
  # compute the depth loss
764
  if inv_depth_preds[0] is not None:
765
  depths_gt = depths_gt[:,:,0]
766
  msk = depths_gt > 5e-2
767
- inv_gt = 1.0 / (depths_gt.clamp(1e-3, 1e16))
768
  inv_gt_reshp = rearrange(inv_gt, 'b t h w -> (b t) h w')
769
  inv_depth_preds_reshp = rearrange(inv_depth_preds[0], 'b t h w -> (b t) h w')
770
  inv_raw_reshp = rearrange(inv_depth_raw[0], 'b t h w -> (b t) h w')
@@ -785,11 +785,11 @@ def loss_fn(
785
  depth_loss,
786
  global_step=global_step)
787
  # if logger is not None:
788
- # logger.info(f"opt_depth: {huber_loss_raw - huber_loss}")
789
  else:
790
  depth_loss = 0.0
791
 
792
-
793
  loss_total = loss_total/(len(poses_preds)) + 20*depth_loss + pose_all_loss
794
 
795
  return loss_total, (huber_loss_raw - huber_loss)
@@ -803,7 +803,7 @@ def vis_depth(x: torch.tensor,
803
  """
804
  assert len(x.shape) == 2
805
 
806
- depth_map_normalized = cv2.normalize(x.cpu().numpy(),
807
  None, 0, 255, cv2.NORM_MINMAX)
808
  depth_map_colored = cv2.applyColorMap(depth_map_normalized.astype(np.uint8),
809
  cv2.COLORMAP_JET)
@@ -848,7 +848,7 @@ def vis_pcd(
848
  return pcl
849
 
850
  def vis_result(rgbs, poses_pred, poses_gt,
851
- depth_gt, depth_pred, iter_num=0,
852
  vis=None, logger_tf=None, cfg=None):
853
  """
854
  Args:
@@ -863,7 +863,7 @@ def vis_result(rgbs, poses_pred, poses_gt,
863
  if vis is None:
864
  return
865
  S, _, H, W = depth_gt.shape
866
- # get the xy
867
  yx = torch.meshgrid(torch.arange(H).to(depth_pred.device),
868
  torch.arange(W).to(depth_pred.device),indexing='ij')
869
  xy = torch.stack(yx[::-1], dim=0).float().to(depth_pred.device)
@@ -880,7 +880,7 @@ def vis_result(rgbs, poses_pred, poses_gt,
880
  pose_encoding_type="absT_quaR_OneFL",to_OpenCV=False)
881
  poses_pred_vis = pose_encoding_to_camera(poses_pred,
882
  pose_encoding_type="absT_quaR_OneFL",to_OpenCV=False)
883
-
884
  R_gt = poses_gt_vis.R.float()
885
  R_pred = poses_pred_vis.R.float()
886
  T_gt = poses_gt_vis.T.float()
@@ -890,8 +890,8 @@ def vis_result(rgbs, poses_pred, poses_gt,
890
  T_gt_c2w = (-R_gt_c2w @ T_gt[:, :, None]).squeeze(-1)
891
  R_pred_c2w = R_pred.permute(0,2,1)
892
  T_pred_c2w = (-R_pred_c2w @ T_pred[:, :, None]).squeeze(-1)
893
- with torch.cuda.amp.autocast(enabled=False):
894
- pick_idx = torch.randperm(S)[:min(24, S)]
895
  # pick_idx = [1]
896
  #NOTE: very strange that the camera need C2W Rotation and W2C translation as input
897
  poses_gt_vis = PerspectiveCamerasVisual(
@@ -922,9 +922,9 @@ def vis_result(rgbs, poses_pred, poses_gt,
922
  fig = plot_scene(visual_dict, camera_scale=0.05)
923
  vis.plotlyplot(fig, env=env_name, win="3D")
924
  vis.save([env_name])
925
-
926
  return
927
-
928
  def depth2pcd(
929
  xy_depth: torch.Tensor,
930
  focal_length: torch.Tensor,
@@ -953,7 +953,7 @@ def depth2pcd(
953
  K_inv = K.inverse()
954
  # xyz
955
  xyz = xy_depth.view(S, -1, 3).permute(0, 2, 1) # S 3 (H W)
956
- depth = xyz[:, 2:].clone() # S (H W) 1
957
  xyz[:, 2] = 1
958
  xyz = K_inv @ xyz # S 3 (H W)
959
  xyz = xyz * depth
@@ -963,29 +963,29 @@ def depth2pcd(
963
  return xyz
964
 
965
 
966
- def pose_enc2mat(poses_pred,
967
  H_resize, W_resize, resolution=336):
968
  """
969
  This function convert the pose encoding into `intrinsic` and `extrinsic`
970
 
971
  Args:
972
  poses_pred: B T 8
973
- Return:
974
  Intrinsic B T 3 3
975
  Extrinsic B T 4 4
976
  """
977
  B, T, _ = poses_pred.shape
978
  focal_pred = poses_pred[:, :, -1].clone()
979
- pos_quat_preds = poses_pred[:, :, :7].clone()
980
- pos_quat_preds = pos_quat_preds.view(B*T, -1)
981
- # get extrinsic
982
  c2w_rot = quaternion_to_matrix(pos_quat_preds[:, 3:])
983
  c2w_tran = pos_quat_preds[:, :3]
984
  c2w_traj = torch.eye(4)[None].repeat(B*T, 1, 1).to(poses_pred.device)
985
  c2w_traj[:, :3, :3], c2w_traj[:, :3, 3] = c2w_rot, c2w_tran
986
  c2w_traj = c2w_traj.view(B, T, 4, 4)
987
  # get intrinsic
988
- fxs, fys = focal_pred*resolution, focal_pred*resolution
989
  intrs = torch.eye(3).to(c2w_traj.device).to(c2w_traj.dtype)[None, None].repeat(B, T, 1, 1)
990
  intrs[:,:,0,0], intrs[:,:,1,1] = fxs, fys
991
  intrs[:,:,0,2], intrs[:,:,1,2] = W_resize/2, H_resize/2
@@ -1001,7 +1001,7 @@ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
1001
  positive_mask = x > 0
1002
  ret[positive_mask] = torch.sqrt(x[positive_mask])
1003
  return ret
1004
-
1005
  def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
1006
  """
1007
  Convert a unit quaternion to a standard form: one in which the real
@@ -1086,11 +1086,11 @@ def meshgrid2d(B, Y, X, stack=False, norm=False, device="cuda"):
1086
  return grid
1087
  else:
1088
  return grid_y, grid_x
1089
-
1090
  def get_points_on_a_grid(grid_size, interp_shape,
1091
  grid_center=(0, 0), device="cuda"):
1092
  if grid_size == 1:
1093
- return torch.tensor([interp_shape[1] / 2,
1094
  interp_shape[0] / 2], device=device)[
1095
  None, None
1096
  ]
@@ -1114,12 +1114,12 @@ def get_points_on_a_grid(grid_size, interp_shape,
1114
  xy = torch.stack([grid_x, grid_y], dim=-1).to(device)
1115
  return xy
1116
 
1117
- def normalize_rgb(x,input_size=224,
1118
  resize_mode: Literal['resize', 'padding'] = 'resize',
1119
  if_da=False):
1120
  """
1121
  normalize the image for depth anything input
1122
-
1123
  args:
1124
  x: the input images [B T C H W]
1125
  """
@@ -1127,8 +1127,8 @@ def normalize_rgb(x,input_size=224,
1127
  x = torch.from_numpy(x) / 255.0
1128
  elif isinstance(x, torch.Tensor):
1129
  x = x / 255.0
1130
- B, T, C, H, W = x.shape
1131
- x = x.view(B * T, C, H, W)
1132
  Resizer = Resize(
1133
  width=input_size,
1134
  height=input_size,
@@ -1136,7 +1136,7 @@ def normalize_rgb(x,input_size=224,
1136
  keep_aspect_ratio=True,
1137
  ensure_multiple_of=14,
1138
  resize_method='lower_bound',
1139
- )
1140
  if resize_mode == 'padding':
1141
  # zero padding to make the input size to be multiple of 14
1142
  if H > W:
@@ -1160,7 +1160,7 @@ def normalize_rgb(x,input_size=224,
1160
  x = F.interpolate(x, size=(int(H_scale), int(W_scale)),
1161
  mode='bicubic', align_corners=True)
1162
  # get the mean and std
1163
- __mean__ = torch.tensor([0.485,
1164
  0.456, 0.406]).view(1, 3, 1, 1).to(x.device)
1165
  __std__ = torch.tensor([0.229,
1166
  0.224, 0.225]).view(1, 3, 1, 1).to(x.device)
@@ -1168,7 +1168,7 @@ def normalize_rgb(x,input_size=224,
1168
  if if_da:
1169
  x = (x - __mean__) / __std__
1170
  else:
1171
- x = x
1172
  return x.view(B, T, C, x.shape[-2], x.shape[-1])
1173
 
1174
  def get_track_points(H, W, T, device, size=100, support_frame=0,
 
95
  return fmtstr.format(**self.__dict__)
96
 
97
 
98
+ def procrustes_analysis(X0,X1): # [N,3]
99
  # translation
100
  t0 = X0.mean(dim=0,keepdim=True)
101
  t1 = X1.mean(dim=0,keepdim=True)
 
218
 
219
  intrinsics = create_intri_matrix(focal_length, principal_point)
220
  return extrinsics, intrinsics
221
+
222
  def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
223
  """
224
  Convert rotations given as quaternions to rotation matrices.
 
278
  # Now converted back
279
  focal_length = (log_focal_length + log_focal_length_bias).exp()
280
  # clamp to avoid weird fl values
281
+ focal_length = torch.clamp(focal_length,
282
  min=min_focal_length, max=max_focal_length)
283
  elif pose_encoding_type == "absT_quaR_OneFL":
284
  # 3 for absT, 4 for quaR, 1 for absFL
 
287
  quaternion_R = pose_encoding_reshaped[:, 3:7]
288
  R = quaternion_to_matrix(quaternion_R)
289
  focal_length = pose_encoding_reshaped[:, 7:8]
290
+ focal_length = torch.clamp(focal_length,
291
  min=min_focal_length, max=max_focal_length)
292
  else:
293
  raise ValueError(f"Unknown pose encoding {pose_encoding_type}")
 
316
 
317
  R = extrinsics_4x4[:, :3, :3].clone()
318
  abs_T = extrinsics_4x4[:, :3, 3].clone()
319
+
320
  if return_dict:
321
  return {"focal_length": focal_length, "R": R, "T": abs_T}
322
 
 
326
 
327
 
328
  def camera_to_pose_encoding(
329
+ camera, pose_encoding_type="absT_quaR_logFL",
330
  log_focal_length_bias=1.8, min_focal_length=0.1, max_focal_length=30
331
  ):
332
  """
 
359
  return pose_encoding
360
 
361
 
362
+ def init_pose_enc(B: int,
363
  S: int, pose_encoding_type: str="absT_quaR_logFL",
364
  device: Optional[torch.device]=None):
365
  """
 
378
  C = 8
379
  else:
380
  raise ValueError(f"Unknown pose encoding {pose_encoding_type}")
381
+
382
  pose_enc = torch.zeros(B, S, C, device=device)
383
  pose_enc[..., :3] = 0 # absT
384
  pose_enc[..., 3] = 1 # quaR
 
389
  pose_encoding_type: str="absT_quaR_OneFL",
390
  pose_mode: str = "W2C"):
391
  """
392
+ make sure the poses in on window are normalized by the first frame, where the
393
  first frame transformation is the Identity Matrix.
394
  NOTE: Poses are all W2C
395
  args:
 
403
  pose_enc, pose_encoding_type=pose_encoding_type,
404
  to_OpenCV=False
405
  ) #NOTE: the camera parameters are not in NDC
406
+
407
  R = pred_cameras.R # [B*S, 3, 3]
408
  T = pred_cameras.T # [B*S, 3]
409
+
410
  Tran_M = torch.cat([R, T.unsqueeze(-1)], dim=-1) # [B*S, 3, 4]
411
  extra_ = torch.tensor([[[0, 0, 0, 1]]],
412
  device=Tran_M.device).expand(Tran_M.shape[0], -1, -1)
413
  Tran_M = torch.cat([Tran_M, extra_
414
  ], dim=1)
415
  Tran_M = rearrange(Tran_M, '(b s) c d -> b s c d', b=B)
416
+
417
  # Take the first frame as the base of world coordinate
418
  if pose_mode == "C2W":
419
  Tran_M_new = (Tran_M[:,:1,...].inverse())@Tran_M
420
  elif pose_mode == "W2C":
421
  Tran_M_new = Tran_M@(Tran_M[:,:1,...].inverse())
422
+
423
  Tran_M_new = rearrange(Tran_M_new, 'b s c d -> (b s) c d')
424
 
425
  R_ = Tran_M_new[:, :3, :3]
 
429
  pred_cameras.R = R_
430
  pred_cameras.T = T_
431
  pose_enc_norm = camera_to_pose_encoding(pred_cameras,
432
+ pose_encoding_type=pose_encoding_type)
433
  pose_enc_norm = rearrange(pose_enc_norm, '(b s) c -> b s c', b=B)
434
  return pose_enc_norm
435
 
 
439
  pose_encoding_type: str="absT_quaR_OneFL",
440
  pose_mode: str = "W2C"):
441
  """
442
+ make sure the poses in on window are de-normalized by the first frame, where the
443
  first frame transformation is the Identity Matrix.
444
  args:
445
  pose_enc: [B S C]
 
457
  ) #NOTE: the camera parameters are not in NDC
458
  R = pred_cameras.R # [B*(1+S), 3, 3]
459
  T = pred_cameras.T # [B*(1+S), 3]
460
+
461
  Tran_M = torch.cat([R, T.unsqueeze(-1)], dim=-1) # [B*(1+S), 3, 4]
462
  extra_ = torch.tensor([[[0, 0, 0, 1]]],
463
  device=Tran_M.device).expand(Tran_M.shape[0], -1, -1)
 
470
  Tran_M_new = Tran_M_1st@Tran_M_new
471
  elif pose_mode == "W2C":
472
  Tran_M_new = Tran_M_new@Tran_M_1st
473
+
474
  Tran_M_new_ = torch.cat([Tran_M_1st, Tran_M_new], dim=1)
475
  R_ = Tran_M_new_[..., :3, :3].view(-1, 3, 3)
476
  T_ = Tran_M_new_[..., :3, 3].view(-1, 3)
 
481
 
482
  # Cameras to Pose encoding
483
  pose_enc_denorm = camera_to_pose_encoding(pred_cameras,
484
+ pose_encoding_type=pose_encoding_type)
485
  pose_enc_denorm = rearrange(pose_enc_denorm, '(b s) c -> b s c', b=B)
486
  return pose_enc_denorm[:, 1:]
487
 
 
560
  target_nm, a_norm_gt, b_norm_gt = normalize_prediction_robust(target.float(), mask, Bs)
561
  depth_loss = nn.functional.l1_loss(prediction_nm[mask], target_nm[mask])
562
  scale = b_norm_gt/b_norm
563
+ shift = a_norm_gt - a_norm*scale
564
  return depth_loss, scale, shift, prediction_nm, target_nm
565
 
566
  def reduction_batch_based(image_loss, M):
 
593
 
594
  def forward(self, prediction, target, mask, Bs,
595
  interpolate=True, return_interpolated=False):
596
+
597
  if prediction.shape[-1] != target.shape[-1] and interpolate:
598
  prediction = nn.functional.interpolate(prediction, target.shape[-2:], mode='bilinear', align_corners=True)
599
  intr_input = prediction
 
602
 
603
  prediction, target, mask = prediction.squeeze(), target.squeeze(), mask.squeeze()
604
  assert prediction.shape == target.shape, f"Shape mismatch: Expected same shape but got {prediction.shape} and {target.shape}."
605
+
606
 
607
  scale, shift = compute_scale_and_shift(prediction, target, mask)
608
  a_norm = scale.view(Bs, -1, 1, 1).mean(dim=1, keepdim=True)
 
634
 
635
  for scale in range(self.__scales):
636
  step = pow(2, scale)
637
+ l1_ln, a_nm, b_nm = ScaleAndShiftInvariantLoss_fn(prediction[:, ::step, ::step],
638
  target[:, ::step, ::step], mask[:, ::step, ::step], 1)
639
  total += l1_ln
640
  a_nm = a_nm.squeeze().detach() # [B, 1, 1]
 
663
  image_loss = torch.sum(grad_x, (1, 2)) + torch.sum(grad_y, (1, 2))
664
 
665
  return reduction(image_loss, M)
666
+
667
  def loss_fn(
668
  poses_preds: List[torch.Tensor],
669
  poses_pred_all: List[torch.Tensor],
 
700
  if logger is not None:
701
  if poses_preds_ij.max()>5e1:
702
  logger.info(f"pose_pred_max_and_mean: {poses_preds_ij.max(), poses_preds_ij.mean()}")
703
+
704
  trans_loss = (poses_preds_ij[...,:3] - poses_gt_i_norm[...,:3]).abs().sum(dim=-1).mean()
705
  rot_loss = (poses_preds_ij[...,3:7] - poses_gt_i_norm[...,3:7]).abs().sum(dim=-1).mean()
706
  focal_loss = (poses_preds_ij[...,7:] - poses_gt_i_norm[...,7:]).abs().sum(dim=-1).mean()
 
714
  logger_tf.add_scalar(f"loss@pose/rot_iter{idx}",
715
  rot_loss, global_step=global_step)
716
  logger_tf.add_scalar(f"loss@pose/focal_iter{idx}",
717
+ focal_loss, global_step=global_step)
718
  # compute the uncertainty loss
719
  with torch.no_grad():
720
  pose_loss_dist = (poses_preds_ij-poses_gt_i_norm).detach().abs()
 
726
  unc_loss,
727
  global_step=global_step)
728
  # if logger is not None:
729
+ # logger.info(f"pose_loss: {pose_loss}, unc_loss: {unc_loss}")
730
  # total loss
731
+ loss_total += 0.1*unc_loss + 2*pose_loss
732
 
733
  poses_gt_norm = poses_gt
734
  pose_all_loss = 0.0
 
743
  prev_loss = (trans_loss + rot_loss + focal_loss)
744
  else:
745
  des_loss = (trans_loss + rot_loss + focal_loss) - prev_loss
746
+ prev_loss = trans_loss + rot_loss + focal_loss
747
  logger_tf.add_scalar(f"loss@global_pose/des_iter{idx}",
748
  des_loss, global_step=global_step)
749
  logger_tf.add_scalar(f"loss@global_pose/trans_iter{idx}",
 
751
  logger_tf.add_scalar(f"loss@global_pose/rot_iter{idx}",
752
  rot_loss, global_step=global_step)
753
  logger_tf.add_scalar(f"loss@global_pose/focal_iter{idx}",
754
+ focal_loss, global_step=global_step)
755
  if torch.isnan((trans_loss + rot_loss + focal_loss)).any():
756
  pose_all_loss += 0
757
  else:
758
  pose_all_loss += i_weight*(trans_loss + rot_loss + focal_loss)
759
+
760
  # if logger is not None:
761
+ # logger.info(f"global_pose_loss: {pose_all_loss}")
762
 
763
  # compute the depth loss
764
  if inv_depth_preds[0] is not None:
765
  depths_gt = depths_gt[:,:,0]
766
  msk = depths_gt > 5e-2
767
+ inv_gt = 1.0 / (depths_gt.clamp(1e-3, 1e16))
768
  inv_gt_reshp = rearrange(inv_gt, 'b t h w -> (b t) h w')
769
  inv_depth_preds_reshp = rearrange(inv_depth_preds[0], 'b t h w -> (b t) h w')
770
  inv_raw_reshp = rearrange(inv_depth_raw[0], 'b t h w -> (b t) h w')
 
785
  depth_loss,
786
  global_step=global_step)
787
  # if logger is not None:
788
+ # logger.info(f"opt_depth: {huber_loss_raw - huber_loss}")
789
  else:
790
  depth_loss = 0.0
791
 
792
+
793
  loss_total = loss_total/(len(poses_preds)) + 20*depth_loss + pose_all_loss
794
 
795
  return loss_total, (huber_loss_raw - huber_loss)
 
803
  """
804
  assert len(x.shape) == 2
805
 
806
+ depth_map_normalized = cv2.normalize(x.cpu().numpy(),
807
  None, 0, 255, cv2.NORM_MINMAX)
808
  depth_map_colored = cv2.applyColorMap(depth_map_normalized.astype(np.uint8),
809
  cv2.COLORMAP_JET)
 
848
  return pcl
849
 
850
  def vis_result(rgbs, poses_pred, poses_gt,
851
+ depth_gt, depth_pred, iter_num=0,
852
  vis=None, logger_tf=None, cfg=None):
853
  """
854
  Args:
 
863
  if vis is None:
864
  return
865
  S, _, H, W = depth_gt.shape
866
+ # get the xy
867
  yx = torch.meshgrid(torch.arange(H).to(depth_pred.device),
868
  torch.arange(W).to(depth_pred.device),indexing='ij')
869
  xy = torch.stack(yx[::-1], dim=0).float().to(depth_pred.device)
 
880
  pose_encoding_type="absT_quaR_OneFL",to_OpenCV=False)
881
  poses_pred_vis = pose_encoding_to_camera(poses_pred,
882
  pose_encoding_type="absT_quaR_OneFL",to_OpenCV=False)
883
+
884
  R_gt = poses_gt_vis.R.float()
885
  R_pred = poses_pred_vis.R.float()
886
  T_gt = poses_gt_vis.T.float()
 
890
  T_gt_c2w = (-R_gt_c2w @ T_gt[:, :, None]).squeeze(-1)
891
  R_pred_c2w = R_pred.permute(0,2,1)
892
  T_pred_c2w = (-R_pred_c2w @ T_pred[:, :, None]).squeeze(-1)
893
+ with torch.amp.autocast('cuda', enabled=False):
894
+ pick_idx = torch.randperm(S)[:min(24, S)]
895
  # pick_idx = [1]
896
  #NOTE: very strange that the camera need C2W Rotation and W2C translation as input
897
  poses_gt_vis = PerspectiveCamerasVisual(
 
922
  fig = plot_scene(visual_dict, camera_scale=0.05)
923
  vis.plotlyplot(fig, env=env_name, win="3D")
924
  vis.save([env_name])
925
+
926
  return
927
+
928
  def depth2pcd(
929
  xy_depth: torch.Tensor,
930
  focal_length: torch.Tensor,
 
953
  K_inv = K.inverse()
954
  # xyz
955
  xyz = xy_depth.view(S, -1, 3).permute(0, 2, 1) # S 3 (H W)
956
+ depth = xyz[:, 2:].clone() # S (H W) 1
957
  xyz[:, 2] = 1
958
  xyz = K_inv @ xyz # S 3 (H W)
959
  xyz = xyz * depth
 
963
  return xyz
964
 
965
 
966
+ def pose_enc2mat(poses_pred,
967
  H_resize, W_resize, resolution=336):
968
  """
969
  This function convert the pose encoding into `intrinsic` and `extrinsic`
970
 
971
  Args:
972
  poses_pred: B T 8
973
+ Return:
974
  Intrinsic B T 3 3
975
  Extrinsic B T 4 4
976
  """
977
  B, T, _ = poses_pred.shape
978
  focal_pred = poses_pred[:, :, -1].clone()
979
+ pos_quat_preds = poses_pred[:, :, :7].clone()
980
+ pos_quat_preds = pos_quat_preds.view(B*T, -1)
981
+ # get extrinsic
982
  c2w_rot = quaternion_to_matrix(pos_quat_preds[:, 3:])
983
  c2w_tran = pos_quat_preds[:, :3]
984
  c2w_traj = torch.eye(4)[None].repeat(B*T, 1, 1).to(poses_pred.device)
985
  c2w_traj[:, :3, :3], c2w_traj[:, :3, 3] = c2w_rot, c2w_tran
986
  c2w_traj = c2w_traj.view(B, T, 4, 4)
987
  # get intrinsic
988
+ fxs, fys = focal_pred*resolution, focal_pred*resolution
989
  intrs = torch.eye(3).to(c2w_traj.device).to(c2w_traj.dtype)[None, None].repeat(B, T, 1, 1)
990
  intrs[:,:,0,0], intrs[:,:,1,1] = fxs, fys
991
  intrs[:,:,0,2], intrs[:,:,1,2] = W_resize/2, H_resize/2
 
1001
  positive_mask = x > 0
1002
  ret[positive_mask] = torch.sqrt(x[positive_mask])
1003
  return ret
1004
+
1005
  def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
1006
  """
1007
  Convert a unit quaternion to a standard form: one in which the real
 
1086
  return grid
1087
  else:
1088
  return grid_y, grid_x
1089
+
1090
  def get_points_on_a_grid(grid_size, interp_shape,
1091
  grid_center=(0, 0), device="cuda"):
1092
  if grid_size == 1:
1093
+ return torch.tensor([interp_shape[1] / 2,
1094
  interp_shape[0] / 2], device=device)[
1095
  None, None
1096
  ]
 
1114
  xy = torch.stack([grid_x, grid_y], dim=-1).to(device)
1115
  return xy
1116
 
1117
+ def normalize_rgb(x,input_size=224,
1118
  resize_mode: Literal['resize', 'padding'] = 'resize',
1119
  if_da=False):
1120
  """
1121
  normalize the image for depth anything input
1122
+
1123
  args:
1124
  x: the input images [B T C H W]
1125
  """
 
1127
  x = torch.from_numpy(x) / 255.0
1128
  elif isinstance(x, torch.Tensor):
1129
  x = x / 255.0
1130
+ B, T, C, H, W = x.shape
1131
+ x = x.view(B * T, C, H, W)
1132
  Resizer = Resize(
1133
  width=input_size,
1134
  height=input_size,
 
1136
  keep_aspect_ratio=True,
1137
  ensure_multiple_of=14,
1138
  resize_method='lower_bound',
1139
+ )
1140
  if resize_mode == 'padding':
1141
  # zero padding to make the input size to be multiple of 14
1142
  if H > W:
 
1160
  x = F.interpolate(x, size=(int(H_scale), int(W_scale)),
1161
  mode='bicubic', align_corners=True)
1162
  # get the mean and std
1163
+ __mean__ = torch.tensor([0.485,
1164
  0.456, 0.406]).view(1, 3, 1, 1).to(x.device)
1165
  __std__ = torch.tensor([0.229,
1166
  0.224, 0.225]).view(1, 3, 1, 1).to(x.device)
 
1168
  if if_da:
1169
  x = (x - __mean__) / __std__
1170
  else:
1171
+ x = x
1172
  return x.view(B, T, C, x.shape[-2], x.shape[-1])
1173
 
1174
  def get_track_points(H, W, T, device, size=100, support_frame=0,
models/SpaTrackV2/models/vggt4track/models/tracker_front.py CHANGED
@@ -75,15 +75,15 @@ class FrontTracker(nn.Module, PyTorchModelHubMixin):
75
  B, T, C, H, W = images.shape
76
  images = (images - self.base_model.image_mean) / self.base_model.image_std
77
  H_14 = H // 14 * 14
78
- W_14 = W // 14 * 14
79
  image_14 = F.interpolate(images.view(B*T, C, H, W), (H_14, W_14), mode="bilinear", align_corners=False, antialias=True).view(B, T, C, H_14, W_14)
80
 
81
  with torch.no_grad():
82
- features = self.base_model.backbone.get_intermediate_layers(rearrange(image_14, 'b t c h w -> (b t) c h w'),
83
  self.base_model.intermediate_layers, return_class_token=True)
84
  # aggregate the features with checkpoint
85
  aggregated_tokens_list, patch_start_idx = self.aggregator(image_14, patch_tokens=features[-1][0])
86
-
87
  # enhance the features
88
  enhanced_features = []
89
  for layer_i, layer in enumerate(self.intermediate_layers):
@@ -94,7 +94,7 @@ class FrontTracker(nn.Module, PyTorchModelHubMixin):
94
 
95
  predictions = {}
96
 
97
- with torch.cuda.amp.autocast(enabled=False):
98
  if self.camera_head is not None:
99
  pose_enc_list = self.camera_head(aggregated_tokens_list)
100
  predictions["pose_enc"] = pose_enc_list[-1] # pose encoding of the last iteration
@@ -104,7 +104,7 @@ class FrontTracker(nn.Module, PyTorchModelHubMixin):
104
  # Predict points (and mask) with checkpoint
105
  output = self.base_model.head(enhanced_features, image_14)
106
  points, mask = output
107
-
108
  # Post-process points and mask
109
  points, mask = points.permute(0, 2, 3, 1), mask.squeeze(1)
110
  points = self.base_model._remap_points(points) # slightly improves the performance in case of very large output values
@@ -119,13 +119,13 @@ class FrontTracker(nn.Module, PyTorchModelHubMixin):
119
  if self.training:
120
  loss = compute_loss(predictions, annots)
121
  predictions["loss"] = loss
122
-
123
  # rescale the points
124
  if self.scale_head is not None:
125
  points_scale = points * predictions["scale"].view(B*T, 1, 1, 2)[..., :1]
126
  points_scale[..., 2:] += predictions["scale"].view(B*T, 1, 1, 2)[..., 1:]
127
  predictions["points_map"] = points_scale
128
-
129
  predictions["poses_pred"] = torch.eye(4)[None].repeat(predictions["images"].shape[1], 1, 1)[None]
130
  predictions["poses_pred"][:,:,:3,:4], predictions["intrs"] = pose_encoding_to_extri_intri(predictions["pose_enc_list"][-1],
131
  predictions["images"].shape[-2:])
 
75
  B, T, C, H, W = images.shape
76
  images = (images - self.base_model.image_mean) / self.base_model.image_std
77
  H_14 = H // 14 * 14
78
+ W_14 = W // 14 * 14
79
  image_14 = F.interpolate(images.view(B*T, C, H, W), (H_14, W_14), mode="bilinear", align_corners=False, antialias=True).view(B, T, C, H_14, W_14)
80
 
81
  with torch.no_grad():
82
+ features = self.base_model.backbone.get_intermediate_layers(rearrange(image_14, 'b t c h w -> (b t) c h w'),
83
  self.base_model.intermediate_layers, return_class_token=True)
84
  # aggregate the features with checkpoint
85
  aggregated_tokens_list, patch_start_idx = self.aggregator(image_14, patch_tokens=features[-1][0])
86
+
87
  # enhance the features
88
  enhanced_features = []
89
  for layer_i, layer in enumerate(self.intermediate_layers):
 
94
 
95
  predictions = {}
96
 
97
+ with torch.amp.autocast('cuda', enabled=False):
98
  if self.camera_head is not None:
99
  pose_enc_list = self.camera_head(aggregated_tokens_list)
100
  predictions["pose_enc"] = pose_enc_list[-1] # pose encoding of the last iteration
 
104
  # Predict points (and mask) with checkpoint
105
  output = self.base_model.head(enhanced_features, image_14)
106
  points, mask = output
107
+
108
  # Post-process points and mask
109
  points, mask = points.permute(0, 2, 3, 1), mask.squeeze(1)
110
  points = self.base_model._remap_points(points) # slightly improves the performance in case of very large output values
 
119
  if self.training:
120
  loss = compute_loss(predictions, annots)
121
  predictions["loss"] = loss
122
+
123
  # rescale the points
124
  if self.scale_head is not None:
125
  points_scale = points * predictions["scale"].view(B*T, 1, 1, 2)[..., :1]
126
  points_scale[..., 2:] += predictions["scale"].view(B*T, 1, 1, 2)[..., 1:]
127
  predictions["points_map"] = points_scale
128
+
129
  predictions["poses_pred"] = torch.eye(4)[None].repeat(predictions["images"].shape[1], 1, 1)[None]
130
  predictions["poses_pred"][:,:,:3,:4], predictions["intrs"] = pose_encoding_to_extri_intri(predictions["pose_enc_list"][-1],
131
  predictions["images"].shape[-2:])
models/SpaTrackV2/models/vggt4track/models/vggt.py CHANGED
@@ -64,7 +64,7 @@ class VGGT(nn.Module, PyTorchModelHubMixin):
64
 
65
  predictions = {}
66
 
67
- with torch.cuda.amp.autocast(enabled=False):
68
  if self.camera_head is not None:
69
  pose_enc_list = self.camera_head(aggregated_tokens_list)
70
  predictions["pose_enc"] = pose_enc_list[-1] # pose encoding of the last iteration
 
64
 
65
  predictions = {}
66
 
67
+ with torch.amp.autocast('cuda', enabled=False):
68
  if self.camera_head is not None:
69
  pose_enc_list = self.camera_head(aggregated_tokens_list)
70
  predictions["pose_enc"] = pose_enc_list[-1] # pose encoding of the last iteration
models/SpaTrackV2/models/vggt4track/models/vggt_moe.py CHANGED
@@ -65,13 +65,13 @@ class VGGT4Track(nn.Module, PyTorchModelHubMixin):
65
 
66
  if len(images.shape) == 4:
67
  images = images.unsqueeze(0)
68
-
69
  with torch.no_grad():
70
  aggregated_tokens_list, patch_start_idx = self.aggregator(images_proc)
71
 
72
  predictions = {}
73
 
74
- with torch.cuda.amp.autocast(enabled=False):
75
  if self.camera_head is not None:
76
  pose_enc_list = self.camera_head(aggregated_tokens_list)
77
  predictions["pose_enc"] = pose_enc_list[-1] # pose encoding of the last iteration
@@ -97,11 +97,11 @@ class VGGT4Track(nn.Module, PyTorchModelHubMixin):
97
  size=(H, W), mode='bilinear', align_corners=True).permute(0,2,3,1)
98
  predictions["unc_metric"] = F.interpolate(predictions["unc_metric"][:,None],
99
  size=(H, W), mode='bilinear', align_corners=True)[:,0]
100
- predictions["intrs"][..., :1, :] *= W/W_proc
101
- predictions["intrs"][..., 1:2, :] *= H/H_proc
102
 
103
  if self.training:
104
  loss = compute_loss(predictions, annots)
105
  predictions["loss"] = loss
106
-
107
  return predictions
 
65
 
66
  if len(images.shape) == 4:
67
  images = images.unsqueeze(0)
68
+
69
  with torch.no_grad():
70
  aggregated_tokens_list, patch_start_idx = self.aggregator(images_proc)
71
 
72
  predictions = {}
73
 
74
+ with torch.amp.autocast('cuda', enabled=False):
75
  if self.camera_head is not None:
76
  pose_enc_list = self.camera_head(aggregated_tokens_list)
77
  predictions["pose_enc"] = pose_enc_list[-1] # pose encoding of the last iteration
 
97
  size=(H, W), mode='bilinear', align_corners=True).permute(0,2,3,1)
98
  predictions["unc_metric"] = F.interpolate(predictions["unc_metric"][:,None],
99
  size=(H, W), mode='bilinear', align_corners=True)[:,0]
100
+ predictions["intrs"][..., :1, :] *= W/W_proc
101
+ predictions["intrs"][..., 1:2, :] *= H/H_proc
102
 
103
  if self.training:
104
  loss = compute_loss(predictions, annots)
105
  predictions["loss"] = loss
106
+
107
  return predictions
models/vggt/vggt/models/tracker_front.py CHANGED
@@ -75,15 +75,15 @@ class FrontTracker(nn.Module, PyTorchModelHubMixin):
75
  B, T, C, H, W = images.shape
76
  images = (images - self.base_model.image_mean) / self.base_model.image_std
77
  H_14 = H // 14 * 14
78
- W_14 = W // 14 * 14
79
  image_14 = F.interpolate(images.view(B*T, C, H, W), (H_14, W_14), mode="bilinear", align_corners=False, antialias=True).view(B, T, C, H_14, W_14)
80
 
81
  with torch.no_grad():
82
- features = self.base_model.backbone.get_intermediate_layers(rearrange(image_14, 'b t c h w -> (b t) c h w'),
83
  self.base_model.intermediate_layers, return_class_token=True)
84
  # aggregate the features with checkpoint
85
  aggregated_tokens_list, patch_start_idx = self.aggregator(image_14, patch_tokens=features[-1][0])
86
-
87
  # enhance the features
88
  enhanced_features = []
89
  for layer_i, layer in enumerate(self.intermediate_layers):
@@ -94,7 +94,7 @@ class FrontTracker(nn.Module, PyTorchModelHubMixin):
94
 
95
  predictions = {}
96
 
97
- with torch.cuda.amp.autocast(enabled=False):
98
  if self.camera_head is not None:
99
  pose_enc_list = self.camera_head(aggregated_tokens_list)
100
  predictions["pose_enc"] = pose_enc_list[-1] # pose encoding of the last iteration
@@ -104,7 +104,7 @@ class FrontTracker(nn.Module, PyTorchModelHubMixin):
104
  # Predict points (and mask) with checkpoint
105
  output = self.base_model.head(enhanced_features, image_14)
106
  points, mask = output
107
-
108
  # Post-process points and mask
109
  points, mask = points.permute(0, 2, 3, 1), mask.squeeze(1)
110
  points = self.base_model._remap_points(points) # slightly improves the performance in case of very large output values
@@ -119,13 +119,13 @@ class FrontTracker(nn.Module, PyTorchModelHubMixin):
119
  if self.training:
120
  loss = compute_loss(predictions, annots)
121
  predictions["loss"] = loss
122
-
123
  # rescale the points
124
  if self.scale_head is not None:
125
  points_scale = points * predictions["scale"].view(B*T, 1, 1, 2)[..., :1]
126
  points_scale[..., 2:] += predictions["scale"].view(B*T, 1, 1, 2)[..., 1:]
127
  predictions["points_map"] = points_scale
128
-
129
  predictions["poses_pred"] = torch.eye(4)[None].repeat(predictions["images"].shape[1], 1, 1)[None]
130
  predictions["poses_pred"][:,:,:3,:4], predictions["intrs"] = pose_encoding_to_extri_intri(predictions["pose_enc_list"][-1],
131
  predictions["images"].shape[-2:])
 
75
  B, T, C, H, W = images.shape
76
  images = (images - self.base_model.image_mean) / self.base_model.image_std
77
  H_14 = H // 14 * 14
78
+ W_14 = W // 14 * 14
79
  image_14 = F.interpolate(images.view(B*T, C, H, W), (H_14, W_14), mode="bilinear", align_corners=False, antialias=True).view(B, T, C, H_14, W_14)
80
 
81
  with torch.no_grad():
82
+ features = self.base_model.backbone.get_intermediate_layers(rearrange(image_14, 'b t c h w -> (b t) c h w'),
83
  self.base_model.intermediate_layers, return_class_token=True)
84
  # aggregate the features with checkpoint
85
  aggregated_tokens_list, patch_start_idx = self.aggregator(image_14, patch_tokens=features[-1][0])
86
+
87
  # enhance the features
88
  enhanced_features = []
89
  for layer_i, layer in enumerate(self.intermediate_layers):
 
94
 
95
  predictions = {}
96
 
97
+ with torch.amp.autocast('cuda', enabled=False):
98
  if self.camera_head is not None:
99
  pose_enc_list = self.camera_head(aggregated_tokens_list)
100
  predictions["pose_enc"] = pose_enc_list[-1] # pose encoding of the last iteration
 
104
  # Predict points (and mask) with checkpoint
105
  output = self.base_model.head(enhanced_features, image_14)
106
  points, mask = output
107
+
108
  # Post-process points and mask
109
  points, mask = points.permute(0, 2, 3, 1), mask.squeeze(1)
110
  points = self.base_model._remap_points(points) # slightly improves the performance in case of very large output values
 
119
  if self.training:
120
  loss = compute_loss(predictions, annots)
121
  predictions["loss"] = loss
122
+
123
  # rescale the points
124
  if self.scale_head is not None:
125
  points_scale = points * predictions["scale"].view(B*T, 1, 1, 2)[..., :1]
126
  points_scale[..., 2:] += predictions["scale"].view(B*T, 1, 1, 2)[..., 1:]
127
  predictions["points_map"] = points_scale
128
+
129
  predictions["poses_pred"] = torch.eye(4)[None].repeat(predictions["images"].shape[1], 1, 1)[None]
130
  predictions["poses_pred"][:,:,:3,:4], predictions["intrs"] = pose_encoding_to_extri_intri(predictions["pose_enc_list"][-1],
131
  predictions["images"].shape[-2:])
models/vggt/vggt/models/vggt.py CHANGED
@@ -64,7 +64,7 @@ class VGGT(nn.Module, PyTorchModelHubMixin):
64
 
65
  predictions = {}
66
 
67
- with torch.cuda.amp.autocast(enabled=False):
68
  if self.camera_head is not None:
69
  pose_enc_list = self.camera_head(aggregated_tokens_list)
70
  predictions["pose_enc"] = pose_enc_list[-1] # pose encoding of the last iteration
 
64
 
65
  predictions = {}
66
 
67
+ with torch.amp.autocast('cuda', enabled=False):
68
  if self.camera_head is not None:
69
  pose_enc_list = self.camera_head(aggregated_tokens_list)
70
  predictions["pose_enc"] = pose_enc_list[-1] # pose encoding of the last iteration
models/vggt/vggt/models/vggt_moe.py CHANGED
@@ -65,13 +65,13 @@ class VGGT_MoE(nn.Module, PyTorchModelHubMixin):
65
 
66
  if len(images.shape) == 4:
67
  images = images.unsqueeze(0)
68
-
69
  with torch.no_grad():
70
  aggregated_tokens_list, patch_start_idx = self.aggregator(images_proc)
71
 
72
  predictions = {}
73
 
74
- with torch.cuda.amp.autocast(enabled=False):
75
  if self.camera_head is not None:
76
  pose_enc_list = self.camera_head(aggregated_tokens_list)
77
  predictions["pose_enc"] = pose_enc_list[-1] # pose encoding of the last iteration
@@ -97,11 +97,11 @@ class VGGT_MoE(nn.Module, PyTorchModelHubMixin):
97
  size=(H, W), mode='bilinear', align_corners=True).permute(0,2,3,1)
98
  predictions["unc_metric"] = F.interpolate(predictions["unc_metric"][:,None],
99
  size=(H, W), mode='bilinear', align_corners=True)[:,0]
100
- predictions["intrs"][..., :1, :] *= W/W_proc
101
- predictions["intrs"][..., 1:2, :] *= H/H_proc
102
 
103
  if self.training:
104
  loss = compute_loss(predictions, annots)
105
  predictions["loss"] = loss
106
-
107
  return predictions
 
65
 
66
  if len(images.shape) == 4:
67
  images = images.unsqueeze(0)
68
+
69
  with torch.no_grad():
70
  aggregated_tokens_list, patch_start_idx = self.aggregator(images_proc)
71
 
72
  predictions = {}
73
 
74
+ with torch.amp.autocast('cuda', enabled=False):
75
  if self.camera_head is not None:
76
  pose_enc_list = self.camera_head(aggregated_tokens_list)
77
  predictions["pose_enc"] = pose_enc_list[-1] # pose encoding of the last iteration
 
97
  size=(H, W), mode='bilinear', align_corners=True).permute(0,2,3,1)
98
  predictions["unc_metric"] = F.interpolate(predictions["unc_metric"][:,None],
99
  size=(H, W), mode='bilinear', align_corners=True)[:,0]
100
+ predictions["intrs"][..., :1, :] *= W/W_proc
101
+ predictions["intrs"][..., 1:2, :] *= H/H_proc
102
 
103
  if self.training:
104
  loss = compute_loss(predictions, annots)
105
  predictions["loss"] = loss
106
+
107
  return predictions