Spaces:
Runtime error
Runtime error
| # coding=utf-8 | |
| from collections import Counter | |
| import torch | |
| from torch import nn | |
| # import seqeval | |
| from .utils_ner import get_entities | |
class metrics_mlm_acc(nn.Module):
    """Accuracy of masked-language-model predictions over the masked positions."""

    def __init__(self):
        super().__init__()

    def forward(self, logits, labels, masked_lm_metric):
        """Return the fraction of masked positions whose prediction is correct.

        Args:
            logits: Score tensor; the predicted id at each position is the
                argmax over its last dimension.
            labels: Target ids; must have as many elements as there are
                positions in ``logits`` (everything is flattened to 1-D).
            masked_lm_metric: Mask/weight tensor with one entry per position.
                Entries ``> 0`` are counted as masked positions, and the
                entries themselves multiply the per-position correctness, so
                presumably a 0/1 mask is expected -- confirm with callers.

        Returns:
            A 0-d float tensor: the (mask-weighted) number of correct
            predictions divided by the number of entries ``> 0`` in the mask,
            or ``0.0`` when the mask has no positive entry.
        """
        # reshape (unlike view) also accepts non-contiguous inputs.
        y_pred = torch.argmax(logits, dim=-1).reshape(-1)
        y_true = labels.reshape(-1)
        mask = masked_lm_metric.reshape(-1)

        # Vectorized count of masked positions; works for a mask of any rank
        # and avoids iterating over tensor elements one by one in Python.
        mask_label_size = int((mask > 0).sum())

        corr = torch.eq(y_pred, y_true)
        corr = torch.multiply(mask, corr)
        correct_total = torch.sum(corr.float())

        if mask_label_size == 0:
            # Nothing was masked: report 0 instead of dividing by zero (NaN).
            return torch.zeros_like(correct_total)
        return correct_total / mask_label_size
class SeqEntityScore(object):
    """Accumulate entity-level recall / precision / F1 over tagged sequences.

    Entities are extracted from tag paths with ``get_entities`` and collected
    in three lists: ``origins`` (from the labels), ``founds`` (from the
    predictions) and ``rights`` (predicted entities that also occur among the
    label entities of the same sequence). The first element of every entity
    is used as its type.
    """

    def __init__(self, id2label, markup='bios', middle_prefix='I-'):
        """Store the ``get_entities`` settings and start with empty state.

        Args:
            id2label: Passed through to ``get_entities``; presumably maps tag
                ids to tag strings -- confirm against ``utils_ner``.
            markup: Tagging scheme name passed through to ``get_entities``.
            middle_prefix: Prefix passed through to ``get_entities``.
        """
        self.id2label = id2label
        self.markup = markup
        self.middle_prefix = middle_prefix
        self.reset()

    def reset(self):
        """Discard every entity accumulated so far."""
        self.origins = []
        self.founds = []
        self.rights = []

    def compute(self, origin, found, right):
        """Return ``(recall, precision, f1)`` for the given counts.

        Args:
            origin: Number of label (gold) entities.
            found: Number of predicted entities.
            right: Number of predicted entities that are correct.

        Returns:
            A tuple of floats; any ratio with a zero denominator is ``0.0``.
        """
        # Always return floats so callers never see a mix of int 0 and float.
        recall = 0.0 if origin == 0 else (right / origin)
        precision = 0.0 if found == 0 else (right / found)
        if recall + precision == 0:
            f1 = 0.0
        else:
            f1 = (2 * precision * recall) / (precision + recall)
        return recall, precision, f1

    def result(self):
        """Summarize the accumulated entities.

        Returns:
            A pair ``(overall, class_info)``. ``overall`` is a dict with keys
            ``'acc'`` (this is the precision), ``'recall'`` and ``'f1'``
            computed over all entities. ``class_info`` maps each entity type
            that occurs in the labels to the same three keys, rounded to four
            decimals. Types that occur only in predictions get no entry.
        """
        origin_counter = Counter(entity[0] for entity in self.origins)
        found_counter = Counter(entity[0] for entity in self.founds)
        right_counter = Counter(entity[0] for entity in self.rights)

        class_info = {}
        for type_, origin in origin_counter.items():
            found = found_counter.get(type_, 0)
            right = right_counter.get(type_, 0)
            recall, precision, f1 = self.compute(origin, found, right)
            class_info[type_] = {
                "acc": round(precision, 4),
                'recall': round(recall, 4),
                'f1': round(f1, 4),
            }

        recall, precision, f1 = self.compute(
            len(self.origins), len(self.founds), len(self.rights)
        )
        return {'acc': precision, 'recall': recall, 'f1': f1}, class_info

    def update(self, label_paths, pred_paths):
        """Add a batch of label / prediction tag paths to the running state.

        Args:
            label_paths: Sequence of gold tag paths, one per sentence.
            pred_paths: Sequence of predicted tag paths, paired with
                ``label_paths`` by position (``zip`` stops at the shorter one).

        Example:
            label_paths = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
            pred_paths = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
        """
        for label_path, pre_path in zip(label_paths, pred_paths):
            label_entities = get_entities(
                label_path, self.id2label, self.markup, self.middle_prefix
            )
            pre_entities = get_entities(
                pre_path, self.id2label, self.markup, self.middle_prefix
            )
            self.origins.extend(label_entities)
            self.founds.extend(pre_entities)
            # A prediction is right only if the identical entity occurs among
            # the gold entities of the same sentence.
            self.rights.extend(
                [pre_entity for pre_entity in pre_entities if pre_entity in label_entities]
            )