# Source: https://github.com/wayveai/LingoQA/blob/main/benchmark/judge.py
from enum import Enum
from typing import List

import torch
from torch import nn
from transformers import AutoModelForSequenceClassification, AutoTokenizer

LINGOQA_TEST = "https://drive.usercontent.google.com/u/1/uc?id=1I8u6uYysQUstoVYZapyRQkXmOwr-AG3d&export=download"
LINGO_JUDGE = "wayveai/Lingo-Judge"
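
# A minimal loading sketch for the test-set URL above. Assumption: the link
# serves the benchmark as a parquet file readable by pandas.
#
#   import pandas as pd
#   benchmark = pd.read_parquet(LINGOQA_TEST)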


class Keys(str, Enum):
    """Column names for benchmark entries and evaluation results."""

    question_id = "question_id"
    segment_id = "segment_id"
    question = "question"
    answer = "answer"
    references = "references"
    prediction = "prediction"
    max_score = "max_score"
    score = "score"
    probability = "probability"
    correct = "correct"
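
# Illustrative benchmark row keyed by the names above (the values are
# invented for illustration, not taken from the real dataset):
#
#   {
#       Keys.question_id: "q_000123",
#       Keys.segment_id: "seg_000042",
#       Keys.question: "Are there any pedestrians crossing the road?",
#       Keys.references: ["Yes, there is one.", "Yes, one pedestrian."],
#       Keys.prediction: "Yes, a single pedestrian is crossing.",
#   }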


class LingoJudge(nn.Module):
    """
    LingoJudge is a textual classifier that evaluates the truthfulness of an answer on the LingoQA benchmark.
    """

    def __init__(self, pretrained_model=LINGO_JUDGE):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model, use_fast=True)
        self.model = AutoModelForSequenceClassification.from_pretrained(pretrained_model).eval()

    def forward(self, question: str, references: List[str], prediction: str):
        """
        Inference function for the textual classifier with multiple reference answers.

        Args:
            question: Input question.
            references: List of reference answers.
            prediction: Model prediction.
        Output:
            scores: Tensor of scores, one per reference, indicating truthfulness.
        """
        device = next(self.parameters()).device
        # Build one prompt per reference answer, pairing it with the prediction.
        texts = [
            f"{self.tokenizer.cls_token}\nQuestion: {question}\nAnswer: {a_gt}\nStudent: {prediction}"
            for a_gt in references
        ]
        encoded_input = self.tokenizer(
            texts, return_tensors="pt", padding=True, truncation=True, max_length=128
        )
        # Move the batch to the same device as the model parameters.
        encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
        output = self.model(**encoded_input)
        # The classification head has a single output, so squeezing the last
        # dimension yields one logit per (reference, prediction) pair.
        scores = output.logits.squeeze(-1)
        return scores
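
    # Illustrative call: for one question with two reference answers,
    # forward() returns one logit per reference, e.g.
    #   scores = judge.forward("Is the light red?", ["yes", "yes, it is"], "yes")
    #   scores.shape  ->  torch.Size([2])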

    def compute(self, questions: List[str], references: List[List[str]], predictions: List[str]):
        """
        Compute the classifier metric for a batch. When a question has multiple
        reference answers, the highest score is selected.

        Args:
            questions: List of input questions.
            references: List of lists, with multiple references per question supported.
            predictions: List of model predictions.
        Output:
            max_scores: Tensor with one score per question, indicating truthfulness.
        """
        max_scores = []
        for index, question in enumerate(questions):
            references_preprocessed = [
                self.preprocess(reference) for reference in references[index]
            ]
            prediction_preprocessed = self.preprocess(predictions[index])
            scores = self.forward(question, references_preprocessed, prediction_preprocessed)
            # Keep only the best-matching reference for this question.
            max_scores.append(scores.max())
        return torch.Tensor(max_scores)

    def preprocess(self, string: str):
        """
        Preprocessing function for consistency.

        Args:
            string: Input string to be processed.
        Output:
            output: Lower-cased string with surrounding whitespace removed.
        """
        return str(string).lower().strip()
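

# Example usage (illustrative, with toy strings): scoring a small batch with
# the judge. Assumption: the score is a raw logit, so sigmoid(score) gives a
# probability and a positive score (probability > 0.5) marks the prediction
# as correct, matching the `probability` and `correct` keys above.
if __name__ == "__main__":
    judge = LingoJudge()

    questions = ["Are there any pedestrians crossing the road? If yes, how many?"]
    references = [["1", "Yes, there is one pedestrian crossing."]]
    predictions = ["Yes, there is one."]

    scores = judge.compute(questions, references, predictions)
    probabilities = torch.sigmoid(scores)
    correct = scores > 0.0

    for q, s, p, c in zip(questions, scores, probabilities, correct):
        print(f"{q}\n  score={s:.3f}  probability={p:.3f}  correct={bool(c)}")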