Spaces:
Sleeping
Sleeping
| """ | |
| Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) | |
| Copyright(c) 2023 lyuwenyu. All Rights Reserved. | |
| """ | |
| from torch.optim.lr_scheduler import LRScheduler | |
| from ..core import register | |
class Warmup(object):
    """Base class for learning-rate warmup applied on top of an ``LRScheduler``.

    On construction, the current ``lr`` of every optimizer param group is
    recorded as that group's warmup target, and one warmup step is applied
    immediately. Each :meth:`step` call advances ``last_step``; while
    ``last_step < warmup_duration`` it overwrites each group's ``lr`` with
    ``factor * target``, where ``factor`` comes from :meth:`get_warmup_factor`.
    Once the window is over, learning rates are left untouched.

    Subclasses must implement :meth:`get_warmup_factor`.

    Args:
        lr_scheduler: Scheduler whose ``optimizer.param_groups`` are modified.
        warmup_duration: Number of steps in the warmup window.
        last_step: Index of the last completed step. The constructor's
            initial :meth:`step` call advances it by one (``-1`` -> ``0``).
    """

    def __init__(self, lr_scheduler: LRScheduler, warmup_duration: int, last_step: int = -1) -> None:
        self.lr_scheduler = lr_scheduler
        # Target learning rates reached at the end of warmup, captured from the
        # optimizer as it is configured at construction time.
        self.warmup_end_values = [pg['lr'] for pg in lr_scheduler.optimizer.param_groups]
        self.last_step = last_step
        self.warmup_duration = warmup_duration
        # Apply the first warmup step right away so the optimizer never runs
        # at the full target LR before warmup begins.
        self.step()

    def state_dict(self):
        """Return the instance attributes as a dict, excluding ``lr_scheduler``.

        The scheduler object itself is not serialized here.
        NOTE(review): presumably it is checkpointed separately — confirm.
        """
        return {k: v for k, v in self.__dict__.items() if k != 'lr_scheduler'}

    def load_state_dict(self, state_dict):
        """Restore attributes previously produced by :meth:`state_dict`.

        An ``lr_scheduler`` entry, if present, is ignored. This mirrors
        :meth:`state_dict`, and a checkpoint can never replace the live
        scheduler this instance wraps.
        """
        self.__dict__.update({k: v for k, v in state_dict.items() if k != 'lr_scheduler'})

    def get_warmup_factor(self, step, **kwargs):
        """Return the LR multiplier for warmup step ``step``.

        Must be overridden by subclasses.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError

    def step(self) -> None:
        """Advance one step and, while still warming up, rescale every group's LR."""
        self.last_step += 1
        if self.last_step >= self.warmup_duration:
            # Warmup is over; leave the LR to the wrapped scheduler / optimizer.
            return
        factor = self.get_warmup_factor(self.last_step)
        # Index-based lookup: a param group added after construction has no
        # recorded target and raises IndexError rather than being silently skipped.
        for i, pg in enumerate(self.lr_scheduler.optimizer.param_groups):
            pg['lr'] = factor * self.warmup_end_values[i]

    def finished(self) -> bool:
        """Return True once ``last_step`` has reached ``warmup_duration``."""
        return self.last_step >= self.warmup_duration
# NOTE(review): `register` is imported in this file but never applied in the
# code shown; this class may have been meant to carry `@register()` — confirm.
class LinearWarmup(Warmup):
    """Warmup whose LR factor grows linearly up to ``1.0``.

    At warmup step ``s`` (0-based) the factor is ``(s + 1) / warmup_duration``,
    clamped to ``[0.0, 1.0]``. With the default ``last_step=-1`` the first
    applied factor is therefore ``1 / warmup_duration``.
    """

    def __init__(self, lr_scheduler: LRScheduler, warmup_duration: int, last_step: int = -1) -> None:
        super().__init__(lr_scheduler, warmup_duration, last_step)

    def get_warmup_factor(self, step, **kwargs):
        """Return the linear warmup multiplier for ``step``.

        Args:
            step: 0-based warmup step index.
            **kwargs: Accepted for signature compatibility with the base
                ``Warmup.get_warmup_factor``; ignored.

        Returns:
            ``(step + 1) / warmup_duration`` clamped to ``[0.0, 1.0]``, or
            ``1.0`` when ``warmup_duration`` is not positive (no warmup).
        """
        if self.warmup_duration <= 0:
            # Avoid ZeroDivisionError; a non-positive window means "no warmup".
            return 1.0
        # Lower clamp prevents a negative LR if constructed with last_step < -1.
        return max(0.0, min(1.0, (step + 1) / self.warmup_duration))