Spaces:
Configuration error
Configuration error
| from torch.optim import lr_scheduler | |
| from models.backbone.skip import skip | |
| def get_scheduler(optimizer, opt): | |
| if opt.lr_policy == "linear": | |
| def lambda_rule(epoch): | |
| lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1) | |
| return lr_l | |
| scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) | |
| elif opt.lr_policy == "step": | |
| scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) | |
| elif opt.lr_policy == "plateau": | |
| scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.2, threshold=0.01, patience=5) | |
| elif opt.lr_policy == "cosine": | |
| scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0) | |
| else: | |
| return NotImplementedError("learning rate policy [%s] is not implemented", opt.lr_policy) | |
| return scheduler | |
| def define_G(cfg): | |
| netG = skip( | |
| 3, | |
| 4, | |
| num_channels_down=[cfg["skip_n33d"]] * cfg["num_scales"] | |
| if isinstance(cfg["skip_n33d"], int) | |
| else cfg["skip_n33d"], | |
| num_channels_up=[cfg["skip_n33u"]] * cfg["num_scales"] | |
| if isinstance(cfg["skip_n33u"], int) | |
| else cfg["skip_n33u"], | |
| num_channels_skip=[cfg["skip_n11"]] * cfg["num_scales"] | |
| if isinstance(cfg["skip_n11"], int) | |
| else cfg["skip_n11"], | |
| need_bias=True, | |
| ) | |
| return netG | |