import fengshen.data.hubert.hubert_dataset as datasets
from fengshen.data.universal_datamodule import UniversalDataModule
from transformers import HubertConfig, HubertModel
# from transformers.models.hubert.modeling_hubert import _compute_mask_indices
import argparse
from fairseq.data import Dictionary
from pytorch_lightning import (
    LightningModule,
    Trainer,
    loggers,
)
from pytorch_lightning.callbacks import LearningRateMonitor
import torch
import os
import torch.nn.functional as F
import torch.nn as nn


class LabelEncoder(object):
    def __init__(self, dictionary: Dictionary):
        self.dictionary = dictionary

    def __call__(self, label: str):
        return self.dictionary.encode_line(
            label,
            append_eos=False,
            add_if_not_exist=False,
        )


class HubertPretrainDataLoader():
    def __init__(self, args):
        self.cfg = args
        self.dictionaries = self.load_dictionaries()
        self.load_datasets = {}

    # TODO: replace this with a HuggingFace tokenizer
    def load_dictionaries(self):
        label_dir = self.cfg.data if self.cfg.label_dir is None else self.cfg.label_dir
        dictionaries = [
            Dictionary.load(f"{label_dir}/dict.{label}.txt")
            for label in self.cfg.labels
        ]
        return dictionaries

    def get_label_dir(self):
        if self.cfg.label_dir is None:
            return self.cfg.data
        return self.cfg.label_dir

    @property
    def datasets(self):
        # Exposed as a property so that `data_module.datasets = data_loader.datasets`
        # in __main__ hands over the dict of loaded splits rather than a bound method.
        return self.load_datasets

    def load_dataset(self, split: str, **kwargs):
        manifest = f"{self.cfg.data}/{split}.tsv"
        dicts = self.dictionaries
        pad_list = [dict.pad() for dict in dicts]
        eos_list = [dict.eos() for dict in dicts]
        procs = [LabelEncoder(dict) for dict in dicts]
        paths = [f"{self.get_label_dir()}/{split}.{lb}" for lb in self.cfg.labels]

        # hubert v1: pad_audio=True, random_crop=False;
        self.load_datasets[split] = datasets.HubertDataset(
            manifest,
            sample_rate=self.cfg.sample_rate,
            label_paths=paths,
            label_rates=self.cfg.label_rate,
            pad_list=pad_list,
            eos_list=eos_list,
            label_processors=procs,
            max_keep_sample_size=self.cfg.max_keep_size,
            min_keep_sample_size=self.cfg.min_sample_size,
            max_sample_size=self.cfg.max_sample_size,
            pad_audio=self.cfg.pad_audio,
            normalize=self.cfg.normalize,
            store_labels=False,
            random_crop=self.cfg.random_crop,
            single_target=self.cfg.single_target,
        )
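

# Expected on-disk layout, inferred from the paths built in load_dataset() above
# (this follows the fairseq HuBERT data convention):
#   {data}/{split}.tsv            -- audio manifest per split
#   {label_dir}/{split}.{label}   -- frame-level labels, one file per entry in --labels
#   {label_dir}/dict.{label}.txt  -- fairseq Dictionary for each label set
# where {label_dir} falls back to {data} when --label_dir is not given.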


def prepare_data(args):
    loader = HubertPretrainDataLoader(args)
    loader.load_dataset('train')
    loader.load_dataset('valid')
    return loader


class HubertLightning(LightningModule):

    @staticmethod
    def add_module_specific_args(parent_parser):
        parser = parent_parser.add_argument_group('HuBert Lightning')
        parser.add_argument('--pred_masked_weight', type=float, default=1.0)
        parser.add_argument('--logit_temp', type=float, default=1.0)
        parser.add_argument('--loss_weights', type=float, nargs='+')
        # parser.add_argument('--mask_prob', type=float, default=0.65)
        # parser.add_argument('--mask_length', type=int, default=10)
        # parser.add_argument('--mask_selection', type=str, default='static',
        #                     choices=["static", "uniform", "normal", "poisson"])
        # parser.add_argument('--mask_other', type=float, default=0)
        # parser.add_argument('--no_mask_overlap', type=bool, default=False)
        # parser.add_argument('--mask_min_space', type=int, default=1)
        return parent_parser

    def __init__(self, args, loader, **kwargs) -> None:
        super().__init__()
        self.save_hyperparameters(args)
        config = HubertConfig.from_pretrained(args.model_path)
        self.config = config
        self.model = HubertModel(config=config)
        self.num_classes = [len(d) for d in loader.dictionaries]
        self.label_embs_concat = nn.Parameter(
            torch.FloatTensor(sum(self.num_classes), self.config.conv_dim[-1] // 2)
        )
        self.final_proj = nn.Linear(
            self.config.hidden_size, self.config.conv_dim[-1] // 2 * len(loader.dictionaries)
        )
        nn.init.uniform_(self.label_embs_concat)
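
    # Note on dimensions, as used in forward() below: `label_embs_concat` holds one
    # embedding of size conv_dim[-1] // 2 for every class of every label set, stacked
    # along dim 0, while `final_proj` maps the transformer hidden states to
    # conv_dim[-1] // 2 units per label set. forward() later split()s the embedding
    # table by self.num_classes and chunk()s the projection per label set.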

    def setup(self, stage) -> None:
        if stage == 'fit':
            train_loader = self.trainer._data_connector._train_dataloader_source.dataloader()

            # Calculate total steps
            if self.trainer.max_epochs > 0:
                world_size = self.trainer.world_size
                tb_size = self.hparams.train_batchsize * max(1, world_size)
                ab_size = self.trainer.accumulate_grad_batches
                self.total_steps = (len(train_loader.dataset) *
                                    self.trainer.max_epochs // tb_size) // ab_size
            else:
                self.total_steps = self.trainer.max_steps // self.trainer.accumulate_grad_batches

            print('Total steps: {}'.format(self.total_steps))

    def configure_optimizers(self):
        from fengshen.models.model_utils import configure_optimizers
        return configure_optimizers(self)
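
    # The contrastive (InfoNCE-style) objective used below: each masked frame's
    # projection is scored by temperature-scaled cosine similarity against its
    # positive label embedding (row 0) and against every embedding of the same
    # label set as negatives; negatives identical to the positive are masked to
    # -inf, and forward() applies cross-entropy towards class index 0.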

    def compute_nce(self, x, pos, negs):
        neg_is_pos = (pos == negs).all(-1)
        pos = pos.unsqueeze(0)
        targets = torch.cat([pos, negs], dim=0)

        logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x)
        logits /= self.hparams.logit_temp
        if neg_is_pos.any():
            logits[1:][neg_is_pos] = float("-inf")
        logits = logits.transpose(0, 1)  # (num_x, num_cls+1)
        return logits

    def forward(self, **batch):
        target_list = batch['target_list']
        padding_mask = batch['net_input']['padding_mask']
        input_values = batch['net_input']['source']
        output = self.model(input_values=input_values,
                            attention_mask=padding_mask,
                            target_list=target_list,
                            mask_time_indices=None,
                            return_dict=False)

        def compute_pred(proj_x, target, label_embs):
            # compute logits for the i-th label set
            y = torch.index_select(label_embs, 0, target.long())
            negs = label_embs.unsqueeze(1).expand(-1, proj_x.size(0), -1)
            # proj_x: (S, D)
            # y: (S, D)
            # negs: (Neg, S, D)
            return self.compute_nce(proj_x, y, negs)

        label_embs_list = self.label_embs_concat.split(self.num_classes, 0)

        x, extra_losses, target_list, mask_indices, padding_mask = (
            output[0], output[-4], output[-3], output[-2], output[-1])

        # Only score frames that are masked for prediction and are not padding.
        masked_indices = torch.logical_and(~padding_mask, mask_indices)
        proj_x_m = self.final_proj(x[masked_indices])
        proj_x_m_list = proj_x_m.chunk(len(target_list), dim=-1)
        logp_m_list = [
            compute_pred(proj_x_m, t[masked_indices], label_embs_list[i])
            for i, (proj_x_m, t) in enumerate(zip(proj_x_m_list, target_list))
        ]
        # The positive is always at index 0 in the NCE logits, so the target class is 0.
        targ_m_list = [x.new_zeros(x.size(0), dtype=torch.long) for x in logp_m_list]

        loss = 0.0
        loss_m_list = []
        for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)):
            loss_m = F.cross_entropy(logp_m, targ_m)
            loss_m_list.append(loss_m)
            self.log(f"loss_m_{i}", loss_m.detach().item())
        loss += self.hparams.pred_masked_weight * sum(loss_m_list)

        loss_weights = self.hparams.loss_weights
        if loss_weights is not None:
            if torch.is_tensor(extra_losses):
                extra_losses = [extra_losses]
                names = ['extra']
            if len(loss_weights) == 1 and len(extra_losses) != 1:
                loss_weights = [loss_weights[0]] * len(extra_losses)
            assert len(extra_losses) == len(
                loss_weights
            ), f"{len(extra_losses)}, {len(loss_weights)}"
            for p, n, coef in zip(extra_losses, names, loss_weights):
                if coef != 0 and p is not None:
                    p = coef * p.float()
                    loss += p
                    self.log(f"loss_{n}", p.item())
        return {'loss': loss}

    def training_step(self, batch, batch_idx):
        output = self(**batch)
        self.log('train_loss', output['loss'])
        return output

    def compute_metrics(self, logits, labels):
        y_pred = torch.argmax(logits, dim=-1)
        y_pred = y_pred.view(size=(-1,))
        y_true = labels.view(size=(-1,)).float()
        corr = torch.eq(y_pred, y_true)
        acc = torch.sum(corr.float()) / y_true.size()[0]
        return acc

    def validation_step(self, batch, batch_idx):
        output = self(**batch)
        # self.log('val_loss', output.loss, sync_dist=True)
        # acc = self.compute_metrics(output.logits, batch['labels'])
        # self.log('val_acc', acc, sync_dist=True)
        return output

    def on_save_checkpoint(self, checkpoint) -> None:
        # Save the current loop state when checkpointing in the middle of an epoch.
        # If your pytorch-lightning version is <= 1.6.0, uncomment the line below:
        # checkpoint['loops'] = self.trainer.checkpoint_connector._get_loops_state_dict()
        if self.trainer.global_rank == 0:
            self.model.save_pretrained(os.path.join(
                self.trainer.checkpoint_callback.dirpath,
                'hf_pretrained_epoch{}_step{}'.format(self.trainer.current_epoch, self.trainer.global_step)))

    def on_load_checkpoint(self, checkpoint) -> None:
        global_step_offset = checkpoint["global_step"]
        if 'global_samples' in checkpoint:
            self.consumed_samples = checkpoint['global_samples']
        self.trainer.fit_loop.epoch_loop._batches_that_stepped = global_step_offset
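

# Illustrative launch command, assuming this script is saved as pretrain_hubert.py
# (a sketch only: --data, --labels and --model_path follow the attribute names
# referenced above and are registered by the imported argument parsers; the Trainer,
# datamodule and checkpoint parsers add further flags such as --default_root_dir):
#
#   python pretrain_hubert.py \
#       --data /path/to/manifests \
#       --labels km \
#       --model_path /path/to/hubert_config_dir \
#       --default_root_dir ./exp/hubert_pretrain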
if __name__ == '__main__':
    args_parser = argparse.ArgumentParser()
    from fengshen.utils import UniversalCheckpoint
    from fengshen.models.model_utils import add_module_args

    args_parser = add_module_args(args_parser)
    args_parser = datasets.add_data_specific_args(args_parser)
    args_parser = UniversalDataModule.add_data_specific_args(args_parser)
    args_parser = Trainer.add_argparse_args(args_parser)
    args_parser = HubertLightning.add_module_specific_args(args_parser)
    args_parser = UniversalCheckpoint.add_argparse_args(args_parser)
    args_parser.add_argument('--ckpt_path', type=str)
    args = args_parser.parse_args()

    data_module = UniversalDataModule(args=args, tokenizer=None, collate_fn=None)
    data_loader = prepare_data(args)
    data_module.datasets = data_loader.datasets
    module = HubertLightning(args, loader=data_loader)

    lr_monitor = LearningRateMonitor(logging_interval='step')
    logger = loggers.TensorBoardLogger(
        save_dir=os.path.join(args.default_root_dir, 'logs/'),
        name=os.path.basename(os.path.dirname(args.model_path)))
    checkpoint_callback = UniversalCheckpoint(args).callbacks

    if args.ckpt_path is not None and \
            not os.path.exists(args.ckpt_path):
        print('-------- warning: checkpoint not found, ignoring --ckpt_path --------')
        args.ckpt_path = None

    trainer = Trainer.from_argparse_args(args,
                                         logger=logger,
                                         callbacks=[lr_monitor, checkpoint_callback])

    trainer.fit(module, data_module, ckpt_path=args.ckpt_path)