# -*- coding: utf-8 -*-
import sys
import os
import torch

# Import files from the local folder
root_path = os.path.abspath('.')
sys.path.append(root_path)
from opt import opt
from architecture.grl import GRL    # This import needs to be adjusted for different models
from train_code.train_master import train_master


# Mixed precision training
scaler = torch.cuda.amp.GradScaler()
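
# Note: the GradScaler above is consumed inside train_master's training loop, which
# is not shown in this file. As an illustrative sketch only (the names `optimizer`
# and `imgs_lr`, and the exact call sites, are assumptions rather than code from
# this repository), a typical mixed-precision step with a GradScaler looks like:
#
#     with torch.cuda.amp.autocast():
#         gen_hr = self.generator(imgs_lr)
#         loss = self.cri_pix(gen_hr, imgs_hr)
#     scaler.scale(loss).backward()
#     scaler.step(optimizer)
#     scaler.update()
#     optimizer.zero_grad()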


class train_grl(train_master):
    def __init__(self, options, args) -> None:
        super().__init__(options, args, "grl")      # Pass a unique model-name code

    def loss_init(self):
        # Prepare the pixel loss
        self.pixel_loss_load()

    def call_model(self):
        patch_size = 144
        window_size = 8

        if opt['model_size'] == "small":
            # GRL small model
            self.generator = GRL(
                upscale = opt['scale'],
                img_size = patch_size,
                window_size = window_size,
                depths = [4, 4, 4, 4],
                embed_dim = 128,
                num_heads_window = [2, 2, 2, 2],
                num_heads_stripe = [2, 2, 2, 2],
                mlp_ratio = 2,
                qkv_proj_type = "linear",
                anchor_proj_type = "avgpool",
                anchor_window_down_factor = 2,
                out_proj_type = "linear",
                conv_type = "1conv",
                upsampler = "pixelshuffle",
            ).cuda()

        elif opt['model_size'] == "tiny":
            # GRL tiny model
            self.generator = GRL(
                upscale = opt['scale'],
                img_size = 64,
                window_size = window_size,
                depths = [4, 4, 4, 4],
                embed_dim = 64,
                num_heads_window = [2, 2, 2, 2],
                num_heads_stripe = [2, 2, 2, 2],
                mlp_ratio = 2,
                qkv_proj_type = "linear",
                anchor_proj_type = "avgpool",
                anchor_window_down_factor = 2,
                out_proj_type = "linear",
                conv_type = "1conv",
                upsampler = "pixelshuffledirect",
            ).cuda()

        elif opt['model_size'] == "tiny2":
            # GRL tiny2 model (same backbone as tiny, different upsampler)
            self.generator = GRL(
                upscale = opt['scale'],
                img_size = 64,
                window_size = window_size,
                depths = [4, 4, 4, 4],
                embed_dim = 64,
                num_heads_window = [2, 2, 2, 2],
                num_heads_stripe = [2, 2, 2, 2],
                mlp_ratio = 2,
                qkv_proj_type = "linear",
                anchor_proj_type = "avgpool",
                anchor_window_down_factor = 2,
                out_proj_type = "linear",
                conv_type = "1conv",
                upsampler = "nearest+conv",     # Changed relative to the "tiny" variant
            ).cuda()

        else:
            raise NotImplementedError("This model size is not supported for the GRL model")

        # self.generator = torch.compile(self.generator).cuda()    # Don't use this on a 3090 Ti
        self.generator.train()

    def run(self):
        self.master_run()

    def calculate_loss(self, gen_hr, imgs_hr):
        # Define the loss function here
        # Generator pixel loss (L1 loss): generated vs. GT
        l_g_pix = self.cri_pix(gen_hr, imgs_hr, self.batch_idx)
        self.weight_store["pixel_loss"] = l_g_pix
        self.generator_loss += l_g_pix

    def tensorboard_report(self, iteration):
        # self.writer.add_scalar('Loss/train-Generator_Loss-Iteration', self.generator_loss, iteration)
        self.writer.add_scalar('Loss/train-Pixel_Loss-Iteration', self.weight_store["pixel_loss"], iteration)
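
# Minimal usage sketch (illustrative only; the real entry point lives elsewhere in
# the repository, and the `argparse` setup below is a hypothetical placeholder, not
# the project's actual CLI):
#
#     import argparse
#     parser = argparse.ArgumentParser()
#     args = parser.parse_args()         # hypothetical: the real arguments come from the repo's CLI
#     trainer = train_grl(opt, args)     # `opt` is the options dict imported at the top of this file
#     trainer.run()                      # dispatches to train_master.master_run()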