""" DEIM: DETR with Improved Matching for Fast Convergence Copyright (c) 2024 The DEIM Authors. All Rights Reserved. --------------------------------------------------------------------------------- Modified from DETR (https://github.com/facebookresearch/detr/blob/main/engine.py) Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. """ import sys import math import gc import os from typing import Iterable import torch import torch.amp from torch.utils.tensorboard import SummaryWriter from torch.cuda.amp.grad_scaler import GradScaler import cv2 import numpy as np from ..optim import ModelEMA, Warmup from ..data import CocoEvaluator from ..misc import MetricLogger, SmoothedValue, dist_utils try: import wandb _WANDB_AVAILABLE = True except ImportError: _WANDB_AVAILABLE = False wandb = None def visualize_augmented_batch(samples, targets, epoch, step): """ Visualize augmented images with bounding boxes using OpenCV imshow. Args: samples (torch.Tensor): Batch of images [B, C, H, W] targets (list): List of target dictionaries epoch (int): Current epoch step (int): Current step """ for idx in range(len(samples)): img = samples[idx].detach().cpu().permute(1, 2, 0).numpy() img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) img_bgr = (img_bgr * 255).astype(np.uint8) print(targets[idx]["image_id"]) boxes = targets[idx]['boxes'].detach().cpu() labels = targets[idx].get('labels', None) h, w = img_bgr.shape[:2] print(len(boxes)) boxes_np = boxes.numpy() for i, box in enumerate(boxes_np): cx, cy, bw, bh = box x1 = int((cx - bw / 2) * w) y1 = int((cy - bh / 2) * h) x2 = int((cx + bw / 2) * w) y2 = int((cy + bh / 2) * h) cv2.rectangle(img_bgr, (x1, y1), (x2, y2), (0, 0, 255), 3) print(f"box: {(x1, y1), (x2, y2)}") label = f"cls_{int(labels[i])}" label_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.8, 2)[0] label_y = max(label_size[1] + 5, y1 - 5) cv2.rectangle(img_bgr, (x1, label_y - label_size[1] - 5), (x1 + label_size[0] + 10, label_y + 5), (0, 0, 255), -1) cv2.putText(img_bgr, label, (x1 + 5, label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2) cv2.namedWindow('img', cv2.WINDOW_KEEPRATIO) cv2.imshow("img", img_bgr) cv2.waitKey(0) def train_one_epoch(self_lr_scheduler, lr_scheduler, model: torch.nn.Module, criterion: torch.nn.Module, data_loader: Iterable, optimizer: torch.optim.Optimizer, device: torch.device, epoch: int, max_norm: float = 0, **kwargs): model.train() criterion.train() metric_logger = MetricLogger(delimiter=" ") metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}')) header = 'Epoch: [{}]'.format(epoch) print_freq = kwargs.get('print_freq', 10) writer :SummaryWriter = kwargs.get('writer', None) ema :ModelEMA = kwargs.get('ema', None) scaler :GradScaler = kwargs.get('scaler', None) lr_warmup_scheduler :Warmup = kwargs.get('lr_warmup_scheduler', None) # wandb parameters use_wandb = kwargs.get('use_wandb', False) wandb_run = kwargs.get('wandb_run', None) wandb_log_freq = kwargs.get('wandb_log_freq', 10) # gradient accumulation parameters gradient_accumulation_steps = kwargs.get('gradient_accumulation_steps', 1) cur_iters = epoch * len(data_loader) # Zero gradients at the beginning of epoch optimizer.zero_grad() for i, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): samples = samples.to(device) targets = [{k: v.to(device) for k, v in t.items()} for t in targets] global_step = epoch * len(data_loader) + i # Create metas as a simple dict without accumulating references metas = {'epoch': epoch, 


def train_one_epoch(self_lr_scheduler, lr_scheduler, model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, max_norm: float = 0, **kwargs):
    model.train()
    criterion.train()
    metric_logger = MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)

    print_freq = kwargs.get('print_freq', 10)
    writer: SummaryWriter = kwargs.get('writer', None)
    ema: ModelEMA = kwargs.get('ema', None)
    scaler: GradScaler = kwargs.get('scaler', None)
    lr_warmup_scheduler: Warmup = kwargs.get('lr_warmup_scheduler', None)

    # wandb parameters
    use_wandb = kwargs.get('use_wandb', False)
    wandb_run = kwargs.get('wandb_run', None)
    wandb_log_freq = kwargs.get('wandb_log_freq', 10)

    # gradient accumulation parameters
    gradient_accumulation_steps = kwargs.get('gradient_accumulation_steps', 1)

    # whether to pop up the augmentation-debug window for every batch
    enable_vis = kwargs.get('enable_vis', False)

    cur_iters = epoch * len(data_loader)

    # Zero gradients at the beginning of the epoch
    optimizer.zero_grad()

    for i, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        samples = samples.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        global_step = epoch * len(data_loader) + i
        # Create metas as a simple dict without accumulating references
        metas = {'epoch': epoch, 'step': i, 'global_step': global_step, 'epoch_step': len(data_loader)}

        # Periodic garbage collection to prevent memory accumulation
        if i > 0 and i % 100 == 0:
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        # --- VISUALIZE AUGMENTED IMAGES & TARGETS (before model inference) ---
        if enable_vis and dist_utils.is_main_process():
            visualize_augmented_batch(samples, targets, epoch, i)
        # ---------------------------------------------------------------------

        if scaler is not None:
            # torch.autocast expects a bare device type ("cuda"/"cpu"), not "cuda:0"
            with torch.autocast(device_type=torch.device(device).type, cache_enabled=True):
                outputs = model(samples, targets=targets)

            if torch.isnan(outputs['pred_boxes']).any() or torch.isinf(outputs['pred_boxes']).any():
                print(outputs['pred_boxes'])
                # Dump a checkpoint (with DDP 'module.' prefixes stripped) for post-mortem debugging
                state = {}
                for key, value in model.state_dict().items():
                    state[key.replace('module.', '')] = value
                dist_utils.save_on_master({'model': state}, "./NaN.pth")

            with torch.autocast(device_type=torch.device(device).type, enabled=False):
                loss_dict = criterion(outputs, targets, **metas)

            # Store keys for later reconstruction
            loss_keys = list(loss_dict.keys())
            loss_values = list(loss_dict.values())
            loss = sum(loss_values)
            # Scale loss for gradient accumulation
            loss = loss / gradient_accumulation_steps
            scaler.scale(loss).backward()

            # Only step and zero gradients every gradient_accumulation_steps
            if (i + 1) % gradient_accumulation_steps == 0:
                if max_norm > 0:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

        else:
            outputs = model(samples, targets=targets)
            loss_dict = criterion(outputs, targets, **metas)

            # Store keys for later reconstruction
            loss_keys = list(loss_dict.keys())
            loss_values = list(loss_dict.values())
            loss: torch.Tensor = sum(loss_values)
            # Scale loss for gradient accumulation
            loss = loss / gradient_accumulation_steps
            loss.backward()

            # Only step and zero gradients every gradient_accumulation_steps
            if (i + 1) % gradient_accumulation_steps == 0:
                if max_norm > 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
                optimizer.step()
                optimizer.zero_grad()

        # EMA is updated every iteration, not only on optimizer steps
        if ema is not None:
            ema.update(model)

        if self_lr_scheduler:
            optimizer = lr_scheduler.step(cur_iters + i, optimizer)
        else:
            if lr_warmup_scheduler is not None:
                lr_warmup_scheduler.step()

        # Recreate loss_dict for logging
        loss_dict = dict(zip(loss_keys, loss_values))
        loss_dict_reduced = dist_utils.reduce_dict(loss_dict)
        loss_value = sum(loss_dict_reduced.values())

        # Clean up references
        del loss_dict, outputs

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        metric_logger.update(loss=loss_value, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

        if writer and dist_utils.is_main_process() and global_step % 10 == 0:
            writer.add_scalar('Loss/total', loss_value.item(), global_step)
            for j, pg in enumerate(optimizer.param_groups):
                writer.add_scalar(f'Lr/pg_{j}', pg['lr'], global_step)
            for k, v in loss_dict_reduced.items():
                writer.add_scalar(f'Loss/{k}', v.item(), global_step)

        # wandb logging
        if (use_wandb and wandb_run is not None and _WANDB_AVAILABLE
                and dist_utils.is_main_process() and global_step % wandb_log_freq == 0):
            log_dict = {
                'train/loss_total': loss_value.item(),
                'train/learning_rate': optimizer.param_groups[0]['lr'],
                'train/epoch': epoch,
                'train/step': global_step,
            }
            # Add individual loss components
            for k, v in loss_dict_reduced.items():
                log_dict[f'train/loss_{k}'] = v.item()
            wandb.log(log_dict, step=global_step)

    # Step the optimizer if there are leftover accumulated gradients at the end of the epoch
    if (i + 1) % gradient_accumulation_steps != 0:
        if scaler is not None:
            if max_norm > 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            if max_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            optimizer.step()
        optimizer.zero_grad()

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    # Final cleanup at end of epoch
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
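

# A minimal driver sketch (an illustrative addition, not part of the original
# solver): one way to wire `train_one_epoch` into an outer epoch loop. The
# model, criterion, optimizer, and loader are assumed to be built elsewhere;
# `self_lr_scheduler=False` keeps the standard warmup path inside the epoch,
# with an assumed epoch-wise `lr_scheduler` stepped outside. The clipping
# threshold is an illustrative default, not a value taken from DEIM configs.
def _training_loop_sketch(model, criterion, optimizer, lr_scheduler,
                          train_loader, device, num_epochs,
                          max_norm=0.1, **kwargs):
    for epoch in range(num_epochs):
        stats = train_one_epoch(False, None, model, criterion, train_loader,
                                optimizer, device, epoch,
                                max_norm=max_norm, **kwargs)
        lr_scheduler.step()  # assumed torch.optim.lr_scheduler-style object
        print(f"epoch {epoch}: averaged loss {stats['loss']:.4f}")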


@torch.no_grad()
def evaluate(model: torch.nn.Module, criterion: torch.nn.Module, postprocessor,
             data_loader, coco_evaluator: CocoEvaluator, device):
    model.eval()
    criterion.eval()
    if coco_evaluator is not None:
        coco_evaluator.cleanup()

    metric_logger = MetricLogger(delimiter="  ")
    # metric_logger.add_meter('class_error', SmoothedValue(window_size=1, fmt='{value:.2f}'))
    header = 'Test:'

    # iou_types = tuple(k for k in ('segm', 'bbox') if k in postprocessor.keys())
    iou_types = coco_evaluator.iou_types
    # coco_evaluator = CocoEvaluator(base_ds, iou_types)
    # coco_evaluator.coco_eval[iou_types[0]].params.iouThrs = [0, 0.1, 0.5, 0.75]

    for samples, targets in metric_logger.log_every(data_loader, 10, header):
        samples = samples.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        outputs = model(samples)

        orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
        results = postprocessor(outputs, orig_target_sizes)

        # if 'segm' in postprocessor.keys():
        #     target_sizes = torch.stack([t["size"] for t in targets], dim=0)
        #     results = postprocessor['segm'](results, outputs, orig_target_sizes, target_sizes)

        res = {target['image_id'].item(): output for target, output in zip(targets, results)}
        if coco_evaluator is not None:
            coco_evaluator.update(res)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    if coco_evaluator is not None:
        coco_evaluator.synchronize_between_processes()

    # accumulate predictions from all images
    if coco_evaluator is not None:
        coco_evaluator.accumulate()
        coco_evaluator.summarize()

    stats = {}
    # stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
    if coco_evaluator is not None:
        if 'bbox' in iou_types:
            stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats.tolist()
        if 'segm' in iou_types:
            stats['coco_eval_masks'] = coco_evaluator.coco_eval['segm'].stats.tolist()

    return stats, coco_evaluator
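

# A minimal validation sketch (an illustrative addition): how `evaluate` is
# typically invoked after an epoch and how to read the returned COCO stats.
# `postprocessor` and `coco_evaluator` are assumed to come from the
# surrounding solver; only the call pattern is shown.
def _validation_sketch(model, criterion, postprocessor, val_loader,
                       coco_evaluator, device):
    stats, _ = evaluate(model, criterion, postprocessor, val_loader,
                        coco_evaluator, device)
    if 'coco_eval_bbox' in stats:
        # pycocotools COCOeval stats order: AP, AP50, AP75, APsmall, APmedium,
        # APlarge, AR@1, AR@10, AR@100, ARsmall, ARmedium, ARlarge
        print(f"bbox mAP: {stats['coco_eval_bbox'][0]:.4f}")
    return stats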