import logging
import os
from typing import Any, List, Optional, Sequence, Tuple

import torch
from omegaconf import DictConfig, OmegaConf
from safetensors.torch import load_file, load_model, save_file

logger = logging.getLogger(__name__)

# Target length of the longest side after `normalize_bbox` rescales a box.
BOUNDING_BOX_MAX_SIZE = 1.925

# Sub-modules whose weights are restored from ``<adaption_path>/unfrozen_weights.pth``
# by `load_model_weights_adaption`; each name is both the attribute on the model and
# the key in that file.
_UNFROZEN_MODULE_NAMES = ("ldr_proj", "ldr_head", "dte", "rte", "xte", "yte", "zte")


def normalize_bbox(bounding_box_xyz: Sequence[float]) -> List[float]:
    """
    Rescale bounding-box extents so the largest one equals BOUNDING_BOX_MAX_SIZE.

    Every element is scaled by the same factor, so the ratios between the
    extents are preserved.

    Args:
        bounding_box_xyz: Box extents, presumably (x, y, z) sizes; any sequence
            of numbers is accepted.

    Returns:
        List[float]: The rescaled extents, in the input order.

    Raises:
        ValueError: If ``bounding_box_xyz`` is empty (raised by ``max``).
        ZeroDivisionError: If the largest extent is 0.
    """
    max_l = max(bounding_box_xyz)
    return [BOUNDING_BOX_MAX_SIZE * elem / max_l for elem in bounding_box_xyz]


def load_config(cfg_path: str) -> DictConfig:
    """
    Load a configuration file with OmegaConf and resolve its interpolations.

    Args:
        cfg_path (str): The path to the configuration file.

    Returns:
        DictConfig: The loaded and resolved configuration object.

    Raises:
        AssertionError: If the loaded configuration is not an instance of DictConfig.
    """
    cfg = OmegaConf.load(cfg_path)
    OmegaConf.resolve(cfg)
    assert isinstance(cfg, DictConfig), (
        f"Config at '{cfg_path}' is not a DictConfig (got {type(cfg).__name__})"
    )
    return cfg


def parse_structured(cfg_type: Any, cfg: DictConfig) -> Any:
    """
    Parse a configuration dictionary into a structured configuration object.

    ``cfg_type`` is instantiated with the entries of ``cfg`` as keyword
    arguments, and the instance is wrapped with ``OmegaConf.structured``.

    Args:
        cfg_type (Any): The type of the structured configuration object.
        cfg (DictConfig): The configuration dictionary to be parsed.

    Returns:
        Any: The structured configuration object created from the dictionary.
    """
    scfg = OmegaConf.structured(cfg_type(**cfg))
    return scfg


def _log_incompatible_keys(
    missing: Sequence[str], unexpected: Sequence[str], ckpt_path: str
) -> None:
    """Log the keys a non-strict checkpoint load skipped, so mismatches are visible."""
    if missing:
        logger.warning(
            "%d model key(s) not found in checkpoint '%s' (left unchanged): %s",
            len(missing),
            ckpt_path,
            sorted(missing),
        )
    if unexpected:
        logger.debug(
            "%d checkpoint key(s) in '%s' not used by the model: %s",
            len(unexpected),
            ckpt_path,
            sorted(unexpected),
        )


def load_model_weights(model: torch.nn.Module, ckpt_path: str) -> None:
    """
    Load a safetensors checkpoint into a PyTorch model. The model is updated in place.

    Handles backward compatibility for embedder weight key naming:
        - Old format: 'embedder.weight'
        - New format: 'encoder.embedder.weight', 'occupancy_decoder.embedder.weight'

    Loading is non-strict: model keys absent from the checkpoint keep their current
    values, and checkpoint keys the model does not have are ignored. Both sets are
    logged (missing at WARNING, unexpected at DEBUG).

    Args:
        model: PyTorch model to load weights into.
        ckpt_path: Path to the safetensors checkpoint file.

    Returns:
        None

    Raises:
        AssertionError: If ``ckpt_path`` does not end with ``.safetensors``.
    """
    assert ckpt_path.endswith(
        ".safetensors"
    ), f"Checkpoint path '{ckpt_path}' is not a safetensors file"

    # Load as a plain dict (rather than safetensors.load_model) so keys can be remapped.
    checkpoint = load_file(ckpt_path)

    # Backward compatibility: older checkpoints store a single 'embedder.weight',
    # while the model expects one copy per sub-module. Keys already present win.
    legacy_key = "embedder.weight"
    if legacy_key in checkpoint:
        for new_key in ("encoder.embedder.weight", "occupancy_decoder.embedder.weight"):
            checkpoint.setdefault(new_key, checkpoint[legacy_key])

    # strict=False tolerates key mismatches; report them instead of dropping them silently.
    result = model.load_state_dict(checkpoint, strict=False)
    _log_incompatible_keys(result.missing_keys, result.unexpected_keys, ckpt_path)


