""" Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) Copyright(c) 2023 lyuwenyu. All Rights Reserved. """ import torch import torch.nn as nn from ..misc import (MetricLogger, SmoothedValue, reduce_dict) def train_one_epoch(model: nn.Module, criterion: nn.Module, dataloader, optimizer, ema, epoch, device): """ """ model.train() metric_logger = MetricLogger(delimiter=" ") metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}')) print_freq = 100 header = 'Epoch: [{}]'.format(epoch) for imgs, labels in metric_logger.log_every(dataloader, print_freq, header): imgs = imgs.to(device) labels = labels.to(device) preds = model(imgs) loss: torch.Tensor = criterion(preds, labels, epoch) optimizer.zero_grad() loss.backward() optimizer.step() if ema is not None: ema.update(model) loss_reduced_values = {k: v.item() for k, v in reduce_dict({'loss': loss}).items()} metric_logger.update(**loss_reduced_values) metric_logger.update(lr=optimizer.param_groups[0]["lr"]) metric_logger.synchronize_between_processes() print("Averaged stats:", metric_logger) stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()} return stats @torch.no_grad() def evaluate(model, criterion, dataloader, device): model.eval() metric_logger = MetricLogger(delimiter=" ") # metric_logger.add_meter('acc', SmoothedValue(window_size=1, fmt='{global_avg:.4f}')) # metric_logger.add_meter('loss', SmoothedValue(window_size=1, fmt='{value:.2f}')) metric_logger.add_meter('acc', SmoothedValue(window_size=1)) metric_logger.add_meter('loss', SmoothedValue(window_size=1)) header = 'Test:' for imgs, labels in metric_logger.log_every(dataloader, 10, header): imgs, labels = imgs.to(device), labels.to(device) preds = model(imgs) acc = (preds.argmax(dim=-1) == labels).sum() / preds.shape[0] loss = criterion(preds, labels) dict_reduced = reduce_dict({'acc': acc, 'loss': loss}) reduced_values = {k: v.item() for k, v in dict_reduced.items()} metric_logger.update(**reduced_values) metric_logger.synchronize_between_processes() print("Averaged stats:", metric_logger) stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()} return stats