xizaoqu committed · Commit 100414d · 1 Parent(s): f07d258

rm

Changed files:
- app.py +1 -25
- configurations/README.md +0 -7
- configurations/algorithm/base_algo.yaml +0 -3
- configurations/algorithm/base_pytorch_algo.yaml +0 -4
- configurations/algorithm/df_base.yaml +0 -42
- configurations/algorithm/df_video_worldmemminecraft.yaml +0 -42
- configurations/algorithm/pose_prediction.yaml +0 -19
- configurations/config.yaml +0 -16
- configurations/dataset/base_dataset.yaml +0 -3
- configurations/dataset/base_video.yaml +0 -14
- configurations/dataset/video_minecraft.yaml +0 -14
- configurations/dataset/video_minecraft_pose.yaml +0 -14
- configurations/experiment/base_experiment.yaml +0 -2
- configurations/experiment/base_pytorch.yaml +0 -50
- configurations/experiment/exp_pose.yaml +0 -31
- configurations/experiment/exp_video.yaml +0 -31
- datasets/README.md +0 -11
- datasets/__init__.py +0 -1
- datasets/video/__init__.py +0 -2
- datasets/video/base_video_dataset.py +0 -158
- datasets/video/minecraft_video_dataset.py +0 -262
- datasets/video/minecraft_video_dataset_oasis_filter.py +0 -99
- datasets/video/minecraft_video_dataset_pose.py +0 -421
- experiments/README.md +0 -19
- experiments/__init__.py +0 -35
- experiments/exp_base.py +0 -473
- experiments/exp_pose.py +0 -310
- experiments/exp_video.py +0 -25
- main.py +0 -219
- scripts/README.md +0 -10
- scripts/dummy_script.sh +0 -1
- split_checkpoint.py +0 -9
app.py
CHANGED
@@ -10,13 +10,8 @@ import hydra
 from omegaconf import DictConfig, OmegaConf
 from omegaconf.omegaconf import open_dict
 
-from utils.print_utils import cyan
-from utils.ckpt_utils import download_latest_checkpoint, is_run_id
-from utils.cluster_utils import submit_slurm_job
-from utils.distributed_utils import is_rank_zero
 import numpy as np
 import torch
-from datasets.video.minecraft_video_dataset import *
 import torchvision.transforms as transforms
 import cv2
 import subprocess
@@ -351,18 +346,7 @@ def set_memory(examples_case, image_display, log_output, slider_denoising_step,
 
     return input_history, out_video[-1], temporal_video_path, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx
 
-
-h1 {
-    text-align: center;
-    display:block;
-}
-"""
-
-def on_select(evt: gr.SelectData):
-    selected_index = evt.index
-    return examples[selected_index]
-
-with gr.Blocks(css=css) as demo:
+with gr.Blocks() as demo:
     gr.Markdown(
         """
         # WORLDMEM: Long-term Consistent World Generation with Memory
@@ -515,13 +499,6 @@ with gr.Blocks(css=css) as demo:
     example_case = gr.Textbox(label="Case", visible=False)
     image_output = gr.Image(visible=False)
 
-    # gr.Examples(examples=example_images,
-    #             inputs=[example_case, image_output, log_output, slider_denoising_step, slider_context_length, slider_memory_length],
-    #             fn=set_memory,
-    #             outputs=[log_output, image_display, video_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx],
-    #             cache_examples=True
-    #             )
-
     examples = gr.Examples(
         examples=example_images,
         inputs=[example_case, image_output, log_output, slider_denoising_step, slider_context_length, slider_memory_length],
@@ -534,7 +511,6 @@ with gr.Blocks(css=css) as demo:
         outputs=[log_output, image_display, video_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx]
     )
 
-
     submit_button.click(generate, inputs=[input_box, log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx], outputs=[image_display, video_display, log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
     reset_btn.click(reset, inputs=[selected_image], outputs=[log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
     image_display_1.select(lambda: on_image_click(SUNFLOWERS_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])

configurations/README.md
DELETED
@@ -1,7 +0,0 @@
-# configurations
-
-We use [Hydra](https://hydra.cc/docs/intro/) to manage configurations. Change/Add the yaml files in this folder
-to change the default configurations. You can also override the default configurations by
-passing command line arguments.
-
-All configurations are automatically saved in wandb run.

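For orientation, the deleted YAML tree above was consumed through Hydra's composition mechanism the README describes. Below is a minimal sketch of a Hydra entrypoint under this layout; the repo's real entrypoint was main.py (also deleted in this commit) and its exact contents are not shown here, so everything beyond the config paths is illustrative.

import hydra
from omegaconf import DictConfig, OmegaConf

@hydra.main(version_base=None, config_path="configurations", config_name="config")
def main(cfg: DictConfig) -> None:
    # Hydra merges config.yaml with the files selected under `defaults:`;
    # any field can then be overridden on the command line, e.g.
    #   python main.py algorithm=df_video_worldmemminecraft debug=true
    print(OmegaConf.to_yaml(cfg, resolve=False))

if __name__ == "__main__":
    main()
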
configurations/algorithm/base_algo.yaml
DELETED
@@ -1,3 +0,0 @@
-# This will be passed as the cfg to Algo.__init__(cfg) of your algorithm class
-
-debug: ${debug} # inherited from configurations/config.yaml

configurations/algorithm/base_pytorch_algo.yaml
DELETED
@@ -1,4 +0,0 @@
-defaults:
-  - base_algo # inherits from configurations/algorithm/base_algo.yaml
-
-lr: ${experiment.training.lr}

configurations/algorithm/df_base.yaml
DELETED
@@ -1,42 +0,0 @@
-defaults:
-  - base_pytorch_algo
-
-# dataset-dependent configurations
-x_shape: ${dataset.observation_shape}
-frame_stack: 1
-frame_skip: 1
-data_mean: ${dataset.data_mean}
-data_std: ${dataset.data_std}
-external_cond_dim: 0 #${dataset.action_dim}
-context_frames: ${dataset.context_length}
-# training hyperparameters
-weight_decay: 1e-4
-warmup_steps: 10000
-optimizer_beta: [0.9, 0.999]
-# diffusion-related
-uncertainty_scale: 1
-guidance_scale: 0.0
-chunk_size: 1 # -1 for full trajectory diffusion, number to specify diffusion chunk size
-scheduling_matrix: autoregressive
-noise_level: random_all
-causal: True
-
-diffusion:
-  # training
-  objective: pred_x0
-  beta_schedule: cosine
-  schedule_fn_kwargs: {}
-  clip_noise: 20.0
-  use_snr: False
-  use_cum_snr: False
-  use_fused_snr: False
-  snr_clip: 5.0
-  cum_snr_decay: 0.98
-  timesteps: 1000
-  # sampling
-  sampling_timesteps: 50 # fixme, numer of diffusion steps, should be increased
-  ddim_sampling_eta: 1.0
-  stabilization_level: 10
-# architecture
-architecture:
-  network_size: 64

configurations/algorithm/df_video_worldmemminecraft.yaml
DELETED
@@ -1,42 +0,0 @@
-defaults:
-  - df_base
-
-n_frames: ${dataset.n_frames}
-frame_skip: ${dataset.frame_skip}
-metadata: ${dataset.metadata}
-
-# training hyperparameters
-weight_decay: 2e-3
-warmup_steps: 10000
-optimizer_beta: [0.9, 0.99]
-action_cond_dim: 25
-
-diffusion:
-  # training
-  beta_schedule: sigmoid
-  objective: pred_v
-  use_fused_snr: True
-  cum_snr_decay: 0.96
-  clip_noise: 20.
-  # sampling
-  sampling_timesteps: 20
-  ddim_sampling_eta: 0.0
-  stabilization_level: 15
-# architecture
-architecture:
-  network_size: 64
-  attn_heads: 4
-  attn_dim_head: 64
-  dim_mults: [1, 2, 4, 8]
-  resolution: ${dataset.resolution}
-  attn_resolutions: [16, 32, 64, 128]
-  use_init_temporal_attn: True
-  use_linear_attn: True
-  time_emb_type: rotary
-
-metrics:
-  # - fvd
-  # - fid
-  # - lpips
-
-_name: df_video_worldmemminecraft

configurations/algorithm/pose_prediction.yaml
DELETED
@@ -1,19 +0,0 @@
-defaults:
-  - df_base
-
-n_frames: ${dataset.n_frames}
-frame_skip: ${dataset.frame_skip}
-metadata: ${dataset.metadata}
-
-# training hyperparameters
-weight_decay: 2e-3
-warmup_steps: 10000
-optimizer_beta: [0.9, 0.99]
-
-
-metrics:
-  # - fvd
-  # - fid
-  # - lpips
-
-_name: pose_prediction

configurations/config.yaml
DELETED
@@ -1,16 +0,0 @@
-# configuration parsing starts here
-defaults:
-  - experiment: exp_video # experiment yaml file name in configurations/experiments folder [fixme]
-  - dataset: video_minecraft_oasis # dataset yaml file name in configurations/dataset folder [fixme]
-  - algorithm: df_video # algorithm yaml file name in configurations/algorithm folder [fixme]
-  - cluster: null # optional, cluster yaml file name in configurations/cluster folder. Leave null for local compute
-
-debug: false # global debug flag will be passed into configuration of experiment, dataset and algorithm
-
-wandb:
-  entity: xizaoqu # wandb account name / organization name [fixme]
-  project: diffusion-forcing # wandb project name; if not provided, defaults to root folder name [fixme]
-  mode: online # set wandb logging to online, offline or dryrun
-
-resume: null # wandb run id to resume logging and loading checkpoint from
-load: null # wandb run id containing checkpoint or a path to a checkpoint file

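These files lean heavily on OmegaConf interpolation: ${debug} in base_algo.yaml and base_dataset.yaml, ${experiment.training.lr} in base_pytorch_algo.yaml, ${dataset.resolution} in the architecture blocks. A standalone sketch, with made-up values, of how such references resolve once the whole tree is composed:

from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "debug": False,
    "experiment": {"training": {"lr": 8e-5}},
    "dataset": {"resolution": 128},
    "algorithm": {
        "lr": "${experiment.training.lr}",   # as in base_pytorch_algo.yaml
        "debug": "${debug}",                 # as in base_algo.yaml
        "architecture": {"resolution": "${dataset.resolution}"},
    },
})

assert cfg.algorithm.lr == 8e-5          # interpolations resolve on access
assert cfg.algorithm.debug is False
assert cfg.algorithm.architecture.resolution == 128
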
configurations/dataset/base_dataset.yaml
DELETED
@@ -1,3 +0,0 @@
-# This will be passed as the cfg to Dataset.__init__(cfg) of your dataset class
-
-debug: ${debug} # inherited from configurations/config.yaml

configurations/dataset/base_video.yaml
DELETED
@@ -1,14 +0,0 @@
-defaults:
-  - base_dataset
-
-metadata: "data/${dataset.name}/metadata.json"
-data_mean: "data/${dataset.name}/data_mean.npy"
-data_std: "data/${dataset.name}/data_std.npy"
-save_dir: ???
-n_frames: 32
-context_length: 4
-resolution: 128
-observation_shape: [3, "${dataset.resolution}", "${dataset.resolution}"]
-external_cond_dim: 0
-validation_multiplier: 1
-frame_skip: 1

configurations/dataset/video_minecraft.yaml
DELETED
@@ -1,14 +0,0 @@
-defaults:
-  - base_video
-
-save_dir: data/minecraft_simple_backforward
-n_frames: 16 # TODO: increase later
-resolution: 128
-data_mean: 0.5
-data_std: 0.5
-action_cond_dim: 25
-context_length: 1
-frame_skip: 1
-validation_multiplier: 1
-
-_name: video_minecraft_oasis

configurations/dataset/video_minecraft_pose.yaml
DELETED
@@ -1,14 +0,0 @@
-defaults:
-  - base_video
-
-save_dir: data/minecraft_simple_backforward
-n_frames: 16 # TODO: increase later
-resolution: 128
-data_mean: 0.5
-data_std: 0.5
-external_cond_dim: 25
-context_length: 1
-frame_skip: 1
-validation_multiplier: 1
-
-_name: video_minecraft_pose

configurations/experiment/base_experiment.yaml
DELETED
@@ -1,2 +0,0 @@
-debug: ${debug} # inherited from configurations/config.yaml
-tasks: [main] # tasks to run sequantially, such as [training, test], useful when your project has multiple stages and you want to run only a subset of them.

configurations/experiment/base_pytorch.yaml
DELETED
@@ -1,50 +0,0 @@
-# inherites from base_experiment.yaml
-# most of the options have docs at https://lightning.ai/docs/pytorch/stable/common/trainer.html
-
-defaults:
-  - base_experiment
-
-tasks: [training] # tasks to run sequantially, change when your project has multiple stages and you want to run only a subset of them.
-num_nodes: 1 # number of gpu servers used in large scale distributed training
-
-training:
-  precision: 16-mixed # set float precision, 16-mixed is faster while 32 is more stable
-  compile: False # whether to compile the model with torch.compile
-  lr: 0.001 # learning rate
-  batch_size: 16 # training batch size; effective batch size is this number * gpu * nodes iff using distributed training
-  max_epochs: 1000 # set to -1 to train forever
-  max_steps: -1 # set to -1 to train forever, will override max_epochs
-  max_time: null # set to something like "00:12:00:00" to enable
-  data:
-    num_workers: 4 # number of CPU threads for data preprocessing.
-    shuffle: True # whether training data will be shuffled
-  optim:
-    accumulate_grad_batches: 1 # accumulate gradients for n batches before backprop
-    gradient_clip_val: 0 # clip gradients with norm above this value, set to 0 to disable
-  checkpointing:
-    # these are arguments to pytorch lightning's callback, `ModelCheckpoint` class
-    every_n_train_steps: 5000 # save a checkpoint every n train steps
-    every_n_epochs: null # mutually exclusive with ``every_n_train_steps`` and ``train_time_interval``
-    train_time_interval: null # in format of "00:12:00:00", mutually exclusive with ``every_n_train_steps`` and ``every_n_epochs``.
-    enable_version_counter: False # If this is ``False``, later checkpoint will be overwrite previous ones.
-
-validation:
-  precision: 16-mixed
-  compile: False # whether to compile the model with torch.compile
-  batch_size: 16 # validation batch size per GPU; effective batch size is this number * gpu * nodes iff using distributed training
-  val_every_n_step: 2000 # controls how frequent do we run validation, can be float (fraction of epoches) or int (steps) or null (if val_every_n_epoch is set)
-  val_every_n_epoch: null # if you want to do validation every n epoches, requires val_every_n_step to be null.
-  limit_batch: null # if null, run through validation set. Otherwise limit the number of batches to use for validation.
-  inference_mode: True # whether to run validation in inference mode (enable_grad won't work!)
-  data:
-    num_workers: 4 # number of CPU threads for data preprocessing, for validation.
-    shuffle: False # whether validation data will be shuffled
-
-test:
-  precision: 16-mixed
-  compile: False # whether to compile the model with torch.compile
-  batch_size: 4 # test batch size per GPU; effective batch size is this number * gpu * nodes iff using distributed training
-  limit_batch: null # if null, run through test set. Otherwise limit the number of batches to use for test.
-  data:
-    num_workers: 4 # number of CPU threads for data preprocessing, for test.
-    shuffle: False # whether test data will be shuffled

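The fields above map almost one-to-one onto pytorch-lightning's Trainer and ModelCheckpoint arguments (per the docs link in the file). A hedged sketch of that wiring, assuming a composed cfg; this is not the repo's deleted experiments/exp_base.py, only an illustration of the mapping:

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

def build_trainer(cfg) -> pl.Trainer:
    checkpoint_cb = ModelCheckpoint(
        every_n_train_steps=cfg.training.checkpointing.every_n_train_steps,
        every_n_epochs=cfg.training.checkpointing.every_n_epochs,
        # note: train_time_interval expects a datetime.timedelta (or None)
        train_time_interval=cfg.training.checkpointing.train_time_interval,
        enable_version_counter=cfg.training.checkpointing.enable_version_counter,
    )
    return pl.Trainer(
        precision=cfg.training.precision,        # e.g. "16-mixed"
        max_epochs=cfg.training.max_epochs,
        max_steps=cfg.training.max_steps,
        max_time=cfg.training.max_time,
        num_nodes=cfg.num_nodes,
        accumulate_grad_batches=cfg.training.optim.accumulate_grad_batches,
        gradient_clip_val=cfg.training.optim.gradient_clip_val,
        val_check_interval=cfg.validation.val_every_n_step,
        check_val_every_n_epoch=cfg.validation.val_every_n_epoch,
        limit_val_batches=cfg.validation.limit_batch,
        inference_mode=cfg.validation.inference_mode,
        callbacks=[checkpoint_cb],
    )
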
configurations/experiment/exp_pose.yaml
DELETED
@@ -1,31 +0,0 @@
-defaults:
-  - base_pytorch
-
-tasks: [training]
-
-training:
-  lr: 8e-5
-  precision: 16-mixed
-  batch_size: 4
-  max_epochs: -1
-  max_steps: 2000005
-  checkpointing:
-    every_n_train_steps: 2500
-  optim:
-    gradient_clip_val: 1.0
-
-validation:
-  val_every_n_step: 300
-  val_every_n_epoch: null
-  batch_size: 4
-  limit_batch: 1
-
-test:
-  limit_batch: 1
-  batch_size: 1
-
-logging:
-  metrics:
-    # - fvd
-    # - fid
-    # - lpips

configurations/experiment/exp_video.yaml
DELETED
@@ -1,31 +0,0 @@
-defaults:
-  - base_pytorch
-
-tasks: [training]
-
-training:
-  lr: 8e-5
-  precision: 16-mixed
-  batch_size: 4
-  max_epochs: -1
-  max_steps: 2000005
-  checkpointing:
-    every_n_train_steps: 2500
-  optim:
-    gradient_clip_val: 1.0
-
-validation:
-  val_every_n_step: 300
-  val_every_n_epoch: null
-  batch_size: 4
-  limit_batch: 1
-
-test:
-  limit_batch: 1
-  batch_size: 1
-
-logging:
-  metrics:
-    # - fvd
-    # - fid
-    # - lpips

datasets/README.md
DELETED
@@ -1,11 +0,0 @@
-The `datasets` folder is used to contain dataset code or environment code.
-Don't store actual data like images here! For those, please use the `data` folder instead of `datasets`.
-
-Create a folder to create your own pytorch dataset definition. Then, update the `__init__.py`
-at every level to register all datasets.
-
-Each dataset class takes in a DictConfig file `cfg` in its `__init__`, which allows you to pass in arguments via configuration file in `configurations/dataset` or [command line override](https://hydra.cc/docs/tutorials/basic/your_first_app/simple_cli/).
-
----
-
-This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research template [repo](https://github.com/buoyancy99/research-template). By its MIT license, you must keep the above sentence in `README.md` and the `LICENSE` file to credit the author.

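Following the convention in the README above (each dataset's __init__ receives the Hydra-composed DictConfig), here is a hypothetical minimal dataset; every name in it is illustrative, not from the repo:

import torch
from omegaconf import DictConfig


class ToyVideoDataset(torch.utils.data.Dataset):
    def __init__(self, cfg: DictConfig, split: str = "training"):
        super().__init__()
        # fields like these come straight from configurations/dataset/*.yaml
        self.resolution = cfg.resolution
        self.n_frames = cfg.n_frames
        self.split = split

    def __len__(self) -> int:
        return 8

    def __getitem__(self, idx: int) -> torch.Tensor:
        # dummy clip of shape (T, C, H, W)
        return torch.zeros(self.n_frames, 3, self.resolution, self.resolution)
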
datasets/__init__.py
DELETED
@@ -1 +0,0 @@
-from .video import MinecraftVideoDataset

datasets/video/__init__.py
DELETED
@@ -1,2 +0,0 @@
-from .minecraft_video_dataset import MinecraftVideoDataset
-from .minecraft_video_dataset_pose import MinecraftVideoPoseDataset

datasets/video/base_video_dataset.py
DELETED
@@ -1,158 +0,0 @@
-from typing import Sequence
-import torch
-import random
-import os
-import numpy as np
-import cv2
-from omegaconf import DictConfig
-from torchvision import transforms
-from pathlib import Path
-from abc import abstractmethod, ABC
-import json
-
-
-class BaseVideoDataset(torch.utils.data.Dataset, ABC):
-    """
-    Base class for video datasets. Videos may be of variable length.
-
-    Folder structure of each dataset:
-    - [save_dir] (specified in config, e.g., data/phys101)
-        - /[split] (one per split)
-            - /data_folder_name (e.g., videos)
-        metadata.json
-    """
-
-    def __init__(self, cfg: DictConfig, split: str = "training"):
-        super().__init__()
-        self.cfg = cfg
-        self.split = split
-        self.resolution = cfg.resolution
-        self.external_cond_dim = cfg.external_cond_dim
-        self.n_frames = (
-            cfg.n_frames * cfg.frame_skip
-            if split == "training"
-            else cfg.n_frames * cfg.frame_skip * cfg.validation_multiplier
-        )
-        self.frame_skip = cfg.frame_skip
-        self.save_dir = Path(cfg.save_dir)
-        self.save_dir.mkdir(exist_ok=True, parents=True)
-        self.split_dir = self.save_dir / f"{split}"
-
-        self.metadata_path = self.save_dir / "metadata.json"
-
-        self.data_paths = self.get_data_paths(self.split)
-
-        if self.split == 'training':
-            self.metadata = [1200] * len(self.data_paths)  # total 1500 f
-        else:
-            self.metadata = [1] * len(self.data_paths)  # total 1500 f
-        # self.clips_per_video = np.clip(np.array(self.metadata[split]) - self.n_frames + 1, a_min=1, a_max=None).astype(
-        #     np.int32
-        # )
-        self.clips_per_video = np.clip(np.array(self.metadata) - self.n_frames + 1, a_min=1, a_max=None).astype(
-            np.int32
-        )
-        self.cum_clips_per_video = np.cumsum(self.clips_per_video)
-        self.transform = transforms.Resize((self.resolution, self.resolution), antialias=True)
-
-        # shuffle but keep the same order for each epoch, so validation sample is diverse yet deterministic
-        random.seed(0)
-        self.idx_remap = list(range(self.__len__()))
-        random.shuffle(self.idx_remap)
-
-    @abstractmethod
-    def download_dataset(self) -> Sequence[int]:
-        """
-        Download dataset from the internet and build it in save_dir
-
-        Returns a list of video lengths
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_data_paths(self, split):
-        """Return a list of data paths (e.g. xxx.mp4) for a given split"""
-        raise NotImplementedError
-
-    def get_data_lengths(self, split):
-        """Return a list of num_frames for each data path (e.g. xxx.mp4) for a given split"""
-        lengths = []
-        for path in self.get_data_paths(split):
-            length = cv2.VideoCapture(str(path)).get(cv2.CAP_PROP_FRAME_COUNT)
-            lengths.append(length)
-        return lengths
-
-    def split_idx(self, idx):
-        video_idx = np.argmax(self.cum_clips_per_video > idx)
-        frame_idx = idx - np.pad(self.cum_clips_per_video, (1, 0))[video_idx]
-        return video_idx, frame_idx
-
-    @staticmethod
-    def load_video(path: Path):
-        """
-        Load video from a path
-        :param filename: path to the video
-        :return: video as a numpy array
-        """
-
-        cap = cv2.VideoCapture(str(path))
-
-        frames = []
-        while cap.isOpened():
-            ret, frame = cap.read()
-            if ret:
-                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-                frames.append(frame)
-            else:
-                break
-
-        cap.release()
-        frames = np.stack(frames, dtype=np.uint8)
-        return np.transpose(frames, (0, 3, 1, 2))  # (T, C, H, W)
-
-    @staticmethod
-    def load_image(filename: Path):
-        """
-        Load image from a path
-        :param filename: path to the image
-        :return: image as a numpy array
-        """
-        image = cv2.imread(str(filename))
-        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-        return np.transpose(image, (2, 0, 1))
-
-    def __len__(self):
-        return self.clips_per_video.sum()
-
-    def __getitem__(self, idx):
-        idx = self.idx_remap[idx]
-        video_idx, frame_idx = self.split_idx(idx)
-        video_path = self.data_paths[video_idx]
-        video = self.load_video(video_path)[frame_idx : frame_idx + self.n_frames]
-
-        pad_len = self.n_frames - len(video)
-
-        nonterminal = np.ones(self.n_frames)
-        if len(video) < self.n_frames:
-            video = np.pad(video, ((0, pad_len), (0, 0), (0, 0), (0, 0)))
-            nonterminal[-pad_len:] = 0
-
-        video = torch.from_numpy(video / 256.0).float()
-        video = self.transform(video)
-
-        if self.external_cond_dim:
-            external_cond = np.load(
-                # pylint: disable=no-member
-                self.condition_dir
-                / f"{video_path.name.replace('.mp4', '.npy')}"
-            )
-            if len(external_cond) < self.n_frames:
-                external_cond = np.pad(external_cond, ((0, pad_len),))
-            external_cond = torch.from_numpy(external_cond).float()
-            return (
-                video[:: self.frame_skip],
-                external_cond[:: self.frame_skip],
-                nonterminal[:: self.frame_skip],
-            )
-        else:
-            return video[:: self.frame_skip], nonterminal[:: self.frame_skip]

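The clip-indexing scheme in BaseVideoDataset is worth spelling out: a video of length L contributes max(L - n_frames + 1, 1) overlapping clips, and split_idx inverts the flat dataset index through a cumulative sum. A standalone check with made-up lengths:

import numpy as np

n_frames = 4
video_lengths = np.array([10, 3, 6])
clips_per_video = np.clip(video_lengths - n_frames + 1, 1, None)  # [7, 1, 3]
cum = np.cumsum(clips_per_video)                                  # [7, 8, 11]

def split_idx(idx):
    video_idx = np.argmax(cum > idx)                  # first video covering idx
    frame_idx = idx - np.pad(cum, (1, 0))[video_idx]  # offset within that video
    return video_idx, frame_idx

assert split_idx(0) == (0, 0)
assert split_idx(6) == (0, 6)    # last clip of the first video
assert split_idx(7) == (1, 0)    # a too-short video still yields one clip
assert split_idx(10) == (2, 2)
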
datasets/video/minecraft_video_dataset.py
DELETED
@@ -1,262 +0,0 @@
-import os
-import io
-import tarfile
-import numpy as np
-import torch
-from typing import Sequence, Mapping
-from omegaconf import DictConfig
-from pytorchvideo.data.encoded_video import EncodedVideo
-import random
-
-from .base_video_dataset import BaseVideoDataset
-
-
-
-
-ACTION_KEYS = [
-    "inventory",
-    "ESC",
-    "hotbar.1",
-    "hotbar.2",
-    "hotbar.3",
-    "hotbar.4",
-    "hotbar.5",
-    "hotbar.6",
-    "hotbar.7",
-    "hotbar.8",
-    "hotbar.9",
-    "forward",
-    "back",
-    "left",
-    "right",
-    "cameraY",
-    "cameraX",
-    "jump",
-    "sneak",
-    "sprint",
-    "swapHands",
-    "attack",
-    "use",
-    "pickItem",
-    "drop",
-]
-
-def convert_action_space(actions):
-    vec_25 = torch.zeros(len(actions), len(ACTION_KEYS))
-    vec_25[actions[:,0]==1, 11] = 1
-    vec_25[actions[:,0]==2, 12] = 1
-    vec_25[actions[:,4]==11, 16] = -1
-    vec_25[actions[:,4]==13, 16] = 1
-    vec_25[actions[:,3]==11, 15] = -1
-    vec_25[actions[:,3]==13, 15] = 1
-    vec_25[actions[:,5]==6, 24] = 1
-    vec_25[actions[:,5]==1, 24] = 1
-    vec_25[actions[:,1]==1, 13] = 1
-    vec_25[actions[:,1]==2, 14] = 1
-    vec_25[actions[:,7]==1, 2] = 1
-    return vec_25
-
-# Dataset class
-class MinecraftVideoDataset(BaseVideoDataset):
-    """
-    Minecraft video dataset for training and validation.
-
-    Args:
-        cfg (DictConfig): Configuration object.
-        split (str): Dataset split ("training" or "validation").
-    """
-    def __init__(self, cfg: DictConfig, split: str = "training"):
-        if split == "test":
-            split = "validation"
-        super().__init__(cfg, split)
-        self.n_frames = cfg.n_frames_valid if split == "validation" and hasattr(cfg, "n_frames_valid") else cfg.n_frames
-        self.use_plucker = cfg.use_plucker
-        self.condition_similar_length = cfg.condition_similar_length
-        self.customized_validation = cfg.customized_validation
-        self.angle_range = cfg.angle_range
-        self.pos_range = cfg.pos_range
-        self.add_frame_timestep_embedder = cfg.add_frame_timestep_embedder
-        self.training_dropout = 0.1
-        self.sample_more_place = getattr(cfg, "sample_more_place", False)
-        self.within_context = getattr(cfg, "within_context", False)
-        self.sample_more_event = getattr(cfg, "sample_more_event", False)
-        self.causal_frame = getattr(cfg, "causal_frame", False)
-
-    def get_data_paths(self, split: str):
-        """
-        Retrieve all video file paths for the given split.
-
-        Args:
-            split (str): Dataset split ("training" or "validation").
-
-        Returns:
-            List[Path]: List of video file paths.
-        """
-        data_dir = self.save_dir / split
-        paths = sorted(list(data_dir.glob("**/*.mp4")), key=lambda x: x.name)
-        if not paths:
-            sub_dirs = os.listdir(data_dir)
-            for sub_dir in sub_dirs:
-                sub_path = data_dir / sub_dir
-                paths += sorted(list(sub_path.glob("**/*.mp4")), key=lambda x: x.name)
-        return paths
-
-    def download_dataset(self):
-        pass
-
-    def __getitem__(self, idx: int):
-        """
-        Retrieve a single data sample by index.
-
-        Args:
-            idx (int): Index of the data sample.
-
-        Returns:
-            Tuple[torch.Tensor, torch.Tensor, np.ndarray, np.ndarray]: Video, actions, poses, and timesteps.
-        """
-        max_retries = 1000
-        for _ in range(max_retries):
-            try:
-                return self.load_data(idx)
-            except Exception as e:
-                print(f"Retrying due to error: {e}")
-                idx = (idx + 1) % len(self)
-
-    def load_data(self, idx):
-        idx = self.idx_remap[idx]
-        file_idx, frame_idx = self.split_idx(idx)
-        action_path = self.data_paths[file_idx]
-        video_path = self.data_paths[file_idx]
-
-        action_path = video_path.with_suffix(".npz")
-        actions_pool = np.load(action_path)['actions']
-        poses_pool = np.load(action_path)['poses']
-
-
-        poses_pool[0,1] = poses_pool[1,1] # wrong first in place
-
-        assert poses_pool[:,1].max() - poses_pool[:,1].min() < 2, f"wrong~~~~{poses_pool[:,1].max() - poses_pool[:,1].min()}-{video_path}"
-
-
-        if len(poses_pool) < len(actions_pool):
-            poses_pool = np.pad(poses_pool, ((1, 0), (0, 0)))
-
-        actions_pool = convert_action_space(actions_pool)
-        video_raw = EncodedVideo.from_path(video_path, decode_audio=False)
-
-        frame_idx = frame_idx + 100 # avoid first frames # first frame is useless
-
-        if self.split == "validation":
-            frame_idx = 240
-
-        if self.sample_more_place and self.split == "training":
-            if random.uniform(0, 1) > 0.5:
-                place_mask = (actions_pool[:,24]==1)
-                place_mask[:100] = 0
-                valid_indices = np.where(place_mask)[0]
-                random_index = np.random.choice(valid_indices)
-                frame_idx = random_index - random.randint(1, self.n_frames-1)
-
-        total_frame = video_raw.duration.numerator
-        fps = 10 # video_raw.duration.denominator
-        total_frame = total_frame * fps / video_raw.duration.denominator
-        video = video_raw.get_clip(start_sec=frame_idx/fps, end_sec=(frame_idx+self.n_frames)/fps)["video"]
-        video = video.permute(1, 2, 3, 0).numpy()
-
-        if self.split != "validation" and 'degrees' in np.load(action_path).keys():
-            degrees = np.load(action_path)['degrees']
-            actions_pool[:,16] *= degrees
-
-        actions = np.copy(actions_pool[frame_idx : frame_idx + self.n_frames])
-
-        poses = np.copy(poses_pool[frame_idx : frame_idx + self.n_frames])
-        pad_len = self.n_frames - len(video)
-        poses_pool[:,:3] -= poses[:1,:3]
-        poses_pool[:,-1] = -poses_pool[:,-1]
-        poses_pool[:,3:] %= 360
-
-        poses[:,:3] -= poses[:1,:3] # do not normalize angle
-        poses[:,-1] = -poses[:,-1]
-        poses[:,3:] %= 360
-
-        assert len(video) >= self.n_frames, f"{video_path}"
-
-        if self.split == "training" and self.condition_similar_length>0:
-            if random.uniform(0, 1) > self.training_dropout:
-                refer_frame_dis = poses[:,None] - poses_pool[None,:]
-                refer_frame_dis = np.abs(refer_frame_dis)
-                refer_frame_dis[...,3:][refer_frame_dis[...,3:] > 180] = 360 - refer_frame_dis[...,3:][refer_frame_dis[...,3:] > 180]
-                valid_index = ((((refer_frame_dis[..., :3] <= self.pos_range).sum(-1))>=3) & (((refer_frame_dis[..., 3:] <= self.angle_range).sum(-1))>=2) & \
-                    ((((refer_frame_dis[..., :3] > 0).sum(-1))>=1) | (((refer_frame_dis[..., 3:] > 0).sum(-1))>=1))
-                ).sum(0)
-                valid_index[:100] = 0 # mute bad initial scene
-
-                if self.add_frame_timestep_embedder and self.causal_frame and (actions_pool[:frame_idx,24]==1).sum() > 0:
-                    valid_index[frame_idx:] = 0
-
-                mask = valid_index >= 1
-                mask[0] = False
-                candidate_indices = np.argwhere(mask)
-
-                mask2 = valid_index >= 0
-                mask2[0] = False
-
-                count = min(self.condition_similar_length, candidate_indices.shape[0])
-                selected_indices = candidate_indices[np.random.choice(candidate_indices.shape[0], count, replace=True)][:,0]
-
-                if count < self.condition_similar_length:
-                    candidate_indices2 = np.argwhere(mask2)
-                    selected_indices2 = candidate_indices2[np.random.choice(candidate_indices2.shape[0], self.condition_similar_length-count, replace=True)][:,0]
-                    selected_indices = np.concatenate([selected_indices, selected_indices2])
-
-                if self.sample_more_event:
-                    if random.uniform(0, 1) > 0.3:
-                        valid_idx = torch.nonzero(actions_pool[:frame_idx,24]==1)[:,0]
-                        if len(valid_idx) > self.condition_similar_length //2:
-                            valid_idx = valid_idx[-self.condition_similar_length //2:]
-
-                        if len(valid_idx) > 0:
-                            selected_indices[-len(valid_idx):] = valid_idx + 4
-
-            else:
-                selected_indices = np.array(list(range(self.condition_similar_length))) * 0 + random.randint(0, frame_idx)
-
-            video_pool = []
-            for si in selected_indices:
-                video_pool.append(video_raw.get_clip(start_sec=si/fps, end_sec=(si+1)/fps)["video"][:,0].permute(1,2,0))
-
-            video_pool = np.stack(video_pool)
-            video = np.concatenate([video, video_pool])
-            actions = np.concatenate([actions, actions_pool[selected_indices]])
-            poses = np.concatenate([poses, poses_pool[selected_indices]])
-
-            timestep = np.concatenate([np.array(list(range(frame_idx, frame_idx + self.n_frames))), selected_indices])
-
-        else:
-            timestep = np.array(list(range(self.n_frames)))
-
-        video = torch.from_numpy(video / 255.0).float().permute(0, 3, 1, 2).contiguous()
-
-        if self.split == "validation" and not self.customized_validation:
-            num_frame = actions.shape[0]
-
-            actions[:] = 0
-            actions[:,16] = 1
-            poses[:] = 0
-            for ff in range(1, num_frame):
-                poses[ff,4] = poses[ff-1,4] + actions[ff,16] * -15
-
-            if self.within_context:
-                actions[:] = 0
-                actions[:self.n_frames//2+1,16] = 1
-                actions[self.n_frames//2+1:,16] = -1
-                poses[:] = 0
-                for ff in range(1, num_frame):
-                    poses[ff,4] = poses[ff-1,4] + actions[ff,16] * -15
-
-        return (
-            video[:: self.frame_skip],
-            actions[:: self.frame_skip],
-            poses[:: self.frame_skip],
-            timestep
-        )

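To make the action encoding above concrete: convert_action_space maps the dataset's compact integer action format into the 25-dimensional layout indexed by ACTION_KEYS (11 = forward, 12 = back, 15 = cameraY, 16 = cameraX, 24 = drop). A small usage sketch, assuming convert_action_space is in scope; the input rows below are made up:

import numpy as np

raw = np.zeros((3, 8), dtype=np.int64)
raw[0, 0] = 1   # column 0 == 1  -> forward     -> vec_25[0, 11] = 1
raw[1, 4] = 13  # column 4 == 13 -> cameraX +1  -> vec_25[1, 16] = 1
raw[2, 5] = 6   # column 5 == 6  -> drop slot   -> vec_25[2, 24] = 1

vec_25 = convert_action_space(raw)
assert vec_25.shape == (3, 25)
assert vec_25[0, 11] == 1 and vec_25[1, 16] == 1 and vec_25[2, 24] == 1
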
datasets/video/minecraft_video_dataset_oasis_filter.py
DELETED
@@ -1,99 +0,0 @@
-import torch
-from typing import Sequence
-import numpy as np
-import io
-from omegaconf import DictConfig
-from tqdm import tqdm
-
-from typing import Mapping, Sequence
-import os
-import math
-from packaging import version as pver
-from PIL import Image
-import random
-import shutil
-import os
-from pathlib import Path
-import traceback
-
-class OASISMinecraftVideoFilterDataset(torch.utils.data.Dataset):
-    """
-    Minecraft dataset
-    """
-
-    def __init__(self, source_dir, target_dir, split):
-        self.source_dir = Path(source_dir)
-        self.split_dir = self.source_dir / f"{split}"
-        self.data_paths = self.get_data_paths(split)
-        self.target_dir = Path(target_dir) / f"{split}"
-        self.target_dir.mkdir(exist_ok=True, parents=True)
-
-    def get_data_paths(self, split):
-        data_dir = self.source_dir / split
-        paths = sorted(list(data_dir.glob("**/*.mp4")), key=lambda x: x.name)
-
-        if len(paths) == 0:
-            sub_path = os.listdir(data_dir)
-            for sp in sub_path:
-                data_dir = self.source_dir / split / sp
-                paths = paths+sorted(list(data_dir.glob("**/*.mp4")), key=lambda x: x.name)
-        return paths
-
-    def __len__(self):
-        return len(self.data_paths)
-
-    def __getitem__(self, idx):
-
-        return self.sub_get(idx)
-        # try:
-        #     return self.sub_get(idx)
-        # except Exception as e:
-        #     traceback.print_exc()
-        #     # return self.sub_get(0)
-
-
-    def sub_get(self, idx):
-        action_path = self.data_paths[idx]
-        video_path = self.data_paths[idx]
-
-        action_path = video_path.with_suffix(".npz")
-        actions_pool = np.load(action_path)['actions']
-        poses_pool = np.load(action_path)['poses']
-
-        poses_pool[0,1] = poses_pool[1,1] # wrong first in place
-
-        print(poses_pool.shape)
-
-        if poses_pool[:,1].max() - poses_pool[:,1].min() < 2:
-            target_action_path = self.target_dir / action_path.parent.name / action_path.name
-            target_video_path = self.target_dir / video_path.parent.name / video_path.name
-            target_action_path.parent.mkdir(exist_ok=True, parents=True)
-            target_video_path.parent.mkdir(exist_ok=True, parents=True)
-
-            try:
-                shutil.copy2(action_path, target_action_path)
-                shutil.copy2(video_path, target_video_path)
-            except:
-                import pdb;pdb.set_trace()
-
-        return poses_pool[:10]
-
-
-
-if __name__ == "__main__":
-    import torch
-    from unittest.mock import MagicMock
-    import tqdm
-
-    cfg = MagicMock()
-    cfg.resolution = 64
-    cfg.external_cond_dim = 0
-    cfg.n_frames = 64
-    cfg.save_dir = "data/minecraft"
-    cfg.validation_multiplier = 1
-
-    dataset = MinecraftVideoDataset(cfg, "training")
-    dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=16)
-
-    for batch in tqdm.tqdm(dataloader):
-        pass

datasets/video/minecraft_video_dataset_pose.py
DELETED
@@ -1,421 +0,0 @@
-import torch
-from typing import Sequence
-import numpy as np
-import io
-import tarfile
-from pytorchvideo.data.encoded_video import EncodedVideo
-from omegaconf import DictConfig
-from tqdm import tqdm
-
-from .base_video_dataset import BaseVideoDataset
-from typing import Mapping, Sequence
-import os
-import math
-from packaging import version as pver
-from PIL import Image
-import random
-
-def euler_to_rotation_matrix(pitch, yaw):
-    """
-    Convert euler angles (pitch, yaw) to a 3x3 rotation matrix.
-    pitch: rotation around x-axis (in radians)
-    yaw: rotation around y-axis (in radians)
-    """
-    # Rotation matrix around x-axis (pitch)
-    R_x = np.array([
-        [1, 0, 0],
-        [0, math.cos(pitch), -math.sin(pitch)],
-        [0, math.sin(pitch), math.cos(pitch)]
-    ])
-
-    # Rotation matrix around y-axis (yaw)
-    R_y = np.array([
-        [math.cos(yaw), 0, math.sin(yaw)],
-        [0, 1, 0],
-        [-math.sin(yaw), 0, math.cos(yaw)]
-    ])
-
-    # Combined rotation matrix
-    R = np.dot(R_y, R_x)
-    return R
-
-def custom_meshgrid(*args):
-    # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
-    if pver.parse(torch.__version__) < pver.parse('1.10'):
-        return torch.meshgrid(*args)
-    else:
-        return torch.meshgrid(*args, indexing='ij')
-
-def camera_to_world_to_world_to_camera(camera_to_world):
-    """
-    Convert Camera-to-World matrix to World-to-Camera matrix by inverting the transformation.
-    """
-    # Extract rotation (R) and translation (T)
-    R = camera_to_world[:3, :3]
-    T = camera_to_world[:3, 3]
-
-    # Calculate World-to-Camera (inverse) matrix
-    world_to_camera = np.eye(4)
-
-    # The rotation part of World-to-Camera is the transpose of Camera-to-World's rotation
-    world_to_camera[:3, :3] = R.T
-
-    # The translation part is the negative of the rotated translation
-    world_to_camera[:3, 3] = -np.dot(R.T, T)
-
-    return world_to_camera
-
-def euler_to_camera_to_world_matrix(pose):
-
-    x, y, z, pitch, yaw = pose
-    # Convert pitch and yaw to radians
-    pitch = math.radians(pitch)
-    yaw = math.radians(yaw)
-
-    # Get the rotation matrix from Euler angles
-    R = euler_to_rotation_matrix(pitch, yaw)
-
-    # Create the 4x4 transformation matrix (rotation + translation)
-    camera_to_world = np.eye(4)
-
-    # Set the rotation part (upper 3x3)
-    camera_to_world[:3, :3] = R
-
-    # Set the translation part (last column)
-    camera_to_world[:3, 3] = [x, y, z]
-
-    return camera_to_world
-
-def tensor_to_gif(tensor, output_path, fps=10):
-    """
-    Converts a PyTorch tensor of shape (F, 3, H, W) to a GIF.
-
-    Args:
-        tensor (torch.Tensor): Input tensor of shape (F, 3, H, W) with values in range [0, 1] or [0, 255].
-        output_path (str): Path to save the output GIF.
-        fps (int): Frames per second for the GIF.
-    """
-    # Ensure the tensor is in [0, 255] range
-    if tensor.max() <= 1.0:
-        tensor = (tensor * 255).byte()
-    else:
-        tensor = tensor.byte()
-
-    # Convert tensor to numpy array and rearrange to (F, H, W, 3)
-    frames = tensor.permute(0, 2, 3, 1).cpu().numpy()
-
-    # Convert frames to PIL Images
-    pil_frames = [Image.fromarray(frame) for frame in frames]
-
-    # Save as GIF
-    pil_frames[0].save(
-        output_path,
-        save_all=True,
-        append_images=pil_frames[1:],
-        duration=int(1000 / fps),
-        loop=0
-    )
-
-def get_relative_pose(cam_params, zero_first_frame_scale):
-    abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
-    abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
-    source_cam_c2w = abs_c2ws[0]
-    if zero_first_frame_scale:
-        cam_to_origin = 0
-    else:
-        cam_to_origin = np.linalg.norm(source_cam_c2w[:3, 3])
-    target_cam_c2w = np.array([
-        [1, 0, 0, 0],
-        [0, 1, 0, -cam_to_origin],
-        [0, 0, 1, 0],
-        [0, 0, 0, 1]
-    ])
-    abs2rel = target_cam_c2w @ abs_w2cs[0]
-    ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
-    ret_poses = np.array(ret_poses, dtype=np.float32)
-    return ret_poses
-
-def ray_condition(K, c2w, H, W, device):
-    # c2w: B, V, 4, 4
-    # K: B, V, 4
-
-    B = K.shape[0]
-
-    j, i = custom_meshgrid(
-        torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
-        torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
-    )
-    i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5  # [B, HxW]
-    j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5  # [B, HxW]
-
-    fx, fy, cx, cy = K.chunk(4, dim=-1)  # B,V, 1
-
-    zs = torch.ones_like(i)  # [B, HxW]
-    xs = (i - cx) / fx * zs
-    ys = (j - cy) / fy * zs
-    zs = zs.expand_as(ys)
-
-    directions = torch.stack((xs, ys, zs), dim=-1)  # B, V, HW, 3
-    directions = directions / directions.norm(dim=-1, keepdim=True)  # B, V, HW, 3
-
-    rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2)  # B, V, 3, HW
-    rays_o = c2w[..., :3, 3]  # B, V, 3
-    rays_o = rays_o[:, :, None].expand_as(rays_d)  # B, V, 3, HW
-    # c2w @ dirctions
-    rays_dxo = torch.linalg.cross(rays_o, rays_d)
-    plucker = torch.cat([rays_dxo, rays_d], dim=-1)
-    plucker = plucker.reshape(B, c2w.shape[1], H, W, 6)  # B, V, H, W, 6
-
-    return plucker
-
-class Camera(object):
-    def __init__(self, entry, focal_length=0.35):
-        self.fx = focal_length  # 0.35 correspond to 110 fov
-        self.fy = focal_length*640/360
-        self.cx = 0.5
-        self.cy = 0.5
-        self.c2w_mat = euler_to_camera_to_world_matrix(entry)
-        self.w2c_mat = camera_to_world_to_world_to_camera(np.copy(self.c2w_mat))
-
-
-ACTION_KEYS = [
-    "inventory",
-    "ESC",
-    "hotbar.1",
-    "hotbar.2",
-    "hotbar.3",
-    "hotbar.4",
-    "hotbar.5",
-    "hotbar.6",
-    "hotbar.7",
-    "hotbar.8",
-    "hotbar.9",
-    "forward",
-    "back",
-    "left",
-    "right",
-    "cameraY",
-    "cameraX",
-    "jump",
-    "sneak",
-    "sprint",
-    "swapHands",
-    "attack",
-    "use",
-    "pickItem",
-    "drop",
-]
-
-def one_hot_actions(actions: Sequence[Mapping[str, int]]) -> torch.Tensor:
-    actions_one_hot = torch.zeros(len(actions), len(ACTION_KEYS))
-    for i, current_actions in enumerate(actions):
-        for j, action_key in enumerate(ACTION_KEYS):
-            if action_key.startswith("camera"):
-                if action_key == "cameraX":
-                    value = current_actions["camera"][0]
-                elif action_key == "cameraY":
-                    value = current_actions["camera"][1]
-                else:
-                    raise ValueError(f"Unknown camera action key: {action_key}")
-                max_val = 20
-                bin_size = 0.5
-                num_buckets = int(max_val / bin_size)
-                value = (value - num_buckets) / num_buckets
-                assert -1 - 1e-3 <= value <= 1 + 1e-3, f"Camera action value must be in [-1, 1], got {value}"
-            else:
-                value = current_actions[action_key]
-                assert 0 <= value <= 1, f"Action value must be in [0, 1] got {value}"
-            actions_one_hot[i, j] = value
-
-    return actions_one_hot
-
-def simpletomulti(actions):
-    vec_25 = torch.zeros(len(actions), len(ACTION_KEYS))
-    vec_25[actions==1, 11] = 1
-    vec_25[actions==2, 16] = -1
-    vec_25[actions==3, 16] = 1
-    vec_25[actions==4, 15] = -1
-    vec_25[actions==5, 15] = 1
-    return vec_25
-
-def simpletomulti2(actions):
-    vec_25 = torch.zeros(len(actions), len(ACTION_KEYS))
-    vec_25[actions[:,0]==1, 11] = 1
-    vec_25[actions[:,0]==2, 12] = 1
-    vec_25[actions[:,4]==11, 16] = -1
-    vec_25[actions[:,4]==13, 16] = 1
-    vec_25[actions[:,3]==11, 15] = -1
-    vec_25[actions[:,3]==13, 15] = 1
-    vec_25[actions[:,5]==6, 24] = 1
-    vec_25[actions[:,5]==1, 24] = 1
-    vec_25[actions[:,1]==1, 13] = 1
-    vec_25[actions[:,1]==2, 14] = 1
-    vec_25[actions[:,7]==1, 2] = 1
-    return vec_25
-
-class MinecraftVideoPoseDataset(BaseVideoDataset):
-    """
-    Minecraft dataset
-    """
-
-    def __init__(self, cfg: DictConfig, split: str = "training"):
-        if split == "test":
-            split = "validation"
-        super().__init__(cfg, split)
-
-        if hasattr(cfg, "n_frames_valid") and split == "validation":
-            self.n_frames = cfg.n_frames_valid
-
-    def get_data_paths(self, split):
-        data_dir = self.save_dir / split
-        paths = sorted(list(data_dir.glob("**/*.mp4")), key=lambda x: x.name)
-
-        if len(paths) == 0:
-            sub_path = os.listdir(data_dir)
-            for sp in sub_path:
-                data_dir = self.save_dir / split / sp
-                paths = paths+sorted(list(data_dir.glob("**/*.mp4")), key=lambda x: x.name)
-        return paths
-
-    def get_data_lengths(self, split):
-        lengths = [300] * len(self.get_data_paths(split))
-        return lengths
-
-    def download_dataset(self) -> Sequence[int]:
-        from internetarchive import download
-
-        part_suffixes = [
-            "aa",
-            "ab",
-            "ac",
-            "ad",
-            "ae",
-            "af",
-            "ag",
-            "ah",
-            "ai",
-            "aj",
-            "ak",
-        ]
-        for part_suffix in part_suffixes:
-            identifier = f"minecraft_marsh_dataset_{part_suffix}"
-            file_name = f"minecraft.tar.part{part_suffix}"
-            download(identifier, file_name, destdir=self.save_dir, verbose=True)
-
-        combined_bytes = io.BytesIO()
-        for part_suffix in part_suffixes:
-            identifier = f"minecraft_marsh_dataset_{part_suffix}"
-            file_name = f"minecraft.tar.part{part_suffix}"
-            part_file = self.save_dir / identifier / file_name
-            with open(part_file, "rb") as part:
-                combined_bytes.write(part.read())
-        combined_bytes.seek(0)
-        with tarfile.open(fileobj=combined_bytes, mode="r") as combined_archive:
-            combined_archive.extractall(self.save_dir)
-        (self.save_dir / "minecraft/test").rename(self.save_dir / "validation")
-        (self.save_dir / "minecraft/train").rename(self.save_dir / "training")
-        (self.save_dir / "minecraft").rmdir()
-        for part_suffix in part_suffixes:
-            identifier = f"minecraft_marsh_dataset_{part_suffix}"
-            file_name = f"minecraft.tar.part{part_suffix}"
-            part_file = self.save_dir / identifier / file_name
-            part_file.rmdir()
-
-    def __getitem__(self, idx):
-        # return self.load_data(idx)
-
-        max_retries = 1000
-        for mr in range(max_retries):
-            try:
-                return self.load_data(idx)
-            except Exception as e:
-                print(f"{mr} Error: {e}")
-                # idx = self.idx_remap[idx]
-                # file_idx, frame_idx = self.split_idx(idx)
-                # video_path = self.data_paths[file_idx]
-                # os.remove(video_path)
-                idx = (idx + 1) % self.__len__()
-
-    def load_data(self, idx):
-        idx = self.idx_remap[idx]
-        file_idx, frame_idx = self.split_idx(idx)
-        action_path = self.data_paths[file_idx]
-        video_path = self.data_paths[file_idx]
-
-        action_path = video_path.with_suffix(".npz")
-        actions_pool = np.load(action_path)['actions']
-        poses_pool = np.load(action_path)['poses']
-
-        poses_pool[0,1] = poses_pool[1,1] # wrong first in place
-
-        assert poses_pool[:,1].max() - poses_pool[:,1].min() < 2, f"wrong~~~~{poses_pool[:,1].max() - poses_pool[:,1].min()}-{video_path}"
-
-        if len(poses_pool) < len(actions_pool):
-            poses_pool = np.pad(poses_pool, ((1, 0), (0, 0)))
-
-        actions_pool = simpletomulti2(actions_pool)
-        video_raw = EncodedVideo.from_path(video_path, decode_audio=False)
-
-        frame_idx = frame_idx + 100 # avoid first frames # first frame is useless
-
-        if self.split == "validation":
-            frame_idx = 240
-
-        total_frame = video_raw.duration.numerator
-        fps = 10 # video_raw.duration.denominator
-        total_frame = total_frame * fps / video_raw.duration.denominator
-        video = video_raw.get_clip(start_sec=frame_idx/fps, end_sec=(frame_idx+self.n_frames)/fps)["video"]
video = video_raw.get_clip(start_sec=frame_idx/fps, end_sec=(frame_idx+self.n_frames)/fps)["video"]
|
| 368 |
-
|
| 369 |
-
video = video.permute(1, 2, 3, 0).numpy()
|
| 370 |
-
|
| 371 |
-
if self.split != "validation" and 'degrees' in np.load(action_path).keys():
|
| 372 |
-
degrees = np.load(action_path)['degrees']
|
| 373 |
-
actions_pool[:,16] *= degrees
|
| 374 |
-
|
| 375 |
-
actions = np.copy(actions_pool[frame_idx : frame_idx + self.n_frames]) # (t, )
|
| 376 |
-
|
| 377 |
-
poses = np.copy(poses_pool[frame_idx : frame_idx + self.n_frames])
|
| 378 |
-
pad_len = self.n_frames - len(video)
|
| 379 |
-
poses_pool[:,:3] -= poses[:1,:3]
|
| 380 |
-
# poses_pool[:,3:] = -poses_pool[:,3:]
|
| 381 |
-
poses_pool[:,-1] = -poses_pool[:,-1]
|
| 382 |
-
poses_pool[:,3:] %= 360
|
| 383 |
-
|
| 384 |
-
poses[:,:3] -= poses[:1,:3] # do not normalize angle
|
| 385 |
-
# poses[:,3:] = -poses[:,3:]
|
| 386 |
-
poses[:,-1] = -poses[:,-1]
|
| 387 |
-
poses[:,3:] %= 360
|
| 388 |
-
|
| 389 |
-
nonterminal = np.ones(self.n_frames)
|
| 390 |
-
if len(video) < self.n_frames:
|
| 391 |
-
video = np.pad(video, ((0, pad_len), (0, 0), (0, 0), (0, 0)))
|
| 392 |
-
actions = np.pad(actions, ((0, pad_len),))
|
| 393 |
-
poses = np.pad(actions, ((0, pad_len),))
|
| 394 |
-
nonterminal[-pad_len:] = 0
|
| 395 |
-
|
| 396 |
-
video = torch.from_numpy(video / 255.0).float().permute(0, 3, 1, 2).contiguous()
|
| 397 |
-
|
| 398 |
-
return (
|
| 399 |
-
video[:: self.frame_skip],
|
| 400 |
-
actions[:: self.frame_skip],
|
| 401 |
-
poses[:: self.frame_skip]
|
| 402 |
-
)
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
if __name__ == "__main__":
|
| 406 |
-
import torch
|
| 407 |
-
from unittest.mock import MagicMock
|
| 408 |
-
import tqdm
|
| 409 |
-
|
| 410 |
-
cfg = MagicMock()
|
| 411 |
-
cfg.resolution = 64
|
| 412 |
-
cfg.external_cond_dim = 0
|
| 413 |
-
cfg.n_frames = 64
|
| 414 |
-
cfg.save_dir = "data/minecraft"
|
| 415 |
-
cfg.validation_multiplier = 1
|
| 416 |
-
|
| 417 |
-
dataset = MinecraftVideoDataset(cfg, "training")
|
| 418 |
-
dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=16)
|
| 419 |
-
|
| 420 |
-
for batch in tqdm.tqdm(dataloader):
|
| 421 |
-
pass
|
experiments/README.md
DELETED
@@ -1,19 +0,0 @@
-# experiments
-
-The `experiments` folder contains the code of experiments. Each file in the experiment folder represents a certain type of
-benchmark specific to a project. Such an experiment can be instantiated with a certain dataset and a certain algorithm.
-
-You should create a new `.py` file for your experiment,
-inheriting from any suitable base classes in `experiments/exp_base.py`,
-and then register your new experiment in `experiments/__init__.py`.
-
-You run an experiment by running `python -m main [options]` in the root directory of the
-project. You should not log any data in this folder, but store it under `outputs` under the root project
-directory.
-
-This folder is only intended to contain formal experiments. For debug code and unit tests, put them under the `debug` folder.
-For scripts that are not meant to be an experiment, please use the `scripts` folder.
-
----
-
-This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research template [repo](https://github.com/buoyancy99/research-template). By its MIT license, you must keep the above sentence in `README.md` and the `LICENSE` file to credit the author.
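The subclass-register-build workflow this README describes can be summarized in a self-contained sketch; `MyExperiment` and `exp_my` are illustrative names, not part of the repo:

```python
# Self-contained sketch of the deleted README's convention: subclass a base
# experiment, register it under the same key as its yaml file, then build
# it from the registry. All names here are illustrative.
class BaseExperiment:
    def __init__(self, cfg):
        self.cfg = cfg

class MyExperiment(BaseExperiment):  # would live in experiments/exp_my.py
    pass

exp_registry = dict(exp_my=MyExperiment)  # key matches configurations/experiment/exp_my.yaml

def build_experiment(name, cfg):
    if name not in exp_registry:
        raise ValueError(f"Experiment {name} not found in registry {list(exp_registry)}.")
    return exp_registry[name](cfg)

experiment = build_experiment("exp_my", cfg={"tasks": ["training"]})
```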
experiments/__init__.py
DELETED
@@ -1,35 +0,0 @@
-from typing import Optional, Union
-from omegaconf import DictConfig
-import pathlib
-from lightning.pytorch.loggers.wandb import WandbLogger
-
-from .exp_base import BaseExperiment
-from .exp_video import VideoPredictionExperiment
-from .exp_pose import PoseExperiment
-
-# each key has to be a yaml file under '[project_root]/configurations/experiment' without .yaml suffix
-exp_registry = dict(
-    exp_video=VideoPredictionExperiment,
-    exp_pose=PoseExperiment
-)
-
-
-def build_experiment(
-    cfg: DictConfig,
-    logger: Optional[WandbLogger] = None,
-    ckpt_path: Optional[Union[str, pathlib.Path]] = None,
-) -> BaseExperiment:
-    """
-    Build an experiment instance based on registry
-    :param cfg: configuration file
-    :param logger: optional logger for the experiment
-    :param ckpt_path: optional checkpoint path for saving and loading
-    :return:
-    """
-    if cfg.experiment._name not in exp_registry:
-        raise ValueError(
-            f"Experiment {cfg.experiment._name} not found in registry {list(exp_registry.keys())}. "
-            "Make sure you register it correctly in 'experiments/__init__.py' under the same name as yaml file."
-        )
-
-    return exp_registry[cfg.experiment._name](cfg, logger, ckpt_path)
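For context, the deleted `main.py` further below consumes this factory directly: it calls `experiment = build_experiment(cfg, logger, checkpoint_path)` and then runs `experiment.exec_task(task)` for each task listed in `cfg.experiment.tasks`.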
experiments/exp_base.py
DELETED
@@ -1,473 +0,0 @@
-"""
-This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research
-template [repo](https://github.com/buoyancy99/research-template).
-By its MIT license, you must keep the above sentence in `README.md`
-and the `LICENSE` file to credit the author.
-"""
-
-from abc import ABC, abstractmethod
-from typing import Optional, Union, Literal, List, Dict
-import pathlib
-import os
-
-import hydra
-import torch
-from lightning.pytorch.strategies.ddp import DDPStrategy
-
-import lightning.pytorch as pl
-from lightning.pytorch.loggers.wandb import WandbLogger
-from lightning.pytorch.utilities.types import TRAIN_DATALOADERS
-from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
-from pytorch_lightning.utilities import rank_zero_info
-
-from omegaconf import DictConfig
-
-from utils.print_utils import cyan
-from utils.distributed_utils import is_rank_zero
-from safetensors.torch import load_model
-from pathlib import Path
-from huggingface_hub import hf_hub_download
-
-torch.set_float32_matmul_precision("high")
-
-def load_custom_checkpoint(algo, optimizer, checkpoint_path):
-    if not checkpoint_path:
-        rank_zero_info("No checkpoint path provided, skipping checkpoint loading.")
-        return None
-
-    if not isinstance(checkpoint_path, Path):
-        checkpoint_path = Path(checkpoint_path)
-
-    if "yslan" in str(checkpoint_path):
-        hf_ckpt = str(checkpoint_path).split('/')
-        repo_id = '/'.join(hf_ckpt[:2])
-        file_name = '/'.join(hf_ckpt[2:])
-        model_path = hf_hub_download(repo_id=repo_id,
-                                     filename=file_name)
-        ckpt = torch.load(model_path, map_location=torch.device('cpu'))
-        algo.load_state_dict(ckpt['state_dict'], strict=False)
-
-    elif checkpoint_path.suffix == ".pt":
-        ckpt = torch.load(checkpoint_path, weights_only=True)
-        algo.load_state_dict(ckpt, strict=False)
-    elif checkpoint_path.suffix == ".ckpt":
-        ckpt = torch.load(checkpoint_path, map_location=torch.device('cpu'))
-        algo.load_state_dict(ckpt['state_dict'], strict=False)
-    elif checkpoint_path.suffix == ".safetensors":
-        load_model(algo, checkpoint_path, strict=False)
-    elif os.path.isdir(checkpoint_path):
-        ckpt_files = [f for f in os.listdir(checkpoint_path) if f.endswith('.ckpt')]
-        if not ckpt_files:
-            raise FileNotFoundError("No .ckpt file found in the specified folder!")
-        selected_ckpt = max(ckpt_files)
-        selected_ckpt_path = os.path.join(checkpoint_path, selected_ckpt)
-        print(f"Loading checkpoint file: {selected_ckpt_path}")
-
-        ckpt = torch.load(selected_ckpt_path, map_location=torch.device('cpu'))
-        algo.load_state_dict(ckpt['state_dict'], strict=False)
-
-    rank_zero_info("Model weights loaded.")
-
-class BaseExperiment(ABC):
-    """
-    Abstract class for an experiment. This generalizes the pytorch lightning Trainer & lightning Module to more
-    flexible experiments that don't fit in the typical ml loop, e.g. multi-stage reinforcement learning benchmarks.
-    """
-
-    # each key has to be a yaml file under '[project_root]/configurations/algorithm' without .yaml suffix
-    compatible_algorithms: Dict = NotImplementedError
-
-    def __init__(
-        self,
-        root_cfg: DictConfig,
-        logger: Optional[WandbLogger] = None,
-        ckpt_path: Optional[Union[str, pathlib.Path]] = None,
-    ) -> None:
-        """
-        Constructor
-
-        Args:
-            cfg: configuration file that contains everything about the experiment
-            logger: a pytorch-lightning WandbLogger instance
-            ckpt_path: an optional path to saved checkpoint
-        """
-        super().__init__()
-        self.root_cfg = root_cfg
-        self.cfg = root_cfg.experiment
-        self.debug = root_cfg.debug
-        self.logger = logger
-        self.ckpt_path = ckpt_path
-        self.algo = None
-        self.customized_load = self.cfg.customized_load
-        self.load_vae = self.cfg.load_vae
-        self.load_t_to_r = self.cfg.load_t_to_r
-        self.zero_init_gate = self.cfg.zero_init_gate
-        self.only_tune_refer = self.cfg.only_tune_refer
-        self.diffusion_path = self.cfg.diffusion_path
-        self.vae_path = self.cfg.vae_path  # "/mnt/xiaozeqi/.cache/huggingface/hub/models--Etched--oasis-500m/snapshots/4ca7d2d811f4f0c6fd1d5719bf83f14af3446c0c/vit-l-20.safetensors"
-        self.pose_predictor_path = self.cfg.pose_predictor_path  # "/mnt/xiaozeqi/diffusionforcing/outputs/2025-03-28/16-45-11/checkpoints/epoch0step595000.ckpt"
-
-    def _build_algo(self):
-        """
-        Build the lightning module
-        :return: a pytorch-lightning module to be launched
-        """
-        algo_name = self.root_cfg.algorithm._name
-        if algo_name not in self.compatible_algorithms:
-            raise ValueError(
-                f"Algorithm {algo_name} not found in compatible_algorithms for this Experiment class. "
-                "Make sure you define compatible_algorithms correctly and make sure that each key has "
-                "same name as yaml file under '[project_root]/configurations/algorithm' without .yaml suffix"
-            )
-        return self.compatible_algorithms[algo_name](self.root_cfg.algorithm)
-
-    def exec_task(self, task: str) -> None:
-        """
-        Executing a certain task specified by string. Each task should be a stage of experiment.
-        In most computer vision / nlp applications, tasks should be just train and test.
-        In reinforcement learning, you might have more stages such as collecting dataset etc
-
-        Args:
-            task: a string specifying a task implemented for this experiment
-        """
-        if hasattr(self, task) and callable(getattr(self, task)):
-            if is_rank_zero:
-                print(cyan("Executing task:"), f"{task} out of {self.cfg.tasks}")
-            getattr(self, task)()
-        else:
-            raise ValueError(
-                f"Specified task '{task}' not defined for class {self.__class__.__name__} or is not callable."
-            )
-
-    def exec_interactive(self, task: str) -> None:
-        """
-        Executing a certain task specified by string. Each task should be a stage of experiment.
-        In most computer vision / nlp applications, tasks should be just train and test.
-        In reinforcement learning, you might have more stages such as collecting dataset etc
-
-        Args:
-            task: a string specifying a task implemented for this experiment
-        """
-        if hasattr(self, task) and callable(getattr(self, task)):
-            if is_rank_zero:
-                print(cyan("Executing task:"), f"{task} out of {self.cfg.tasks}")
-            return getattr(self, task)()
-        else:
-            raise ValueError(
-                f"Specified task '{task}' not defined for class {self.__class__.__name__} or is not callable."
-            )
-
-class BaseLightningExperiment(BaseExperiment):
-    """
-    Abstract class for pytorch lightning experiments. Useful for computer vision & nlp where main components are
-    simply models, datasets and train loop.
-    """
-
-    # each key has to be a yaml file under '[project_root]/configurations/algorithm' without .yaml suffix
-    compatible_algorithms: Dict = NotImplementedError
-
-    # each key has to be a yaml file under '[project_root]/configurations/dataset' without .yaml suffix
-    compatible_datasets: Dict = NotImplementedError
-
-    def _build_trainer_callbacks(self):
-        callbacks = []
-        if self.logger:
-            callbacks.append(LearningRateMonitor("step", True))
-
-    def _build_training_loader(self) -> Optional[Union[TRAIN_DATALOADERS, pl.LightningDataModule]]:
-        train_dataset = self._build_dataset("training")
-        shuffle = (
-            False if isinstance(train_dataset, torch.utils.data.IterableDataset) else self.cfg.training.data.shuffle
-        )
-        if train_dataset:
-            return torch.utils.data.DataLoader(
-                train_dataset,
-                batch_size=self.cfg.training.batch_size,
-                num_workers=min(os.cpu_count(), self.cfg.training.data.num_workers),
-                shuffle=shuffle,
-                persistent_workers=True,
-            )
-        else:
-            return None
-
-    def _build_validation_loader(self) -> Optional[Union[TRAIN_DATALOADERS, pl.LightningDataModule]]:
-        validation_dataset = self._build_dataset("validation")
-        shuffle = (
-            False
-            if isinstance(validation_dataset, torch.utils.data.IterableDataset)
-            else self.cfg.validation.data.shuffle
-        )
-        if validation_dataset:
-            return torch.utils.data.DataLoader(
-                validation_dataset,
-                batch_size=self.cfg.validation.batch_size,
-                num_workers=min(os.cpu_count(), self.cfg.validation.data.num_workers),
-                shuffle=shuffle,
-                persistent_workers=True,
-            )
-        else:
-            return None
-
-    def _build_test_loader(self) -> Optional[Union[TRAIN_DATALOADERS, pl.LightningDataModule]]:
-        test_dataset = self._build_dataset("test")
-        shuffle = False if isinstance(test_dataset, torch.utils.data.IterableDataset) else self.cfg.test.data.shuffle
-        if test_dataset:
-            return torch.utils.data.DataLoader(
-                test_dataset,
-                batch_size=self.cfg.test.batch_size,
-                num_workers=min(os.cpu_count(), self.cfg.test.data.num_workers),
-                shuffle=shuffle,
-                persistent_workers=True,
-            )
-        else:
-            return None
-
-    def training(self) -> None:
-        """
-        All training happens here
-        """
-        if not self.algo:
-            self.algo = self._build_algo()
-        if self.cfg.training.compile:
-            self.algo = torch.compile(self.algo)
-
-        callbacks = []
-        if self.logger:
-            callbacks.append(LearningRateMonitor("step", True))
-        if "checkpointing" in self.cfg.training:
-            callbacks.append(
-                ModelCheckpoint(
-                    pathlib.Path(hydra.core.hydra_config.HydraConfig.get()["runtime"]["output_dir"]) / "checkpoints",
-                    **self.cfg.training.checkpointing,
-                )
-            )
-
-        # TODO do not upload checkpoint to wandb
-
-        # trainer = pl.Trainer(
-        #     accelerator="auto",
-        #     logger=self.logger if self.logger else False,
-        #     devices=torch.cuda.device_count(),
-        #     num_nodes=self.cfg.num_nodes,
-        #     strategy=DDPStrategy(find_unused_parameters=True) if torch.cuda.device_count() > 1 else "auto",
-        #     callbacks=callbacks,
-        #     gradient_clip_val=self.cfg.training.optim.gradient_clip_val,
-        #     val_check_interval=self.cfg.validation.val_every_n_step,
-        #     limit_val_batches=self.cfg.validation.limit_batch,
-        #     check_val_every_n_epoch=self.cfg.validation.val_every_n_epoch,
-        #     accumulate_grad_batches=self.cfg.training.optim.accumulate_grad_batches,
-        #     precision=self.cfg.training.precision,
-        #     detect_anomaly=False,  # self.cfg.debug,
-        #     num_sanity_val_steps=int(self.cfg.debug),
-        #     max_epochs=self.cfg.training.max_epochs,
-        #     max_steps=self.cfg.training.max_steps,
-        #     max_time=self.cfg.training.max_time,
-        # )
-
-        trainer = pl.Trainer(
-            accelerator="auto",
-            devices="auto",  # select devices automatically
-            strategy=DDPStrategy(find_unused_parameters=True) if torch.cuda.device_count() > 1 else "auto",
-            logger=self.logger or False,  # simplified form
-            callbacks=callbacks,
-            gradient_clip_val=self.cfg.training.optim.gradient_clip_val or 0.0,  # ensure a default value
-            val_check_interval=self.cfg.validation.val_every_n_step if self.cfg.validation.val_every_n_step else None,
-            limit_val_batches=self.cfg.validation.limit_batch,
-            check_val_every_n_epoch=self.cfg.validation.val_every_n_epoch if not self.cfg.validation.val_every_n_step else None,
-            accumulate_grad_batches=self.cfg.training.optim.accumulate_grad_batches or 1,  # default accumulation of 1
-            precision=self.cfg.training.precision or 32,  # default to 32-bit precision
-            detect_anomaly=False,  # anomaly detection disabled by default
-            num_sanity_val_steps=int(self.cfg.debug) if self.cfg.debug else 0,
-            max_epochs=self.cfg.training.max_epochs,
-            max_steps=self.cfg.training.max_steps,
-            max_time=self.cfg.training.max_time
-        )
-
-
-        if self.customized_load:
-            if self.load_vae:
-                load_custom_checkpoint(algo=self.algo.diffusion_model.model, optimizer=None, checkpoint_path=self.ckpt_path)
-                load_custom_checkpoint(algo=self.algo.vae, optimizer=None, checkpoint_path=self.vae_path)
-            else:
-                load_custom_checkpoint(algo=self.algo, optimizer=None, checkpoint_path=self.ckpt_path)
-
-            if self.load_t_to_r:
-                param_list = []
-                for name, para in self.algo.diffusion_model.named_parameters():
-                    if 't_' in name and 't_embedder' not in name:
-                        print(name)
-                        param_list.append(para)
-
-                it = 0
-                for name, para in self.algo.diffusion_model.named_parameters():
-                    if 'r_' in name:
-                        para.requires_grad_(False)
-                        try:
-                            para.copy_(param_list[it].detach().cpu())
-                        except:
-                            import pdb;pdb.set_trace()
-                        para.requires_grad_(True)
-                        it += 1
-
-            if self.zero_init_gate:
-                for name, para in self.algo.diffusion_model.named_parameters():
-                    if 'r_adaLN_modulation' in name:
-                        para.requires_grad_(False)
-                        para[2*1024:3*1024] = 0
-                        para[5*1024:6*1024] = 0
-                        para.requires_grad_(True)
-
-            if self.only_tune_refer:
-                for name, para in self.algo.diffusion_model.named_parameters():
-                    para.requires_grad_(False)
-                    if 'r_' in name or 'pose_embedder' in name or 'pose_cond_mlp' in name or 'lora_' in name:
-                        para.requires_grad_(True)
-
-            trainer.fit(
-                self.algo,
-                train_dataloaders=self._build_training_loader(),
-                val_dataloaders=self._build_validation_loader(),
-                ckpt_path=None,
-            )
-        else:
-
-            if self.only_tune_refer:
-                for name, para in self.algo.diffusion_model.named_parameters():
-                    para.requires_grad_(False)
-                    if 'r_' in name or 'pose_embedder' in name or 'pose_cond_mlp' in name or 'lora_' in name:
-                        para.requires_grad_(True)
-
-            trainer.fit(
-                self.algo,
-                train_dataloaders=self._build_training_loader(),
-                val_dataloaders=self._build_validation_loader(),
-                ckpt_path=self.ckpt_path,
-            )
-
-    def validation(self) -> None:
-        """
-        All validation happens here
-        """
-        if not self.algo:
-            self.algo = self._build_algo()
-        if self.cfg.validation.compile:
-            self.algo = torch.compile(self.algo)
-
-        callbacks = []
-
-        trainer = pl.Trainer(
-            accelerator="auto",
-            logger=self.logger,
-            devices="auto",
-            num_nodes=self.cfg.num_nodes,
-            strategy=DDPStrategy(find_unused_parameters=False) if torch.cuda.device_count() > 1 else "auto",
-            callbacks=callbacks,
-            # limit_val_batches=self.cfg.validation.limit_batch,
-            limit_val_batches=self.cfg.validation.limit_batch,
-            precision=self.cfg.validation.precision,
-            detect_anomaly=False,  # self.cfg.debug,
-            inference_mode=self.cfg.validation.inference_mode,
-        )
-
-        if self.customized_load:
-
-            if self.load_vae:
-                load_custom_checkpoint(algo=self.algo.diffusion_model.model, optimizer=None, checkpoint_path=self.ckpt_path)
-                load_custom_checkpoint(algo=self.algo.vae, optimizer=None, checkpoint_path=self.vae_path)
-            else:
-                load_custom_checkpoint(algo=self.algo, optimizer=None, checkpoint_path=self.ckpt_path)
-
-            if self.load_t_to_r:
-                param_list = []
-                for name, para in self.algo.diffusion_model.named_parameters():
-                    if 't_' in name and 't_embedder' not in name:
-                        print(name)
-                        param_list.append(para)
-
-                it = 0
-                for name, para in self.algo.diffusion_model.named_parameters():
-                    if 'r_' in name:
-                        para.requires_grad_(False)
-                        try:
-                            para.copy_(param_list[it].detach().cpu())
-                        except:
-                            import pdb;pdb.set_trace()
-                        para.requires_grad_(True)
-                        it += 1
-
-            if self.zero_init_gate:
-                for name, para in self.algo.diffusion_model.named_parameters():
-                    if 'r_adaLN_modulation' in name:
-                        para.requires_grad_(False)
-                        para[2*1024:3*1024] = 0
-                        para[5*1024:6*1024] = 0
-                        para.requires_grad_(True)
-
-            trainer.validate(
-                self.algo,
-                dataloaders=self._build_validation_loader(),
-                ckpt_path=None,
-            )
-        else:
-            trainer.validate(
-                self.algo,
-                dataloaders=self._build_validation_loader(),
-                ckpt_path=self.ckpt_path,
-            )
-
-    def test(self) -> None:
-        """
-        All testing happens here
-        """
-        if not self.algo:
-            self.algo = self._build_algo()
-        if self.cfg.test.compile:
-            self.algo = torch.compile(self.algo)
-
-        callbacks = []
-
-        trainer = pl.Trainer(
-            accelerator="auto",
-            logger=self.logger,
-            devices="auto",
-            num_nodes=self.cfg.num_nodes,
-            strategy=DDPStrategy(find_unused_parameters=False) if torch.cuda.device_count() > 1 else "auto",
-            callbacks=callbacks,
-            limit_test_batches=self.cfg.test.limit_batch,
-            precision=self.cfg.test.precision,
-            detect_anomaly=False,  # self.cfg.debug,
-        )
-
-        # Only load the checkpoint if only testing. Otherwise, it will have been loaded
-        # and further trained during train.
-        trainer.test(
-            self.algo,
-            dataloaders=self._build_test_loader(),
-            ckpt_path=self.ckpt_path,
-        )
-        if not self.algo:
-            self.algo = self._build_algo()
-        if self.cfg.validation.compile:
-            self.algo = torch.compile(self.algo)
-
-
-    def interactive(self):
-
-        if not self.algo:
-            self.algo = self._build_algo()
-        if self.cfg.validation.compile:
-            self.algo = torch.compile(self.algo)
-
-        if self.customized_load:
-            load_custom_checkpoint(algo=self.algo.diffusion_model, optimizer=None, checkpoint_path=self.diffusion_path)
-            load_custom_checkpoint(algo=self.algo.vae, optimizer=None, checkpoint_path=self.vae_path)
-            load_custom_checkpoint(algo=self.algo.pose_prediction_model, optimizer=None, checkpoint_path=self.pose_predictor_path)
-            return self.algo
-        else:
-            raise NotImplementedError
-
-    def _build_dataset(self, split: str) -> Optional[torch.utils.data.Dataset]:
-        if split in ["training", "test", "validation"]:
-            return self.compatible_datasets[self.root_cfg.dataset._name](self.root_cfg.dataset, split=split)
-        else:
-            raise NotImplementedError(f"split '{split}' is not implemented")
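The `load_custom_checkpoint` helper above dispatches on the checkpoint path. A condensed, standalone sketch of that dispatch (the `hf_hub_download` and `.safetensors` branches of the original are omitted; the name `load_weights` is illustrative):

```python
# Condensed sketch of the suffix-based dispatch in the deleted
# load_custom_checkpoint; load_weights is an illustrative name.
import os
from pathlib import Path

import torch

def load_weights(model: torch.nn.Module, checkpoint_path) -> None:
    path = Path(checkpoint_path)
    if path.suffix == ".pt":
        model.load_state_dict(torch.load(path, weights_only=True), strict=False)
    elif path.suffix == ".ckpt":
        ckpt = torch.load(path, map_location="cpu")
        model.load_state_dict(ckpt["state_dict"], strict=False)
    elif path.is_dir():
        # like the original, pick the lexicographically largest .ckpt via max()
        candidates = [f for f in os.listdir(path) if f.endswith(".ckpt")]
        if not candidates:
            raise FileNotFoundError("No .ckpt file found in the specified folder!")
        ckpt = torch.load(path / max(candidates), map_location="cpu")
        model.load_state_dict(ckpt["state_dict"], strict=False)
    else:
        raise ValueError(f"Unsupported checkpoint format: {path.suffix}")
```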
experiments/exp_pose.py
DELETED
@@ -1,310 +0,0 @@
-"""
-This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research
-template [repo](https://github.com/buoyancy99/research-template).
-By its MIT license, you must keep the above sentence in `README.md`
-and the `LICENSE` file to credit the author.
-"""
-
-from abc import ABC, abstractmethod
-from typing import Optional, Union, Literal, List, Dict
-import pathlib
-import os
-
-import hydra
-import torch
-from lightning.pytorch.strategies.ddp import DDPStrategy
-
-import lightning.pytorch as pl
-from lightning.pytorch.loggers.wandb import WandbLogger
-from lightning.pytorch.utilities.types import TRAIN_DATALOADERS
-from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
-from pytorch_lightning.utilities import rank_zero_info
-
-from omegaconf import DictConfig
-
-from utils.print_utils import cyan
-from utils.distributed_utils import is_rank_zero
-from safetensors.torch import load_model
-from pathlib import Path
-from algorithms.worldmem import PosePrediction
-from datasets.video import MinecraftVideoPoseDataset
-
-
-torch.set_float32_matmul_precision("high")
-
-def load_custom_checkpoint(algo, optimizer, checkpoint_path):
-    if not checkpoint_path:
-        rank_zero_info("No checkpoint path provided, skipping checkpoint loading.")
-        return None
-
-    if not isinstance(checkpoint_path, Path):
-        checkpoint_path = Path(checkpoint_path)
-
-    if checkpoint_path.suffix == ".pt":
-        ckpt = torch.load(checkpoint_path, weights_only=True)
-        algo.load_state_dict(ckpt, strict=False)
-    elif checkpoint_path.suffix == ".ckpt":
-        ckpt = torch.load(checkpoint_path, map_location=torch.device('cpu'))
-        algo.load_state_dict(ckpt['state_dict'], strict=False)
-    elif checkpoint_path.suffix == ".safetensors":
-        load_model(algo, checkpoint_path, strict=False)
-    elif os.path.isdir(checkpoint_path):
-        ckpt_files = [f for f in os.listdir(checkpoint_path) if f.endswith('.ckpt')]
-        if not ckpt_files:
-            raise FileNotFoundError("No .ckpt file found in the specified folder!")
-        selected_ckpt = max(ckpt_files)
-        selected_ckpt_path = os.path.join(checkpoint_path, selected_ckpt)
-        print(f"Loading checkpoint file: {selected_ckpt_path}")
-
-        ckpt = torch.load(selected_ckpt_path, map_location=torch.device('cpu'))
-        algo.load_state_dict(ckpt['state_dict'], strict=False)
-
-    rank_zero_info("Model weights loaded.")
-
-class PoseExperiment(ABC):
-    """
-    Abstract class for an experiment. This generalizes the pytorch lightning Trainer & lightning Module to more
-    flexible experiments that don't fit in the typical ml loop, e.g. multi-stage reinforcement learning benchmarks.
-    """
-
-    # each key has to be a yaml file under '[project_root]/configurations/algorithm' without .yaml suffix
-    compatible_algorithms = dict(
-        pose_prediction=PosePrediction
-    )
-
-    compatible_datasets = dict(
-        video_minecraft_pose=MinecraftVideoPoseDataset
-    )
-
-    def __init__(
-        self,
-        root_cfg: DictConfig,
-        logger: Optional[WandbLogger] = None,
-        ckpt_path: Optional[Union[str, pathlib.Path]] = None,
-    ) -> None:
-        """
-        Constructor
-
-        Args:
-            cfg: configuration file that contains everything about the experiment
-            logger: a pytorch-lightning WandbLogger instance
-            ckpt_path: an optional path to saved checkpoint
-        """
-        super().__init__()
-        self.root_cfg = root_cfg
-        self.cfg = root_cfg.experiment
-        self.debug = root_cfg.debug
-        self.logger = logger
-        self.ckpt_path = ckpt_path
-        self.algo = None
-        self.vae_path = "/cpfs01/user/xiaozeqi/.cache/huggingface/hub/models--Etched--oasis-500m/snapshots/4ca7d2d811f4f0c6fd1d5719bf83f14af3446c0c/vit-l-20.safetensors"
-
-    def _build_algo(self):
-        """
-        Build the lightning module
-        :return: a pytorch-lightning module to be launched
-        """
-        algo_name = self.root_cfg.algorithm._name
-        if algo_name not in self.compatible_algorithms:
-            raise ValueError(
-                f"Algorithm {algo_name} not found in compatible_algorithms for this Experiment class. "
-                "Make sure you define compatible_algorithms correctly and make sure that each key has "
-                "same name as yaml file under '[project_root]/configurations/algorithm' without .yaml suffix"
-            )
-        return self.compatible_algorithms[algo_name](self.root_cfg.algorithm)
-
-    def exec_task(self, task: str) -> None:
-        """
-        Executing a certain task specified by string. Each task should be a stage of experiment.
-        In most computer vision / nlp applications, tasks should be just train and test.
-        In reinforcement learning, you might have more stages such as collecting dataset etc
-
-        Args:
-            task: a string specifying a task implemented for this experiment
-        """
-        if hasattr(self, task) and callable(getattr(self, task)):
-            if is_rank_zero:
-                print(cyan("Executing task:"), f"{task} out of {self.cfg.tasks}")
-            getattr(self, task)()
-        else:
-            raise ValueError(
-                f"Specified task '{task}' not defined for class {self.__class__.__name__} or is not callable."
-            )
-
-
-    def _build_trainer_callbacks(self):
-        callbacks = []
-        if self.logger:
-            callbacks.append(LearningRateMonitor("step", True))
-
-    def _build_training_loader(self) -> Optional[Union[TRAIN_DATALOADERS, pl.LightningDataModule]]:
-        train_dataset = self._build_dataset("training")
-        shuffle = (
-            False if isinstance(train_dataset, torch.utils.data.IterableDataset) else self.cfg.training.data.shuffle
-        )
-        if train_dataset:
-            return torch.utils.data.DataLoader(
-                train_dataset,
-                batch_size=self.cfg.training.batch_size,
-                num_workers=min(os.cpu_count(), self.cfg.training.data.num_workers),
-                shuffle=shuffle,
-                persistent_workers=True,
-            )
-        else:
-            return None
-
-    def _build_validation_loader(self) -> Optional[Union[TRAIN_DATALOADERS, pl.LightningDataModule]]:
-        validation_dataset = self._build_dataset("validation")
-        shuffle = (
-            False
-            if isinstance(validation_dataset, torch.utils.data.IterableDataset)
-            else self.cfg.validation.data.shuffle
-        )
-        if validation_dataset:
-            return torch.utils.data.DataLoader(
-                validation_dataset,
-                batch_size=self.cfg.validation.batch_size,
-                num_workers=min(os.cpu_count(), self.cfg.validation.data.num_workers),
-                shuffle=shuffle,
-                persistent_workers=True,
-            )
-        else:
-            return None
-
-    def _build_test_loader(self) -> Optional[Union[TRAIN_DATALOADERS, pl.LightningDataModule]]:
-        test_dataset = self._build_dataset("test")
-        shuffle = False if isinstance(test_dataset, torch.utils.data.IterableDataset) else self.cfg.test.data.shuffle
-        if test_dataset:
-            return torch.utils.data.DataLoader(
-                test_dataset,
-                batch_size=self.cfg.test.batch_size,
-                num_workers=min(os.cpu_count(), self.cfg.test.data.num_workers),
-                shuffle=shuffle,
-                persistent_workers=True,
-            )
-        else:
-            return None
-
-    def training(self) -> None:
-        """
-        All training happens here
-        """
-        if not self.algo:
-            self.algo = self._build_algo()
-        if self.cfg.training.compile:
-            self.algo = torch.compile(self.algo)
-
-        callbacks = []
-        if self.logger:
-            callbacks.append(LearningRateMonitor("step", True))
-        if "checkpointing" in self.cfg.training:
-            callbacks.append(
-                ModelCheckpoint(
-                    pathlib.Path(hydra.core.hydra_config.HydraConfig.get()["runtime"]["output_dir"]) / "checkpoints",
-                    **self.cfg.training.checkpointing,
-                )
-            )
-
-        trainer = pl.Trainer(
-            accelerator="auto",
-            devices="auto",  # select devices automatically
-            strategy=DDPStrategy(find_unused_parameters=True) if torch.cuda.device_count() > 1 else "auto",
-            logger=self.logger or False,  # simplified form
-            callbacks=callbacks,
-            gradient_clip_val=self.cfg.training.optim.gradient_clip_val or 0.0,  # ensure a default value
-            val_check_interval=self.cfg.validation.val_every_n_step if self.cfg.validation.val_every_n_step else None,
-            limit_val_batches=self.cfg.validation.limit_batch,
-            check_val_every_n_epoch=self.cfg.validation.val_every_n_epoch if not self.cfg.validation.val_every_n_step else None,
-            accumulate_grad_batches=self.cfg.training.optim.accumulate_grad_batches or 1,  # default accumulation of 1
-            precision=self.cfg.training.precision or 32,  # default to 32-bit precision
-            detect_anomaly=False,  # anomaly detection disabled by default
-            num_sanity_val_steps=int(self.cfg.debug) if self.cfg.debug else 0,
-            max_epochs=self.cfg.training.max_epochs,
-            max_steps=self.cfg.training.max_steps,
-            max_time=self.cfg.training.max_time
-        )
-
-        load_custom_checkpoint(algo=self.algo.vae, optimizer=None, checkpoint_path=self.vae_path)
-
-        trainer.fit(
-            self.algo,
-            train_dataloaders=self._build_training_loader(),
-            val_dataloaders=self._build_validation_loader(),
-            ckpt_path=self.ckpt_path,
-        )
-
-    def validation(self) -> None:
-        """
-        All validation happens here
-        """
-        if not self.algo:
-            self.algo = self._build_algo()
-        if self.cfg.validation.compile:
-            self.algo = torch.compile(self.algo)
-
-        callbacks = []
-
-        trainer = pl.Trainer(
-            accelerator="auto",
-            logger=self.logger,
-            devices="auto",
-            num_nodes=self.cfg.num_nodes,
-            strategy=DDPStrategy(find_unused_parameters=False) if torch.cuda.device_count() > 1 else "auto",
-            callbacks=callbacks,
-            # limit_val_batches=self.cfg.validation.limit_batch,
-            limit_val_batches=self.cfg.validation.limit_batch,
-            precision=self.cfg.validation.precision,
-            detect_anomaly=False,  # self.cfg.debug,
-            inference_mode=self.cfg.validation.inference_mode,
-        )
-
-        load_custom_checkpoint(algo=self.algo.vae, optimizer=None, checkpoint_path=self.vae_path)
-
-        trainer.validate(
-            self.algo,
-            dataloaders=self._build_validation_loader(),
-            ckpt_path=self.ckpt_path,
-        )
-
-    def test(self) -> None:
-        """
-        All testing happens here
-        """
-        if not self.algo:
-            self.algo = self._build_algo()
-        if self.cfg.test.compile:
-            self.algo = torch.compile(self.algo)
-
-        callbacks = []
-
-        trainer = pl.Trainer(
-            accelerator="auto",
-            logger=self.logger,
-            devices="auto",
-            num_nodes=self.cfg.num_nodes,
-            strategy=DDPStrategy(find_unused_parameters=False) if torch.cuda.device_count() > 1 else "auto",
-            callbacks=callbacks,
-            limit_test_batches=self.cfg.test.limit_batch,
-            precision=self.cfg.test.precision,
-            detect_anomaly=False,  # self.cfg.debug,
-        )
-
-        # Only load the checkpoint if only testing. Otherwise, it will have been loaded
-        # and further trained during train.
-        trainer.test(
-            self.algo,
-            dataloaders=self._build_test_loader(),
-            ckpt_path=self.ckpt_path,
-        )
-        if not self.algo:
-            self.algo = self._build_algo()
-        if self.cfg.validation.compile:
-            self.algo = torch.compile(self.algo)
-
-    def _build_dataset(self, split: str) -> Optional[torch.utils.data.Dataset]:
-        if split in ["training", "test", "validation"]:
-            return self.compatible_datasets[self.root_cfg.dataset._name](self.root_cfg.dataset, split=split)
-        else:
-            raise NotImplementedError(f"split '{split}' is not implemented")
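Both experiment classes dispatch pipeline stages by name through `exec_task`. A minimal runnable sketch of that pattern (the `Experiment` class here is illustrative, not from the repo):

```python
# Minimal sketch of the string-based task dispatch used by exec_task:
# each task name must correspond to a callable method on the experiment.
class Experiment:
    def training(self):
        print("running training stage")

    def exec_task(self, task: str) -> None:
        if hasattr(self, task) and callable(getattr(self, task)):
            getattr(self, task)()
        else:
            raise ValueError(f"Task '{task}' is not defined or not callable.")

Experiment().exec_task("training")  # prints "running training stage"
```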
experiments/exp_video.py
DELETED
@@ -1,25 +0,0 @@
-from datasets.video import (
-    MinecraftVideoDataset,
-    MinecraftVideoPoseDataset
-)
-
-from algorithms.worldmem import WorldMemMinecraft
-from algorithms.worldmem import PosePrediction
-from .exp_base import BaseLightningExperiment
-
-
-class VideoPredictionExperiment(BaseLightningExperiment):
-    """
-    A video prediction experiment
-    """
-
-    compatible_algorithms = dict(
-        df_video_worldmemminecraft=WorldMemMinecraft,
-        pose_prediction=PosePrediction
-    )
-
-    compatible_datasets = dict(
-        # video datasets
-        video_minecraft=MinecraftVideoDataset,
-        video_minecraft_pose=MinecraftVideoPoseDataset
-    )
main.py
DELETED
|
@@ -1,219 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research
|
| 3 |
-
template [repo](https://github.com/buoyancy99/research-template).
|
| 4 |
-
By its MIT license, you must keep the above sentence in `README.md`
|
| 5 |
-
and the `LICENSE` file to credit the author.
|
| 6 |
-
|
| 7 |
-
Main file for the project. This will create and run new experiments and load checkpoints from wandb.
|
| 8 |
-
Borrowed part of the code from David Charatan and wandb.
|
| 9 |
-
"""
|
| 10 |
-
|
| 11 |
-
import sys
|
| 12 |
-
import subprocess
|
| 13 |
-
import time
|
| 14 |
-
from pathlib import Path
|
| 15 |
-
|
| 16 |
-
import hydra
|
| 17 |
-
from omegaconf import DictConfig, OmegaConf
|
| 18 |
-
from omegaconf.omegaconf import open_dict
|
| 19 |
-
|
| 20 |
-
from utils.print_utils import cyan
|
| 21 |
-
from utils.ckpt_utils import download_latest_checkpoint, is_run_id
|
| 22 |
-
from utils.cluster_utils import submit_slurm_job
|
| 23 |
-
from utils.distributed_utils import is_rank_zero
|
| 24 |
-
|
| 25 |
-
def get_latest_checkpoint(checkpoint_folder: Path, pattern: str = '*.ckpt'):
|
| 26 |
-
# 获取文件夹中所有符合 pattern 的文件
|
| 27 |
-
checkpoint_files = list(checkpoint_folder.glob(pattern))
|
| 28 |
-
if not checkpoint_files:
|
| 29 |
-
return None # 如果没有找到 checkpoint 文件,返回 None
|
| 30 |
-
# 根据文件修改时间(st_mtime)选取最新的文件
|
| 31 |
-
latest_checkpoint = max(checkpoint_files, key=lambda f: f.stat().st_mtime)
|
| 32 |
-
return latest_checkpoint
|
| 33 |
-
|
| 34 |
-
-def run_local(cfg: DictConfig):
-    # delay some imports in case they are not needed in non-local envs for submission
-    from experiments import build_experiment
-    from utils.wandb_utils import OfflineWandbLogger, SpaceEfficientWandbLogger
-
-    # Get yaml names
-    hydra_cfg = hydra.core.hydra_config.HydraConfig.get()
-    cfg_choice = OmegaConf.to_container(hydra_cfg.runtime.choices)
-
-    with open_dict(cfg):
-        if cfg_choice["experiment"] is not None:
-            cfg.experiment._name = cfg_choice["experiment"]
-        if cfg_choice["dataset"] is not None:
-            cfg.dataset._name = cfg_choice["dataset"]
-        if cfg_choice["algorithm"] is not None:
-            cfg.algorithm._name = cfg_choice["algorithm"]
-
-    # Set up the output directory.
-    output_dir = getattr(cfg, "output_dir", None)
-    if output_dir is not None:
-        OmegaConf.set_readonly(hydra_cfg, False)
-        hydra_cfg.runtime.output_dir = output_dir
-        OmegaConf.set_readonly(hydra_cfg, True)
-
-    output_dir = Path(hydra_cfg.runtime.output_dir)
-
-    if is_rank_zero:
-        print(cyan("Outputs will be saved to:"), output_dir)
-        (output_dir.parents[1] / "latest-run").unlink(missing_ok=True)
-        (output_dir.parents[1] / "latest-run").symlink_to(output_dir, target_is_directory=True)
-
-    # Set up logging with wandb.
-    if cfg.wandb.mode != "disabled":
-        # If resuming, merge into the existing run on wandb.
-        resume = cfg.get("resume", None)
-        name = f"{cfg.name} ({output_dir.parent.name}/{output_dir.name})" if resume is None else None
-
-        if "_on_compute_node" in cfg and cfg.cluster.is_compute_node_offline:
-            logger_cls = OfflineWandbLogger
-        else:
-            logger_cls = SpaceEfficientWandbLogger
-
-        offline = cfg.wandb.mode != "online"
-        logger = logger_cls(
-            name=name,
-            save_dir=str(output_dir),
-            offline=offline,
-            entity=cfg.wandb.entity,
-            project=cfg.wandb.project,
-            log_model=False,
-            config=OmegaConf.to_container(cfg),
-            id=resume,
-            resume="auto",
-        )
-    else:
-        logger = None
-
-    # Load checkpoint
-    resume = cfg.get("resume", None)
-    load = cfg.get("load", None)
-    checkpoint_path = None
-    load_id = None
-    if load and not is_run_id(load):
-        checkpoint_path = load
-    if resume:
-        load_id = resume
-    elif load and is_run_id(load):
-        load_id = load
-    else:
-        load_id = None
-
-    if load_id:
-        run_path = f"{cfg.wandb.entity}/{cfg.wandb.project}/{load_id}"
-        checkpoint_path = Path("outputs/downloaded") / run_path / "model.ckpt"
-        checkpoint_path = output_dir / get_latest_checkpoint(output_dir / "checkpoints")
-
-    if checkpoint_path and is_rank_zero:
-        print(f"Will load checkpoint from {checkpoint_path}")
-
-    # launch experiment
-    experiment = build_experiment(cfg, logger, checkpoint_path)
-    for task in cfg.experiment.tasks:
-        experiment.exec_task(task)
-
-
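The resume/load branching above is easy to misread; note in particular that inside `if load_id:` the wandb download path is computed and then immediately shadowed by the newest local checkpoint. A condensed sketch of the same resolution rules; the function name and parameter list are ours, not an API from the repo:

    def resolve_checkpoint(cfg, output_dir, is_run_id, get_latest_checkpoint):
        # Sketch of the deleted logic above, for readability only.
        resume = cfg.get("resume", None)
        load = cfg.get("load", None)
        checkpoint_path = load if (load and not is_run_id(load)) else None
        load_id = resume or (load if (load and is_run_id(load)) else None)
        if load_id:
            # Mirrors the diff: the downloaded-model path is overwritten by
            # the latest checkpoint found under the local output directory.
            checkpoint_path = output_dir / get_latest_checkpoint(output_dir / "checkpoints")
        return checkpoint_path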
-def run_slurm(cfg: DictConfig):
-    python_args = " ".join(sys.argv[1:]) + " +_on_compute_node=True"
-    project_root = Path.cwd()
-    while not (project_root / ".git").exists():
-        project_root = project_root.parent
-        if project_root == Path("/"):
-            raise Exception("Could not find repo directory!")
-
-    slurm_log_dir = submit_slurm_job(
-        cfg,
-        python_args,
-        project_root,
-    )
-
-    if "cluster" in cfg and cfg.cluster.is_compute_node_offline and cfg.wandb.mode == "online":
-        print("Job submitted to a compute node without internet. This requires manual syncing on the login node.")
-        osh_command_dir = project_root / ".wandb_osh_command_dir"
-
-        osh_proc = subprocess.Popen(["wandb-osh", "--command-dir", osh_command_dir])
-        print(f"Running wandb-osh in background... PID: {osh_proc.pid}")
-        print(f"To kill the sync process, run 'kill {osh_proc.pid}' in the terminal.")
-        print(
-            "You can manually start a sync loop later by running the following:",
-            cyan(f"wandb-osh --command-dir {osh_command_dir}"),
-        )
-
-    print(
-        "Once the job gets allocated and starts running, we will print a command below "
-        "for you to trace the errors and outputs: (Ctrl + C to exit without waiting)"
-    )
-    msg = f"tail -f {slurm_log_dir}/* \n"
-    try:
-        while not list(slurm_log_dir.glob("*.out")) and not list(slurm_log_dir.glob("*.err")):
-            time.sleep(1)
-        print(cyan("To trace the outputs and errors, run the following command:"), msg)
-    except KeyboardInterrupt:
-        print("Keyboard interrupt detected. Exiting...")
-        print(
-            cyan("To trace the outputs and errors, manually wait for the job to start and run the following command:"),
-            msg,
-        )
-
-
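The sync loop above is left running in the background with no cleanup. A minimal sketch, reusing the same wandb-osh CLI invocation shown in the diff, that ties the sync process's lifetime to the submitting process; the function name is ours:

    import atexit
    import subprocess

    def start_osh_sync(command_dir):
        # Launch the wandb-osh sync loop and terminate it when this process exits.
        proc = subprocess.Popen(["wandb-osh", "--command-dir", str(command_dir)])
        atexit.register(proc.terminate)
        return proc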
-@hydra.main(
-    version_base=None,
-    config_path="configurations",
-    config_name="config",
-)
-def run(cfg: DictConfig):
-    if "_on_compute_node" in cfg and cfg.cluster.is_compute_node_offline:
-        with open_dict(cfg):
-            if cfg.cluster.is_compute_node_offline and cfg.wandb.mode == "online":
-                cfg.wandb.mode = "offline"
-
-    if "name" not in cfg:
-        raise ValueError("must specify a name for the run with command line argument '+name=[name]'")
-
-    if not cfg.wandb.get("entity", None):
-        raise ValueError(
-            "must specify wandb entity in 'configurations/config.yaml' or with command line"
-            " argument 'wandb.entity=[entity]'. \nAn entity is your wandb user name or group"
-            " name. This is used for logging. If you don't have a wandb account, please sign up at https://wandb.ai/"
-        )
-
-    if cfg.wandb.project is None:
-        cfg.wandb.project = str(Path(__file__).parent.name)
-
-    # If resuming or loading a wandb ckpt and not on a compute node, download the checkpoint.
-    resume = cfg.get("resume", None)
-    load = cfg.get("load", None)
-
-    if resume and load:
-        raise ValueError(
-            "When resuming a wandb run with `resume=[wandb id]`, the checkpoint will be loaded from the cloud"
-            " and `load` should not be specified."
-        )
-
-    if resume:
-        load_id = resume
-    elif load and is_run_id(load):
-        load_id = load
-    else:
-        load_id = None
-
-    # if load_id and "_on_compute_node" not in cfg:
-    #     run_path = f"{cfg.wandb.entity}/{cfg.wandb.project}/{load_id}"
-    #     download_latest_checkpoint(run_path, Path("outputs/downloaded"))
-
-    if "cluster" in cfg and "_on_compute_node" not in cfg:
-        print(cyan("Slurm detected, submitting to compute node instead of running locally..."))
-        run_slurm(cfg)
-    else:
-        run_local(cfg)
-
-
-if __name__ == "__main__":
-    run()  # pylint: disable=no-value-for-parameter
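The resume/load dispatch above hinges on `is_run_id`, whose definition lives in utils.ckpt_utils and is not part of this diff. A hypothetical stand-in, assuming wandb's usual 8-character lowercase alphanumeric run ids:

    import re

    def is_run_id(value: str) -> bool:
        # Assumption: a wandb run id is 8 lowercase alphanumeric characters.
        return bool(re.fullmatch(r"[a-z0-9]{8}", value))

A run would then be launched with Hydra-style overrides, e.g. `python main.py +name=my_run wandb.entity=my_entity` or `python main.py +name=my_run resume=<wandb id>` (placeholder values).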
scripts/README.md
DELETED
@@ -1,10 +0,0 @@
-# scripts
-
-The `scripts` folder contains bash scripts for scaling up your project on the cloud.
-Don't put your jupyter notebooks here! They belong to the `debug` folder.
-
-General scripts that are useful for all projects can be put directly in the `scripts` folder.
-
----
-
-This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research template [repo](https://github.com/buoyancy99/research-template). By its MIT license, you must keep the above sentence in `README.md` and the `LICENSE` file to credit the author.
scripts/dummy_script.sh
DELETED
@@ -1 +0,0 @@
-echo 'hello world'
split_checkpoint.py
DELETED
@@ -1,9 +0,0 @@
-import torch
-
-ckpt_path = "/mnt/xiaozeqi/diffusionforcing/outputs/2025-03-28/16-45-11/checkpoints/epoch0step595000.ckpt"
-checkpoint = torch.load(ckpt_path, map_location="cpu")  # change map_location as needed
-
-state_dict = checkpoint['state_dict']
-pose_prediction_model_dict = {k.replace('pose_prediction_model.', ''): v for k, v in state_dict.items() if k.startswith('pose_prediction_model.')}
-
-torch.save({'state_dict': pose_prediction_model_dict}, "pose_prediction_model_only.ckpt")
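The deleted script hard-codes one checkpoint path and one key prefix. A generalized sketch of the same extraction, with function and argument names of our choosing:

    import torch

    def extract_submodule(ckpt_path: str, prefix: str, out_path: str) -> None:
        # Load the full Lightning checkpoint on CPU.
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        state_dict = checkpoint["state_dict"]
        # Keep only keys under `prefix`, stripping the prefix itself.
        sub = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
        torch.save({"state_dict": sub}, out_path)

    # e.g. extract_submodule(ckpt_path, "pose_prediction_model.", "pose_prediction_model_only.ckpt")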