def save_model_weights(model: torch.nn.Module, save_path: str) -> None:
    """
    Save the model's state dict in safetensors format.

    NOTE(review): safetensors' ``save_file`` rejects tensors that share storage
    (e.g. tied weights); if the model ties parameters, consider
    ``safetensors.torch.save_model`` instead.

    Args:
        model: PyTorch model to save.
        save_path: Output path (must end with .safetensors).

    Raises:
        AssertionError: If ``save_path`` does not end with ``.safetensors``.
    """
    assert save_path.endswith(".safetensors"), "Path must be .safetensors"
    save_file(model.state_dict(), save_path)


def load_model_weights_adaption(
    model: torch.nn.Module, ckpt_path: str, adaption_path: str
) -> torch.nn.Module:
    """
    Load base weights, attach a PEFT adapter, and restore the extra unfrozen modules.

    Steps:
        1. Load the safetensors checkpoint at ``ckpt_path`` into ``model`` in place
           (non-strict; skipped keys are logged).
        2. Wrap the model with ``peft.PeftModel.from_pretrained(model, adaption_path)``.
        3. Load ``<adaption_path>/unfrozen_weights.pth`` and copy each per-module
           state dict into the matching sub-module (ldr_proj, ldr_head, dte, rte,
           xte, yte, zte).

    Args:
        model: Base PyTorch model to load weights into.
        ckpt_path: Path to the base safetensors checkpoint file.
        adaption_path: Local directory holding the PEFT adapter and
            ``unfrozen_weights.pth``.

    Returns:
        torch.nn.Module: The model returned by ``PeftModel.from_pretrained``.
        Use this return value rather than the ``model`` argument.

    Raises:
        AssertionError: If ``ckpt_path`` does not end with ``.safetensors``.
        KeyError: If ``unfrozen_weights.pth`` lacks one of the expected module entries.
    """
    assert ckpt_path.endswith(
        ".safetensors"
    ), f"Checkpoint path '{ckpt_path}' is not a safetensors file"
    missing, unexpected = load_model(model, ckpt_path, strict=False)
    _log_incompatible_keys(missing, unexpected, ckpt_path)

    # Imported here so this module can be imported without peft installed.
    from peft import PeftModel

    model = PeftModel.from_pretrained(model, adaption_path)

    # Load onto the CPU rather than a fixed 'cuda:0' (which fails without a GPU):
    # load_state_dict copies the values into the existing parameters, which keep
    # their current device.
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary code
    # on torch versions where weights_only defaults to False. Only load adapters
    # from trusted sources.
    custom_weights = torch.load(
        os.path.join(adaption_path, "unfrozen_weights.pth"), map_location="cpu"
    )
    for module_name in _UNFROZEN_MODULE_NAMES:
        getattr(model, module_name).load_state_dict(custom_weights[module_name])
    return model


def select_device() -> torch.device:
    """
    Select the PyTorch device for tensor allocation.

    The preference order is CUDA, then Apple MPS, then CPU.

    Returns:
        torch.device: The selected device.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def decode_ldr(output_ids: torch.Tensor):
    """
    Decode model output ids into an LDR file.

    NOTE: Not implemented. No decoding logic exists in this module yet, so the
    function raises explicitly instead of returning an undefined value.

    Args:
        output_ids: Model output ids to decode.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError("decode_ldr is not implemented yet")