| """ | |
| Simple training loop; Boilerplate that could apply to any arbitrary neural network, | |
| so nothing in this file really has anything to do with GPT specifically. | |
| """ | |
| from typing import Optional, Tuple, List | |
| import time | |
| import os | |
| from collections import defaultdict | |
| from accelerate import Accelerator | |
| import torch | |
| from torch.nn import functional as F | |
| from torch.utils.data.dataloader import DataLoader | |
| from mingpt.utils import CfgNode as CN | |
| from cube3d.training.utils import save_model_weights, mask_cross_entropy, normalize_bboxs, top_k_prob_mask | |
| from cube3d.training.process_single_ldr import logits2ldr, logits2ldrot, logits2ldrp, logits2flatldrp, logits2flatldrpr | |
| from cube3d.inference.utils import load_model_weights | |
| from tqdm import tqdm | |

def generate_tokens(
    engine,
    prompt,
    inputs_ids,
    latent,
    resolution_base=8.0,
    disable_postprocess=False,  # accepted for signature compatibility; not forwarded to engine.t2t below
    top_p=None,
    bounding_box_xyz=None,
    strategy=None,
):
    output_ids = engine.t2t(
        #[prompt],
        prompt,
        #use_kv_cache=True,
        inputs_ids=inputs_ids,
        latent=latent,
        use_kv_cache=False,
        resolution_base=resolution_base,
        top_p=top_p,
        bounding_box_xyz=bounding_box_xyz,
        strategy=strategy,
    )
    return output_ids
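
# NOTE (assumption, based on how Trainer.run() unpacks the result below): engine.t2t is
# expected to return a 5-tuple (logits, inputs_ids, strategy, mask, cut_idx) when called
# with precomputed inputs_ids as above; this wrapper simply forwards that value.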

class Trainer:

    @staticmethod
    def get_default_config():
        C = CN()
        # device to train on
        C.device = 'auto'
        # dataloader parameters
        C.num_workers = 4
        # optimizer parameters
        C.max_iters = None
        C.batch_size = 4
        C.learning_rate = 3e-4
        C.betas = (0.9, 0.95)
        C.weight_decay = 0.1  # only applied on matmul weights
        C.grad_norm_clip = 1.0
        C.save_interval = None
        return C
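
    # Configuration sketch (illustrative only): max_iters and save_interval default to
    # None above and must be overridden before calling run(), e.g.
    #   config = Trainer.get_default_config()
    #   config.max_iters = 10000
    #   config.save_interval = 1000
    #   config.batch_size = 8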

    def __init__(
        self,
        config,
        engine,
        train_dataset,
        accelerator,
        tb,
        prompt: str,
        indices: Optional[List[int]] = None,
        resolution_base: float = 8.0,
        disable_postprocessing: bool = False,
        top_p: Optional[float] = None,
        bounding_box_xyz: Optional[Tuple[float]] = None,
        save_gpt_ckpt_path: Optional[str] = None,
        mode: str = 'train',
    ):
        self.config = config
        self.engine = engine
        self.model = engine.gpt_model
        self.optimizer = None
        self.callbacks = defaultdict(list)
        self.train_dataset = train_dataset
        self.accelerator = accelerator

        # Training parameters
        self.prompt = prompt
        self.targets = indices
        self.resolution_base = resolution_base
        self.disable_postprocessing = disable_postprocessing
        self.top_p = top_p
        self.bounding_box_xyz = bounding_box_xyz
        self.save_gpt_ckpt_path = save_gpt_ckpt_path

        # determine the device we'll train on
        if config.device == 'auto':
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        else:
            self.device = config.device
        self.model = self.model.to(self.device)
        print("running on device", self.device)

        # variables that will be assigned to the trainer later, for logging etc.
        self.iter_num = 0
        self.iter_time = 0.0
        self.iter_dt = 0.0
        self.tb_writer = tb
        self.mode = mode

    def add_callback(self, onevent: str, callback):
        self.callbacks[onevent].append(callback)

    def set_callback(self, onevent: str, callback):
        self.callbacks[onevent] = [callback]

    def trigger_callbacks(self, onevent: str):
        for callback in self.callbacks.get(onevent, []):
            callback(self)
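
    # Callback registration sketch (illustrative only). Note that run() below never calls
    # trigger_callbacks, so registered callbacks stay inert unless a
    # trigger_callbacks('on_batch_end')-style hook is added to the loop:
    #   def on_batch_end(trainer):
    #       print(trainer.iter_num, trainer.loss.item())
    #   trainer.set_callback('on_batch_end', on_batch_end)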

    def run(self):
        model, config = self.model, self.config

        # setup the optimizer and LR scheduler
        #self.optimizer = self.engine.configure_optimizers(config)
        self.optimizer, self.scheduler = self.engine.configure_optimizers_scratch_linear(config)  # alternatively: self.engine.configure_optimizers_lora_linear(config)

        # setup the dataloader
        train_loader = DataLoader(
            self.train_dataset,
            shuffle=(self.mode == 'train'),
            batch_size=config.batch_size,
        )

        model.train()
        model, self.optimizer, train_loader = self.accelerator.prepare(model, self.optimizer, train_loader)
        self.iter_num = 0
        self.iter_time = time.time()
        data_iter = iter(train_loader)

        # exponential moving averages of the loss terms, used for progress-bar display
        ema_loss_for_log = 0.0
        ema_ploss_for_log = 0.0
        ema_rloss_for_log = 0.0
        ema_dloss_for_log = 0.0
        ema_floss_for_log = 0.0

        # vocabulary sizes and token-layout constants used by the loss slicing below
        dat_num = 1217  #286
        x_num = 251
        y_num = 215
        z_num = 525
        rot_num = 24
        shift = 0
        stride = 5
        attr_shift = stride - 3  # with dat and rot, +1 for bert
        bert_shift = 1
        x = x_num
        xy = x_num + y_num + rot_num
        xyz = x_num + y_num + z_num + rot_num
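
        # Layout note (inferred from the slices below, stated as an assumption): the logit
        # vocabulary appears to be laid out as contiguous blocks
        #   [0, rot_num]               rotation classes (+1 padding slot)
        #   [rot_num+1, x+rot_num+1]   x-position classes (+1 padding slot)
        #   [x+rot_num+2, xy+2]        y-position classes (+1 padding slot)
        #   [xy+3, xyz+3]              z-position classes (+1 padding slot)
        # while each brick occupies `stride` (= 5) consecutive sequence positions, so each
        # attribute is read out at a fixed offset within every group of 5 tokens.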

        progress_bar = tqdm(range(0, config.max_iters), desc="Training progress")
        #while True:
        for self.iter_num in range(0, config.max_iters + 1):
            # fetch the next batch and re-init the iterator if the dataloader is exhausted
            try:
                batch = next(data_iter)
            except StopIteration:
                data_iter = iter(train_loader)
                batch = next(data_iter)
            #batch = [t['latent'].to(self.device) for t in batch]
            self.prompt, self.targets, self.box = batch['prompt'], batch['target'].to(self.device), batch['bbox']
            #self.targets = batch['latent'].to(self.device)
            targets = self.targets.clone()

            logits, inputs_ids, strategy, mask, cut_idx = generate_tokens(
                self.engine,
                self.prompt,
                targets,
                None,
                self.resolution_base,
                self.disable_postprocessing,
                self.top_p,
                #self.bounding_box_xyz,
                normalize_bboxs(self.box.float(), [x_num - 1, y_num - 1, z_num - 1]),  #batch_normalization(self.box)
                None,
            )
            # rotation_loss = F.cross_entropy(
            #     logits[:,:-1,:rot_num].permute(0, 2, 1),
            #     inputs_ids[:,shift:,:rot_num].argmax(-1),
            # )
            # px_loss = mask_cross_entropy(rot_num, x+rot_num, self.box[:, 0], logits, inputs_ids, shift)
            # py_loss = mask_cross_entropy(x+rot_num, xy, self.box[:, 1], logits, inputs_ids, shift)
            # pz_loss = mask_cross_entropy(xy, xyz, self.box[:, 2], logits, inputs_ids, shift)
            px_loss = F.cross_entropy(
                logits[:, 1+attr_shift+bert_shift:-1:stride, rot_num+1:x+rot_num+1+1].permute(0, 2, 1),
                inputs_ids[:, shift:, -5],
                ignore_index=-1,  # +1 for padding
            )
            py_loss = F.cross_entropy(
                logits[:, 0+attr_shift+bert_shift:-2:stride, x+rot_num+2:xy+3].permute(0, 2, 1),
                inputs_ids[:, shift:, -4],
                ignore_index=-1,
            )
            pz_loss = F.cross_entropy(
                logits[:, 2+attr_shift+bert_shift::stride, xy+3:xyz+4].permute(0, 2, 1),
                inputs_ids[:, shift:, -3],
                ignore_index=-1,
            )
            position_loss = px_loss + py_loss + pz_loss
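
            # Alignment note (assumption, based on the indexing above): logits has shape
            # [B, seq_len, vocab] with one group of `stride` tokens per brick, while
            # inputs_ids has shape [B, num_bricks, num_attrs] whose trailing columns hold
            # the per-brick attribute targets (-7 rotation, -5 x, -4 y, -3 z). Each loss
            # reads one logit position per brick via the strided slice, and padded bricks
            # are masked out through ignore_index=-1.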
            # dat_loss = F.cross_entropy(
            #     logits[:,0:-4:stride,:dat_num+1].permute(0, 2, 1),
            #     inputs_ids[:,shift:,-6],
            #     ignore_index=-1
            # )
            rotation_loss = F.cross_entropy(
                logits[:, 1+bert_shift:-3:stride, :rot_num+1].permute(0, 2, 1),
                inputs_ids[:, shift:, -7],
                ignore_index=-1,
            )
            # flag_loss = F.cross_entropy(
            #     logits[:,:-1,xyz+dat_num:xyz+dat_num+2].permute(0, 2, 1),
            #     inputs_ids[:,shift:,xyz+dat_num:xyz+dat_num+2].argmax(-1),
            # )
            # flag_loss = F.cross_entropy(
            #     logits[:,:-1,-2:].permute(0, 2, 1),
            #     inputs_ids[:,shift:,-2:].argmax(-1),
            # )

            # loss weights; rotation is only supervised for strategies 1 and 2
            lambda_position = 1.0
            lambda_rotation = 1.0
            lambda_dat = 1.0
            lambda_flag = 50.0
            self.loss = lambda_position * position_loss
            #            + lambda_rotation * rotation_loss
            #            + lambda_flag * flag_loss
            #            + lambda_dat * dat_loss
            if strategy == 1 or strategy == 2:
                self.loss += lambda_rotation * rotation_loss
            # (disabled) experimental iterative re-prediction / .ldr export code, kept for reference:
            # targets = self.targets.clone()
            # # mask_topk, mask_inv = top_k_prob_mask(F.softmax(logits[:, 1:-3:stride, :rot_num+1], dim=2), cut_idx, top_percent=0.5)
            # # targets[:,shift:,-7][mask_topk] = logits[:,1:-3:stride,:rot_num+1].permute(0, 2, 1).argmax(dim=1)[mask_topk]
            # # targets[:,shift:,-7][mask_inv] = self.engine.gpt_model.rot_num+1
            # targets[:,shift:,-7] = logits[:,1:-3:stride,:rot_num+1].permute(0, 2, 1).argmax(dim=1)
            # #targets[:,shift:,-4] = logits_y[:,0+attr_shift:-2:stride,x+rot_num+2:xy+3].permute(0, 2, 1).argmax(dim=1)
            # logits_x, inputs_ids, strategy, mask, cut_idx = generate_tokens(
            #     self.engine,
            #     self.prompt,
            #     targets,
            #     None,
            #     self.resolution_base,
            #     self.disable_postprocessing,
            #     self.top_p,
            #     #self.bounding_box_xyz,
            #     normalize_bboxs(self.box.float(), [x_num-1, y_num-1, z_num-1]), #batch_normalization(self.box)
            #     0
            # )
            # targets = self.targets.clone()
            # targets[:,shift:,-7] = logits_x[:,1+bert_shift:-3:stride,:rot_num+1].permute(0, 2, 1).argmax(dim=1)
            # mask_x, mask_x_inv = top_k_prob_mask(F.softmax(logits[:,1+attr_shift+bert_shift:-1:stride,rot_num+1:x+rot_num+1+1], dim=2), cut_idx, top_percent=0.5)
            # mask_y, mask_y_inv = top_k_prob_mask(F.softmax(logits[:,0+attr_shift+bert_shift:-2:stride,x+rot_num+2:xy+3], dim=2), cut_idx, top_percent=0.5)
            # mask_z, mask_z_inv = top_k_prob_mask(F.softmax(logits[:,2+attr_shift+bert_shift::stride,xy+3:xyz+4], dim=2), cut_idx, top_percent=0.5)
            # targets[:,shift:,-5][mask_x] = logits_x[:,1+attr_shift+bert_shift:-1:stride,rot_num+1:x+rot_num+1+1].permute(0, 2, 1).argmax(dim=1)[mask_x]
            # targets[:,shift:,-5][mask_x_inv] = self.engine.gpt_model.x_num+1
            # targets[:,shift:,-4][mask_y] = logits_x[:,0+attr_shift+bert_shift:-2:stride,x+rot_num+2:xy+3].permute(0, 2, 1).argmax(dim=1)[mask_y]
            # targets[:,shift:,-4][mask_y_inv] = self.engine.gpt_model.y_num+1
            # targets[:,shift:,-3][mask_z] = logits_x[:,2+attr_shift+bert_shift::stride,xy+3:xyz+4].permute(0, 2, 1).argmax(dim=1)[mask_z]
            # targets[:,shift:,-3][mask_z_inv] = self.engine.gpt_model.z_num+1
            # logits_p, inputs_ids, strategy, mask, cut_idx = generate_tokens(
            #     self.engine,
            #     self.prompt,
            #     targets,
            #     None,
            #     self.resolution_base,
            #     self.disable_postprocessing,
            #     self.top_p,
            #     #self.bounding_box_xyz,
            #     normalize_bboxs(self.box.float(), [x_num-1, y_num-1, z_num-1]), #batch_normalization(self.box)
            #     None
            # )
            # logits_p[:,1+bert_shift:-3:stride,:rot_num+1] = logits[:,1+bert_shift:-3:stride,:rot_num+1]
            # logits2flatldrpr(logits_p[0].cpu().detach().numpy(), inputs_ids[0].cpu().detach().numpy(), stride, 0, output_file=f"test_rightd2r2p2p_{self.iter_num}_scratch_0p5_bert.ldr")
            # targets = self.targets.clone()
            # targets[:,shift:,-7] = logits[:,1:-3:stride,:rot_num+1].permute(0, 2, 1).argmax(dim=1)
            # targets[:,shift:,-4] = logits_y[:,0+attr_shift:-2:stride,x+rot_num+2:xy+3].permute(0, 2, 1).argmax(dim=1)
            # targets[:,shift:,-5] = logits_x[:,1+attr_shift:-1:stride,rot_num+1:x+rot_num+1+1].permute(0, 2, 1).argmax(dim=1)
            # logits_z, inputs_ids, strategy = generate_tokens(
            #     self.engine,
            #     self.prompt,
            #     targets,
            #     None,
            #     self.resolution_base,
            #     self.disable_postprocessing,
            #     self.top_p,
            #     #self.bounding_box_xyz,
            #     normalize_bboxs(self.box.float(), [x_num-1, y_num-1, z_num-1]), #batch_normalization(self.box)
            #     3
            # )
            # backprop and update the parameters
            model.zero_grad(set_to_none=True)
            # #if self.mode!='train':
            # logits_z[:,1:-3:stride,:rot_num+1] = logits[:,1:-3:stride,:rot_num+1]
            # logits_z[:,0+attr_shift:-2:stride,x+rot_num+2:xy+3] = logits_y[:,0+attr_shift:-2:stride,x+rot_num+2:xy+3]
            # logits_z[:,1+attr_shift:-1:stride,rot_num+1:x+rot_num+1+1] = logits_x[:,1+attr_shift:-1:stride,rot_num+1:x+rot_num+1+1]
            # if self.iter_num > 4:
            #     break
            self.accelerator.backward(self.loss)
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
            self.optimizer.step()
            self.scheduler.step()

            with torch.no_grad():
                # progress bar: exponential moving averages of the logged losses
                ema_loss_for_log = 0.4 * self.loss.item() + 0.6 * ema_loss_for_log
                ema_ploss_for_log = 0.4 * position_loss.item() + 0.6 * ema_ploss_for_log
                ema_rloss_for_log = 0.4 * rotation_loss.item() + 0.6 * ema_rloss_for_log
                #ema_dloss_for_log = 0.4 * dat_loss.item() + 0.6 * ema_dloss_for_log
                #ema_floss_for_log = 0.4 * flag_loss.item() + 0.6 * ema_floss_for_log
                if self.iter_num % 10 == 0:
                    progress_bar.set_postfix({
                        "Loss": f"{ema_loss_for_log:.7f}",
                        "Position_Loss": f"{ema_ploss_for_log:.7f}",
                        "Rotation_Loss": f"{ema_rloss_for_log:.7f}",
                        #"Dat_Loss": f"{ema_dloss_for_log:.7f}",
                        #"Flag_Loss": f"{ema_floss_for_log:.7f}",
                    })
                    progress_bar.update(10)
            #logits2ldr(logits[0].cpu().detach().numpy())

            if self.iter_num % config.save_interval == 0 and self.iter_num != 0:
                if self.accelerator.is_main_process:
                    save_model_weights(
                        self.engine.gpt_model,
                        self.save_gpt_ckpt_path,
                    )
                    # self.engine.gpt_model.save_pretrained(self.save_gpt_ckpt_path)
                    # torch.save({
                    #     "ldr_proj": self.engine.gpt_model.ldr_proj.state_dict(),
                    #     "ldr_head": self.engine.gpt_model.ldr_head.state_dict(),
                    #     "rte": self.engine.gpt_model.rte.state_dict(),
                    #     "dte": self.engine.gpt_model.dte.state_dict(),
                    #     "xte": self.engine.gpt_model.xte.state_dict(),
                    #     "yte": self.engine.gpt_model.yte.state_dict(),
                    #     "zte": self.engine.gpt_model.zte.state_dict(),
                    # }, f"{self.save_gpt_ckpt_path}/unfrozen_weights.pth")

            if self.tb_writer:  # and self.accelerator.is_main_process:
                self.tb_writer.add_scalar('train_loss/position_loss', position_loss.item(), self.iter_num)
                self.tb_writer.add_scalar('train_loss/rotation_loss', rotation_loss.item(), self.iter_num)
                #self.tb_writer.add_scalar('train_loss/dat_loss', dat_loss.item(), self.iter_num)
                #self.tb_writer.add_scalar('train_loss/flag_loss', flag_loss.item(), self.iter_num)
                self.tb_writer.add_scalar('train_loss/total_loss', self.loss.item(), self.iter_num)

            if self.iter_num == config.max_iters:
                progress_bar.close()
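

# Usage sketch (illustrative only, not part of the original entry point): shows how a
# training script might wire this class together. `engine`, `train_dataset`, and
# `tb_writer` are assumed to be built elsewhere (the cube3d engine exposing gpt_model,
# t2t and configure_optimizers_scratch_linear; a dataset yielding dicts with 'prompt',
# 'target' and 'bbox'; and an optional TensorBoard SummaryWriter). The checkpoint path
# is a placeholder.
def run_training(engine, train_dataset, tb_writer=None, save_gpt_ckpt_path="checkpoints/gpt"):
    config = Trainer.get_default_config()
    config.max_iters = 10_000      # defaults to None above, must be set before run()
    config.save_interval = 1_000   # likewise required by the checkpointing check in run()
    config.batch_size = 4

    accelerator = Accelerator()
    trainer = Trainer(
        config,
        engine,
        train_dataset,
        accelerator,
        tb_writer,
        prompt="",  # run() takes per-batch prompts from the dataset, so this is only a default
        save_gpt_ckpt_path=save_gpt_ckpt_path,
    )
    trainer.run()
    return trainer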