| | from logging import getLogger |
| | import math |
| | import os |
| | from typing import Union, Tuple |
| | from types import MethodType |
| |
|
| | import torch |
| | from torch import nn |
| | from torch.nn import functional as F |
| | from torch.nn.utils import parametrize |
| | from torch.nn.utils.parametrizations import _SpectralNorm |
| |
|
| | from timm.models.vision_transformer import Attention, Mlp |
| |
|
# Smallest target value allowed for the softplus-encoded (version 2) scale in
# _SNReweight.__init__: the inverse softplus log(exp(t) - 1) is only defined
# for t > 0, so smaller targets are clamped up to this value.
_EPS = 1e-5
| |
|
| |
|
| | class _SNReweight(_SpectralNorm): |
| | def __init__(self, weight: torch.Tensor, *args, init_norm_to_current: bool = False, alpha: float = 0.05, version: int = 2, **kwargs): |
| | super().__init__(weight, *args, **kwargs) |
| |
|
| | self.alpha = alpha |
| | self.version = version |
| | self.register_buffer('_sn_version', torch.tensor(version)) |
| |
|
| | if init_norm_to_current: |
| | |
| | init_scale = self._get_sigma(weight).item() |
| | else: |
| | init_scale = 1.0 |
| |
|
| | if version == 1: |
| | init_value = init_scale |
| | elif version == 2: |
| | t = init_scale - alpha |
| | if t < _EPS: |
| | getLogger("spectral_reparam").warn(f'The initialized spectral norm {init_scale} is too small to be represented. Setting to {_EPS} instead.') |
| | t = _EPS |
| |
|
| | init_value = math.log(math.exp(t) - 1) |
| | else: |
| | raise ValueError(f'Unsupported version: {version}') |
| |
|
| | |
| | self.scale = nn.Parameter(torch.tensor([[init_value]], dtype=torch.float32, device=weight.device)) |
| |
|
| | |
| | def _get_sigma(self, weight: torch.Tensor) -> torch.Tensor: |
| | if weight.ndim == 1: |
| | |
| | sigma = weight.norm() |
| | else: |
| | weight_mat = self._reshape_weight_to_matrix(weight) |
| | if self.training: |
| | self._power_method(weight_mat, self.n_power_iterations) |
| | |
| | u = self._u.clone(memory_format=torch.contiguous_format) |
| | v = self._v.clone(memory_format=torch.contiguous_format) |
| | |
| | |
| | |
| | sigma = torch.dot(u, torch.mv(weight_mat, v)) |
| |
|
| | return sigma + self.eps |
| |
|
| | def forward(self, weight: torch.Tensor, *args, **kwargs): |
| | dtype = weight.dtype |
| | sigma = self._get_sigma(weight, *args, **kwargs) |
| |
|
| | if self.version == 1: |
| | scale = self.scale |
| | elif self.version == 2: |
| | scale = F.softplus(self.scale) + self.alpha |
| | else: |
| | raise ValueError(f'Unsupported version: {self.version}') |
| |
|
| | scale = scale.float() / sigma.float() |
| |
|
| | y = weight * scale |
| |
|
| | if dtype in (torch.float16, torch.bfloat16): |
| | y = y.to(dtype) |
| | return y |
| |
|
| | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): |
| | version_key = f'{prefix}_sn_version' |
| | if version_key not in state_dict: |
| | self.version = 1 |
| | state_dict[version_key] = torch.tensor(1) |
| | return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) |
| |
|
| |
|
class _AttnSNReweight(nn.Module):
    """Spectral reparametrization for a fused qkv weight.

    The weight is split along dim 0 into chunks of ``shape[0] // 3`` rows.
    The first two chunks (q and k) always get their own ``_SNReweight``; the
    third (v) is reweighted only when ``renorm_values`` is True. Every other
    chunk passes through an ``nn.Identity``.

    NOTE(review): assumes ``weight.shape[0]`` is divisible by 3. Otherwise
    ``split`` yields an extra chunk, which is passed through unchanged.
    """

    def __init__(self, weight: torch.Tensor, *args, init_norm_to_current: bool = False, renorm_values: bool = False, **kwargs):
        super().__init__()

        rows_per_part = weight.shape[0] // 3
        num_reweighted = 3 if renorm_values else 2

        modules = []
        for index, chunk in enumerate(weight.split(rows_per_part, dim=0)):
            if index >= num_reweighted:
                modules.append(nn.Identity())
            else:
                modules.append(_SNReweight(chunk, *args, init_norm_to_current=init_norm_to_current, **kwargs))
        self.parts = nn.ModuleList(modules)

    def forward(self, weight: torch.Tensor, *args, **kwargs):
        """Apply each per-chunk module and re-concatenate along dim 0."""
        rows_per_part = weight.shape[0] // 3
        chunks = weight.split(rows_per_part, dim=0)
        return torch.cat([part(chunk) for part, chunk in zip(self.parts, chunks)], dim=0)
