import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import gradio as gr
import re

# ======================
# Load RL Model (General Summary)
# ======================
rl_model_path = "./final_rl_model"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

rl_tokenizer = AutoTokenizer.from_pretrained(rl_model_path)
rl_model = AutoModelForSeq2SeqLM.from_pretrained(rl_model_path).to(device)
rl_model.eval()

# ======================
# Load Role-based Pegasus Model
# ======================
role_model_name = "nsi319/legal-pegasus"
role_tokenizer = AutoTokenizer.from_pretrained(role_model_name)
role_model = AutoModelForSeq2SeqLM.from_pretrained(role_model_name).to(device)
role_model.eval()

# ======================
# Text Preprocessing
# Keep only alphabets, numbers, dots, and spaces
# ======================
def clean_text(text):
    text = text.strip()
    text = re.sub(r'[^a-zA-Z0-9\. ]+', '', text)  # drop all other characters
    text = re.sub(r'\s+', ' ', text)              # collapse repeated whitespace
    text = re.sub(r'\.\s*', '. ', text)           # normalize spacing after periods
    return text

# ======================
# Inference Function
# ======================
def generate_summaries(file_bytes, max_length, role):
    try:
        # Decode uploaded file content
        text = file_bytes.decode('utf-8')

        # ----- General RL Summary -----
        inputs = rl_tokenizer(
            text,
            max_length=512,
            truncation=True,
            padding='max_length',
            return_tensors='pt'
        ).to(device)

        with torch.no_grad():
            rl_summary_ids = rl_model.generate(
                inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                max_length=max_length,
                min_length=30,
                num_beams=4,
                length_penalty=2.0,
                early_stopping=True
            )
        rl_summary = rl_tokenizer.decode(rl_summary_ids[0], skip_special_tokens=True)
        rl_summary = clean_text(rl_summary)

        # ----- Role-based Summary -----
        if role == "Lawyer":
            prompt = (
                "From the following legal document, extract and summarize the key legal arguments "
                "and claims made by the lawyer, focusing on the points raised during the trial, "
                "evidence presented, and the legal basis for their arguments:\n\n"
                f"{text}"
            )
        elif role == "Judge":
            prompt = (
                "From the following legal document, extract and summarize the judge's final decision, "
                "legal reasoning, and justification for the ruling. Include references to relevant laws, "
                "prior case references, and the overall rationale:\n\n"
                f"{text}"
            )
        elif role == "Client":
            prompt = (
                "From the following legal document, extract and summarize the client's concerns, "
                "claims, and statements. Focus on the client's expectations, grievances, and key facts "
                "presented from the client's perspective:\n\n"
                f"{text}"
            )
        else:
            role_summary = "Invalid role selected."
            return rl_summary, role_summary

        role_inputs = role_tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=512
        ).to(device)

        with torch.no_grad():
            role_summary_ids = role_model.generate(
                role_inputs['input_ids'],
                max_length=150,
                num_beams=4,
                no_repeat_ngram_size=3,
                repetition_penalty=1.5,
                early_stopping=True
            )
        role_summary = role_tokenizer.decode(role_summary_ids[0], skip_special_tokens=True)
        role_summary = clean_text(role_summary)

        return rl_summary, role_summary

    except Exception as e:
        error_msg = f"Error during summarization: {str(e)}"
        return error_msg, error_msg

# ======================
# Gradio UI
# ======================
iface = gr.Interface(
    fn=generate_summaries,
    inputs=[
        gr.File(label="Upload a .txt Document", type="binary"),
        gr.Slider(minimum=50, maximum=300, value=128, step=10, label="Max Length for General Summary"),
        gr.Dropdown(choices=["Lawyer", "Judge", "Client"], label="Select Role for Role-based Summary")
    ],
    outputs=[
        gr.Textbox(label="General RL Summary"),
        gr.Textbox(label="Role-based Summary")
    ],
    title="Legal Abstractive Summarizer",
    description=(
        "Upload a legal document (.txt) to generate a general abstractive summary "
        "and a role-specific summary for a Lawyer, Judge, or Client."
    )
)

if __name__ == "__main__":
    iface.launch()
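# ----------------------
# Optional smoke test (a minimal sketch; not part of the original app).
# Calls generate_summaries() directly instead of going through the UI.
# "sample_case.txt" is a hypothetical file name used for illustration;
# substitute any UTF-8 legal text file. Uncomment and run this instead of
# iface.launch() to verify both models load and generate end to end:
#
# with open("sample_case.txt", "rb") as f:
#     general, role_based = generate_summaries(f.read(), max_length=128, role="Judge")
# print("General RL Summary:\n", general)
# print("Role-based Summary:\n", role_based)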