File size: 4,763 Bytes
25cd1c4
 
 
 
 
daa3ea5
25cd1c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
daa3ea5
 
 
 
25cd1c4
 
 
 
 
 
 
 
 
 
daa3ea5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25cd1c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import logging
from typing import Any, Optional, Tuple

import torch
from omegaconf import DictConfig, OmegaConf
from safetensors.torch import load_model, load_file

# Length that the longest bounding-box side is rescaled to by `normalize_bbox`.
BOUNDING_BOX_MAX_SIZE = 1.925


def normalize_bbox(bounding_box_xyz: Tuple[float, ...]) -> list:
    """
    Rescale a bounding box so that its largest side equals BOUNDING_BOX_MAX_SIZE.
    Args:
        bounding_box_xyz (Tuple[float, ...]): The side lengths of the bounding box.
    Returns:
        list: The side lengths, all multiplied by the same factor
            `BOUNDING_BOX_MAX_SIZE / max(bounding_box_xyz)`, in the input order.
    Raises:
        ValueError: If `bounding_box_xyz` is empty (raised by `max`) or if its
            largest entry is not strictly positive.
    """
    max_l = max(bounding_box_xyz)
    # A non-positive largest side would either divide by zero or flip the sign
    # of every entry, so reject it explicitly instead of returning nonsense.
    if max_l <= 0:
        raise ValueError(
            f"Bounding box must have a positive largest side, got {bounding_box_xyz!r}"
        )
    return [BOUNDING_BOX_MAX_SIZE * elem / max_l for elem in bounding_box_xyz]


def load_config(cfg_path: str) -> Any:
    """
    Read a configuration file with OmegaConf and return it resolved.
    Args:
        cfg_path (str): Location of the configuration file to read.
    Returns:
        Any: The object produced by `OmegaConf.load` after `OmegaConf.resolve`
            has been applied to it.
    Raises:
        AssertionError: If the loaded object is not a DictConfig.
    """
    loaded = OmegaConf.load(cfg_path)

    # `resolve` is called for its effect on `loaded`; its return value is unused.
    OmegaConf.resolve(loaded)

    # Only mapping-style configurations are accepted by this loader.
    assert isinstance(loaded, DictConfig)
    return loaded


def parse_structured(cfg_type: Any, cfg: DictConfig) -> Any:
    """
    Build a structured configuration from a configuration dictionary.
    Args:
        cfg_type (Any): Callable (typically a config class) that is invoked with
            the entries of `cfg` as keyword arguments.
        cfg (DictConfig): The configuration whose entries are passed to `cfg_type`.
    Returns:
        Any: The result of `OmegaConf.structured` applied to the `cfg_type` instance.
    """
    # Instantiate first so a bad/missing field fails in `cfg_type` itself.
    typed_instance = cfg_type(**cfg)
    return OmegaConf.structured(typed_instance)


def load_model_weights(model: torch.nn.Module, ckpt_path: str) -> None:
    """
    Load a safetensors checkpoint into a PyTorch model.
    The model is updated in place.

    Handles backward compatibility for embedder weight key naming:
    - Old format: 'embedder.weight'
    - New format: 'encoder.embedder.weight', 'occupancy_decoder.embedder.weight'

    Loading is non-strict, so keys that do not line up between the checkpoint
    and the model do not raise. They are reported through a `logging` warning
    instead of being dropped silently.

    Args:
        model: PyTorch model to load weights into
        ckpt_path: Path to the safetensors checkpoint file

    Returns:
        None

    Raises:
        AssertionError: If `ckpt_path` does not end with '.safetensors'.
    """
    assert ckpt_path.endswith(
        ".safetensors"
    ), f"Checkpoint path '{ckpt_path}' is not a safetensors file"

    # Load checkpoint as dictionary for key remapping
    checkpoint = load_file(ckpt_path)

    # Backward compatibility: when the checkpoint only has the old
    # 'embedder.weight' key, expose the same tensor under the new key names.
    # The old key is kept so a model that still uses the old name loads too.
    legacy_key = "embedder.weight"
    new_keys = ("encoder.embedder.weight", "occupancy_decoder.embedder.weight")
    compat_keys = set()
    if legacy_key in checkpoint:
        compat_keys.add(legacy_key)
        for alias in new_keys:
            if alias not in checkpoint:
                checkpoint[alias] = checkpoint[legacy_key]
                compat_keys.add(alias)

    # Load remapped checkpoint into model with strict=False for flexibility
    result = model.load_state_dict(checkpoint, strict=False)

    # Surface mismatches that strict=False would otherwise hide. Keys that only
    # exist because of the compatibility remap above are expected leftovers and
    # are not reported.
    missing = list(getattr(result, "missing_keys", ()))
    unexpected = [
        key
        for key in getattr(result, "unexpected_keys", ())
        if key not in compat_keys
    ]
    logger = logging.getLogger(__name__)
    if missing:
        logger.warning(
            "Checkpoint '%s' has no weights for %d model key(s): %s",
            ckpt_path,
            len(missing),
            missing,
        )
    if unexpected:
        logger.warning(
            "Checkpoint '%s' has %d key(s) not used by the model: %s",
            ckpt_path,
            len(unexpected),
            unexpected,
        )


