# Uploaded by 0xZohar -- commit 681b5d5 (verified): "Add code/cube3d/train.py"
import argparse
import os
import numpy as np
from accelerate import Accelerator
import torch
import trimesh
torch.autograd.set_detect_anomaly(True)
from cube3d.training.trainer import Trainer
from cube3d.training.bert_infer import Infer
from cube3d.training.engine import Engine, EngineFast
from cube3d.training.utils import normalize_bbox, select_device
from cube3d.training.dataset import CubeDataset, LegosDataset, LegosTestDataset
MESH_SCALE = 0.96
try:
from torch.utils.tensorboard import SummaryWriter
TENSORBOARD_FOUND = True
except ImportError:
TENSORBOARD_FOUND = False
def rescale(vertices: np.ndarray, mesh_scale: float = MESH_SCALE) -> np.ndarray:
    """Center vertices on the origin and scale them uniformly into a cube.

    The axis-aligned bounding box of ``vertices`` is moved so that its center
    sits at the origin, then everything is multiplied by a single factor so
    that the longest bounding-box side has length ``2 * mesh_scale``. With
    ``mesh_scale=1.0`` the longest axis therefore spans ``[-1, 1]``.

    Args:
        vertices: Array of points; the bounding box is computed along axis 0.
        mesh_scale: Half-length of the longest bounding-box side after scaling.

    Returns:
        A new array holding the rescaled vertices.
    """
    bbmin = vertices.min(0)
    bbmax = vertices.max(0)
    center = (bbmin + bbmax) * 0.5
    extent = (bbmax - bbmin).max()
    if extent == 0:
        # Every point coincides, so there is no size to normalise. Dividing by
        # the zero extent would yield inf and turn the result into NaNs; only
        # center the points instead.
        return vertices - center
    scale = 2.0 * mesh_scale / extent
    return (vertices - center) * scale
def load_scaled_mesh(file_path: str) -> trimesh.Trimesh:
    """
    Load a mesh from disk, clean it, and rescale its vertices with ``rescale``.

    Cleaning is done with trimesh's own helpers, in this order:
    ``remove_infinite_values``, ``update_faces`` with the
    ``nondegenerate_faces()`` mask, ``update_faces`` with the
    ``unique_faces()`` mask, and ``remove_unreferenced_vertices``.

    Parameters:
        file_path: str
            Path passed to ``trimesh.load`` with ``force="mesh"``.
    Returns:
        mesh: trimesh.Trimesh
            The cleaned mesh with rescaled vertices.
    Raises:
        ValueError: if the mesh has no vertices or no faces after cleaning.
    """
    cleaned: trimesh.Trimesh = trimesh.load(file_path, force="mesh")
    cleaned.remove_infinite_values()
    # Each mask is computed only after the previous filter has been applied.
    for face_mask_fn in (cleaned.nondegenerate_faces, cleaned.unique_faces):
        cleaned.update_faces(face_mask_fn())
    cleaned.remove_unreferenced_vertices()

    has_geometry = len(cleaned.vertices) > 0 and len(cleaned.faces) > 0
    if not has_geometry:
        raise ValueError("Mesh has no vertices or faces after cleaning")

    cleaned.vertices = rescale(cleaned.vertices)
    return cleaned
