Spaces:
Paused
Paused
| import argparse | |
| import os | |
| import numpy as np | |
| from accelerate import Accelerator | |
| import torch | |
| import trimesh | |
| torch.autograd.set_detect_anomaly(True) | |
| from cube3d.training.trainer import Trainer | |
| from cube3d.training.bert_infer import Infer | |
| from cube3d.training.engine import Engine, EngineFast | |
| from cube3d.training.utils import normalize_bbox, select_device | |
| from cube3d.training.dataset import CubeDataset, LegosDataset, LegosTestDataset | |
| MESH_SCALE = 0.96 | |
| try: | |
| from torch.utils.tensorboard import SummaryWriter | |
| TENSORBOARD_FOUND = True | |
| except ImportError: | |
| TENSORBOARD_FOUND = False | |
def rescale(vertices: np.ndarray, mesh_scale: float = MESH_SCALE) -> np.ndarray:
    """Center vertices on their bounding box and scale them uniformly.

    The longest side of the axis-aligned bounding box is mapped to a length
    of ``2 * mesh_scale``, so with ``mesh_scale=1.0`` the result fits in the
    cube spanning [-1, -1, -1] to [1, 1, 1].

    Args:
        vertices: Array of points, one point per row (the bounding box is
            computed with min/max along axis 0).
        mesh_scale: Half of the target length of the longest bounding-box side.

    Returns:
        A new array holding the centered and scaled vertices; the input array
        is not modified. If every point coincides (zero-size bounding box) the
        points are only centered, because no finite scale factor exists.
    """
    bbmin = vertices.min(0)
    bbmax = vertices.max(0)
    center = (bbmin + bbmax) * 0.5
    extent = (bbmax - bbmin).max()
    if extent == 0:
        # Dividing by a zero extent would turn every coordinate into NaN/inf;
        # a degenerate cloud is simply moved to the origin instead.
        return vertices - center
    return (vertices - center) * (2.0 * mesh_scale / extent)
def load_scaled_mesh(file_path: str) -> trimesh.Trimesh:
    """Load a mesh from disk, clean it, and rescale its vertices.

    Cleaning steps, in order: drop infinite values, drop degenerate faces,
    drop duplicate faces, then drop vertices no face refers to. The surviving
    vertices are passed through ``rescale`` with its default scale.

    Args:
        file_path: Path of the mesh file, loaded with ``force="mesh"``.

    Returns:
        The cleaned ``trimesh.Trimesh`` with rescaled vertices.

    Raises:
        ValueError: If cleaning leaves the mesh without vertices or faces.
    """
    loaded = trimesh.load(file_path, force="mesh")

    # Cleaning order matters: faces are filtered before unreferenced
    # vertices are discarded.
    loaded.remove_infinite_values()
    keep_nondegenerate = loaded.nondegenerate_faces()
    loaded.update_faces(keep_nondegenerate)
    keep_unique = loaded.unique_faces()
    loaded.update_faces(keep_unique)
    loaded.remove_unreferenced_vertices()

    has_geometry = len(loaded.vertices) != 0 and len(loaded.faces) != 0
    if not has_geometry:
        raise ValueError("Mesh has no vertices or faces after cleaning")

    loaded.vertices = rescale(loaded.vertices)
    return loaded
def load_and_process_mesh(file_path: str, n_samples: int = 8192):
    """Sample an oriented point cloud from the surface of a mesh file.

    The mesh is loaded through ``load_scaled_mesh``; ``n_samples`` surface
    points are drawn and each is paired with the normal of the face it was
    sampled from.

    Args:
        file_path (str): The file path to the 3D mesh file.
        n_samples (int, optional): Number of surface points to sample.
            Defaults to 8192.

    Returns:
        torch.Tensor: Float tensor of shape (1, n_samples, 6); each point is
        its position (x, y, z) followed by its normal (nx, ny, nz).
    """
    scaled_mesh = load_scaled_mesh(file_path)
    sampled_xyz, sampled_faces = trimesh.sample.sample_surface(scaled_mesh, n_samples)

    # One row per sample: position columns first, then the face normal.
    features = np.hstack((sampled_xyz, scaled_mesh.face_normals[sampled_faces]))

    # Prepend a batch axis of size one before handing the data to torch.
    batched = features[np.newaxis, ...]
    return torch.from_numpy(batched).float()
