Spaces:
Runtime error
Runtime error
Update inference_engine.py
Browse files - inference_engine.py +7 -1
inference_engine.py
CHANGED
|
@@ -9,6 +9,7 @@ from torchvision.transforms import ToPILImage, transforms, InterpolationMode, fu
|
|
| 9 |
import numpy as np
|
| 10 |
import pickle
|
| 11 |
import copy
|
|
|
|
| 12 |
from draw_pose import get_pose_images
|
| 13 |
from utils import concat_images_grid, sample_video, get_sample_indexes, get_new_height_width
|
| 14 |
|
|
@@ -18,7 +19,7 @@ def run_inference(device, motion_data_path, ref_image_path='', dst_width=512, ds
|
|
| 18 |
normalize = transforms.Normalize([0.5], [0.5])
|
| 19 |
pretrained_model_path = "THUDM/CogVideoX-5b"
|
| 20 |
transformer_path = "yanboding/MTVCrafter/MV-DiT/CogVideoX"
|
| 21 |
-
tokenizer_path = "mp_rank_00_model_states.pt"
|
| 22 |
|
| 23 |
with open(motion_data_path, 'rb') as f:
|
| 24 |
data_list = pickle.load(f)
|
|
@@ -38,6 +39,11 @@ def run_inference(device, motion_data_path, ref_image_path='', dst_width=512, ds
|
|
| 38 |
pipe.vae.enable_slicing()
|
| 39 |
|
| 40 |
# load VQVAE
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
state_dict = torch.load(tokenizer_path, map_location="cpu")
|
| 42 |
motion_encoder = Encoder(in_channels=3, mid_channels=[128, 512], out_channels=3072, downsample_time=[2, 2], downsample_joint=[1, 1])
|
| 43 |
motion_quant = VectorQuantizer(nb_code=8192, code_dim=3072, is_train=False)
|
|
|
|
| 9 |
import numpy as np
|
| 10 |
import pickle
|
| 11 |
import copy
|
| 12 |
+
from huggingface_hub import hf_hub_download
|
| 13 |
from draw_pose import get_pose_images
|
| 14 |
from utils import concat_images_grid, sample_video, get_sample_indexes, get_new_height_width
|
| 15 |
|
|
|
|
| 19 |
normalize = transforms.Normalize([0.5], [0.5])
|
| 20 |
pretrained_model_path = "THUDM/CogVideoX-5b"
|
| 21 |
transformer_path = "yanboding/MTVCrafter/MV-DiT/CogVideoX"
|
| 22 |
+
tokenizer_path = "4DMoT/mp_rank_00_model_states.pt"
|
| 23 |
|
| 24 |
with open(motion_data_path, 'rb') as f:
|
| 25 |
data_list = pickle.load(f)
|
|
|
|
| 39 |
pipe.vae.enable_slicing()
|
| 40 |
|
| 41 |
# load VQVAE
|
| 42 |
+
|
| 43 |
+
vqvae_model_path = hf_hub_download(
|
| 44 |
+
repo_id="yanboding/MTVCrafter",
|
| 45 |
+
filename="4DMoT/mp_rank_00_model_states.pt"
|
| 46 |
+
)
|
| 47 |
state_dict = torch.load(tokenizer_path, map_location="cpu")
|
| 48 |
motion_encoder = Encoder(in_channels=3, mid_channels=[128, 512], out_channels=3072, downsample_time=[2, 2], downsample_joint=[1, 1])
|
| 49 |
motion_quant = VectorQuantizer(nb_code=8192, code_dim=3072, is_train=False)
|