def save_model_weights(model: torch.nn.Module, save_path: str) -> None:
    """
    Write a model's state dict to disk in safetensors format.

    Args:
        model: PyTorch model whose `state_dict()` is written
        save_path: Destination file path (must end with .safetensors)

    Raises:
        AssertionError: If `save_path` does not end with '.safetensors'.
    """
    assert save_path.endswith(".safetensors"), "Path must be .safetensors"

    # `save_file` is only needed here, so it is imported at call time.
    from safetensors.torch import save_file

    save_file(model.state_dict(), save_path)


def load_model_weights_adaption(model: torch.nn.Module, ckpt_path: str, adaption_path: str) -> torch.nn.Module:
    """
    Load base weights, attach a PEFT adapter, and restore extra submodule weights.

    Steps:
    1. Load the safetensors checkpoint at `ckpt_path` into `model` (non-strict).
    2. Wrap `model` with `PeftModel.from_pretrained(model, adaption_path)`.
    3. Load `<adaption_path>/unfrozen_weights.pth` and apply its entries to the
       `ldr_proj`, `ldr_head`, `dte`, `rte`, `xte`, `yte` and `zte` submodules.

    Args:
        model: PyTorch model to load weights into
        ckpt_path: Path to the safetensors checkpoint file
        adaption_path: Location passed to `PeftModel.from_pretrained`; it must
            also contain the file `unfrozen_weights.pth`

    Returns:
        The object returned by `PeftModel.from_pretrained`, which wraps `model`.
        Callers must use this return value rather than the `model` they passed in.

    Raises:
        AssertionError: If `ckpt_path` does not end with '.safetensors'.
    """
    assert ckpt_path.endswith(
        ".safetensors"
    ), f"Checkpoint path '{ckpt_path}' is not a safetensors file"

    from peft import PeftModel

    load_model(model, ckpt_path, strict=False)
    model = PeftModel.from_pretrained(model, adaption_path)

    # Deserialize onto the CPU so this also works on hosts without a CUDA
    # device; `load_state_dict` below copies the values into the existing
    # parameters wherever those already live.
    # NOTE(review): `torch.load` unpickles the file, so `adaption_path` must be
    # trusted. Consider `weights_only=True` if the supported torch versions allow it.
    custom_weights = torch.load(
        f"{adaption_path}/unfrozen_weights.pth", map_location="cpu"
    )

    # Each entry of the file is a state dict for the submodule of the same name.
    for name in ("ldr_proj", "ldr_head", "dte", "rte", "xte", "yte", "zte"):
        getattr(model, name).load_state_dict(custom_weights[name])

    return model


def select_device() -> Any:
    """
    Pick the PyTorch device to allocate tensors on.

    CUDA is preferred when available, then MPS, and the CPU otherwise.

    Returns:
        Any: The chosen `torch.device` object.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    # The MPS backend is only queried when CUDA is unavailable.
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

def decode_ldr(output_ids: torch.Tensor):
    """
    Placeholder for decoding `output_ids` into LDR content.

    The decoding logic is not implemented in this module yet. The previous body
    returned an undefined name and therefore failed with a NameError on every
    call; the function now fails with an explicit, descriptive error instead.

    Args:
        output_ids: Tensor of ids to decode (currently unused).

    Raises:
        NotImplementedError: Always, until the decoding is implemented.
    """
    raise NotImplementedError("decode_ldr is not implemented yet")