"""Gradio app that batch-scores YouTube comments for sentiment and toxicity.

The tokenizer and the two-headed ``CommentMTLModel`` (sentiment logits and
toxicity logits) are loaded once at import time. ``analyse_batch`` is then
exposed through a text-in / JSON-out Gradio interface.
"""

import json
from typing import Dict, List

import gradio as gr
import torch
import torch.nn.functional as F
from transformers import BertTokenizer

from model import CommentMTLModel  # your class

# ------------ Tunables -------------------------------------------------------------
MAX_COMMENTS = 100          # hard cap on comments analysed per request
MAX_SEQ_LEN = 512           # tokenizer truncation length (tokens per comment)
BATCH_SIZE = 32             # forward-pass mini-batch size
TOXICITY_THRESHOLD = 0.30   # sigmoid probability above which a label is counted

# ------------ Device optimisation -------------------------------------------------
# Prefer Apple-silicon MPS, then CUDA, then fall back to CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# ------------ Model / tokenizer ---------------------------------------------------
TOKENIZER_DIR = "/app/bert-base-uncased"  # local copy of the tokenizer files
tokenizer = BertTokenizer.from_pretrained(
    TOKENIZER_DIR,
    local_files_only=True,  # force offline loading; never contact the Hub
)

with open("config.json", encoding="utf-8") as f:
    cfg = json.load(f)

# Label names are positional: index i of each model head maps to entry i here.
sentiment_labels = ["Negative", "Neutral", "Positive"]
toxicity_labels = ["Toxic", "Severe Toxic", "Obscene", "Threat", "Insult", "Identity Hate"]

# Fail fast at startup if the checkpoint config disagrees with the label lists,
# instead of raising IndexError (or silently ignoring extra heads) per request.
if cfg["num_sentiment_labels"] != len(sentiment_labels):
    raise ValueError(
        f"config.json num_sentiment_labels={cfg['num_sentiment_labels']} "
        f"but {len(sentiment_labels)} sentiment label names are defined"
    )
if cfg["num_toxicity_labels"] != len(toxicity_labels):
    raise ValueError(
        f"config.json num_toxicity_labels={cfg['num_toxicity_labels']} "
        f"but {len(toxicity_labels)} toxicity label names are defined"
    )

# NOTE(review): the tokenizer is loaded strictly offline from TOKENIZER_DIR, but the
# model is built from the Hub id "bert-base-uncased". If CommentMTLModel calls
# from_pretrained(model_name) internally, this may need network access or the HF
# cache; confirm against model.py.
model = CommentMTLModel(
    model_name="bert-base-uncased",
    num_sentiment_labels=cfg["num_sentiment_labels"],
    num_toxicity_labels=cfg["num_toxicity_labels"],
    dropout_prob=cfg.get("dropout_prob", 0.1),
)
# NOTE: torch.load unpickles the checkpoint; only load files from a trusted source.
model.load_state_dict(torch.load("pytorch_model.bin", map_location=device))
model.to(device).eval()


# ------------ Core inference function ---------------------------------------------
@torch.inference_mode()
def analyse_batch(comments_text: str) -> Dict:
    """Score up to ``MAX_COMMENTS`` comments and return aggregate label counts.

    Args:
        comments_text: Multiline string; each non-blank line is one comment.
            Only the first ``MAX_COMMENTS`` non-blank lines are analysed.
            Empty or ``None`` input yields all-zero counts.

    Returns:
        A dict with:
          * ``"sentiment_counts"``: label -> number of comments whose arg-max
            sentiment is that label.
          * ``"toxicity_counts"``: label -> number of comments whose sigmoid
            probability for that label exceeds ``TOXICITY_THRESHOLD``.
          * ``"comments_with_any_toxicity"``: number of comments with at least
            one toxicity label above the threshold.
    """
    # Split into comments, drop blank lines, cap the count.
    comments: List[str] = [
        line for line in (comments_text or "").splitlines() if line.strip()
    ][:MAX_COMMENTS]

    sent_counts = {lab: 0 for lab in sentiment_labels}
    tox_counts = {lab: 0 for lab in toxicity_labels}
    comments_with_any_tox = 0

    # Tokenize per mini-batch so each batch is padded only to its own longest
    # comment, not to the longest comment in the whole request. An empty
    # `comments` list skips the loop; the tokenizer rejects an empty batch.
    for start in range(0, len(comments), BATCH_SIZE):
        enc = tokenizer(
            comments[start:start + BATCH_SIZE],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=MAX_SEQ_LEN,
        )
        enc = {k: v.to(device) for k, v in enc.items()}
        out = model(
            input_ids=enc["input_ids"],
            attention_mask=enc["attention_mask"],
            token_type_ids=enc.get("token_type_ids"),  # None if tokenizer omits it
        )

        # Sentiment is single-label. Softmax is monotonic, so arg-max over the
        # raw logits picks the same class.
        sent_pred = out["sentiment_logits"].argmax(dim=1)
        for idx in sent_pred.tolist():
            sent_counts[sentiment_labels[idx]] += 1

        # Toxicity is multi-label: each label is thresholded independently.
        toxic_mask = out["toxicity_logits"].sigmoid() > TOXICITY_THRESHOLD
        comments_with_any_tox += int(toxic_mask.any(dim=1).sum().item())
        for lab, count in zip(toxicity_labels, toxic_mask.sum(dim=0).tolist()):
            tox_counts[lab] += int(count)

    return {
        "sentiment_counts": sent_counts,
        "toxicity_counts": tox_counts,
        "comments_with_any_toxicity": int(comments_with_any_tox),
    }


# ------------ Gradio interface ----------------------------------------------------
iface = gr.Interface(
    fn=analyse_batch,
    inputs=gr.Textbox(
        label="YouTube comments (max 100, one per line)",
        placeholder="Paste up to 100 comments, each on its own line.",
        lines=20,
        max_lines=100,
    ),
    outputs=gr.JSON(label="Aggregated statistics"),
    title="YouTube Comment Sentiment & Toxicity Batch API",
    description=(
        "Paste up to 100 raw comment strings, each on a new line, "
        "then click Analyze to receive counts of Positive/Neutral/Negative comments "
        "plus counts of toxicity labels where probability > 0.30."
    ),
    # NOTE(review): newer Gradio releases rename this to `flagging_mode`;
    # check the pinned Gradio version before upgrading.
    allow_flagging="never",
)

if __name__ == "__main__":
    iface.launch()