|
|
import gradio as gr |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import numpy as np |
|
|
from transformers import AutoTokenizer, AutoModel |
|
|
from torchcrf import CRF |
|
|
from huggingface_hub import hf_hub_download |
|
|
import PyPDF2 |
|
|
from docx import Document |
|
|
import re |
|
|
import json |
|
|
|
|
|
|
|
|
|
|
|
class PositionalEncoding(nn.Module):
    """Add a fixed sinusoidal position table to a batch-first sequence.

    A table of shape ``(1, max_len, d_model)`` is computed once in
    ``__init__`` and registered as the buffer ``pe``. ``forward`` adds the
    first ``x.size(1)`` rows of that table to the input.

    Args:
        d_model: Size of the last (feature) dimension of the inputs. Both
            even and odd values are supported.
        dropout: Probability used for the ``nn.Dropout`` stored on
            ``self.dropout``.
        max_len: Number of positions precomputed in the table. Inputs whose
            second dimension exceeds this cannot be added to the table.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        # NOTE: the dropout module is created and stored, but ``forward``
        # below does not apply it.
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)
        )
        angles = position * div_term

        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so only the first d_model // 2 frequencies are used for the cosine
        # half. For even d_model this slice is the full set (no change).
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Leading singleton dimension lets the table broadcast over the batch.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the position table rows for ``x.size(1)`` steps."""
        return x + self.pe[:, :x.size(1)]
