"""
DEIM: DETR with Improved Matching for Fast Convergence
Copyright (c) 2024 The DEIM Authors. All Rights Reserved.
---------------------------------------------------------------------------------
Modified from D-FINE (https://github.com/Peterande/D-FINE)
Copyright (c) 2024 D-FINE authors. All Rights Reserved.
"""
import time
import json
import datetime
import torch
from ..misc import dist_utils, stats
try:
    import wandb
    _WANDB_AVAILABLE = True
except ImportError:
    _WANDB_AVAILABLE = False
    wandb = None
from ._solver import BaseSolver
from .det_engine import train_one_epoch, evaluate
from ..optim.lr_scheduler import FlatCosineLRScheduler
class DetSolver(BaseSolver):
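    """Detection solver: runs training (fit) and standalone evaluation (val) on top of BaseSolver."""
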
    def fit(self):
        self.train()
        args = self.cfg

        n_parameters, model_stats = stats(self.cfg)
        print(model_stats)
        print("-" * 42 + "Start training" + "-" * 43)

        self.self_lr_scheduler = False
        if args.lrsheduler is not None:  # note: 'lrsheduler' matches the attribute name exposed by the config
            iter_per_epoch = len(self.train_dataloader)
            print(" ## Using Self-defined Scheduler-{} ## ".format(args.lrsheduler))
            self.lr_scheduler = FlatCosineLRScheduler(self.optimizer, args.lr_gamma, iter_per_epoch,
                                                      total_epochs=args.epoches, warmup_iter=args.warmup_iter,
                                                      flat_epochs=args.flat_epoch, no_aug_epochs=args.no_aug_epoch)
            self.self_lr_scheduler = True

        n_parameters = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print(f'number of trainable parameters: {n_parameters}')
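
        # Track the best value seen for each validation metric; `top1` holds the best
        # primary score and decides when a new best checkpoint is written.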
        top1 = 0
        best_stat = {'epoch': -1}

        # evaluate again before resuming training
        if self.last_epoch > 0:
            module = self.ema.module if self.ema else self.model
            test_stats, coco_evaluator = evaluate(
                module,
                self.criterion,
                self.postprocessor,
                self.val_dataloader,
                self.evaluator,
                self.device
            )
            for k in test_stats:
                best_stat['epoch'] = self.last_epoch
                best_stat[k] = test_stats[k][0]
                top1 = test_stats[k][0]
                print(f'best_stat: {best_stat}')
        best_stat_print = best_stat.copy()
        start_time = time.time()
        start_epoch = self.last_epoch + 1

        for epoch in range(start_epoch, args.epoches):
            self.train_dataloader.set_epoch(epoch)
            # self.train_dataloader.dataset.set_epoch(epoch)
            if dist_utils.is_dist_available_and_initialized():
                self.train_dataloader.sampler.set_epoch(epoch)
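
            # Two-stage schedule: once the dataloader's stop_epoch is reached, reload the best
            # stage-1 weights and restart EMA with a fresh decay before continuing training.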
            if epoch == self.train_dataloader.collate_fn.stop_epoch:
                self.load_resume_state(str(self.output_dir / 'best_stg1.pth'))
                self.ema.decay = self.train_dataloader.collate_fn.ema_restart_decay
                print(f'Refresh EMA at epoch {epoch} with decay {self.ema.decay}')

            train_stats = train_one_epoch(
                self.self_lr_scheduler,
                self.lr_scheduler,
                self.model,
                self.criterion,
                self.train_dataloader,
                self.optimizer,
                self.device,
                epoch,
                max_norm=args.clip_max_norm,
                print_freq=args.print_freq,
                ema=self.ema,
                scaler=self.scaler,
                lr_warmup_scheduler=self.lr_warmup_scheduler,
                writer=self.writer,
                use_wandb=args.use_wandb,
                wandb_run=args.wandb_run,
                wandb_log_freq=args.wandb.get('log_frequency', 10) if args.wandb else 10,
                gradient_accumulation_steps=args.gradient_accumulation_steps,
                enable_vis=getattr(args, 'enable_vis', False)
            )

            if not self.self_lr_scheduler:  # update by epoch
                if self.lr_warmup_scheduler is None or self.lr_warmup_scheduler.finished():
                    self.lr_scheduler.step()

            self.last_epoch += 1

            if self.output_dir and epoch < self.train_dataloader.collate_fn.stop_epoch:
                checkpoint_paths = [self.output_dir / 'last.pth']
                # extra checkpoint every `checkpoint_freq` epochs
                if (epoch + 1) % args.checkpoint_freq == 0:
                    checkpoint_paths.append(self.output_dir / f'checkpoint{epoch:04}.pth')
                for checkpoint_path in checkpoint_paths:
                    dist_utils.save_on_master(self.state_dict(), checkpoint_path)

            # evaluate the EMA weights when EMA is enabled, otherwise the raw model
            module = self.ema.module if self.ema else self.model
            test_stats, coco_evaluator = evaluate(
                module,
                self.criterion,
                self.postprocessor,
                self.val_dataloader,
                self.evaluator,
                self.device
            )

            # TODO
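            # Per-metric bookkeeping: test_stats[k] is a list of scores (for COCO-style
            # evaluation, index 0 is presumably AP@[0.5:0.95]); keep the best value per metric.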
            for k in test_stats:
                if self.writer and dist_utils.is_main_process():
                    for i, v in enumerate(test_stats[k]):
                        self.writer.add_scalar(f'Test/{k}_{i}', v, epoch)

                # wandb logging for validation metrics
                if (args.use_wandb and args.wandb_run is not None and _WANDB_AVAILABLE and
                        dist_utils.is_main_process()):
                    val_log_dict = {}
                    for i, v in enumerate(test_stats[k]):
                        val_log_dict[f'val/{k}_{i}'] = v
                    val_log_dict['val/epoch'] = epoch
                    wandb.log(val_log_dict, step=epoch)

                if k in best_stat:
                    best_stat['epoch'] = epoch if test_stats[k][0] > best_stat[k] else best_stat['epoch']
                    best_stat[k] = max(best_stat[k], test_stats[k][0])
                else:
                    best_stat['epoch'] = epoch
                    best_stat[k] = test_stats[k][0]

                if best_stat[k] > top1:
                    best_stat_print['epoch'] = epoch
                    top1 = best_stat[k]
                    if self.output_dir:
                        if epoch >= self.train_dataloader.collate_fn.stop_epoch:
                            checkpoint_path = self.output_dir / 'best_stg2.pth'
                            dist_utils.save_on_master(self.state_dict(), checkpoint_path)
                        else:
                            checkpoint_path = self.output_dir / 'best_stg1.pth'
                            dist_utils.save_on_master(self.state_dict(), checkpoint_path)

                        # Save model as wandb artifact
                        if (args.use_wandb and args.wandb_run is not None and _WANDB_AVAILABLE and
                                dist_utils.is_main_process() and args.wandb.get('save_artifacts', True)):
                            artifact = wandb.Artifact(
                                name=f'model-epoch-{epoch}',
                                type='model',
                                description=f'Best model at epoch {epoch} with {k}={best_stat[k]:.4f}'
                            )
                            artifact.add_file(str(checkpoint_path))
                            args.wandb_run.log_artifact(artifact)

                best_stat_print[k] = max(best_stat[k], top1)
                print(f'best_stat: {best_stat_print}')  # global best

                if best_stat['epoch'] == epoch and self.output_dir:
                    if epoch >= self.train_dataloader.collate_fn.stop_epoch:
                        if test_stats[k][0] > top1:
                            top1 = test_stats[k][0]
                            dist_utils.save_on_master(self.state_dict(), self.output_dir / 'best_stg2.pth')
                    else:
                        top1 = max(test_stats[k][0], top1)
                        dist_utils.save_on_master(self.state_dict(), self.output_dir / 'best_stg1.pth')

                elif epoch >= self.train_dataloader.collate_fn.stop_epoch:
                    best_stat = {'epoch': -1}
                    self.ema.decay -= 0.0001
                    self.load_resume_state(str(self.output_dir / 'best_stg1.pth'))
                    print(f'Refresh EMA at epoch {epoch} with decay {self.ema.decay}')

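            # collect per-epoch train/val statistics for the JSON log written below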
            log_stats = {
                **{f'train_{k}': v for k, v in train_stats.items()},
                **{f'test_{k}': v for k, v in test_stats.items()},
                'epoch': epoch,
                'n_parameters': n_parameters
            }

            if self.output_dir and dist_utils.is_main_process():
                with (self.output_dir / "log.txt").open("a") as f:
                    f.write(json.dumps(log_stats) + "\n")

                # for evaluation logs
                if coco_evaluator is not None:
                    (self.output_dir / 'eval').mkdir(exist_ok=True)
                    if "bbox" in coco_evaluator.coco_eval:
                        filenames = ['latest.pth']
                        if epoch % 50 == 0:
                            filenames.append(f'{epoch:03}.pth')
                        for name in filenames:
                            torch.save(coco_evaluator.coco_eval["bbox"].eval,
                                       self.output_dir / "eval" / name)

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('Training time {}'.format(total_time_str))

    def val(self):
        self.eval()

        module = self.ema.module if self.ema else self.model
        test_stats, coco_evaluator = evaluate(module, self.criterion, self.postprocessor,
                                              self.val_dataloader, self.evaluator, self.device)

        if self.output_dir:
            dist_utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, self.output_dir / "eval.pth")

        return
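
# ---------------------------------------------------------------------------------
# Usage note (a minimal sketch, not part of this module): DetSolver is normally
# constructed by the project's training entry point from a parsed YAML config and
# then driven via fit() or val(). The loader name and config path below are
# assumptions based on the D-FINE codebase this file is modified from; check the
# repo's train.py for the exact wiring.
#
#   from src.core import YAMLConfig                      # assumed config loader
#   cfg = YAMLConfig('configs/deim_example.yml')          # hypothetical config path
#   solver = DetSolver(cfg)
#   solver.fit()        # train, writing checkpoints and log.txt to cfg.output_dir
#   # solver.val()      # or: evaluate the current/EMA weights only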