import itertools

import torch
import torch.nn as nn
import torch.optim as optim


def add_full_model_gradient_clipping(optim_cls, clip_norm_val):
    """Wrap an optimizer class so every step first clips the full-model gradient norm."""

    class FullModelGradientClippingOptimizer(optim_cls):
        def step(self, closure=None):
            # Gather the parameters of all param groups and clip their global norm.
            all_params = itertools.chain(*[x["params"] for x in self.param_groups])
            torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
            super().step(closure=closure)

    return FullModelGradientClippingOptimizer


class Optimizer(object):
    def __init__(self, models, training_params, sep_lr=None, sep_params=None, gradient_clip=0):
        # Collect parameters from the given models; each entry may be a module or a bare nn.Parameter.
        params = []
        for model in models:
            if isinstance(model, nn.Parameter):
                params += [model]
            else:
                params += list(model.parameters())

        # Optionally put a second set of parameters into its own group with a separate learning rate.
        if sep_lr is not None:
            print(sep_lr)
            add_params = []
            for model in sep_params:
                if isinstance(model, nn.Parameter):
                    add_params += [model]
                else:
                    add_params += list(model.parameters())
            params = [{'params': params}, {'params': add_params, 'lr': sep_lr}]

        self.lr = training_params['lr']
        self.weight_decay = training_params['weight_decay']
        method = training_params['optimizer']

        if method == 'SGD':
            self.momentum = training_params['momentum']
            if gradient_clip > 0:
                self.optim = add_full_model_gradient_clipping(optim.SGD, gradient_clip)(
                    params, lr=self.lr, momentum=self.momentum, weight_decay=self.weight_decay)
            else:
                self.optim = optim.SGD(
                    params, lr=self.lr, momentum=self.momentum, weight_decay=self.weight_decay)
        elif method == 'AdamW':
            self.optim = optim.AdamW(params, lr=self.lr, weight_decay=self.weight_decay)
        else:
            raise Exception('{} is not supported'.format(method))

        # Build the learning-rate schedule from torch.optim.lr_scheduler by name.
        schedule_name = training_params['lr_schedule']
        schedule_params = training_params['schedule_params']
        if schedule_name == 'CosineAnnealingLR':
            schedule_params['T_max'] = training_params['inter_val'] * 4
        self.lr_schedule = getattr(optim.lr_scheduler, schedule_name)(self.optim, **schedule_params)

    def update_lr(self):
        self.lr_schedule.step()

    def z_grad(self):
        self.optim.zero_grad()

    def g_step(self):
        self.optim.step()

    def get_lr(self):
        # Return the learning rate of the first param group.
        for param_group in self.optim.param_groups:
            return param_group['lr']
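

# Minimal usage sketch (illustrative only): the model, the hyperparameter values, and the
# 'inter_val' entry below are hypothetical assumptions, not values taken from this repo.
if __name__ == '__main__':
    backbone = nn.Linear(16, 4)                    # hypothetical model
    extra_scale = nn.Parameter(torch.ones(1))      # hypothetical parameter trained with its own lr
    training_params = {
        'optimizer': 'SGD',
        'lr': 1e-2,
        'momentum': 0.9,
        'weight_decay': 1e-4,
        'lr_schedule': 'CosineAnnealingLR',
        'schedule_params': {},
        'inter_val': 100,                          # T_max becomes inter_val * 4 for CosineAnnealingLR
    }
    opt = Optimizer([backbone], training_params,
                    sep_lr=1e-3, sep_params=[extra_scale], gradient_clip=1.0)

    # One training step: zero gradients, backpropagate, step the (clipped) optimizer, then the schedule.
    loss = backbone(torch.randn(2, 16)).sum() + extra_scale.sum()
    opt.z_grad()
    loss.backward()
    opt.g_step()
    opt.update_lr()
    print(opt.get_lr())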