Move the Wan pipeline to CUDA at ZeroGPU inference time
- app.py +4 -2
- app_3rd/spatrack_utils/infer_track.py +18 -18
- inference.py +15 -15
- models/SpaTrackV2/models/utils.py +59 -59
- models/SpaTrackV2/models/vggt4track/models/tracker_front.py +7 -7
- models/SpaTrackV2/models/vggt4track/models/vggt.py +1 -1
- models/SpaTrackV2/models/vggt4track/models/vggt_moe.py +5 -5
- models/vggt/vggt/models/tracker_front.py +7 -7
- models/vggt/vggt/models/vggt.py +1 -1
- models/vggt/vggt/models/vggt_moe.py +5 -5
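
What the commit actually changes: `wan_pipeline.to("cuda")` is deleted at module scope in app.py (old line 106) and re-added inside `run_wan_ttm_generation` (new line 296), and the VGGT forward passes are wrapped in `torch.amp.autocast('cuda', ...)` contexts. Deferring the `.to("cuda")` call is the usual pattern for ZeroGPU Spaces, where the CUDA device is generally only usable while a GPU-decorated handler is running. A minimal sketch of that pattern, assuming the app uses the `spaces.GPU` decorator; the model id and function name below are illustrative and not taken from this repo:

    import spaces          # HuggingFace ZeroGPU helper (assumption: the Space uses it)
    import torch
    from diffusers import DiffusionPipeline

    # Load weights on CPU at startup; no CUDA work happens at import time.
    pipe = DiffusionPipeline.from_pretrained(
        "placeholder/model-id",            # illustrative id, not the real checkpoint
        torch_dtype=torch.bfloat16,
    )

    @spaces.GPU(duration=120)              # a GPU is attached only while this runs
    def generate(prompt: str):
        pipe.to("cuda")                    # move to CUDA at inference time, as in this commit
        with torch.inference_mode():
            return pipe(prompt)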

app.py
CHANGED

@@ -103,7 +103,6 @@ wan_pipeline = WanImageToVideoTTMPipeline.from_pretrained(
)
wan_pipeline.vae.enable_tiling()
wan_pipeline.vae.enable_slicing()
-wan_pipeline.to("cuda")




@@ -218,7 +217,7 @@ def run_spatial_tracker(video_tensor: torch.Tensor):
video_input = preprocess_image(video_tensor)[None].cuda()

with torch.no_grad():
-with torch.
+with torch.amp.autocast('cuda', dtype=torch.bfloat16):
predictions = vggt4track_model(video_input / 255)
extrinsic = predictions["poses_pred"]
intrinsic = predictions["intrs"]
@@ -293,6 +292,9 @@ def run_wan_ttm_generation(prompt, tweak_index, tstrong_index, first_frame_path,
"毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
)

+
+wan_pipeline.to("cuda")
+
# Match resolution logic from run_wan.py
max_area = 480 * 832
mod_value = wan_pipeline.vae_scale_factor_spatial * \
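
The other recurring change in this commit is the inference context: the VGGT forward pass now runs under `torch.amp.autocast('cuda', dtype=torch.bfloat16)` nested inside `torch.no_grad()` (the line it replaces is truncated in this view). A self-contained sketch of that pattern with a stand-in module; the toy model and tensor shapes are illustrative only:

    import torch

    # Stand-in for the tracking model; only the no_grad/autocast pattern matters here.
    model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.GELU()).cuda().eval()
    video_input = torch.rand(1, 8, 64, device="cuda") * 255   # pretend raw frames

    with torch.no_grad():                                  # no autograd graph at inference
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            out = model(video_input / 255)                 # matmuls execute in bfloat16

    print(out.dtype)   # torch.bfloat16 inside the autocast region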

app_3rd/spatrack_utils/infer_track.py
CHANGED

@@ -34,13 +34,13 @@ def get_tracker_predictor(output_dir: str, vo_points: int = 756, tracker_model=N
"""
viz = True
os.makedirs(output_dir, exist_ok=True)
-
+
with open(config["cfg_dir"], "r") as f:
cfg = yaml.load(f, Loader=yaml.FullLoader)
cfg = easydict.EasyDict(cfg)
cfg.out_dir = output_dir
cfg.model.track_num = vo_points
-
+
# Check if it's a local path or HuggingFace repo
if tracker_model is not None:
model = tracker_model
@@ -60,8 +60,8 @@ def get_tracker_predictor(output_dir: str, vo_points: int = 756, tracker_model=N
model = Predictor.from_pretrained(checkpoint_path, model_cfg=cfg["model"])
model.eval()
model.to("cuda")
-
-viser = Visualizer(save_dir=cfg.out_dir, grayscale=True,
+
+viser = Visualizer(save_dir=cfg.out_dir, grayscale=True,
fps=10, pad_value=0, tracks_leave_trace=5)

return model, viser
@@ -83,11 +83,11 @@ def run_tracker(model, viser, temp_dir, video_name, grid_size, vo_points, fps=3)
mask_path = os.path.join(temp_dir, f"{video_name}.png")
out_dir = os.path.join(temp_dir, "results")
os.makedirs(out_dir, exist_ok=True)
-
+
# Load video using decord
video_reader = decord.VideoReader(video_path)
video_tensor = torch.from_numpy(video_reader.get_batch(range(len(video_reader))).asnumpy()).permute(0, 3, 1, 2) # Convert to tensor and permute to (N, C, H, W)
-
+
# resize make sure the shortest side is 336
h, w = video_tensor.shape[2:]
scale = max(336 / h, 336 / w)
@@ -99,7 +99,7 @@ def run_tracker(model, viser, temp_dir, video_name, grid_size, vo_points, fps=3)
intrs = None
extrs = None
data_npz_load = {}
-
+
# Load and process mask
if os.path.exists(mask_path):
mask = cv2.imread(mask_path)
@@ -107,20 +107,20 @@ def run_tracker(model, viser, temp_dir, video_name, grid_size, vo_points, fps=3)
mask = mask.sum(axis=-1)>0
else:
mask = np.ones_like(video_tensor[0,0].numpy())>0
-
+
# Get frame dimensions and create grid points
frame_H, frame_W = video_tensor.shape[2:]
grid_pts = get_points_on_a_grid(grid_size, (frame_H, frame_W), device="cpu")
-
+
# Sample mask values at grid points and filter out points where mask=0
if os.path.exists(mask_path):
grid_pts_int = grid_pts[0].long()
mask_values = mask[grid_pts_int[...,1], grid_pts_int[...,0]]
grid_pts = grid_pts[:, mask_values]
-
+
query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[0].numpy()

-# run vggt
+# run vggt
if os.environ.get("VGGT_DIR", None) is not None:
vggt_model = VGGT()
vggt_model.load_state_dict(torch.load(VGGT_DIR))
@@ -128,7 +128,7 @@ def run_tracker(model, viser, temp_dir, video_name, grid_size, vo_points, fps=3)
vggt_model = vggt_model.to("cuda")
# process the image tensor
video_tensor = preprocess_image(video_tensor)[None]
-with torch.
+with torch.amp.autocast('cuda', dtype=torch.bfloat16):
# Predict attributes including cameras, depth maps, and point maps.
aggregated_tokens_list, ps_idx = vggt_model.aggregator(video_tensor.cuda()/255)
pose_enc = vggt_model.camera_head(aggregated_tokens_list)[-1]
@@ -154,12 +154,12 @@ def run_tracker(model, viser, temp_dir, video_name, grid_size, vo_points, fps=3)
c2w_traj, intrs, point_map, conf_depth,
track3d_pred, track2d_pred, vis_pred, conf_pred, video
) = model.forward(video_tensor, depth=depth_tensor,
-intrs=intrs, extrs=extrs,
+intrs=intrs, extrs=extrs,
queries=query_xyt,
fps=1, full_point=False, iters_track=4,
query_no_BA=True, fixed_cam=False, stage=1,
-support_frame=len(video_tensor)-1, replace_ratio=0.2)
-
+support_frame=len(video_tensor)-1, replace_ratio=0.2)
+
# Resize results to avoid too large I/O Burden
max_size = 336
h, w = video.shape[2:]
@@ -174,12 +174,12 @@ def run_tracker(model, viser, temp_dir, video_name, grid_size, vo_points, fps=3)
if depth_tensor is not None:
depth_tensor = T.Resize((new_h, new_w))(depth_tensor)
conf_depth = T.Resize((new_h, new_w))(conf_depth)
-
+
# Visualize tracks
viser.visualize(video=video[None],
tracks=track2d_pred[None][...,:2],
visibility=vis_pred[None],filename="test")
-
+
# Save in tapip3d format
data_npz_load["coords"] = (torch.einsum("tij,tnj->tni", c2w_traj[:,:3,:3], track3d_pred[:,:,:3].cpu()) + c2w_traj[:,:3,3][:,None,:]).numpy()
data_npz_load["extrinsics"] = torch.inverse(c2w_traj).cpu().numpy()
@@ -190,5 +190,5 @@ def run_tracker(model, viser, temp_dir, video_name, grid_size, vo_points, fps=3)
data_npz_load["confs"] = conf_pred.cpu().numpy()
data_npz_load["confs_depth"] = conf_depth.cpu().numpy()
np.savez(os.path.join(out_dir, f'result.npz'), **data_npz_load)
-
+
print(f"Results saved to {out_dir}.\nTo visualize them with tapip3d, run: [bold yellow]python tapip3d_viz.py {out_dir}/result.npz[/bold yellow]")

inference.py
CHANGED

@@ -38,7 +38,7 @@ if __name__ == "__main__":
# fps
fps = int(args.fps)
mask_dir = args.data_dir + f"/{args.video_name}.png"
-
+
vggt4track_model = VGGT4Track.from_pretrained("Yuxihenry/SpatialTrackerV2_Front")
vggt4track_model.eval()
vggt4track_model = vggt4track_model.to("cuda")
@@ -66,12 +66,12 @@ if __name__ == "__main__":
# process the image tensor
video_tensor = preprocess_image(video_tensor)[None]
with torch.no_grad():
-with torch.
+with torch.amp.autocast('cuda', dtype=torch.bfloat16):
# Predict attributes including cameras, depth maps, and point maps.
predictions = vggt4track_model(video_tensor.cuda()/255)
extrinsic, intrinsic = predictions["poses_pred"], predictions["intrs"]
depth_map, depth_conf = predictions["points_map"][..., 2], predictions["unc_metric"]
-
+
depth_tensor = depth_map.squeeze().cpu().numpy()
extrs = np.eye(4)[None].repeat(len(depth_tensor), axis=0)
extrs = extrinsic.squeeze().cpu().numpy()
@@ -82,7 +82,7 @@ if __name__ == "__main__":
unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5

data_npz_load = {}
-
+
if os.path.exists(mask_dir):
mask_files = mask_dir
mask = cv2.imread(mask_files)
@@ -90,11 +90,11 @@ if __name__ == "__main__":
mask = mask.sum(axis=-1)>0
else:
mask = np.ones_like(video_tensor[0,0].numpy())>0
-
+
# get all data pieces
viz = True
os.makedirs(out_dir, exist_ok=True)
-
+
# with open(cfg_dir, "r") as f:
# cfg = yaml.load(f, Loader=yaml.FullLoader)
# cfg = easydict.EasyDict(cfg)
@@ -108,12 +108,12 @@ if __name__ == "__main__":

# config the model; the track_num is the number of points in the grid
model.spatrack.track_num = args.vo_points
-
+
model.eval()
model.to("cuda")
-viser = Visualizer(save_dir=out_dir, grayscale=True,
+viser = Visualizer(save_dir=out_dir, grayscale=True,
fps=10, pad_value=0, tracks_leave_trace=5)
-
+
grid_size = args.grid_size

# get frame H W
@@ -124,13 +124,13 @@ if __name__ == "__main__":
else:
frame_H, frame_W = video_tensor.shape[2:]
grid_pts = get_points_on_a_grid(grid_size, (frame_H, frame_W), device="cpu")
-
+
# Sample mask values at grid points and filter out points where mask=0
if os.path.exists(mask_dir):
grid_pts_int = grid_pts[0].long()
mask_values = mask[grid_pts_int[...,1], grid_pts_int[...,0]]
grid_pts = grid_pts[:, mask_values]
-
+
query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[0].numpy()

# Run model inference
@@ -139,12 +139,12 @@ if __name__ == "__main__":
c2w_traj, intrs, point_map, conf_depth,
track3d_pred, track2d_pred, vis_pred, conf_pred, video
) = model.forward(video_tensor, depth=depth_tensor,
-intrs=intrs, extrs=extrs,
+intrs=intrs, extrs=extrs,
queries=query_xyt,
fps=1, full_point=False, iters_track=4,
query_no_BA=True, fixed_cam=False, stage=1, unc_metric=unc_metric,
-support_frame=len(video_tensor)-1, replace_ratio=0.2)
-
+support_frame=len(video_tensor)-1, replace_ratio=0.2)
+
# resize the results to avoid too large I/O Burden
# depth and image, the maximum side is 336
max_size = 336
@@ -169,7 +169,7 @@ if __name__ == "__main__":
tracks=track2d_pred[None][...,:2],
visibility=vis_pred[None],filename="test")

-# save as the tapip3d format
+# save as the tapip3d format
data_npz_load["coords"] = (torch.einsum("tij,tnj->tni", c2w_traj[:,:3,:3], track3d_pred[:,:,:3].cpu()) + c2w_traj[:,:3,3][:,None,:]).numpy()
data_npz_load["extrinsics"] = torch.inverse(c2w_traj).cpu().numpy()
data_npz_load["intrinsics"] = intrs.cpu().numpy()

models/SpaTrackV2/models/utils.py
CHANGED

@@ -95,7 +95,7 @@ class AverageMeter(object):
return fmtstr.format(**self.__dict__)


-def procrustes_analysis(X0,X1): # [N,3]
+def procrustes_analysis(X0,X1): # [N,3]
# translation
t0 = X0.mean(dim=0,keepdim=True)
t1 = X1.mean(dim=0,keepdim=True)
@@ -218,7 +218,7 @@ def get_EFP(pred_cameras, image_size, B, S, default_focal=False):

intrinsics = create_intri_matrix(focal_length, principal_point)
return extrinsics, intrinsics
-
+
def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
"""
Convert rotations given as quaternions to rotation matrices.
@@ -278,7 +278,7 @@ def pose_encoding_to_camera(
# Now converted back
focal_length = (log_focal_length + log_focal_length_bias).exp()
# clamp to avoid weird fl values
-focal_length = torch.clamp(focal_length,
+focal_length = torch.clamp(focal_length,
min=min_focal_length, max=max_focal_length)
elif pose_encoding_type == "absT_quaR_OneFL":
# 3 for absT, 4 for quaR, 1 for absFL
@@ -287,7 +287,7 @@ def pose_encoding_to_camera(
quaternion_R = pose_encoding_reshaped[:, 3:7]
R = quaternion_to_matrix(quaternion_R)
focal_length = pose_encoding_reshaped[:, 7:8]
-focal_length = torch.clamp(focal_length,
+focal_length = torch.clamp(focal_length,
min=min_focal_length, max=max_focal_length)
else:
raise ValueError(f"Unknown pose encoding {pose_encoding_type}")
@@ -316,7 +316,7 @@ def pose_encoding_to_camera(

R = extrinsics_4x4[:, :3, :3].clone()
abs_T = extrinsics_4x4[:, :3, 3].clone()
-
+
if return_dict:
return {"focal_length": focal_length, "R": R, "T": abs_T}

@@ -326,7 +326,7 @@ def pose_encoding_to_camera(


def camera_to_pose_encoding(
-camera, pose_encoding_type="absT_quaR_logFL",
+camera, pose_encoding_type="absT_quaR_logFL",
log_focal_length_bias=1.8, min_focal_length=0.1, max_focal_length=30
):
"""
@@ -359,7 +359,7 @@ def camera_to_pose_encoding(
return pose_encoding


-def init_pose_enc(B: int,
+def init_pose_enc(B: int,
S: int, pose_encoding_type: str="absT_quaR_logFL",
device: Optional[torch.device]=None):
"""
@@ -378,7 +378,7 @@ def init_pose_enc(B: int,
C = 8
else:
raise ValueError(f"Unknown pose encoding {pose_encoding_type}")
-
+
pose_enc = torch.zeros(B, S, C, device=device)
pose_enc[..., :3] = 0 # absT
pose_enc[..., 3] = 1 # quaR
@@ -389,7 +389,7 @@ def first_pose_enc_norm(pose_enc: torch.Tensor,
pose_encoding_type: str="absT_quaR_OneFL",
pose_mode: str = "W2C"):
"""
-make sure the poses in on window are normalized by the first frame, where the
+make sure the poses in on window are normalized by the first frame, where the
first frame transformation is the Identity Matrix.
NOTE: Poses are all W2C
args:
@@ -403,23 +403,23 @@ def first_pose_enc_norm(pose_enc: torch.Tensor,
pose_enc, pose_encoding_type=pose_encoding_type,
to_OpenCV=False
) #NOTE: the camera parameters are not in NDC
-
+
R = pred_cameras.R # [B*S, 3, 3]
T = pred_cameras.T # [B*S, 3]
-
+
Tran_M = torch.cat([R, T.unsqueeze(-1)], dim=-1) # [B*S, 3, 4]
extra_ = torch.tensor([[[0, 0, 0, 1]]],
device=Tran_M.device).expand(Tran_M.shape[0], -1, -1)
Tran_M = torch.cat([Tran_M, extra_
], dim=1)
Tran_M = rearrange(Tran_M, '(b s) c d -> b s c d', b=B)
-
+
# Take the first frame as the base of world coordinate
if pose_mode == "C2W":
Tran_M_new = (Tran_M[:,:1,...].inverse())@Tran_M
elif pose_mode == "W2C":
Tran_M_new = Tran_M@(Tran_M[:,:1,...].inverse())
-
+
Tran_M_new = rearrange(Tran_M_new, 'b s c d -> (b s) c d')

R_ = Tran_M_new[:, :3, :3]
@@ -429,7 +429,7 @@ def first_pose_enc_norm(pose_enc: torch.Tensor,
pred_cameras.R = R_
pred_cameras.T = T_
pose_enc_norm = camera_to_pose_encoding(pred_cameras,
-pose_encoding_type=pose_encoding_type)
+pose_encoding_type=pose_encoding_type)
pose_enc_norm = rearrange(pose_enc_norm, '(b s) c -> b s c', b=B)
return pose_enc_norm

@@ -439,7 +439,7 @@ def first_pose_enc_denorm(
pose_encoding_type: str="absT_quaR_OneFL",
pose_mode: str = "W2C"):
"""
-make sure the poses in on window are de-normalized by the first frame, where the
+make sure the poses in on window are de-normalized by the first frame, where the
first frame transformation is the Identity Matrix.
args:
pose_enc: [B S C]
@@ -457,7 +457,7 @@ def first_pose_enc_denorm(
) #NOTE: the camera parameters are not in NDC
R = pred_cameras.R # [B*(1+S), 3, 3]
T = pred_cameras.T # [B*(1+S), 3]
-
+
Tran_M = torch.cat([R, T.unsqueeze(-1)], dim=-1) # [B*(1+S), 3, 4]
extra_ = torch.tensor([[[0, 0, 0, 1]]],
device=Tran_M.device).expand(Tran_M.shape[0], -1, -1)
@@ -470,7 +470,7 @@ def first_pose_enc_denorm(
Tran_M_new = Tran_M_1st@Tran_M_new
elif pose_mode == "W2C":
Tran_M_new = Tran_M_new@Tran_M_1st
-
+
Tran_M_new_ = torch.cat([Tran_M_1st, Tran_M_new], dim=1)
R_ = Tran_M_new_[..., :3, :3].view(-1, 3, 3)
T_ = Tran_M_new_[..., :3, 3].view(-1, 3)
@@ -481,7 +481,7 @@ def first_pose_enc_denorm(

# Cameras to Pose encoding
pose_enc_denorm = camera_to_pose_encoding(pred_cameras,
-pose_encoding_type=pose_encoding_type)
+pose_encoding_type=pose_encoding_type)
pose_enc_denorm = rearrange(pose_enc_denorm, '(b s) c -> b s c', b=B)
return pose_enc_denorm[:, 1:]

@@ -560,7 +560,7 @@ def median_loss(prediction, target, mask, Bs):
target_nm, a_norm_gt, b_norm_gt = normalize_prediction_robust(target.float(), mask, Bs)
depth_loss = nn.functional.l1_loss(prediction_nm[mask], target_nm[mask])
scale = b_norm_gt/b_norm
-shift = a_norm_gt - a_norm*scale
+shift = a_norm_gt - a_norm*scale
return depth_loss, scale, shift, prediction_nm, target_nm

def reduction_batch_based(image_loss, M):
@@ -593,7 +593,7 @@ class ScaleAndShiftInvariantLoss(nn.Module):

def forward(self, prediction, target, mask, Bs,
interpolate=True, return_interpolated=False):
-
+
if prediction.shape[-1] != target.shape[-1] and interpolate:
prediction = nn.functional.interpolate(prediction, target.shape[-2:], mode='bilinear', align_corners=True)
intr_input = prediction
@@ -602,7 +602,7 @@ class ScaleAndShiftInvariantLoss(nn.Module):

prediction, target, mask = prediction.squeeze(), target.squeeze(), mask.squeeze()
assert prediction.shape == target.shape, f"Shape mismatch: Expected same shape but got {prediction.shape} and {target.shape}."
-
+

scale, shift = compute_scale_and_shift(prediction, target, mask)
a_norm = scale.view(Bs, -1, 1, 1).mean(dim=1, keepdim=True)
@@ -634,7 +634,7 @@ class GradientLoss(nn.Module):

for scale in range(self.__scales):
step = pow(2, scale)
-l1_ln, a_nm, b_nm = ScaleAndShiftInvariantLoss_fn(prediction[:, ::step, ::step],
+l1_ln, a_nm, b_nm = ScaleAndShiftInvariantLoss_fn(prediction[:, ::step, ::step],
target[:, ::step, ::step], mask[:, ::step, ::step], 1)
total += l1_ln
a_nm = a_nm.squeeze().detach() # [B, 1, 1]
@@ -663,7 +663,7 @@ def gradient_loss(prediction, target, mask, reduction=reduction_batch_based):
image_loss = torch.sum(grad_x, (1, 2)) + torch.sum(grad_y, (1, 2))

return reduction(image_loss, M)
-
+
def loss_fn(
poses_preds: List[torch.Tensor],
poses_pred_all: List[torch.Tensor],
@@ -700,7 +700,7 @@ def loss_fn(
if logger is not None:
if poses_preds_ij.max()>5e1:
logger.info(f"pose_pred_max_and_mean: {poses_preds_ij.max(), poses_preds_ij.mean()}")
-
+
trans_loss = (poses_preds_ij[...,:3] - poses_gt_i_norm[...,:3]).abs().sum(dim=-1).mean()
rot_loss = (poses_preds_ij[...,3:7] - poses_gt_i_norm[...,3:7]).abs().sum(dim=-1).mean()
focal_loss = (poses_preds_ij[...,7:] - poses_gt_i_norm[...,7:]).abs().sum(dim=-1).mean()
@@ -714,7 +714,7 @@ def loss_fn(
logger_tf.add_scalar(f"loss@pose/rot_iter{idx}",
rot_loss, global_step=global_step)
logger_tf.add_scalar(f"loss@pose/focal_iter{idx}",
-focal_loss, global_step=global_step)
+focal_loss, global_step=global_step)
# compute the uncertainty loss
with torch.no_grad():
pose_loss_dist = (poses_preds_ij-poses_gt_i_norm).detach().abs()
@@ -726,9 +726,9 @@ def loss_fn(
unc_loss,
global_step=global_step)
# if logger is not None:
-# logger.info(f"pose_loss: {pose_loss}, unc_loss: {unc_loss}")
+# logger.info(f"pose_loss: {pose_loss}, unc_loss: {unc_loss}")
# total loss
-loss_total += 0.1*unc_loss + 2*pose_loss
+loss_total += 0.1*unc_loss + 2*pose_loss

poses_gt_norm = poses_gt
pose_all_loss = 0.0
@@ -743,7 +743,7 @@ def loss_fn(
prev_loss = (trans_loss + rot_loss + focal_loss)
else:
des_loss = (trans_loss + rot_loss + focal_loss) - prev_loss
-prev_loss = trans_loss + rot_loss + focal_loss
+prev_loss = trans_loss + rot_loss + focal_loss
logger_tf.add_scalar(f"loss@global_pose/des_iter{idx}",
des_loss, global_step=global_step)
logger_tf.add_scalar(f"loss@global_pose/trans_iter{idx}",
@@ -751,20 +751,20 @@ def loss_fn(
logger_tf.add_scalar(f"loss@global_pose/rot_iter{idx}",
rot_loss, global_step=global_step)
logger_tf.add_scalar(f"loss@global_pose/focal_iter{idx}",
-focal_loss, global_step=global_step)
+focal_loss, global_step=global_step)
if torch.isnan((trans_loss + rot_loss + focal_loss)).any():
pose_all_loss += 0
else:
pose_all_loss += i_weight*(trans_loss + rot_loss + focal_loss)
-
+
# if logger is not None:
-# logger.info(f"global_pose_loss: {pose_all_loss}")
+# logger.info(f"global_pose_loss: {pose_all_loss}")

# compute the depth loss
if inv_depth_preds[0] is not None:
depths_gt = depths_gt[:,:,0]
msk = depths_gt > 5e-2
-inv_gt = 1.0 / (depths_gt.clamp(1e-3, 1e16))
+inv_gt = 1.0 / (depths_gt.clamp(1e-3, 1e16))
inv_gt_reshp = rearrange(inv_gt, 'b t h w -> (b t) h w')
inv_depth_preds_reshp = rearrange(inv_depth_preds[0], 'b t h w -> (b t) h w')
inv_raw_reshp = rearrange(inv_depth_raw[0], 'b t h w -> (b t) h w')
@@ -785,11 +785,11 @@ def loss_fn(
depth_loss,
global_step=global_step)
# if logger is not None:
-# logger.info(f"opt_depth: {huber_loss_raw - huber_loss}")
+# logger.info(f"opt_depth: {huber_loss_raw - huber_loss}")
else:
depth_loss = 0.0

-
+
loss_total = loss_total/(len(poses_preds)) + 20*depth_loss + pose_all_loss

return loss_total, (huber_loss_raw - huber_loss)
@@ -803,7 +803,7 @@ def vis_depth(x: torch.tensor,
"""
assert len(x.shape) == 2

-depth_map_normalized = cv2.normalize(x.cpu().numpy(),
+depth_map_normalized = cv2.normalize(x.cpu().numpy(),
None, 0, 255, cv2.NORM_MINMAX)
depth_map_colored = cv2.applyColorMap(depth_map_normalized.astype(np.uint8),
cv2.COLORMAP_JET)
@@ -848,7 +848,7 @@ def vis_pcd(
return pcl

def vis_result(rgbs, poses_pred, poses_gt,
-depth_gt, depth_pred, iter_num=0,
+depth_gt, depth_pred, iter_num=0,
vis=None, logger_tf=None, cfg=None):
"""
Args:
@@ -863,7 +863,7 @@ def vis_result(rgbs, poses_pred, poses_gt,
if vis is None:
return
S, _, H, W = depth_gt.shape
-# get the xy
+# get the xy
yx = torch.meshgrid(torch.arange(H).to(depth_pred.device),
torch.arange(W).to(depth_pred.device),indexing='ij')
xy = torch.stack(yx[::-1], dim=0).float().to(depth_pred.device)
@@ -880,7 +880,7 @@ def vis_result(rgbs, poses_pred, poses_gt,
pose_encoding_type="absT_quaR_OneFL",to_OpenCV=False)
poses_pred_vis = pose_encoding_to_camera(poses_pred,
pose_encoding_type="absT_quaR_OneFL",to_OpenCV=False)
-
+
R_gt = poses_gt_vis.R.float()
R_pred = poses_pred_vis.R.float()
T_gt = poses_gt_vis.T.float()
@@ -890,8 +890,8 @@ def vis_result(rgbs, poses_pred, poses_gt,
T_gt_c2w = (-R_gt_c2w @ T_gt[:, :, None]).squeeze(-1)
R_pred_c2w = R_pred.permute(0,2,1)
T_pred_c2w = (-R_pred_c2w @ T_pred[:, :, None]).squeeze(-1)
-with torch.
-pick_idx = torch.randperm(S)[:min(24, S)]
+with torch.amp.autocast('cuda', enabled=False):
+pick_idx = torch.randperm(S)[:min(24, S)]
# pick_idx = [1]
#NOTE: very strange that the camera need C2W Rotation and W2C translation as input
poses_gt_vis = PerspectiveCamerasVisual(
@@ -922,9 +922,9 @@ def vis_result(rgbs, poses_pred, poses_gt,
fig = plot_scene(visual_dict, camera_scale=0.05)
vis.plotlyplot(fig, env=env_name, win="3D")
vis.save([env_name])
-
+
return
-
+
def depth2pcd(
xy_depth: torch.Tensor,
focal_length: torch.Tensor,
@@ -953,7 +953,7 @@ def depth2pcd(
K_inv = K.inverse()
# xyz
xyz = xy_depth.view(S, -1, 3).permute(0, 2, 1) # S 3 (H W)
-depth = xyz[:, 2:].clone() # S (H W) 1
+depth = xyz[:, 2:].clone() # S (H W) 1
xyz[:, 2] = 1
xyz = K_inv @ xyz # S 3 (H W)
xyz = xyz * depth
@@ -963,29 +963,29 @@
return xyz


-def pose_enc2mat(poses_pred,
+def pose_enc2mat(poses_pred,
H_resize, W_resize, resolution=336):
"""
This function convert the pose encoding into `intrinsic` and `extrinsic`

Args:
poses_pred: B T 8
-Return:
+Return:
Intrinsic B T 3 3
Extrinsic B T 4 4
"""
B, T, _ = poses_pred.shape
focal_pred = poses_pred[:, :, -1].clone()
-pos_quat_preds = poses_pred[:, :, :7].clone()
-pos_quat_preds = pos_quat_preds.view(B*T, -1)
-# get extrinsic
+pos_quat_preds = poses_pred[:, :, :7].clone()
+pos_quat_preds = pos_quat_preds.view(B*T, -1)
+# get extrinsic
c2w_rot = quaternion_to_matrix(pos_quat_preds[:, 3:])
c2w_tran = pos_quat_preds[:, :3]
c2w_traj = torch.eye(4)[None].repeat(B*T, 1, 1).to(poses_pred.device)
c2w_traj[:, :3, :3], c2w_traj[:, :3, 3] = c2w_rot, c2w_tran
c2w_traj = c2w_traj.view(B, T, 4, 4)
# get intrinsic
-fxs, fys = focal_pred*resolution, focal_pred*resolution
+fxs, fys = focal_pred*resolution, focal_pred*resolution
intrs = torch.eye(3).to(c2w_traj.device).to(c2w_traj.dtype)[None, None].repeat(B, T, 1, 1)
intrs[:,:,0,0], intrs[:,:,1,1] = fxs, fys
intrs[:,:,0,2], intrs[:,:,1,2] = W_resize/2, H_resize/2
@@ -1001,7 +1001,7 @@ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
positive_mask = x > 0
ret[positive_mask] = torch.sqrt(x[positive_mask])
return ret
-
+
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
"""
Convert a unit quaternion to a standard form: one in which the real
@@ -1086,11 +1086,11 @@ def meshgrid2d(B, Y, X, stack=False, norm=False, device="cuda"):
return grid
else:
return grid_y, grid_x
-
+
def get_points_on_a_grid(grid_size, interp_shape,
grid_center=(0, 0), device="cuda"):
if grid_size == 1:
-return torch.tensor([interp_shape[1] / 2,
+return torch.tensor([interp_shape[1] / 2,
interp_shape[0] / 2], device=device)[
None, None
]
@@ -1114,12 +1114,12 @@ def get_points_on_a_grid(grid_size, interp_shape,
xy = torch.stack([grid_x, grid_y], dim=-1).to(device)
return xy

-def normalize_rgb(x,input_size=224,
+def normalize_rgb(x,input_size=224,
resize_mode: Literal['resize', 'padding'] = 'resize',
if_da=False):
"""
normalize the image for depth anything input
-
+
args:
x: the input images [B T C H W]
"""
@@ -1127,8 +1127,8 @@ def normalize_rgb(x,input_size=224,
x = torch.from_numpy(x) / 255.0
elif isinstance(x, torch.Tensor):
x = x / 255.0
-B, T, C, H, W = x.shape
-x = x.view(B * T, C, H, W)
+B, T, C, H, W = x.shape
+x = x.view(B * T, C, H, W)
Resizer = Resize(
width=input_size,
height=input_size,
@@ -1136,7 +1136,7 @@ def normalize_rgb(x,input_size=224,
keep_aspect_ratio=True,
ensure_multiple_of=14,
resize_method='lower_bound',
-)
+)
if resize_mode == 'padding':
# zero padding to make the input size to be multiple of 14
if H > W:
@@ -1160,7 +1160,7 @@ def normalize_rgb(x,input_size=224,
x = F.interpolate(x, size=(int(H_scale), int(W_scale)),
mode='bicubic', align_corners=True)
# get the mean and std
-__mean__ = torch.tensor([0.485,
+__mean__ = torch.tensor([0.485,
0.456, 0.406]).view(1, 3, 1, 1).to(x.device)
__std__ = torch.tensor([0.229,
0.224, 0.225]).view(1, 3, 1, 1).to(x.device)
@@ -1168,7 +1168,7 @@ def normalize_rgb(x,input_size=224,
if if_da:
x = (x - __mean__) / __std__
else:
-x = x
+x = x
return x.view(B, T, C, x.shape[-2], x.shape[-1])

def get_track_points(H, W, T, device, size=100, support_frame=0,
models/SpaTrackV2/models/vggt4track/models/tracker_front.py
CHANGED
|
@@ -75,15 +75,15 @@ class FrontTracker(nn.Module, PyTorchModelHubMixin):
|
|
| 75 |
B, T, C, H, W = images.shape
|
| 76 |
images = (images - self.base_model.image_mean) / self.base_model.image_std
|
| 77 |
H_14 = H // 14 * 14
|
| 78 |
-
W_14 = W // 14 * 14
|
| 79 |
image_14 = F.interpolate(images.view(B*T, C, H, W), (H_14, W_14), mode="bilinear", align_corners=False, antialias=True).view(B, T, C, H_14, W_14)
|
| 80 |
|
| 81 |
with torch.no_grad():
|
| 82 |
-
features = self.base_model.backbone.get_intermediate_layers(rearrange(image_14, 'b t c h w -> (b t) c h w'),
+            features = self.base_model.backbone.get_intermediate_layers(rearrange(image_14, 'b t c h w -> (b t) c h w'),
                 self.base_model.intermediate_layers, return_class_token=True)
         # aggregate the features with checkpoint
         aggregated_tokens_list, patch_start_idx = self.aggregator(image_14, patch_tokens=features[-1][0])
-
+
         # enhance the features
         enhanced_features = []
         for layer_i, layer in enumerate(self.intermediate_layers):
@@ -94,7 +94,7 @@ class FrontTracker(nn.Module, PyTorchModelHubMixin):

         predictions = {}

-        with torch.
+        with torch.amp.autocast('cuda', enabled=False):
             if self.camera_head is not None:
                 pose_enc_list = self.camera_head(aggregated_tokens_list)
                 predictions["pose_enc"] = pose_enc_list[-1] # pose encoding of the last iteration
@@ -104,7 +104,7 @@ class FrontTracker(nn.Module, PyTorchModelHubMixin):
         # Predict points (and mask) with checkpoint
         output = self.base_model.head(enhanced_features, image_14)
         points, mask = output
-
+
         # Post-process points and mask
         points, mask = points.permute(0, 2, 3, 1), mask.squeeze(1)
         points = self.base_model._remap_points(points) # slightly improves the performance in case of very large output values
@@ -119,13 +119,13 @@ class FrontTracker(nn.Module, PyTorchModelHubMixin):
         if self.training:
             loss = compute_loss(predictions, annots)
             predictions["loss"] = loss
-
+
         # rescale the points
         if self.scale_head is not None:
             points_scale = points * predictions["scale"].view(B*T, 1, 1, 2)[..., :1]
             points_scale[..., 2:] += predictions["scale"].view(B*T, 1, 1, 2)[..., 1:]
             predictions["points_map"] = points_scale
-
+
         predictions["poses_pred"] = torch.eye(4)[None].repeat(predictions["images"].shape[1], 1, 1)[None]
         predictions["poses_pred"][:,:,:3,:4], predictions["intrs"] = pose_encoding_to_extri_intri(predictions["pose_enc_list"][-1],
                                                 predictions["images"].shape[-2:])
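
Across these model files, the functional change is the same one applied in app.py: the autocast context manager is switched to the torch.amp entry point, with the device type passed as the first argument. The removed lines are truncated to "with torch." in this diff view; they presumably held the older torch.cuda.amp.autocast(...) spelling, which recent PyTorch releases flag as deprecated. A minimal sketch of the two spellings (illustrative only, not part of the commit; camera_head_region is an invented name):

import torch

def camera_head_region(tokens: torch.Tensor) -> torch.Tensor:
    # Deprecated spelling (what the truncated "-" lines presumably contained):
    #     with torch.cuda.amp.autocast(enabled=False):
    # Current spelling, as added by this commit:
    with torch.amp.autocast('cuda', enabled=False):
        return tokens @ tokens.transpose(-1, -2)

if torch.cuda.is_available():
    out = camera_head_region(torch.randn(2, 8, 8, device='cuda'))
    print(out.dtype)  # float32: autocast is disabled inside the block
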
models/SpaTrackV2/models/vggt4track/models/vggt.py
CHANGED
@@ -64,7 +64,7 @@ class VGGT(nn.Module, PyTorchModelHubMixin):

         predictions = {}

-        with torch.
+        with torch.amp.autocast('cuda', enabled=False):
             if self.camera_head is not None:
                 pose_enc_list = self.camera_head(aggregated_tokens_list)
                 predictions["pose_enc"] = pose_enc_list[-1] # pose encoding of the last iteration
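
The enabled=False argument is preserved by the change: the camera-head branch still opts out of mixed precision, so pose regression runs in float32 even when the surrounding inference code enters a bfloat16 autocast region (as app.py does). A small illustration of that nesting behaviour, with invented tensor names:

import torch

if torch.cuda.is_available():
    feats = torch.randn(2, 16, 16, device='cuda')
    with torch.amp.autocast('cuda', dtype=torch.bfloat16):
        low = feats @ feats.transpose(-1, -2)       # under autocast -> bfloat16
        with torch.amp.autocast('cuda', enabled=False):
            full = feats @ feats.transpose(-1, -2)  # autocast disabled -> float32
    print(low.dtype, full.dtype)  # torch.bfloat16 torch.float32
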
models/SpaTrackV2/models/vggt4track/models/vggt_moe.py
CHANGED
@@ -65,13 +65,13 @@ class VGGT4Track(nn.Module, PyTorchModelHubMixin):

         if len(images.shape) == 4:
             images = images.unsqueeze(0)
-
+
         with torch.no_grad():
             aggregated_tokens_list, patch_start_idx = self.aggregator(images_proc)

         predictions = {}

-        with torch.
+        with torch.amp.autocast('cuda', enabled=False):
             if self.camera_head is not None:
                 pose_enc_list = self.camera_head(aggregated_tokens_list)
                 predictions["pose_enc"] = pose_enc_list[-1] # pose encoding of the last iteration
@@ -97,11 +97,11 @@ class VGGT4Track(nn.Module, PyTorchModelHubMixin):
                                    size=(H, W), mode='bilinear', align_corners=True).permute(0,2,3,1)
         predictions["unc_metric"] = F.interpolate(predictions["unc_metric"][:,None],
                                    size=(H, W), mode='bilinear', align_corners=True)[:,0]
-        predictions["intrs"][..., :1, :] *= W/W_proc
-        predictions["intrs"][..., 1:2, :] *= H/H_proc
+        predictions["intrs"][..., :1, :] *= W/W_proc
+        predictions["intrs"][..., 1:2, :] *= H/H_proc

         if self.training:
             loss = compute_loss(predictions, annots)
             predictions["loss"] = loss
-
+
         return predictions
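
The two intrs lines (identical on the - and + sides here, so likely a whitespace-only edit) scale the predicted intrinsics from the processed resolution back to the caller's resolution: the first row of K (fx, cx) follows W/W_proc and the second row (fy, cy) follows H/H_proc. A hedged sketch of the same bookkeeping on a single 3x3 matrix, with example resolutions that are not taken from the repo:

import torch

H, W = 480, 832            # original frame size (example values)
H_proc, W_proc = 518, 518  # size the network ran at (example values)

K = torch.tensor([[400.0, 0.0, 259.0],
                  [0.0, 400.0, 259.0],
                  [0.0, 0.0, 1.0]])  # intrinsics at the processed resolution

K_scaled = K.clone()
K_scaled[:1, :] *= W / W_proc   # fx and cx follow the width scale
K_scaled[1:2, :] *= H / H_proc  # fy and cy follow the height scale
print(K_scaled)
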
models/vggt/vggt/models/tracker_front.py
CHANGED
@@ -75,15 +75,15 @@ class FrontTracker(nn.Module, PyTorchModelHubMixin):
         B, T, C, H, W = images.shape
         images = (images - self.base_model.image_mean) / self.base_model.image_std
         H_14 = H // 14 * 14
-        W_14 = W // 14 * 14
+        W_14 = W // 14 * 14
         image_14 = F.interpolate(images.view(B*T, C, H, W), (H_14, W_14), mode="bilinear", align_corners=False, antialias=True).view(B, T, C, H_14, W_14)

         with torch.no_grad():
-            features = self.base_model.backbone.get_intermediate_layers(rearrange(image_14, 'b t c h w -> (b t) c h w'),
+            features = self.base_model.backbone.get_intermediate_layers(rearrange(image_14, 'b t c h w -> (b t) c h w'),
                 self.base_model.intermediate_layers, return_class_token=True)
         # aggregate the features with checkpoint
         aggregated_tokens_list, patch_start_idx = self.aggregator(image_14, patch_tokens=features[-1][0])
-
+
         # enhance the features
         enhanced_features = []
         for layer_i, layer in enumerate(self.intermediate_layers):
@@ -94,7 +94,7 @@ class FrontTracker(nn.Module, PyTorchModelHubMixin):

         predictions = {}

-        with torch.
+        with torch.amp.autocast('cuda', enabled=False):
             if self.camera_head is not None:
                 pose_enc_list = self.camera_head(aggregated_tokens_list)
                 predictions["pose_enc"] = pose_enc_list[-1] # pose encoding of the last iteration
@@ -104,7 +104,7 @@ class FrontTracker(nn.Module, PyTorchModelHubMixin):
         # Predict points (and mask) with checkpoint
         output = self.base_model.head(enhanced_features, image_14)
         points, mask = output
-
+
         # Post-process points and mask
         points, mask = points.permute(0, 2, 3, 1), mask.squeeze(1)
         points = self.base_model._remap_points(points) # slightly improves the performance in case of very large output values
@@ -119,13 +119,13 @@ class FrontTracker(nn.Module, PyTorchModelHubMixin):
         if self.training:
             loss = compute_loss(predictions, annots)
             predictions["loss"] = loss
-
+
         # rescale the points
         if self.scale_head is not None:
             points_scale = points * predictions["scale"].view(B*T, 1, 1, 2)[..., :1]
             points_scale[..., 2:] += predictions["scale"].view(B*T, 1, 1, 2)[..., 1:]
             predictions["points_map"] = points_scale
-
+
         predictions["poses_pred"] = torch.eye(4)[None].repeat(predictions["images"].shape[1], 1, 1)[None]
         predictions["poses_pred"][:,:,:3,:4], predictions["intrs"] = pose_encoding_to_extri_intri(predictions["pose_enc_list"][-1],
                                                 predictions["images"].shape[-2:])
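
The interpolation at the top of FrontTracker.forward snaps the input to the nearest lower multiple of 14 before it reaches the ViT backbone, whose patch embedding expects spatial sizes divisible by its 14-pixel patch. A small standalone sketch of that rounding and resize, with made-up shapes:

import torch
import torch.nn.functional as F

B, T, C, H, W = 1, 4, 3, 480, 832            # example shapes, not from the repo
images = torch.randn(B, T, C, H, W)

H_14 = H // 14 * 14                           # 476
W_14 = W // 14 * 14                           # 826
image_14 = F.interpolate(images.view(B * T, C, H, W), (H_14, W_14),
                         mode="bilinear", align_corners=False, antialias=True)
image_14 = image_14.view(B, T, C, H_14, W_14)
print(image_14.shape, H_14 % 14, W_14 % 14)   # both remainders are 0
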
models/vggt/vggt/models/vggt.py
CHANGED
@@ -64,7 +64,7 @@ class VGGT(nn.Module, PyTorchModelHubMixin):

         predictions = {}

-        with torch.
+        with torch.amp.autocast('cuda', enabled=False):
             if self.camera_head is not None:
                 pose_enc_list = self.camera_head(aggregated_tokens_list)
                 predictions["pose_enc"] = pose_enc_list[-1] # pose encoding of the last iteration
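
As the inline comment notes, the camera head returns one pose encoding per refinement iteration and only the last entry is kept in the output dict. A toy sketch of that pattern; the head below is invented for illustration and is not the repo's CameraHead:

import torch
import torch.nn as nn

class ToyCameraHead(nn.Module):
    """Refines a pose encoding vector for a fixed number of iterations."""
    def __init__(self, dim: int = 9, iters: int = 4):
        super().__init__()
        self.update = nn.Linear(dim, dim)
        self.iters = iters

    def forward(self, tokens: torch.Tensor) -> list:
        pose_enc = tokens.mean(dim=1)        # crude pooling to a per-frame vector
        outputs = []
        for _ in range(self.iters):
            pose_enc = pose_enc + self.update(pose_enc)
            outputs.append(pose_enc)
        return outputs                        # one entry per iteration

head = ToyCameraHead()
pose_enc_list = head(torch.randn(2, 10, 9))
predictions = {"pose_enc": pose_enc_list[-1]}  # keep the last iteration, as in the diff
print(len(pose_enc_list), predictions["pose_enc"].shape)
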
models/vggt/vggt/models/vggt_moe.py
CHANGED
@@ -65,13 +65,13 @@ class VGGT_MoE(nn.Module, PyTorchModelHubMixin):

         if len(images.shape) == 4:
             images = images.unsqueeze(0)
-
+
         with torch.no_grad():
             aggregated_tokens_list, patch_start_idx = self.aggregator(images_proc)

         predictions = {}

-        with torch.
+        with torch.amp.autocast('cuda', enabled=False):
             if self.camera_head is not None:
                 pose_enc_list = self.camera_head(aggregated_tokens_list)
                 predictions["pose_enc"] = pose_enc_list[-1] # pose encoding of the last iteration
@@ -97,11 +97,11 @@ class VGGT_MoE(nn.Module, PyTorchModelHubMixin):
                                    size=(H, W), mode='bilinear', align_corners=True).permute(0,2,3,1)
         predictions["unc_metric"] = F.interpolate(predictions["unc_metric"][:,None],
                                    size=(H, W), mode='bilinear', align_corners=True)[:,0]
-        predictions["intrs"][..., :1, :] *= W/W_proc
-        predictions["intrs"][..., 1:2, :] *= H/H_proc
+        predictions["intrs"][..., :1, :] *= W/W_proc
+        predictions["intrs"][..., 1:2, :] *= H/H_proc

         if self.training:
             loss = compute_loss(predictions, annots)
             predictions["loss"] = loss
-
+
         return predictions
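
The tail of VGGT_MoE.forward resizes the per-pixel outputs from the processed resolution back to the caller's resolution before returning them; for example, the uncertainty map gains a channel dimension, is upsampled bilinearly to (H, W), and the channel is squeezed away again. A minimal sketch of that step, with invented shapes:

import torch
import torch.nn.functional as F

H_proc, W_proc = 518, 518            # resolution the network ran at (example values)
H, W = 480, 832                      # resolution of the caller's frames (example values)

unc_metric = torch.rand(8, H_proc, W_proc)            # (frames, H_proc, W_proc)
unc_up = F.interpolate(unc_metric[:, None],            # add a channel dim for interpolate
                       size=(H, W), mode='bilinear', align_corners=True)[:, 0]
print(unc_up.shape)                   # torch.Size([8, 480, 832])
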