from typing import Any, List, Tuple

import torch
from omegaconf import DictConfig, OmegaConf
from safetensors.torch import load_file, load_model, save_file
BOUNDING_BOX_MAX_SIZE = 1.925
def normalize_bbox(bounding_box_xyz: Tuple[float, ...]) -> List[float]:
    """
    Scale a bounding box so its longest edge equals BOUNDING_BOX_MAX_SIZE.

    Args:
        bounding_box_xyz: The (x, y, z) extents of the bounding box.

    Returns:
        List[float]: The rescaled extents, with the aspect ratio preserved.
    """
    max_l = max(bounding_box_xyz)
    return [BOUNDING_BOX_MAX_SIZE * elem / max_l for elem in bounding_box_xyz]
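# Worked example (illustrative values): a 2 x 4 x 1 box is scaled so its
# longest edge equals BOUNDING_BOX_MAX_SIZE = 1.925:
#   normalize_bbox((2.0, 4.0, 1.0)) -> [0.9625, 1.925, 0.48125]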
def load_config(cfg_path: str) -> DictConfig:
    """
    Load and resolve a configuration file.

    Args:
        cfg_path (str): The path to the configuration file.

    Returns:
        DictConfig: The loaded configuration with all interpolations resolved.

    Raises:
        AssertionError: If the loaded configuration is not an instance of DictConfig.
    """
    cfg = OmegaConf.load(cfg_path)
    OmegaConf.resolve(cfg)
    assert isinstance(cfg, DictConfig)
    return cfg
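# Example usage (the path below is hypothetical):
#   cfg = load_config("configs/train.yaml")
#   print(cfg)  # resolved DictConfig, interpolations already expanded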
def parse_structured(cfg_type: Any, cfg: DictConfig) -> Any:
"""
Parses a configuration dictionary into a structured configuration object.
Args:
cfg_type (Any): The type of the structured configuration object.
cfg (DictConfig): The configuration dictionary to be parsed.
Returns:
Any: The structured configuration object created from the dictionary.
"""
scfg = OmegaConf.structured(cfg_type(**cfg))
return scfg
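# Example usage (ModelConfig is an illustrative dataclass, not part of this
# module); OmegaConf validates the fields against the dataclass schema:
#   from dataclasses import dataclass
#
#   @dataclass
#   class ModelConfig:
#       hidden_dim: int = 256
#
#   scfg = parse_structured(ModelConfig, OmegaConf.create({"hidden_dim": 512}))
#   assert scfg.hidden_dim == 512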
def load_model_weights(model: torch.nn.Module, ckpt_path: str) -> None:
"""
Load a safetensors checkpoint into a PyTorch model.
The model is updated in place.
Handles backward compatibility for embedder weight key naming:
- Old format: 'embedder.weight'
- New format: 'encoder.embedder.weight', 'occupancy_decoder.embedder.weight'
Args:
model: PyTorch model to load weights into
ckpt_path: Path to the safetensors checkpoint file
Returns:
None
"""
assert ckpt_path.endswith(
".safetensors"
), f"Checkpoint path '{ckpt_path}' is not a safetensors file"
# Load checkpoint as dictionary for key remapping
checkpoint = load_file(ckpt_path)
# Backward compatibility: remap old embedder key format to new format
# This handles cases where checkpoint has 'embedder.weight' but model expects
# 'encoder.embedder.weight' and 'occupancy_decoder.embedder.weight'
    if "embedder.weight" in checkpoint:
        if "encoder.embedder.weight" not in checkpoint:
            checkpoint["encoder.embedder.weight"] = checkpoint["embedder.weight"]
        if "occupancy_decoder.embedder.weight" not in checkpoint:
            checkpoint["occupancy_decoder.embedder.weight"] = checkpoint["embedder.weight"]
    # Load the remapped checkpoint; strict=False tolerates missing or unexpected
    # keys, so older and newer model variants can share the same loader
    model.load_state_dict(checkpoint, strict=False)
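# Example usage (model construction and path are illustrative):
#   model = build_model(cfg)  # hypothetical factory
#   load_model_weights(model, "checkpoints/shape_model.safetensors")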
def save_model_weights(model: torch.nn.Module, save_path: str) -> None:
"""
Save model weights in safetensors format.
Args:
model: PyTorch model to save
save_path: Output path (must end with .safetensors)
"""
    assert save_path.endswith(
        ".safetensors"
    ), f"Save path '{save_path}' is not a safetensors file"
    save_file(model.state_dict(), save_path)
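# Example round trip (path is illustrative):
#   save_model_weights(model, "checkpoints/export.safetensors")
#   load_model_weights(model, "checkpoints/export.safetensors")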
def load_model_weights_adaption(
    model: torch.nn.Module, ckpt_path: str, adaption_path: str
) -> torch.nn.Module:
    """
    Load a safetensors checkpoint into a PyTorch model, then apply a PEFT
    adapter and the unfrozen custom-module weights saved alongside it.

    Args:
        model: PyTorch model to load the base weights into
        ckpt_path: Path to the safetensors checkpoint file
        adaption_path: Directory containing the PEFT adapter and
            'unfrozen_weights.pth'

    Returns:
        torch.nn.Module: The PEFT-wrapped model with all weights loaded.
    """
    assert ckpt_path.endswith(
        ".safetensors"
    ), f"Checkpoint path '{ckpt_path}' is not a safetensors file"
    load_model(model, ckpt_path, strict=False)
    # Imported locally so that peft stays an optional dependency
    from peft import PeftModel

    model = PeftModel.from_pretrained(model, adaption_path)
    # Map to the best available device rather than assuming 'cuda:0'
    custom_weights = torch.load(f"{adaption_path}/unfrozen_weights.pth", map_location=select_device())
    # Restore the modules that were left unfrozen during adapter training
    for name in ("ldr_proj", "ldr_head", "dte", "rte", "xte", "yte", "zte"):
        getattr(model, name).load_state_dict(custom_weights[name])
    return model
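# Example usage (paths are illustrative); note that unlike load_model_weights
# this returns a new PEFT-wrapped model rather than updating in place:
#   model = load_model_weights_adaption(model, "base.safetensors", "adapters/run1")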
def select_device() -> torch.device:
    """
    Selects the appropriate PyTorch device for tensor allocation.

    Prefers CUDA, then Apple MPS, and falls back to CPU.

    Returns:
        torch.device: The selected device.
    """
    return torch.device(
        "cuda"
        if torch.cuda.is_available()
        else "mps"
        if torch.backends.mps.is_available()
        else "cpu"
    )
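# Example usage:
#   device = select_device()
#   model = model.to(device)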
def decode_ldr(output_ids: torch.Tensor, tokenizer: Any) -> str:
    """
    Decode generated token IDs into the text of an LDraw (.ldr) file.

    NOTE: the `tokenizer` parameter and its Hugging Face style `decode`
    method are assumptions; the original stub did not specify the decoding.
    """
    if output_ids.dim() > 1:
        output_ids = output_ids.squeeze(0)  # drop a leading batch dimension
    ldr = tokenizer.decode(output_ids, skip_special_tokens=True)
    return ldr
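# Example usage (the tokenizer is assumed, see the note in the docstring):
#   ldr_text = decode_ldr(output_ids, tokenizer)
#   with open("model_output.ldr", "w") as f:
#       f.write(ldr_text)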