if __name__ == "__main__":
    # Script entry point: parse the command line, build the engine, encode the
    # input mesh into shape-token indices, then run either the Trainer or the
    # Infer loop depending on `mode`.
    parser = argparse.ArgumentParser(description="cube shape generation script")
    parser.add_argument(
        "--config-path",
        type=str,
        default="cube3d/configs/open_model_v0.5.yaml",
        help="Path to the configuration YAML file.",
    )
    parser.add_argument(
        "--mesh-path",
        type=str,
        required=True,
        help="Path to the input mesh file.",
    )
    parser.add_argument(
        "--data-dir",
        type=str,
        required=True,
        help="Path to the input dataset file.",
    )
    parser.add_argument(
        "--gpt-ckpt-path",
        type=str,
        required=True,
        help="Path to the main GPT checkpoint file.",
    )
    parser.add_argument(
        "--save-gpt-ckpt-path",
        type=str,
        required=True,
        help="Path to the save main GPT checkpoint file.",
    )
    parser.add_argument(
        "--shape-ckpt-path",
        type=str,
        required=True,
        help="Path to the shape encoder/decoder checkpoint file.",
    )
    parser.add_argument(
        "--expname",
        type=str,
        required=True,
        # The value is joined onto "runs/" below, so it is a run name rather
        # than a path to a file.
        help="Experiment name; TensorBoard logs are written to runs/<expname>.",
    )
    parser.add_argument(
        "--fast-training",
        help="Use optimized training with cuda graphs",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        required=True,
        help="Text prompt for generating a 3D mesh",
    )
    parser.add_argument(
        "--top-p",
        type=float,
        default=None,
        help="Float < 1: Keep smallest set of tokens with cumulative probability ≥ top_p. Default None: deterministic generation.",
    )
    parser.add_argument(
        "--bounding-box-xyz",
        nargs=3,
        type=float,
        help="Three float values for x, y, z bounding box",
        default=None,
        required=False,
    )
    parser.add_argument(
        "--render-gif",
        help="Render a turntable gif of the mesh",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--disable-postprocessing",
        help="Disable postprocessing on the mesh. This will result in a mesh with more faces.",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--resolution-base",
        type=float,
        default=8.0,
        help="Resolution base for the shape decoder.",
    )
    args = parser.parse_args()

    # Create Tensorboard writer
    tb_writer = None
    if TENSORBOARD_FOUND:
        tb_writer = SummaryWriter(log_dir=os.path.join('runs', args.expname))
    else:
        print("Tensorboard not available: not logging progress")

    try:
        device = select_device()
        print(f"Using device: {device}")

        # NOTE: the mode is hard-coded; any value other than 'test' selects
        # the Trainer path and the larger batch size below.
        mode = 'test'
        accelerator = Accelerator()

        # Initialize engine based on fast_training flag
        if args.fast_training:
            print(
                "Using cuda graphs, this will take some time to warmup and capture the graph."
            )
            engine_device = accelerator.device
            engine = EngineFast(
                args.config_path,
                args.gpt_ckpt_path,
                args.shape_ckpt_path,
                args.save_gpt_ckpt_path,
                device=engine_device,
                mode=mode,
            )
            print("Compiled the graph.")
        else:
            engine_device = device
            engine = Engine(
                args.config_path,
                args.gpt_ckpt_path,
                args.shape_ckpt_path,
                device=engine_device,
            )

        if args.bounding_box_xyz is not None:
            args.bounding_box_xyz = normalize_bbox(tuple(args.bounding_box_xyz))

        point_cloud = load_and_process_mesh(args.mesh_path)
        # Send the input to the device the engine was created on: the fast
        # path uses accelerator.device, which need not equal select_device().
        output = engine.shape_model.encode(point_cloud.to(engine_device))
        # The fourth element of the encoder output is indexed as a mapping
        # that carries the token indices.
        indices = output[3]["indices"]
        print("Got the following shape indices:")
        print(indices)
        print("Indices shape: ", indices.shape)

        train_config = Trainer.get_default_config()
        train_config.learning_rate = 5e-4  # many possible options, see the file
        train_config.max_iters = 40000
        train_config.batch_size = 1 if mode == 'test' else 28
        train_config.save_interval = 1000

        # Both datasets are still constructed, in the original order, even
        # though only one of them is used for a given mode.
        train_dataset = LegosDataset(args)
        test_dataset = LegosTestDataset(args)
        dataset = test_dataset if mode == 'test' else train_dataset

        # Trainer and Infer are called with the same keyword arguments, so
        # only the class differs between the two modes.
        runner_cls = Infer if mode == 'test' else Trainer
        runner = runner_cls(
            config=train_config,
            engine=engine,
            accelerator=accelerator,
            tb=tb_writer,
            prompt=args.prompt,
            train_dataset=dataset,
            indices=indices,
            resolution_base=args.resolution_base,
            disable_postprocessing=args.disable_postprocessing,
            top_p=args.top_p,
            bounding_box_xyz=args.bounding_box_xyz,
            save_gpt_ckpt_path=args.save_gpt_ckpt_path,
            mode=mode,
        )
        runner.run()
    finally:
        # Flush and close the event file so buffered scalars are not lost
        # when the process exits or an exception propagates.
        if tb_writer is not None:
            tb_writer.close()