import argparse
import os

import numpy as np
from accelerate import Accelerator
import torch
import trimesh

# Autograd anomaly detection is switched on for the whole process. It is kept
# ahead of the project imports so the statement order of the script is unchanged.
torch.autograd.set_detect_anomaly(True)

from cube3d.training.trainer import Trainer
from cube3d.training.bert_infer import Infer
from cube3d.training.engine import Engine, EngineFast
from cube3d.training.utils import normalize_bbox, select_device
from cube3d.training.dataset import CubeDataset, LegosDataset, LegosTestDataset

# Default fraction of the [-1, 1] cube that a rescaled mesh occupies.
MESH_SCALE = 0.96

# Tensorboard is optional: logging is skipped when it cannot be imported.
try:
    from torch.utils.tensorboard import SummaryWriter

    TENSORBOARD_FOUND = True
except ImportError:
    TENSORBOARD_FOUND = False


def rescale(vertices: np.ndarray, mesh_scale: float = MESH_SCALE) -> np.ndarray:
    """Center the vertices and scale them uniformly into a cube.

    With ``mesh_scale=1.0`` the longest bounding-box axis spans
    ``[-1, 1]``; smaller values shrink the result proportionally.

    Args:
        vertices: Array of vertex positions, reduced along axis 0 to get
            the bounding box.
        mesh_scale: Half-length of the target cube along the longest axis.

    Returns:
        The centered, uniformly scaled vertices. If the bounding box has
        zero extent (all vertices coincide) the vertices are only centered,
        because no finite scale factor exists.
    """
    bbmin = vertices.min(0)
    bbmax = vertices.max(0)
    center = (bbmin + bbmax) * 0.5
    centered = vertices - center

    extent = (bbmax - bbmin).max()
    if extent <= 0:
        # Degenerate input: dividing by the extent would yield NaN/inf.
        return centered
    return centered * (2.0 * mesh_scale / extent)


def load_scaled_mesh(file_path: str) -> trimesh.Trimesh:
    """Load a mesh, clean it, and rescale it with :func:`rescale`.

    Cleaning removes infinite values, degenerate faces, duplicate faces and
    unreferenced vertices, in that order.

    Args:
        file_path: Path of the mesh file, passed to ``trimesh.load`` with
            ``force="mesh"``.

    Returns:
        The cleaned mesh with rescaled vertices.

    Raises:
        ValueError: If no vertices or no faces remain after cleaning.
    """
    mesh: trimesh.Trimesh = trimesh.load(file_path, force="mesh")
    mesh.remove_infinite_values()
    mesh.update_faces(mesh.nondegenerate_faces())
    mesh.update_faces(mesh.unique_faces())
    mesh.remove_unreferenced_vertices()
    if len(mesh.vertices) == 0 or len(mesh.faces) == 0:
        raise ValueError("Mesh has no vertices or faces after cleaning")
    mesh.vertices = rescale(mesh.vertices)
    return mesh


