| """ | |
| DEIM: DETR with Improved Matching for Fast Convergence | |
| Copyright (c) 2024 The DEIM Authors. All Rights Reserved. | |
| """ | |
| import math | |
| from functools import partial | |


def flat_cosine_schedule(total_iter, warmup_iter, flat_iter, no_aug_iter, current_iter, init_lr, min_lr):
    """
    Computes the learning rate using a warm-up, flat, and cosine decay schedule.

    Args:
        total_iter (int): Total number of iterations.
        warmup_iter (int): Number of iterations for the warm-up phase.
        flat_iter (int): Number of iterations for the flat phase.
        no_aug_iter (int): Number of iterations for the no-augmentation phase.
        current_iter (int): Current iteration.
        init_lr (float): Initial learning rate.
        min_lr (float): Minimum learning rate.

    Returns:
        float: Calculated learning rate.
    """
    if current_iter <= warmup_iter:
        # Quadratic warm-up from 0 to init_lr; max(..., 1) guards against a
        # zero-length warm-up causing division by zero.
        return init_lr * (current_iter / float(max(warmup_iter, 1))) ** 2
    elif warmup_iter < current_iter <= flat_iter:
        # Flat phase: hold the initial learning rate.
        return init_lr
    elif current_iter >= total_iter - no_aug_iter:
        # No-augmentation tail: hold the minimum learning rate.
        return min_lr
    else:
        # Cosine decay from init_lr down to min_lr over the remaining iterations.
        cosine_decay = 0.5 * (1 + math.cos(math.pi * (current_iter - flat_iter) /
                                           (total_iter - flat_iter - no_aug_iter)))
        return min_lr + (init_lr - min_lr) * cosine_decay
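
# Worked example (illustrative values, not part of the original file): with
# total_iter=1000, warmup_iter=100, flat_iter=300, no_aug_iter=100,
# init_lr=1e-3, min_lr=1e-5, the schedule evaluates to:
#   current_iter=50  -> 1e-3 * (50 / 100) ** 2                       = 2.5e-4   (warm-up)
#   current_iter=200 -> 1e-3                                                    (flat)
#   current_iter=600 -> 1e-5 + (1e-3 - 1e-5) * 0.5 * (1 + cos(pi/2)) ≈ 5.05e-4  (cosine decay)
#   current_iter=950 -> 1e-5                                                    (no-aug tail)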


class FlatCosineLRScheduler:
    """
    Learning rate scheduler with warm-up, optional flat phase, and cosine decay, following RTMDet.

    Args:
        optimizer (torch.optim.Optimizer): Optimizer instance.
        lr_gamma (float): Scaling factor for the minimum learning rate.
        iter_per_epoch (int): Number of iterations per epoch.
        total_epochs (int): Total number of training epochs.
        warmup_iter (int): Number of warm-up iterations.
        flat_epochs (int): Number of flat epochs (for the flat-cosine scheduler).
        no_aug_epochs (int): Number of no-augmentation epochs.
        scheduler_type (str): Scheduler type identifier (currently unused).
    """
    def __init__(self, optimizer, lr_gamma, iter_per_epoch, total_epochs,
                 warmup_iter, flat_epochs, no_aug_epochs, scheduler_type="cosine"):
        self.base_lrs = [group["initial_lr"] for group in optimizer.param_groups]
        self.min_lrs = [base_lr * lr_gamma for base_lr in self.base_lrs]
        total_iter = int(iter_per_epoch * total_epochs)
        no_aug_iter = int(iter_per_epoch * no_aug_epochs)
        flat_iter = int(iter_per_epoch * flat_epochs)
        print(f"base_lrs={self.base_lrs} min_lrs={self.min_lrs} total_iter={total_iter} "
              f"warmup_iter={warmup_iter} flat_iter={flat_iter} no_aug_iter={no_aug_iter}")
        # Bind the schedule constants once; step() only supplies the current
        # iteration and the per-group learning-rate bounds.
        self.lr_func = partial(flat_cosine_schedule, total_iter, warmup_iter, flat_iter, no_aug_iter)

    def step(self, current_iter, optimizer):
        """
        Updates the learning rate of the optimizer at the current iteration.

        Args:
            current_iter (int): Current iteration.
            optimizer (torch.optim.Optimizer): Optimizer instance.

        Returns:
            torch.optim.Optimizer: The optimizer with updated learning rates.
        """
        for i, group in enumerate(optimizer.param_groups):
            group["lr"] = self.lr_func(current_iter, self.base_lrs[i], self.min_lrs[i])
        return optimizer
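

if __name__ == "__main__":
    # Minimal usage sketch (an assumption, not part of the original file): the
    # scheduler reads "initial_lr" from each param group, which PyTorch
    # optimizers do not set by default, so it is seeded here manually.
    import torch

    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    for group in optimizer.param_groups:
        group.setdefault("initial_lr", group["lr"])

    scheduler = FlatCosineLRScheduler(
        optimizer, lr_gamma=0.01, iter_per_epoch=100, total_epochs=10,
        warmup_iter=50, flat_epochs=3, no_aug_epochs=1,
    )
    # Sample the schedule at a few points: warm-up, flat, decay, and tail.
    for it in (0, 25, 50, 300, 700, 950):
        scheduler.step(it, optimizer)
        print(f"iter {it:4d}: lr = {optimizer.param_groups[0]['lr']:.6f}")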