Spaces:
Configuration error
Configuration error
| import torch.nn | |
| from models.clip_extractor import ClipExtractor | |
| from util.losses import LossG | |
class AtlasLoss(torch.nn.Module):
    """Apply ``LossG`` to a single atlas layer chosen by config flags.

    The constructor builds a ``ClipExtractor`` from the full config and hands it,
    together with a subset of the config entries, to ``LossG``. ``forward`` then
    evaluates that loss on either the background or the foreground outputs.
    """

    def __init__(self, config):
        """Build the CLIP extractor and the wrapped ``LossG``.

        Args:
            config: mapping that must contain every key in ``forwarded_keys`` and
                ``prompt_keys`` below; a missing key raises ``KeyError``.
                ``forward`` additionally reads ``finetune_background`` and
                ``finetune_foreground`` from it.
        """
        super().__init__()
        # Registered first so the submodule order matches the original layout.
        self.clip_extractor = ClipExtractor(config)

        # Settings forwarded unchanged from ``config`` to ``LossG``.
        forwarded_keys = (
            "lambda_composition",
            "lambda_sparsity",
            "lambda_screen",
            "lambda_alpha_l1",
            "lambda_alpha_l0",
            "text_criterion",
            "clip_model_name",
            "bootstrap_epoch",
            "lambda_bootstrap",
            "relevancy_num_layers",
            "lambda_structure",
            "bootstrap_text",
            "bootstrap_scheduler",
            "bootstrapping_min_cover",
            "use_negative_bootstrap",
            "bootstrap_negative_text",
            "bootstrap_negative_map_threshold",
            "lambda_bootstrap_min",
            "device",
        )
        # Text prompts, placed after the settings above in the LossG config.
        prompt_keys = ("screen_text", "comp_text", "src_text")

        loss_config = {name: config[name] for name in forwarded_keys + prompt_keys}
        self.loss = LossG(loss_config, self.clip_extractor)
        self.config = config

    def forward(self, outputs, inputs):
        """Compute the loss for the atlas layer selected by the config flags.

        ``finetune_background`` takes precedence: ``finetune_foreground`` is only
        consulted when the background flag is falsy.

        Args:
            outputs: mapping with ``"background"`` / ``"foreground"`` entries, each
                holding a ``"cnn_inputs"`` sequence of tensors.
            inputs: mapping handed on to ``LossG``. NOTE: it is mutated in place;
                ``inputs["input_crop"]`` is overwritten for the selected layer.

        Returns:
            ``{layer_name: <result of LossG>}`` for the selected layer, or ``{}``
            when neither flag is set (``inputs`` is then left untouched).
        """
        if self.config["finetune_background"]:
            layer = "background"
        elif self.config["finetune_foreground"]:
            layer = "foreground"
        else:
            return {}

        layer_outputs = outputs[layer]
        # squeeze(0) drops dim 0 only when it has size 1; this presumably removes
        # a singleton batch dimension. TODO confirm against the producer of cnn_inputs.
        inputs["input_crop"] = [crop.squeeze(0) for crop in layer_outputs["cnn_inputs"]]
        return {layer: self.loss(layer_outputs, inputs)}