def load_and_process_mesh(file_path: str, n_samples: int = 8192):
    """Sample an oriented point cloud from the surface of a mesh file.

    The mesh is loaded through :func:`load_scaled_mesh`, points are sampled
    from its surface, and each point is paired with the normal of the face
    it was sampled from.

    Args:
        file_path (str): The file path to the 3D mesh file.
        n_samples (int, optional): The number of points to sample from the
            mesh surface. Defaults to 8192.

    Returns:
        torch.Tensor: A float tensor of shape (1, n_samples, 6); each point
        is its position (x, y, z) followed by its normal (nx, ny, nz).
    """
    mesh = load_scaled_mesh(file_path)
    positions, face_indices = trimesh.sample.sample_surface(mesh, n_samples)
    normals = mesh.face_normals[face_indices]
    # Shape: (num_samples, 6)
    points_with_normals = np.concatenate([positions, normals], axis=1)
    return torch.from_numpy(points_with_normals.reshape(1, -1, 6)).float()


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for this script."""
    parser = argparse.ArgumentParser(description="cube shape generation script")
    parser.add_argument(
        "--config-path",
        type=str,
        default="cube3d/configs/open_model_v0.5.yaml",
        help="Path to the configuration YAML file.",
    )
    parser.add_argument(
        "--mesh-path",
        type=str,
        required=True,
        help="Path to the input mesh file.",
    )
    parser.add_argument(
        "--data-dir",
        type=str,
        required=True,
        help="Path to the input dataset file.",
    )
    parser.add_argument(
        "--gpt-ckpt-path",
        type=str,
        required=True,
        help="Path to the main GPT checkpoint file.",
    )
    parser.add_argument(
        "--save-gpt-ckpt-path",
        type=str,
        required=True,
        help="Path to the save main GPT checkpoint file.",
    )
    parser.add_argument(
        "--shape-ckpt-path",
        type=str,
        required=True,
        help="Path to the shape encoder/decoder checkpoint file.",
    )
    parser.add_argument(
        "--expname",
        type=str,
        required=True,
        help="Path to the tensorboard file.",
    )
    parser.add_argument(
        "--fast-training",
        help="Use optimized training with cuda graphs",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        required=True,
        help="Text prompt for generating a 3D mesh",
    )
    parser.add_argument(
        "--top-p",
        type=float,
        default=None,
        help="Float < 1: Keep smallest set of tokens with cumulative probability ≥ top_p. Default None: deterministic generation.",
    )
    parser.add_argument(
        "--bounding-box-xyz",
        nargs=3,
        type=float,
        help="Three float values for x, y, z bounding box",
        default=None,
        required=False,
    )
    parser.add_argument(
        "--render-gif",
        help="Render a turntable gif of the mesh",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--disable-postprocessing",
        help="Disable postprocessing on the mesh. This will result in a mesh with more faces.",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--resolution-base",
        type=float,
        default=8.0,
        help="Resolution base for the shape decoder.",
    )
    return parser


def _run(args: argparse.Namespace, tb_writer) -> None:
    """Build the engine, encode the reference mesh, and run training or inference.

    Args:
        args: Parsed command-line arguments from :func:`_build_arg_parser`.
        tb_writer: Tensorboard writer, or ``None`` when tensorboard is
            unavailable; it is forwarded to the runner unchanged.
    """
    device = select_device()
    print(f"Using device: {device}")

    # The run mode is fixed here; 'test' selects the Infer runner below.
    mode = 'test'
    accelerator = Accelerator()

    # Initialize engine based on fast_training flag
    if args.fast_training:
        print(
            "Using cuda graphs, this will take some time to warmup and capture the graph."
        )
        # NOTE(review): this branch places the engine on accelerator.device
        # while the point cloud below is moved to select_device() -- confirm
        # the two always agree.
        engine = EngineFast(
            args.config_path,
            args.gpt_ckpt_path,
            args.shape_ckpt_path,
            args.save_gpt_ckpt_path,
            device=accelerator.device,
            mode=mode,
        )
        print("Compiled the graph.")
    else:
        engine = Engine(
            args.config_path, args.gpt_ckpt_path, args.shape_ckpt_path, device=device
        )

    if args.bounding_box_xyz is not None:
        args.bounding_box_xyz = normalize_bbox(tuple(args.bounding_box_xyz))

    point_cloud = load_and_process_mesh(args.mesh_path)
    output = engine.shape_model.encode(point_cloud.to(device))
    # The shape token indices are read from output[3]["indices"].
    indices = output[3]["indices"]
    print("Got the following shape indices:")
    print(indices)
    print("Indices shape: ", indices.shape)

    train_config = Trainer.get_default_config()
    train_config.learning_rate = 5e-4  # many possible options, see the file
    train_config.max_iters = 40000
    train_config.batch_size = 1 if mode == 'test' else 28
    train_config.save_interval = 1000

    # NOTE(review): both datasets are constructed regardless of mode, as in
    # the original script; only one of them is used.
    train_dataset = LegosDataset(args)
    test_dataset = LegosTestDataset(args)
    dataset = test_dataset if mode == 'test' else train_dataset

    # Trainer and Infer are constructed with the same keyword arguments.
    runner_cls = Infer if mode == 'test' else Trainer
    runner = runner_cls(
        config=train_config,
        engine=engine,
        accelerator=accelerator,
        tb=tb_writer,
        prompt=args.prompt,
        train_dataset=dataset,
        indices=indices,
        resolution_base=args.resolution_base,
        disable_postprocessing=args.disable_postprocessing,
        top_p=args.top_p,
        bounding_box_xyz=args.bounding_box_xyz,
        save_gpt_ckpt_path=args.save_gpt_ckpt_path,
        mode=mode,
    )
    runner.run()


def main() -> None:
    """Parse the command line, set up tensorboard logging, and run the job."""
    args = _build_arg_parser().parse_args()

    # Create Tensorboard writer
    tb_writer = None
    if TENSORBOARD_FOUND:
        tb_writer = SummaryWriter(log_dir=os.path.join('runs', args.expname))
    else:
        print("Tensorboard not available: not logging progress")

    try:
        _run(args, tb_writer)
    finally:
        # Flush pending events even when the run fails part-way.
        if tb_writer is not None:
            tb_writer.close()


if __name__ == "__main__":
    main()