from typing import Any, Tuple

import torch
from omegaconf import DictConfig, OmegaConf
from safetensors.torch import load_file, load_model, save_file

BOUNDING_BOX_MAX_SIZE = 1.925


def normalize_bbox(bounding_box_xyz: Tuple[float, ...]) -> list:
    """Scale a bounding box so its longest side equals BOUNDING_BOX_MAX_SIZE."""
    max_l = max(bounding_box_xyz)
    return [BOUNDING_BOX_MAX_SIZE * elem / max_l for elem in bounding_box_xyz]
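
# Illustrative check of normalize_bbox: a 2 x 4 x 1 box is scaled so its
# longest side equals BOUNDING_BOX_MAX_SIZE (i.e. 1.925 * 2/4, 4/4, 1/4):
#
#     >>> normalize_bbox((2.0, 4.0, 1.0))
#     [0.9625, 1.925, 0.48125]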


def load_config(cfg_path: str) -> DictConfig:
    """
    Load and resolve a configuration file.

    Args:
        cfg_path (str): The path to the configuration file.

    Returns:
        DictConfig: The loaded and resolved configuration object.

    Raises:
        AssertionError: If the loaded configuration is not an instance of DictConfig.
    """
    cfg = OmegaConf.load(cfg_path)
    OmegaConf.resolve(cfg)
    assert isinstance(cfg, DictConfig)
    return cfg
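
# Illustrative usage, assuming a hypothetical OmegaConf YAML file at
# "configs/train.yaml"; interpolations are resolved on load:
#
#     >>> cfg = load_config("configs/train.yaml")
#     >>> isinstance(cfg, DictConfig)
#     True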


def parse_structured(cfg_type: Any, cfg: DictConfig) -> Any:
    """
    Parse a configuration dictionary into a structured configuration object.

    Args:
        cfg_type (Any): The type of the structured configuration object.
        cfg (DictConfig): The configuration dictionary to be parsed.

    Returns:
        Any: The structured configuration object created from the dictionary.
    """
    scfg = OmegaConf.structured(cfg_type(**cfg))
    return scfg
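
# Illustrative sketch with a hypothetical dataclass: unknown keys in `cfg`
# raise at construction, while declared fields are type-checked by
# OmegaConf.structured:
#
#     >>> from dataclasses import dataclass
#     >>> @dataclass
#     ... class ModelConfig:
#     ...     hidden_dim: int = 256
#     >>> scfg = parse_structured(ModelConfig, OmegaConf.create({"hidden_dim": 512}))
#     >>> scfg.hidden_dim
#     512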


def load_model_weights(model: torch.nn.Module, ckpt_path: str) -> None:
    """
    Load a safetensors checkpoint into a PyTorch model.

    The model is updated in place.

    Handles backward compatibility for embedder weight key naming:
    - Old format: 'embedder.weight'
    - New format: 'encoder.embedder.weight', 'occupancy_decoder.embedder.weight'

    Args:
        model: PyTorch model to load weights into
        ckpt_path: Path to the safetensors checkpoint file

    Returns:
        None
    """
    assert ckpt_path.endswith(
        ".safetensors"
    ), f"Checkpoint path '{ckpt_path}' is not a safetensors file"

    # Load the checkpoint as a plain dict so keys can be remapped before loading.
    checkpoint = load_file(ckpt_path)

    # Backward compatibility: remap the old embedder key format to the new one.
    # This handles checkpoints that carry 'embedder.weight' where the model
    # expects 'encoder.embedder.weight' and 'occupancy_decoder.embedder.weight'.
    if "embedder.weight" in checkpoint:
        if "encoder.embedder.weight" not in checkpoint:
            checkpoint["encoder.embedder.weight"] = checkpoint["embedder.weight"]
        if "occupancy_decoder.embedder.weight" not in checkpoint:
            checkpoint["occupancy_decoder.embedder.weight"] = checkpoint["embedder.weight"]

    # strict=False tolerates the leftover 'embedder.weight' key and any other
    # benign mismatches between checkpoint and model.
    model.load_state_dict(checkpoint, strict=False)
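
# Illustrative usage, assuming a hypothetical model class whose encoder and
# occupancy decoder both expect the remapped embedder keys:
#
#     >>> model = OccupancyModel()  # hypothetical class
#     >>> load_model_weights(model, "checkpoints/model.safetensors")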


def save_model_weights(model: torch.nn.Module, save_path: str) -> None:
    """
    Save model weights in safetensors format.

    Args:
        model: PyTorch model to save
        save_path: Output path (must end with .safetensors)
    """
    assert save_path.endswith(".safetensors"), "Path must be .safetensors"
    save_file(model.state_dict(), save_path)
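
# Illustrative round trip with the loader above (paths are hypothetical):
#
#     >>> save_model_weights(model, "checkpoints/model.safetensors")
#     >>> load_model_weights(model, "checkpoints/model.safetensors")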


def load_model_weights_adaption(
    model: torch.nn.Module, ckpt_path: str, adaption_path: str
) -> torch.nn.Module:
    """
    Load a safetensors checkpoint into a PyTorch model, then apply a PEFT
    adapter and the unfrozen custom weights saved alongside it.

    Args:
        model: PyTorch model to load weights into
        ckpt_path: Path to the safetensors checkpoint file
        adaption_path: Directory containing the PEFT adapter and
            'unfrozen_weights.pth'

    Returns:
        torch.nn.Module: The adapted model.
    """
    assert ckpt_path.endswith(
        ".safetensors"
    ), f"Checkpoint path '{ckpt_path}' is not a safetensors file"
    load_model(model, ckpt_path, strict=False)

    # Imported lazily so peft is only required when adapters are used.
    from peft import PeftModel

    model = PeftModel.from_pretrained(model, adaption_path)

    # Restore the submodules that were trained unfrozen alongside the adapter.
    # Map to whichever device is available rather than assuming 'cuda:0'.
    custom_weights = torch.load(
        f"{adaption_path}/unfrozen_weights.pth", map_location=select_device()
    )
    model.ldr_proj.load_state_dict(custom_weights["ldr_proj"])
    model.ldr_head.load_state_dict(custom_weights["ldr_head"])
    model.dte.load_state_dict(custom_weights["dte"])
    model.rte.load_state_dict(custom_weights["rte"])
    model.xte.load_state_dict(custom_weights["xte"])
    model.yte.load_state_dict(custom_weights["yte"])
    model.zte.load_state_dict(custom_weights["zte"])
    return model
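
# Illustrative usage, assuming hypothetical paths: the adapter directory is
# expected to contain both the PEFT adapter files and 'unfrozen_weights.pth':
#
#     >>> model = load_model_weights_adaption(
#     ...     model, "checkpoints/base.safetensors", "checkpoints/adapter"
#     ... )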


def select_device() -> torch.device:
    """
    Select the appropriate PyTorch device for tensor allocation.

    Returns:
        torch.device: CUDA if available, otherwise MPS, otherwise CPU.
    """
    return torch.device(
        "cuda"
        if torch.cuda.is_available()
        else "mps"
        if torch.backends.mps.is_available()
        else "cpu"
    )
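
# Illustrative usage: allocate tensors on whichever backend is available.
#
#     >>> device = select_device()
#     >>> x = torch.zeros(3, device=device)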


def decode_ldr(output_ids: torch.Tensor, tokenizer: Any) -> str:
    """
    Decode generated token IDs into LDR file text.

    Returns:
        str: The decoded LDR file contents. The `tokenizer` argument (assumed
        here to expose a Hugging Face-style `decode` method) was left implicit
        in the original stub.
    """
    ldr = tokenizer.decode(output_ids.tolist(), skip_special_tokens=True)
    return ldr
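
# Illustrative usage, assuming a Hugging Face-style tokenizer and a
# hypothetical generation call:
#
#     >>> output_ids = model.generate(**inputs)[0]  # hypothetical
#     >>> ldr_text = decode_ldr(output_ids, tokenizer)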