|
|
|
|
|
class VanillaTransformer(nn.Module):
    """Positional encoding followed by a stack of ``nn.TransformerEncoder`` layers.

    Args:
        d_model: Feature size of the inputs and of every encoder layer.
        nhead: Number of attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        dim_feedforward: Width of the feed-forward sub-layer.
        dropout: Dropout probability passed to the positional encoder and to
            each encoder layer.
    """

    def __init__(self, d_model=768, nhead=8, num_layers=3, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.pos_encoder = PositionalEncoding(d_model, dropout)

        # Layers use GELU activations and expect batch-first inputs.
        layer_config = dict(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(**layer_config),
            num_layers=num_layers,
        )

    def forward(self, src, src_key_padding_mask=None):
        """Encode ``src``; ``src_key_padding_mask`` is forwarded to the encoder unchanged."""
        positioned = self.pos_encoder(src)
        encoded = self.transformer(positioned, src_key_padding_mask=src_key_padding_mask)
        return encoded
|
|
|
|
|
class HierarchicalLegalSegModel(nn.Module):
    """Label every sentence of a document with a CRF over contextual sentence vectors.

    Pipeline: each sentence is run through ``longformer_model`` and the hidden
    state at token position 0 is taken as its embedding; the sequence of
    sentence embeddings goes through ``VanillaTransformer``; a linear layer
    produces per-label emissions; a CRF scores or decodes label sequences.

    Args:
        longformer_model: Encoder called with ``input_ids`` and
            ``attention_mask``; its output must expose ``last_hidden_state``.
        num_labels: Number of output labels.
        hidden_dim: Embedding size used throughout the model.
        transformer_layers: Number of layers in the sentence-level transformer.
        transformer_heads: Number of attention heads in that transformer.
        dropout: Dropout probability for the sentence embeddings and the
            sentence-level transformer.
    """

    def __init__(self, longformer_model, num_labels, hidden_dim=768, transformer_layers=3, transformer_heads=8, dropout=0.1):
        super().__init__()
        self.longformer = longformer_model
        self.hidden_dim = hidden_dim
        self.vanilla_transformer = VanillaTransformer(
            d_model=hidden_dim,
            nhead=transformer_heads,
            num_layers=transformer_layers,
            dim_feedforward=hidden_dim * 4,
            dropout=dropout,
        )
        self.classifier = nn.Linear(hidden_dim, num_labels)
        self.crf = CRF(num_labels, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.num_labels = num_labels

    def encode_sentences(self, input_ids, attention_mask):
        """Return one embedding per sentence.

        ``input_ids`` is unpacked as ``(batch, num_sentences, max_seq_len)``;
        ``attention_mask`` is reshaped the same way. The result has shape
        ``(batch, num_sentences, hidden_dim)``.
        """
        batch_size, num_sentences, max_seq_len = input_ids.shape

        # Fold the sentence axis into the batch axis so the encoder sees a
        # plain 2-D batch of token sequences.
        encoder_out = self.longformer(
            input_ids=input_ids.view(-1, max_seq_len),
            attention_mask=attention_mask.view(-1, max_seq_len),
        )

        first_token_states = encoder_out.last_hidden_state[:, 0, :]
        return first_token_states.view(batch_size, num_sentences, self.hidden_dim)

    def forward(self, input_ids, attention_mask, labels=None, sentence_mask=None):
        """Return the training loss when ``labels`` is given, otherwise decoded labels.

        With ``labels``: the negated CRF output using ``reduction='mean'``.
        Without ``labels``: the result of ``crf.decode``. ``sentence_mask``,
        when provided, is inverted to form the transformer's key padding mask
        and passed as-is to the CRF.
        """
        embeddings = self.dropout(self.encode_sentences(input_ids, attention_mask))

        padding_mask = None if sentence_mask is None else ~sentence_mask
        contextual = self.vanilla_transformer(
            embeddings,
            src_key_padding_mask=padding_mask,
        )
        emissions = self.classifier(contextual)

        if labels is None:
            return self.crf.decode(emissions, mask=sentence_mask)
        return -self.crf(emissions, labels, mask=sentence_mask, reduction='mean')
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Model loading (runs once at import time).
# ---------------------------------------------------------------------------
print("⏳ Loading model...")
device = torch.device("cpu")

tokenizer = AutoTokenizer.from_pretrained("lexlms/legal-longformer-base")
longformer = AutoModel.from_pretrained("lexlms/legal-longformer-base").to(device)

# The Longformer backbone is frozen: none of its parameters require gradients.
for param in longformer.parameters():
    param.requires_grad = False

model = HierarchicalLegalSegModel(longformer, num_labels=7, hidden_dim=768, transformer_layers=3, transformer_heads=8, dropout=0.1)
model = model.to(device)

model_path = hf_hub_download(
    repo_id="Prateek0515/legal-document-segmentation",
    filename="model.pth"
)

# SECURITY NOTE(review): torch.load unpickles the downloaded file, and
# unpickling can execute arbitrary code on PyTorch versions where
# ``weights_only`` defaults to False. If the checkpoint holds only tensors and
# plain containers, pass ``weights_only=True`` — TODO confirm against the
# published checkpoint before changing.
checkpoint = torch.load(model_path, map_location=device)

# The file is either a bare state dict or a training checkpoint that wraps it
# under the 'model_state_dict' key.
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
    model.load_state_dict(checkpoint['model_state_dict'])
else:
    model.load_state_dict(checkpoint)

model.eval()
print("✅ Model loaded successfully!\n")
|
|
|
|
|
|
|
|
|
|
|
# Maps an integer label id to its human-readable segment name. The position of
# each name in the tuple is its id (0-6); `predict` looks decoded ids up here.
id2label = dict(enumerate((
    "Arguments of Petitioner",
    "Arguments of Respondent",
    "Decision",
    "Facts",
    "Issue",
    "None",
    "Reasoning",
)))
|
|
|
|
|
def split_sentences(text):
    """Break ``text`` into a list of trimmed, non-empty sentences.

    A boundary is any run of whitespace that follows ``.``, ``!`` or ``?`` and
    is immediately followed by an uppercase ASCII letter.
    """
    boundary = re.compile(r'(?<=[.!?])\s+(?=[A-Z])')

    pieces = []
    for chunk in boundary.split(text):
        trimmed = chunk.strip()
        if trimmed:
            pieces.append(trimmed)
    return pieces
|
|
|
|
|
def extract_text_from_pdf(file_path):
    """Extract the text of every page of a PDF.

    Args:
        file_path: Path (or file object) accepted by ``PyPDF2.PdfReader``.

    Returns:
        The page texts joined by newlines and stripped of surrounding
        whitespace. On any failure, the string
        ``"Error reading PDF: <message>"`` is returned instead of raising.
    """
    try:
        reader = PyPDF2.PdfReader(file_path)
        # Guard against a page yielding None/empty text (e.g. no text layer),
        # which would otherwise abort the whole extraction.
        page_texts = [page.extract_text() or "" for page in reader.pages]
        # Join with a newline so the last word of one page is not glued to
        # the first word of the next, which would hide sentence boundaries.
        return "\n".join(page_texts).strip()
    except Exception as e:
        return f"Error reading PDF: {str(e)}"
|
|
|
|
|
def extract_text_from_docx(file_path):
    """Return the paragraph text of a DOCX file, one paragraph per line.

    The joined text is stripped of surrounding whitespace. On any failure the
    string ``"Error reading DOCX: <message>"`` is returned instead of raising.
    """
    try:
        paragraphs = Document(file_path).paragraphs
        joined = "\n".join(paragraph.text for paragraph in paragraphs)
    except Exception as exc:
        return f"Error reading DOCX: {str(exc)}"
    return joined.strip()
|
|
|
|
|
|
|
|
|
|
|
def _read_uploaded_file(file_input):
    """Return ``(text, error)`` for an upload; ``error`` is None on success."""
    # Depending on the Gradio version/configuration the upload arrives either
    # as a temp-file wrapper exposing ``.name`` or as a plain path string.
    file_path = str(getattr(file_input, "name", file_input))
    file_path_lower = file_path.lower()

    if file_path_lower.endswith('.pdf'):
        text = extract_text_from_pdf(file_path)
    elif file_path_lower.endswith(('.docx', '.doc')):
        text = extract_text_from_docx(file_path)
    elif file_path_lower.endswith('.txt'):
        with open(file_path, 'r', encoding='utf-8') as f:
            return f.read(), None
    else:
        return None, "❌ Unsupported file type. Please use: PDF, DOCX, or TXT"

    # The PDF/DOCX extractors report failures by *returning* an
    # "Error reading ..." string; surface it as an error instead of feeding
    # the message to the model as if it were document text.
    if text.startswith(("Error reading PDF:", "Error reading DOCX:")):
        return None, f"❌ {text}"
    return text, None


def predict(text_input, file_input):
    """Classify each sentence of the given text or uploaded file.

    Args:
        text_input: Raw text typed by the user; used only when no file is
            uploaded.
        file_input: Uploaded file (object with ``.name`` or a path string), or
            None. Takes priority over ``text_input``. Supported extensions:
            ``.pdf``, ``.docx``/``.doc``, ``.txt``.

    Returns:
        On success, a list of ``{"label": str, "sentence": str}`` dicts, one
        per sentence in document order. On any problem, a dict
        ``{"error": message}``; exceptions are caught and reported this way.
    """
    try:
        if file_input is not None:
            text, error = _read_uploaded_file(file_input)
            if error is not None:
                return {"error": error}
        elif text_input:
            text = text_input
        else:
            return {"error": "⚠️ Please provide either text or upload a file"}

        if not text or not text.strip():
            return {"error": "⚠️ No text content found"}

        sentences = split_sentences(text)
        if not sentences:
            return {"error": "⚠️ Could not split text into sentences"}

        encoded = tokenizer(
            sentences,
            padding="max_length",
            truncation=True,
            max_length=512,
            return_tensors="pt"
        )

        # Add a leading batch axis: the model takes (batch, sentences, tokens).
        input_ids = encoded["input_ids"].unsqueeze(0).to(device)
        attention_mask = encoded["attention_mask"].unsqueeze(0).to(device)
        # Every sentence is real (no padding sentences) in a single document.
        sentence_mask = torch.ones(1, len(sentences), dtype=torch.bool).to(device)

        with torch.no_grad():
            predictions = model(input_ids, attention_mask, sentence_mask=sentence_mask)

        # One decoded label sequence per batch item; the batch holds one document.
        # The model's output is reported as-is. A uniform prediction (always
        # the case for a single sentence) is a valid result and is no longer
        # replaced with made-up labels.
        predicted_labels = predictions[0]

        return [
            {"label": id2label.get(label_id, "Unknown"), "sentence": sentence.strip()}
            for sentence, label_id in zip(sentences, predicted_labels)
        ]

    except Exception as e:
        # Top-level UI boundary: report the failure to the user as JSON.
        return {"error": f"❌ Error during prediction: {str(e)}"}
|
|
|
|
|
|
|
|
|
|
|
# Gradio UI: a text box and a file upload feed `predict`; the per-sentence
# results (or an error dict) are shown as JSON.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label="Enter Legal Text", placeholder="Paste legal text here...", lines=5),
        gr.File(label="Or Upload File (PDF, DOCX, TXT)")
    ],
    outputs=gr.JSON(label="Per-Sentence Predictions"),
    title="⚖️ Legal Document Segmentation",
    description="Classify legal documents sentence-by-sentence into: Arguments (Petitioner/Respondent), Decision, Facts, Issue, None, or Reasoning",
    examples=[
        ["The appellant filed a petition against the respondent. The court decides that the appellant is liable.", None],
    ],
    api_name="predict"
)

# Launch the web app only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|
|
|