| |
|
| |
|
def enable_spectral_reparam(model: nn.Module,
                            n_power_iterations: int = 1,
                            eps: float = 1e-6,
                            init_norm_to_current: bool = False,
                            renorm_values: bool = True,
                            renorm_mlp: bool = True):
    """Register spectral-norm reweighting parametrizations on ``model`` in place.

    Every timm ``Attention`` gets an ``_AttnSNReweight`` on ``qkv.weight``
    (q and k always, v only when ``renorm_values``). When ``renorm_mlp`` is
    True, every timm ``Mlp`` also gets an ``_SNReweight`` on ``fc1.weight``
    and on ``fc2.weight``.

    Args:
        model: Module whose submodules are modified in place.
        n_power_iterations: Power-iteration steps per training-mode forward.
        eps: Added to each spectral-norm estimate.
        init_norm_to_current: Initialize each learnable scale from the
            current spectral norm of the weight.
        renorm_values: Also reweight the value slice of ``qkv``.
        renorm_mlp: Also reweight the MLP layers.
    """
    def build(weight, factory, **extra):
        # Shared constructor arguments for both parametrization types.
        return factory(weight, n_power_iterations, dim=0, eps=eps, init_norm_to_current=init_norm_to_current, **extra)

    # Walk a snapshot of the module list: registration adds new submodules
    # (none of which are Attention or Mlp), so iterating a fixed list avoids
    # traversing a tree that is being modified.
    for module in list(model.modules()):
        if isinstance(module, Attention):
            parametrize.register_parametrization(
                module.qkv,
                'weight',
                build(module.qkv.weight, _AttnSNReweight, renorm_values=renorm_values),
            )
        elif renorm_mlp and isinstance(module, Mlp):
            for layer in (module.fc1, module.fc2):
                parametrize.register_parametrization(layer, 'weight', build(layer.weight, _SNReweight))
| |
|
| |
|
def configure_spectral_reparam_from_args(model: nn.Module, args):
    """Enable spectral reparametrization on ``model`` according to ``args``.

    ``args.spectral_reparam`` selects the behavior:

    * ``True``: call enable_spectral_reparam with its defaults.
    * a dict: read ``n_power_iterations`` (default 1), ``eps`` (default
      1e-12), ``renorm_values`` (default True) and ``renorm_mlp`` (default
      True) from it.
    * anything else, or missing: leave the model unchanged.

    ``init_norm_to_current`` follows ``args.pretrained`` (False when the
    attribute is absent).
    """
    spectral_reparam = getattr(args, 'spectral_reparam', False)
    pretrained = getattr(args, 'pretrained', False)

    if isinstance(spectral_reparam, bool):
        if spectral_reparam:
            enable_spectral_reparam(model, init_norm_to_current=pretrained)
    elif isinstance(spectral_reparam, dict):
        enable_spectral_reparam(
            model,
            n_power_iterations=spectral_reparam.get('n_power_iterations', 1),
            # NOTE: differs from enable_spectral_reparam's own default (1e-6);
            # kept so existing dict-style configs behave as before.
            eps=spectral_reparam.get('eps', 1e-12),
            init_norm_to_current=pretrained,
            renorm_values=spectral_reparam.get('renorm_values', True),
            renorm_mlp=spectral_reparam.get('renorm_mlp', True),
        )
