# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed.nn
import torch.distributed as dist
from torch.nn.init import trunc_normal_
from torch.nn.utils import weight_norm

import models_dinov2
from models_IB import IF_Module


class MetaArch(nn.Module):
    """Student/teacher distillation architecture with an information bottleneck.

    A student ViT (from ``models_dinov2``, selected by ``cfg.target_model``)
    is distilled from a frozen DINOv2 teacher (downloaded via ``torch.hub``,
    selected by ``cfg.teacher_model``). Three projection heads map student
    features into the teacher's embedding space, and an ``IF_Module``
    information bottleneck produces likelihoods used for a bitrate (bpp) loss.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        student_model_dict = dict()
        teacher_model_dict = dict()

        # Build the student backbone by name from models_dinov2.
        import_student = getattr(models_dinov2, cfg.target_model)
        student = import_student(
            img_size=224,
            patch_size=cfg.patch_size,
            init_values=1.0,
            ffn_layer='mlp',
            block_chunks=0,
            num_register_tokens=0,
            interpolate_antialias=False,
            interpolate_offset=0.1,
        )
        embed_dim = student.embed_dim

        # Load the pretrained DINOv2 teacher (with linear classifier wrapper);
        # only its .backbone is kept below.
        if cfg.teacher_model == 'vit_base':
            teacher_backbone = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_lc')
        elif cfg.teacher_model == 'vit_small':
            teacher_backbone = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_lc')
        elif cfg.teacher_model == 'vit_large':
            teacher_backbone = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_lc')
        elif cfg.teacher_model == 'vit_giant':
            teacher_backbone = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_lc')
        else:
            # Previously an unknown name fell through and raised a confusing
            # NameError on teacher_backbone; fail fast with a clear message.
            raise ValueError(
                f"Unknown cfg.teacher_model {cfg.teacher_model!r}; expected one of "
                "'vit_small', 'vit_base', 'vit_large', 'vit_giant'"
            )
        teacher_backbone.eval()

        student_model_dict['backbone'] = student
        teacher_model_dict['backbone'] = teacher_backbone.backbone
        self.embed_dim = embed_dim

        # initialize parameters and checks
        self.total_n_global_crops = cfg.batch_size

        self.student = nn.ModuleDict(student_model_dict)
        self.teacher = nn.ModuleDict(teacher_model_dict)

        # Projection heads mapping student embeddings -> teacher embedding dim.
        teacher_embed_dim = teacher_backbone.backbone.embed_dim
        self.ibot_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, teacher_embed_dim))
        self.token_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, teacher_embed_dim))
        self.fea_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, teacher_embed_dim))

        self.soft_criterion = torch.nn.MSELoss()
        self.info_bottleneck = IF_Module(embed_dim=embed_dim, num_heads=12,
                                         mlp_ratio=4, depth=4)

        # The teacher is frozen; gradients only flow through the student.
        for param in self.teacher.backbone.parameters():
            param.requires_grad = False

    def cal_bpp(self, image, unmask_likelihood, mask_likelihood):
        """Return the bits-per-pixel estimate from bottleneck likelihoods.

        Sums the log-likelihoods of the masked and unmasked branches and
        normalizes by ``-log(2) * num_pixels * 1.5``.

        NOTE(review): the 1.5 factor is a hard-coded normalization constant —
        presumably it accounts for the two forward branches; confirm intent.
        """
        b, _, h, w = image.size()
        num_pixels = b * h * w
        log_unmask_likelihoods = torch.log(unmask_likelihood)
        log_mask_likelihoods = torch.log(mask_likelihood)
        bpp = (log_unmask_likelihoods.sum() + log_mask_likelihoods.sum()) / (
            -math.log(2) * num_pixels * 1.5)
        return bpp

    def forward(self, inputs):
        """Compute the distillation + bitrate losses for one batch.

        Args:
            inputs: dict with (at least) keys ``collated_global_crops``,
                ``collated_masks``, ``mask_indices_list``, ``upperbound``
                — assumed to come from the data collator; shapes not
                verifiable from this file.

        Returns:
            dict with ``bpp_loss``, ``patch_loss``, ``fea_loss``,
            ``token_loss``, ``loss`` (total, weighted) and ``task_loss``
            (unweighted sum of the three distillation terms).
        """
        global_crops = inputs["collated_global_crops"]
        masks = inputs["collated_masks"]
        mask_indices_list = inputs["mask_indices_list"]
        n_masked_patches = mask_indices_list.shape[0]
        upperbound = inputs["upperbound"]
        n_global_crops = 1

        # compute teacher output
        # @torch.no_grad()
        def compute_teacher_output():
            with torch.no_grad():
                teacher_backbone_output_dict = self.teacher.backbone(global_crops, is_training=True)
                teacher_cls_tokens = teacher_backbone_output_dict["x_norm_clstoken"]
                teacher_patch_tokens = teacher_backbone_output_dict["x_norm_patchtokens"]
                _dim = teacher_patch_tokens.shape[-1]

                # Gather only the teacher patch tokens at masked positions into
                # a preallocated buffer of size `upperbound` (>= n_masked_patches).
                buffer_tensor_teacher = teacher_patch_tokens.new_zeros(upperbound, _dim)
                torch.index_select(
                    teacher_patch_tokens.flatten(0, 1),
                    dim=0,
                    index=mask_indices_list,
                    out=buffer_tensor_teacher[:n_masked_patches],
                )
                teacher_patch_tokens_masked = buffer_tensor_teacher[:n_masked_patches]

            return teacher_cls_tokens, teacher_patch_tokens, teacher_patch_tokens_masked

        # get the teacher outputs
        (
            teacher_cls_tokens,
            teacher_patch_tokens,
            teacher_patch_tokens_masked,
        ) = compute_teacher_output()

        # Student sees the same crops twice: once masked, once unmasked.
        cur_masks = masks if self.cfg.mask_probability > 0 else None
        student_backbone_output_dict, student_backbone_output_dict_unmask = self.student.backbone(
            [global_crops, global_crops], masks=[cur_masks, None], is_training=True
        )
        student_cls_token_unmask = student_backbone_output_dict_unmask["x_norm_clstoken"]
        student_patch_tokens_unmask = student_backbone_output_dict_unmask["x_norm_patchtokens"]
        student_patch_tokens = student_backbone_output_dict["x_norm_patchtokens"]

        # calculate bitrate: the bottleneck returns (tokens, likelihoods).
        student_patch_tokens_unmask, unmask_likelihood = self.info_bottleneck(
            student_patch_tokens_unmask, is_training=True)
        student_patch_tokens, mask_likelihood = self.info_bottleneck(
            student_patch_tokens, is_training=True)
        bpp = self.cal_bpp(global_crops, unmask_likelihood, mask_likelihood)

        # mask student patch tokens (same gather as for the teacher)
        _dim = student_patch_tokens.shape[-1]
        buffer_tensor_student = student_patch_tokens.new_zeros(upperbound, _dim)
        buffer_tensor_student[:n_masked_patches].copy_(
            torch.index_select(student_patch_tokens.flatten(0, 1), dim=0, index=mask_indices_list)
        )

        ## projection head
        student_patch_tokens_unmask = self.fea_head(student_patch_tokens_unmask)
        student_cls_token_unmask = self.token_head(student_cls_token_unmask)
        tokens_after_head = self.ibot_head(buffer_tensor_student)
        student_patch_tokens_masked = tokens_after_head[:n_masked_patches]

        ## token objective: CLS token distillation
        distillation_loss_token = self.soft_criterion(student_cls_token_unmask, teacher_cls_tokens)

        ## fea objective: full (CLS + patches) feature-map distillation
        student_whole_fea = torch.cat(
            (student_cls_token_unmask.unsqueeze(1), student_patch_tokens_unmask), dim=1)
        teacher_whole_fea = torch.cat(
            (teacher_cls_tokens.unsqueeze(1), teacher_patch_tokens), dim=1)
        distillation_loss_fea = self.soft_criterion(student_whole_fea, teacher_whole_fea)

        ## patch objective: masked-patch (iBOT-style) distillation
        patch_loss = self.soft_criterion(student_patch_tokens_masked, teacher_patch_tokens_masked)

        # coefficient
        token_loss = self.cfg.lambda_token * distillation_loss_token
        fea_loss = self.cfg.lambda_fea * distillation_loss_fea
        patch_loss_weighted = self.cfg.lambda_patch * patch_loss

        # compute the total loss (0.48 is a hard-coded bpp weight)
        total_loss = patch_loss_weighted + fea_loss + token_loss + 0.48 * bpp
        task_loss = patch_loss + distillation_loss_fea + distillation_loss_token

        # return the final loss dict
        # NOTE(review): "token_loss" is reported weighted while "patch_loss"
        # and "fea_loss" are unweighted — confirm this asymmetry is intended.
        loss_dict = {
            "bpp_loss": bpp,
            "patch_loss": patch_loss,
            "fea_loss": distillation_loss_fea,
            "token_loss": token_loss,
            "loss": total_loss,
            "task_loss": task_loss,
        }
        return loss_dict