# Spaces:
# Configuration error
# Configuration error
from pathlib import Path

import imageio
import torch

from util.util import tensor2im
class DataLogger:
    """Build per-epoch log payloads (losses, atlases, videos) and save checkpoints.

    The logger keeps a CPU copy of the latest composite for each layer so that a
    full video can be re-blended with the dataset alpha at every image-logging
    epoch.

    Config keys read by this class: ``finetune_foreground``,
    ``finetune_background``, ``log_images_freq``, ``use_wandb``,
    ``save_model_starting_epoch`` and ``results_folder``.
    """

    def __init__(self, config, dataset):
        """Cache CPU copies of the dataset video and alpha and pick the logged layer.

        Args:
            config: Mapping with the keys listed in the class docstring.
            dataset: Object exposing ``original_video`` and ``all_alpha`` tensors.
        """
        # Both layers start out as the unedited video; the fine-tuned layer is
        # overwritten with its edited composite in ``log_data``.
        self.layers_edits = {
            "background": dataset.original_video.detach().cpu(),
            "foreground": dataset.original_video.detach().cpu(),
        }
        self.alpha_video = dataset.all_alpha.detach().cpu()
        self.config = config
        self.layer_name = "foreground" if config["finetune_foreground"] else "background"

    def _as_image(self, image):
        """Return ``image`` unchanged, or wrapped in ``wandb.Image`` when wandb is enabled."""
        if not self.config["use_wandb"]:
            return image
        # Imported lazily: the module has no top-level ``wandb`` import, and runs
        # with ``use_wandb`` disabled must not require the package.
        import wandb

        return wandb.Image(image)

    def _as_video(self, video):
        """Return ``video`` unchanged, or wrapped in ``wandb.Video`` when wandb is enabled."""
        if not self.config["use_wandb"]:
            return video
        import wandb  # lazy for the same reason as in ``_as_image``

        return wandb.Video(video, fps=10, format="mp4")

    def _save_checkpoint(self, epoch, model):
        """Write ``model.state_dict()`` to ``checkpoint_epoch_{epoch}.pt``.

        The file goes to ``wandb.run.dir`` when wandb is enabled, otherwise to
        ``config["results_folder"]``. The target directory is created if missing.
        """
        filename = f"checkpoint_epoch_{epoch}.pt"
        if self.config["use_wandb"]:
            import wandb  # lazy for the same reason as in ``_as_image``

            checkpoint_dir = Path(wandb.run.dir)
        else:
            checkpoint_dir = Path(self.config["results_folder"])
        # torch.save does not create parent directories, so make sure it exists.
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        torch.save({"model": model.state_dict()}, str(checkpoint_dir / filename))

    def log_data(self, epoch, lr, losses, model, dataset):
        """Assemble the log dictionary for one epoch.

        Losses, ``epoch`` and ``lr`` are always recorded. When ``epoch`` is a
        multiple of ``config["log_images_freq"]``, the atlases and videos
        rendered by ``dataset.render_video_from_atlas`` are added as well, the
        cached composite of the fine-tuned layer is refreshed, and a model
        checkpoint may be written.

        Args:
            epoch: Current epoch number.
            lr: Current learning rate, stored as-is.
            losses: Mapping ``layer -> {loss_name: tensor}``.
            model: Model passed to the renderer; its ``state_dict()`` is saved.
            dataset: Object providing ``render_video_from_atlas(model, layer=...)``.

        Returns:
            dict: Log entries keyed ``Loss/...``, ``epoch``, ``lr``,
            ``Atlases/...`` and ``Videos/...``. Image and video values are
            wrapped in wandb objects when ``config["use_wandb"]`` is true.
        """
        log_data = {}
        for layer, layer_losses in losses.items():
            for key, value in layer_losses.items():
                log_data[f"Loss/{layer}_{key}"] = value.detach()
        log_data["epoch"] = epoch
        log_data["lr"] = lr

        # Everything below is image/video logging, done only every N epochs.
        if epoch % self.config["log_images_freq"] != 0:
            return log_data

        layer = self.layer_name
        edited_atlas_dict, edit_dict, uv_mask = dataset.render_video_from_atlas(model, layer=layer)

        alpha_of_edit = None
        edit_only = None
        for key, atlas in edited_atlas_dict.items():
            masked_atlas = atlas.detach().cpu() * uv_mask
            if key != "edit":
                log_data[f"Atlases/{layer}_masked_{key}"] = self._as_image(tensor2im(masked_atlas))
            if key == "alpha":
                alpha_of_edit = masked_atlas
            if key == "edit":
                edit_only = masked_atlas

        # Append the first alpha channel to the edit along dim 1
        # (presumably giving an RGBA atlas -- TODO confirm channel layout).
        rgba_edit = tensor2im(torch.cat((edit_only, alpha_of_edit[:, [0]]), dim=1))
        log_data[f"Atlases/{layer}_rgba_layer"] = self._as_image(rgba_edit)

        for key, frames in edit_dict.items():
            if key in ("composite", "edit"):
                continue
            video = (255 * frames.detach().cpu()).to(torch.uint8)
            log_data[f"Videos/{layer}_{key}"] = self._as_video(video)

        if self.config[f"finetune_{layer}"]:
            self.layers_edits[layer] = edit_dict["composite"].detach().cpu()

        # Alpha-blend the cached foreground over the cached background.
        full_video = (
            self.alpha_video * self.layers_edits["foreground"]
            + (1 - self.alpha_video) * self.layers_edits["background"]
        )
        full_video = (255 * full_video.detach().cpu()).to(torch.uint8)
        log_data["Videos/full_video"] = self._as_video(full_video)

        # save model checkpoint
        # NOTE(review): the source this was recovered from had lost its
        # indentation; the checkpoint step is assumed to belong to the
        # image-logging branch -- confirm against the upstream file.
        if epoch > self.config["save_model_starting_epoch"]:
            self._save_checkpoint(epoch, model)

        return log_data

    def save_locally(self, log_data):
        """Write the videos and atlases in ``log_data`` under ``results_folder/<epoch>``.

        Keys starting with ``Videos`` are written as ``.mp4`` files and keys
        starting with ``Atlases`` as ``.png`` files; other keys are ignored.
        ``/`` in a key is replaced by ``_`` to form the file name.

        Args:
            log_data: Dictionary produced by ``log_data``. Video values must
                support ``permute``, so this expects raw tensors rather than
                wandb wrappers.
        """
        path = Path(self.config["results_folder"], str(log_data["epoch"]))
        path.mkdir(parents=True, exist_ok=True)
        for key, value in log_data.items():
            save_name = key.replace("/", "_")
            if key.startswith("Videos"):
                # Move axis 1 to the end (presumably channels-first to
                # channels-last for the video writer -- TODO confirm).
                imageio.mimwrite(f"{path}/{save_name}.mp4", value.permute(0, 2, 3, 1))
            elif key.startswith("Atlases"):
                imageio.imwrite(f"{path}/{save_name}.png", value)