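"""Streaming text-to-motion feature generation.

Loads a VAE and a generative model from a config file, then streams generated
feature segments for a text prompt via the model's stream_generate interface.
"""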
import os
import sys

import torch
from lightning import seed_everything
from safetensors.torch import load_file as load_safetensors

from ldf_utils.initialize import compare_statedict_and_parameters, instantiate, load_config

# Set tokenizers parallelism to false to avoid warnings in multiprocessing
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def load_model_from_config():
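    """
    Instantiate the VAE and the generative model from the loaded config, resolve
    checkpoint paths relative to the config file's directory, load the safetensors
    weights, and return both models on the selected device in eval mode.

    Returns:
        tuple: (vae, model)
    """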
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.set_float32_matmul_precision("high")
    cfg = load_config()
    seed_everything(cfg.seed)
    
    # Resolve the directory containing the config file so relative checkpoint
    # paths can be made absolute; fall back to the current working directory.
    config_dir = os.getcwd()
    if "--config" in sys.argv:
        config_idx = sys.argv.index("--config") + 1
        config_dir = os.path.dirname(os.path.abspath(sys.argv[config_idx]))
    else:
        # Also handle the "--config=path" form accepted by argparse.
        for arg in sys.argv:
            if arg.startswith("--config="):
                config_dir = os.path.dirname(os.path.abspath(arg.split("=", 1)[1]))
                break

    vae = instantiate(
        target=cfg.test_vae.target,
        cfg=None,
        hfstyle=False,
        **cfg.test_vae.params,
    )
    
    # Handle relative paths
    vae_path = cfg.test_vae_ckpt
    if not os.path.isabs(vae_path):
        vae_path = os.path.join(config_dir, vae_path)
    
    # Load from safetensors (already contains EMA weights)
    vae_state_dict = load_safetensors(vae_path)
    vae.load_state_dict(vae_state_dict, strict=True)
    print(f"Loaded VAE model from {vae_path}")

    compare_statedict_and_parameters(
        state_dict=vae.state_dict(),
        named_parameters=vae.named_parameters(),
        named_buffers=vae.named_buffers(),
    )
    vae.to(device)
    vae.eval()

    # Model: convert any relative paths in the model params to absolute paths
    model_params = dict(cfg.model.params)
    for key in ("checkpoint_path", "tokenizer_path"):
        if model_params.get(key) and not os.path.isabs(model_params[key]):
            model_params[key] = os.path.join(config_dir, model_params[key])
    
    model = instantiate(
        target=cfg.model.target, cfg=None, hfstyle=False, **model_params
    )
    
    # Handle relative paths
    model_path = cfg.test_ckpt
    if not os.path.isabs(model_path):
        model_path = os.path.join(config_dir, model_path)
    
    # Load from safetensors (already contains EMA weights)
    model_state_dict = load_safetensors(model_path)
    model.load_state_dict(model_state_dict, strict=True)
    print(f"Loaded model from {model_path}")

    compare_statedict_and_parameters(
        state_dict=model.state_dict(),
        named_parameters=model.named_parameters(),
        named_buffers=model.named_buffers(),
    )
    model.to(device)
    model.eval()

    return vae, model


@torch.inference_mode()
def generate_feature_stream(
    model, feature_length, text, feature_text_end=None, num_denoise_steps=None
):
    """
    Streaming interface for feature generation
    Args:
        model: Loaded model
        feature_length: List[int], generation length for each sample
        text: List[str] or List[List[str]], text prompts
        feature_text_end: List[List[int]], time points where text ends (if text is list of list)
        num_denoise_steps: Number of denoising steps
    Yields:
        dict: Contains "generated" (current generated feature segment)
    """

    # Construct the input dict x. stream_generate expects "feature_length" and
    # "text", plus "feature_text_end" when text is a list of lists.
    x = {"feature_length": torch.tensor(feature_length), "text": text}

    if feature_text_end is not None:
        x["feature_text_end"] = feature_text_end

    # Call model's stream_generate
    # Note: stream_generate is a generator
    generator = model.stream_generate(x, num_denoise_steps=num_denoise_steps)

    for step_output in generator:
        # step_output is already a dict with "generated" key
        yield step_output

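# Example sketch of consuming the stream from another script. The concatenation
# axis of the yielded segments is an assumption; adjust it to the model's output.
#
#   vae, model = load_model_from_config()
#   segments = []
#   for out in generate_feature_stream(model, [120], ["a person walks forward"]):
#       segments.append(out["generated"])
#   features = torch.cat(segments, dim=1)  # assumed time axis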

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True, help="Path to config")
    parser.add_argument(
        "--text", type=str, default="a person walks forward", help="Text prompt"
    )
    parser.add_argument("--length", type=int, default=120, help="Motion length")
    parser.add_argument(
        "--output", type=str, default="output.mp4", help="Output video path"
    )
    parser.add_argument(
        "--num_denoise_steps", type=int, default=None, help="Number of denoising steps"
    )
    args = parser.parse_args()

    print("Loading model...")
    vae, model = load_model_from_config()