# IndoBERT-MultiPredict / inference.py
# Last commit c0d9f51 by orva06 (verified): "copas dari training gcolab"
# (i.e. code copied over from the Google Colab training notebook).
# inference.py
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
class IndoBertMultiPredict(nn.Module):
    """IndoBERT encoder with two linear classification heads (topic and tax).

    Both heads read the same first-token ([CLS]) vector of the encoder's last
    hidden layer, after dropout.

    Args:
        n_topic: Number of output classes of the topic head.
        n_tax: Number of output classes of the tax head.
        dropout: Dropout probability applied to the pooled vector.
        pretrained: Hugging Face model id or local path of the encoder backbone.
    """

    def __init__(self, n_topic, n_tax, dropout=0.3, pretrained="indobenchmark/indobert-base-p2"):
        super().__init__()
        # NOTE: these attribute names become the state_dict key prefixes of
        # saved checkpoints; renaming them breaks load_state_dict.
        self.bert = AutoModel.from_pretrained(pretrained)
        hidden_size = self.bert.config.hidden_size
        self.dropout = nn.Dropout(dropout)
        self.head_topic = nn.Linear(hidden_size, n_topic)
        self.head_tax = nn.Linear(hidden_size, n_tax)

    def forward(self, input_ids, attention_mask):
        """Return the tuple ``(logits_topic, logits_tax)`` for a batch of token ids.

        Args:
            input_ids: Token id tensor, passed unchanged to the encoder.
            attention_mask: Padding mask, passed unchanged to the encoder.
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Sequence position 0 of the last hidden layer = the first ([CLS]) token.
        cls_vector = self.dropout(encoded.last_hidden_state[:, 0])
        return self.head_topic(cls_vector), self.head_tax(cls_vector)
# --- 2. CLASS THAT HANDLES PREDICTION ---
class InferenceEngine:
    """Load a trained IndoBertMultiPredict checkpoint and run batched predictions.

    Each input text gets two predictions, one per classification head of the
    model: a topic label and a tax label.
    """

    def __init__(self, ckpt_path, tokenizer_name, n_topic=15, n_tax=4,
                 topic_classes_path="le_topic_classes.npy",
                 tax_classes_path="le_tax_classes.npy"):
        """Build the model, load the fine-tuned weights and the label names.

        Args:
            ckpt_path: Path to a saved ``state_dict`` of IndoBertMultiPredict.
            tokenizer_name: Hugging Face model id / path, used both for the
                tokenizer and for the pretrained encoder backbone.
            n_topic: Number of topic classes (must match the checkpoint).
            n_tax: Number of tax classes (must match the checkpoint).
            topic_classes_path: ``.npy`` file with the topic label names,
                indexed by class id (presumably a saved
                ``LabelEncoder.classes_``). If it cannot be loaded, the labels
                fall back to "0".."n_topic-1".
            tax_classes_path: Same as ``topic_classes_path``, for the tax head.

        Raises:
            RuntimeError: If the checkpoint does not match the model structure.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

        print("Loading model architecture...")
        # The same id is used for the tokenizer and for the encoder backbone.
        self.model = IndoBertMultiPredict(n_topic, n_tax, pretrained=tokenizer_name)

        print(f"Loading weights from {ckpt_path}...")
        # Deserialize onto the CPU so the checkpoint also loads on hosts
        # without a GPU (e.g. a free-tier Space). The model is moved to
        # self.device below.
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. On torch versions that support it, consider
        # weights_only=True.
        state_dict = torch.load(ckpt_path, map_location=torch.device("cpu"))
        try:
            self.model.load_state_dict(state_dict)
        except RuntimeError as e:
            # Raised on missing/unexpected keys or shape mismatches, e.g. when
            # n_topic / n_tax differ from the heads stored in the checkpoint.
            print("ERROR: Struktur model tidak cocok dengan file weights.")
            print("Detail:", e)
            raise
        else:
            print("Weights loaded successfully!")

        self.model.to(self.device)
        self.model.eval()  # inference mode: dropout becomes a no-op

        # Label names per class index. Each file falls back independently,
        # so one missing file does not discard the other one.
        self.le_topic = self._load_label_classes(topic_classes_path, n_topic)
        self.le_tax = self._load_label_classes(tax_classes_path, n_tax)

    @staticmethod
    def _load_label_classes(path, n_classes):
        """Return the label names stored at ``path``, or "0".."n_classes-1" if unavailable."""
        import pickle

        import numpy as np

        try:
            # allow_pickle is needed for object-dtype string arrays.
            # NOTE(review): only load trusted files, since this can unpickle data.
            return np.load(path, allow_pickle=True)
        except (OSError, ValueError, EOFError, pickle.UnpicklingError):
            # Best effort: a missing or unreadable file must not stop inference.
            print("Warning: .npy classes not found, using index numbers.")
            return [str(i) for i in range(n_classes)]

    @staticmethod
    def _label_for(classes, idx):
        """Map a class index to its label name as ``str``, or ``str(idx)`` if out of range."""
        return str(classes[idx]) if idx < len(classes) else str(idx)

    def predict_texts(self, texts, max_length=128):
        """Predict the topic and tax labels for a batch of texts.

        Args:
            texts: A list of strings, or a single string (treated as a batch
                of one).
            max_length: Token limit per text. Longer texts are truncated.

        Returns:
            One dict per input text, in input order, with the keys ``text``,
            ``topic_label``, ``topic_idx``, ``topic_conf``, ``topic_probs``,
            ``tax_label``, ``tax_idx``, ``tax_conf`` and ``tax_probs``.
            Labels are ``str``, indices are ``int``, confidences are the
            highest softmax probability, and ``*_probs`` are lists of floats.
            An empty input returns an empty list.
        """
        if isinstance(texts, str):
            # A bare string would otherwise be split into characters below.
            texts = [texts]
        texts = list(texts)
        if not texts:
            return []

        inputs = self.tokenizer(texts, return_tensors="pt", padding=True,
                                truncation=True, max_length=max_length)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            logits_topic, logits_tax = self.model(inputs["input_ids"], inputs["attention_mask"])
            # Move the results to the CPU once. A per-element .item() would
            # force one device sync per value.
            probs_topic = torch.nn.functional.softmax(logits_topic, dim=1).cpu()
            probs_tax = torch.nn.functional.softmax(logits_tax, dim=1).cpu()
            conf_topic, idx_topic = probs_topic.max(dim=1)
            conf_tax, idx_tax = probs_tax.max(dim=1)

        rows = zip(
            texts,
            idx_topic.tolist(), conf_topic.tolist(), probs_topic.tolist(),
            idx_tax.tolist(), conf_tax.tolist(), probs_tax.tolist(),
        )
        return [
            {
                "text": text,
                "topic_label": self._label_for(self.le_topic, t_idx),
                "topic_idx": t_idx,
                "topic_conf": t_conf,
                "topic_probs": t_probs,
                "tax_label": self._label_for(self.le_tax, x_idx),
                "tax_idx": x_idx,
                "tax_conf": x_conf,
                "tax_probs": x_probs,
            }
            for text, t_idx, t_conf, t_probs, x_idx, x_conf, x_probs in rows
        ]