import logging
from typing import Any, Optional, Tuple

import torch
from omegaconf import DictConfig, OmegaConf
from safetensors.torch import load_model, load_file

BOUNDING_BOX_MAX_SIZE = 1.925

def normalize_bbox(bounding_box_xyz: Tuple[float, float, float]) -> list:
    """Scale a bounding box so that its longest edge equals BOUNDING_BOX_MAX_SIZE."""
    max_l = max(bounding_box_xyz)
    return [BOUNDING_BOX_MAX_SIZE * elem / max_l for elem in bounding_box_xyz]
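
# Example (sketch): the longest edge is mapped to BOUNDING_BOX_MAX_SIZE and the
# other edges are scaled by the same factor, preserving the aspect ratio.
#   normalize_bbox((2.0, 1.0, 0.5))  # -> [1.925, 0.9625, 0.48125]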

def load_config(cfg_path: str) -> Any:
    """
    Load and resolve a configuration file.

    Args:
        cfg_path (str): The path to the configuration file.

    Returns:
        Any: The loaded and resolved configuration object.

    Raises:
        AssertionError: If the loaded configuration is not an instance of DictConfig.
    """
    cfg = OmegaConf.load(cfg_path)
    OmegaConf.resolve(cfg)
    assert isinstance(cfg, DictConfig)
    return cfg
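
# Example usage (sketch; "configs/model.yaml" is a hypothetical path):
#   cfg = load_config("configs/model.yaml")
#   print(list(cfg.keys()))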

def parse_structured(cfg_type: Any, cfg: DictConfig) -> Any:
    """
    Parses a configuration dictionary into a structured configuration object.

    Args:
        cfg_type (Any): The type of the structured configuration object.
        cfg (DictConfig): The configuration dictionary to be parsed.

    Returns:
        Any: The structured configuration object created from the dictionary.
    """
    scfg = OmegaConf.structured(cfg_type(**cfg))
    return scfg
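
# Example usage (sketch; `ModelConfig` is a hypothetical dataclass whose fields
# match the keys of `cfg`):
#   from dataclasses import dataclass
#   @dataclass
#   class ModelConfig:
#       hidden_size: int = 512
#   scfg = parse_structured(ModelConfig, cfg)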

def load_model_weights(model: torch.nn.Module, ckpt_path: str) -> None:
    """
    Load a safetensors checkpoint into a PyTorch model.
    The model is updated in place.

    Handles backward compatibility for embedder weight key naming:
    - Old format: 'embedder.weight'
    - New format: 'encoder.embedder.weight', 'occupancy_decoder.embedder.weight'

    Args:
        model: PyTorch model to load weights into
        ckpt_path: Path to the safetensors checkpoint file

    Returns:
        None
    """
    assert ckpt_path.endswith(
        ".safetensors"
    ), f"Checkpoint path '{ckpt_path}' is not a safetensors file"

    # Load the checkpoint as a plain state dict so that keys can be remapped.
    checkpoint = load_file(ckpt_path)

    # Backward compatibility: remap the old embedder key format to the new one.
    # This handles checkpoints that have 'embedder.weight' while the model expects
    # 'encoder.embedder.weight' and 'occupancy_decoder.embedder.weight'.
    if "embedder.weight" in checkpoint:
        if "encoder.embedder.weight" not in checkpoint:
            checkpoint["encoder.embedder.weight"] = checkpoint["embedder.weight"]
        if "occupancy_decoder.embedder.weight" not in checkpoint:
            checkpoint["occupancy_decoder.embedder.weight"] = checkpoint["embedder.weight"]

    # Load the remapped checkpoint; strict=False tolerates missing or unexpected
    # keys such as the legacy 'embedder.weight' entry left in the dict.
    model.load_state_dict(checkpoint, strict=False)
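
# Example usage (sketch; `MyModel` and the checkpoint path are hypothetical):
#   model = MyModel()
#   load_model_weights(model, "checkpoints/model.safetensors")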

def save_model_weights(model: torch.nn.Module, save_path: str) -> None:
    """
    Save model weights in safetensors format.

    Args:
        model: PyTorch model to save
        save_path: Output path (must end with .safetensors)
    """
    assert save_path.endswith(".safetensors"), "Path must be .safetensors"
    from safetensors.torch import save_file

    state_dict = model.state_dict()
    save_file(state_dict, save_path)
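
# Example usage (sketch; the output path is hypothetical):
#   save_model_weights(model, "checkpoints/model.safetensors")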

def load_model_weights_adaption(model: torch.nn.Module, ckpt_path: str, adaption_path: str) -> torch.nn.Module:
    """
    Load a safetensors checkpoint into a PyTorch model, then apply a PEFT adapter
    and restore the unfrozen custom weights stored alongside it.

    Args:
        model: PyTorch model to load weights into
        ckpt_path: Path to the safetensors checkpoint file
        adaption_path: Directory containing the PEFT adapter and 'unfrozen_weights.pth'

    Returns:
        torch.nn.Module: The adapted model.
    """
    assert ckpt_path.endswith(
        ".safetensors"
    ), f"Checkpoint path '{ckpt_path}' is not a safetensors file"
    load_model(model, ckpt_path, strict=False)

    # Wrap the base model with the trained PEFT adapter.
    from peft import PeftModel
    model = PeftModel.from_pretrained(model, adaption_path)

    # Restore the modules that were left unfrozen during adapter training.
    custom_weights = torch.load(f"{adaption_path}/unfrozen_weights.pth", map_location=torch.device("cuda:0"))
    model.ldr_proj.load_state_dict(custom_weights["ldr_proj"])
    model.ldr_head.load_state_dict(custom_weights["ldr_head"])
    model.dte.load_state_dict(custom_weights["dte"])
    model.rte.load_state_dict(custom_weights["rte"])
    model.xte.load_state_dict(custom_weights["xte"])
    model.yte.load_state_dict(custom_weights["yte"])
    model.zte.load_state_dict(custom_weights["zte"])
    return model
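
# Example usage (sketch; both paths are hypothetical; the adapter directory is
# expected to contain the PEFT adapter files plus 'unfrozen_weights.pth'):
#   model = load_model_weights_adaption(model, "checkpoints/base.safetensors", "checkpoints/adapter")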

def select_device() -> torch.device:
    """
    Selects the appropriate PyTorch device for tensor allocation.

    Returns:
        torch.device: CUDA if available, otherwise MPS, otherwise CPU.
    """
    return torch.device(
        "cuda"
        if torch.cuda.is_available()
        else "mps"
        if torch.backends.mps.is_available()
        else "cpu"
    )
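
# Example usage (sketch):
#   device = select_device()
#   model = model.to(device)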

def decode_ldr(output_ids: torch.Tensor):
    """
    Decode generated output token ids into LDR (LDraw) file content.

    Returns:
        The decoded LDR file content.
    """
    # `ldr` is assumed to be produced from `output_ids` by the decoding step,
    # which is not included in this listing.
    return ldr