| |
|
| |
|
def disable_spectral_reparam(model: nn.Module):
    """Remove the spectral reparametrization, baking the current weights in.

    Removes the parametrization on ``qkv.weight`` of every timm ``Attention``
    and on ``fc1.weight`` / ``fc2.weight`` of every timm ``Mlp``. Layers
    whose weight is not parametrized are skipped instead of raising. This
    happens, for example, for MLP layers when reparametrization was enabled
    with ``renorm_mlp=False``.
    """
    # Collect targets first so the module tree is not mutated while iterating it.
    targets = []
    for mod in model.modules():
        if isinstance(mod, Attention):
            targets.append(mod.qkv)
        elif isinstance(mod, Mlp):
            targets.extend((mod.fc1, mod.fc2))

    for layer in targets:
        if parametrize.is_parametrized(layer, 'weight'):
            # leave_parametrized defaults to True: the parametrization's current
            # output becomes the layer's plain weight.
            parametrize.remove_parametrizations(layer, 'weight')
| |
|
| |
|
if __name__ == '__main__':
    # Command-line tool: load a training checkpoint, bake any spectral
    # reparametrization into plain weights, and save a cleaned checkpoint.
    import argparse
    from . import radio_model as create_model

    parser = argparse.ArgumentParser(description='Remove parametrization from state dict')
    parser.add_argument('--checkpoint', type=str, required=True, help='The checkpoint to load')
    parser.add_argument('--output', type=str, default='', help='Where to store the checkpoint')
    parser.add_argument('--release', default=False, action='store_true', help='Prune extraneous checkpoint fields')
    parser.add_argument('--strict', default=False, action='store_true', help='Strictly load the state dict')

    args = parser.parse_args()

    if not args.output:
        chk_dir, chk_name = os.path.split(args.checkpoint)
        args.output = os.path.join(chk_dir, f'clean_{chk_name}')
        print(f'Set output to "{args.output}"')

    if os.path.abspath(args.output) == os.path.abspath(args.checkpoint):
        # The input is memory-mapped below; truncating it while its tensors are
        # still mapped is unsafe.
        parser.error('--output must differ from --checkpoint')

    # SECURITY: weights_only=False unpickles arbitrary objects (the checkpoint
    # carries non-tensor data such as the training args under 'args'), so only
    # run this on trusted checkpoints. The flag is explicit because newer torch
    # releases default to weights_only=True, which cannot load such objects.
    chk = torch.load(args.checkpoint, map_location='cpu', mmap=True, weights_only=False)
    train_args = chk['args']

    model = create_model.create_model_from_args(train_args)

    # Split the model's own weights from any other entries in the checkpoint.
    key = 'base_model.'
    mod_state = {}
    extra_state = {}
    for k, v in chk['state_dict'].items():
        if k.startswith(key):
            mod_state[k[len(key):]] = v
        else:
            extra_state[k] = v

    chk_load_info = model.load_state_dict(mod_state, strict=args.strict)
    if chk_load_info.unexpected_keys or chk_load_info.missing_keys:
        print(chk_load_info)

    # Bake in eval mode: in training mode every parametrized weight access runs
    # another power-iteration step, so the baked weights would not match what
    # the parametrized model computes at inference time.
    model.eval()

    # Mirror the condition under which configure_spectral_reparam_from_args
    # enables reparametrization (True or any dict).
    # NOTE(review): assumes create_model_from_args applies that function — confirm.
    spectral_reparam = getattr(train_args, 'spectral_reparam', False)
    if (isinstance(spectral_reparam, bool) and spectral_reparam) or isinstance(spectral_reparam, dict):
        disable_spectral_reparam(model)

    if hasattr(train_args, 'dtype'):
        model.to(dtype=train_args.dtype)

    # Re-prefix the cleaned model weights; extra entries win on key collisions,
    # as before.
    final_state = {f'{key}{k}': v for k, v in model.state_dict().items()}
    final_state.update(extra_state)

    chk['state_dict'] = final_state
    train_args.spectral_reparam = False

    if args.release:
        chk = {
            'arch': chk['arch'],
            'epoch': chk['epoch'],
            'state_dict': chk['state_dict'],
            'args': chk['args'],
        }

    torch.save(chk, args.output)
| |
|