def load_and_process_mesh(file_path: str, n_samples: int = 8192):
    """
    Turn a mesh file into a batched point cloud of surface samples with normals.

    The mesh is loaded through ``load_scaled_mesh``, ``n_samples`` points are
    drawn with ``trimesh.sample.sample_surface``, and every point is paired
    with the normal of the face it was sampled from.

    Args:
        file_path (str): The file path to the 3D mesh file.
        n_samples (int, optional): Number of surface points to request. Defaults to 8192.
    Returns:
        torch.Tensor: A float tensor reshaped to (1, -1, 6); each row of the
        last axis is a position (x, y, z) followed by a normal (nx, ny, nz).
    """
    scaled_mesh = load_scaled_mesh(file_path)
    xyz, hit_faces = trimesh.sample.sample_surface(scaled_mesh, n_samples)
    per_point_normals = scaled_mesh.face_normals[hit_faces]

    # Columns 0-2 are positions, columns 3-5 are normals.
    features = np.concatenate((xyz, per_point_normals), axis=1)
    batched = features.reshape(1, -1, 6)  # add the leading batch axis
    return torch.from_numpy(batched).float()
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script."""
    parser = argparse.ArgumentParser(description="cube shape generation script")
    parser.add_argument(
        "--config-path",
        type=str,
        default="cube3d/configs/open_model_v0.5.yaml",
        help="Path to the configuration YAML file.",
    )
    parser.add_argument(
        "--mesh-path",
        type=str,
        required=True,
        help="Path to the input mesh file.",
    )
    parser.add_argument(
        "--data-dir",
        type=str,
        required=True,
        help="Path to the input dataset file.",
    )
    parser.add_argument(
        "--gpt-ckpt-path",
        type=str,
        required=True,
        help="Path to the main GPT checkpoint file.",
    )
    parser.add_argument(
        "--save-gpt-ckpt-path",
        type=str,
        required=True,
        help="Path to the save main GPT checkpoint file.",
    )
    parser.add_argument(
        "--shape-ckpt-path",
        type=str,
        required=True,
        help="Path to the shape encoder/decoder checkpoint file.",
    )
    parser.add_argument(
        "--expname",
        type=str,
        required=True,
        help="Path to the tensorboard file.",
    )
    parser.add_argument(
        "--fast-training",
        help="Use optimized training with cuda graphs",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        required=True,
        help="Text prompt for generating a 3D mesh",
    )
    parser.add_argument(
        "--top-p",
        type=float,
        default=None,
        help="Float < 1: Keep smallest set of tokens with cumulative probability ≥ top_p. Default None: deterministic generation.",
    )
    parser.add_argument(
        "--bounding-box-xyz",
        nargs=3,
        type=float,
        help="Three float values for x, y, z bounding box",
        default=None,
        required=False,
    )
    # NOTE(review): --render-gif is parsed but not read anywhere in this script.
    parser.add_argument(
        "--render-gif",
        help="Render a turntable gif of the mesh",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--disable-postprocessing",
        help="Disable postprocessing on the mesh. This will result in a mesh with more faces.",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--resolution-base",
        type=float,
        default=8.0,
        help="Resolution base for the shape decoder.",
    )
    return parser


def _run(args: argparse.Namespace, tb_writer) -> None:
    """Build the engine, encode the reference mesh, then train or infer.

    Args:
        args: Parsed command-line arguments from ``_build_arg_parser``.
        tb_writer: A tensorboard ``SummaryWriter`` or ``None``; it is passed
            through to the ``Trainer`` / ``Infer`` object as ``tb``.
    """
    device = select_device()
    print(f"Using device: {device}")

    # Hard-coded switch: 'test' runs Infer on LegosTestDataset, anything else
    # runs Trainer on LegosDataset.
    mode = 'test'
    accelerator = Accelerator()

    # Initialize engine based on fast_training flag
    if args.fast_training:
        print(
            "Using cuda graphs, this will take some time to warmup and capture the graph."
        )
        # NOTE(review): this branch places the engine on accelerator.device,
        # while the mesh below is moved to `device` from select_device();
        # confirm the two always agree.
        engine = EngineFast(
            args.config_path,
            args.gpt_ckpt_path,
            args.shape_ckpt_path,
            args.save_gpt_ckpt_path,
            device=accelerator.device,
            mode=mode,
        )
        print("Compiled the graph.")
    else:
        engine = Engine(
            args.config_path, args.gpt_ckpt_path, args.shape_ckpt_path, device=device
        )

    if args.bounding_box_xyz is not None:
        args.bounding_box_xyz = normalize_bbox(tuple(args.bounding_box_xyz))

    # Encode the reference mesh with the shape model and keep its indices.
    point_cloud = load_and_process_mesh(args.mesh_path)
    output = engine.shape_model.encode(point_cloud.to(device))
    indices = output[3]["indices"]
    print("Got the following shape indices:")
    print(indices)
    print("Indices shape: ", indices.shape)

    train_config = Trainer.get_default_config()
    train_config.learning_rate = 5e-4  # many possible options, see the file
    train_config.max_iters = 40000
    train_config.batch_size = 1 if mode == 'test' else 28
    train_config.save_interval = 1000

    # Only build the dataset this mode actually uses.
    if mode == 'test':
        dataset = LegosTestDataset(args)
        runner_cls = Infer
    else:
        dataset = LegosDataset(args)
        runner_cls = Trainer

    # Trainer and Infer are constructed with the same keyword arguments.
    runner = runner_cls(
        config=train_config,
        engine=engine,
        accelerator=accelerator,
        tb=tb_writer,
        prompt=args.prompt,
        train_dataset=dataset,
        indices=indices,
        resolution_base=args.resolution_base,
        disable_postprocessing=args.disable_postprocessing,
        top_p=args.top_p,
        bounding_box_xyz=args.bounding_box_xyz,
        save_gpt_ckpt_path=args.save_gpt_ckpt_path,
        mode=mode,
    )
    runner.run()


def main() -> None:
    """Parse arguments, set up tensorboard logging, and run the job."""
    args = _build_arg_parser().parse_args()

    # Create Tensorboard writer
    tb_writer = None
    if TENSORBOARD_FOUND:
        tb_writer = SummaryWriter(log_dir=os.path.join('runs', args.expname))
    else:
        print("Tensorboard not available: not logging progress")

    try:
        _run(args, tb_writer)
    finally:
        # Close the writer even when the run raises, so pending events are
        # flushed to disk.
        if tb_writer is not None:
            tb_writer.close()


if __name__ == "__main__":
    main()