Commit History

Update final_rl_model/config.json
56f2e6d
verified

suruthik71 committed on

Update app.py
4d00bc3
verified

suruthik71 committed on

Update app.py
22f0593
verified

suruthik71 committed on

Update requirements.txt
625ad6e
verified

suruthik71 committed on

Update requirements.txt
cb05324
verified

suruthik71 committed on

Update app.py
3497aa3
verified

suruthik71 committed on

Update app.py
9346b84
verified

suruthik71 committed on

Update app.py
7c2c177
verified

suruthik71 committed on

Upload 7 files
9f79875
verified

suruthik71 committed on

Upload 4 files
873a04f
verified

suruthik71 committed on

Delete app.py
cf64432
verified

suruthik71 committed on

"""Gradio app producing a general and a role-focused summary of a legal text.

Two seq2seq models are loaded once at import time:

* a local checkpoint at ``./final_rl_model`` used for the general summary;
* ``nsi319/legal-pegasus`` used for the role-based summary.
"""

import logging
import re
from typing import Optional, Tuple

import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

logger = logging.getLogger(__name__)

# ======================
# Load RL Model (General Summary)
# ======================
rl_model_path = "./final_rl_model"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

rl_tokenizer = AutoTokenizer.from_pretrained(rl_model_path)
rl_model = AutoModelForSeq2SeqLM.from_pretrained(rl_model_path).to(device)
rl_model.eval()

# ======================
# Load Role-based Pegasus Model
# ======================
role_model_name = "nsi319/legal-pegasus"

role_tokenizer = AutoTokenizer.from_pretrained(role_model_name)
role_model = AutoModelForSeq2SeqLM.from_pretrained(role_model_name).to(device)
role_model.eval()

# Instruction prefix per role; the document text is appended to it verbatim.
# The Dropdown choices below are derived from these keys so the two cannot
# drift apart.
_ROLE_INSTRUCTIONS = {
    "Lawyer": (
        "From the following legal document, extract and summarize the key legal arguments "
        "and claims made by the lawyer, focusing on the points raised during the trial, "
        "evidence presented, and the legal basis for their arguments:\n\n"
    ),
    "Judge": (
        "From the following legal document, extract and summarize the judge's final decision, "
        "legal reasoning, and justification for the ruling. Include references to relevant laws, "
        "prior case references, and the overall rationale:\n\n"
    ),
    "Client": (
        "From the following legal document, extract and summarize the client's concerns, "
        "claims, and statements. Focus on the client's expectations, grievances, and key facts "
        "presented from the client's perspective:\n\n"
    ),
}


# ======================
# Text Preprocessing
# Keep only alphabets, numbers, dots, and spaces
# ======================
def clean_text(text: str) -> str:
    """Normalize a generated summary for display.

    Removes every character that is not an ASCII letter, digit, dot or space,
    collapses whitespace runs to a single space, and puts exactly one space
    after each dot.

    Args:
        text: Raw decoded model output.

    Returns:
        The cleaned text without leading or trailing whitespace.
    """
    text = text.strip()
    text = re.sub(r'[^a-zA-Z0-9\. ]+', '', text)
    text = re.sub(r'\s+', ' ', text)
    # NOTE(review): this also splits decimals such as "3.5" into "3. 5".
    text = re.sub(r'\.\s*', '. ', text)
    # The substitution above leaves a dangling space after a final period.
    return text.strip()


def _build_role_prompt(role: Optional[str], text: str) -> Optional[str]:
    """Return the role-specific prompt for *text*, or None for an unknown role."""
    instruction = _ROLE_INSTRUCTIONS.get(role)
    if instruction is None:
        return None
    return instruction + text


# ======================
# Inference Function
# ======================
def generate_summaries(
    file_bytes: Optional[bytes],
    max_length: float,
    role: Optional[str],
) -> Tuple[str, str]:
    """Summarize an uploaded document in general and from one role's viewpoint.

    Args:
        file_bytes: Raw content of the uploaded file (the ``gr.File`` input
            uses ``type="binary"``); ``None`` when nothing was uploaded.
        max_length: Maximum token length of the general summary. The slider
            may deliver a float, so it is coerced to ``int`` before use.
        role: One of the keys of ``_ROLE_INSTRUCTIONS``.

    Returns:
        ``(general_summary, role_summary)``. If the role is not recognized,
        the second element is ``"Invalid role selected."``. On any failure,
        both elements carry the same error message; no exception propagates
        to the UI.
    """
    if file_bytes is None:
        error_msg = "Error during summarization: no document was uploaded."
        return error_msg, error_msg

    try:
        # "utf-8-sig" decodes plain UTF-8 and also drops a leading BOM so it
        # does not end up in the model input.
        text = file_bytes.decode('utf-8-sig')
        if not text.strip():
            error_msg = "Error during summarization: the uploaded document is empty."
            return error_msg, error_msg

        # ----- General RL Summary -----
        inputs = rl_tokenizer(
            text,
            max_length=512,
            truncation=True,
            padding='max_length',
            return_tensors='pt',
        ).to(device)

        with torch.no_grad():
            rl_summary_ids = rl_model.generate(
                inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                max_length=int(max_length),
                min_length=30,
                num_beams=4,
                length_penalty=2.0,
                early_stopping=True,
            )

        rl_summary = rl_tokenizer.decode(rl_summary_ids[0], skip_special_tokens=True)
        rl_summary = clean_text(rl_summary)

        # ----- Role-based Summary -----
        prompt = _build_role_prompt(role, text)
        if prompt is None:
            return rl_summary, "Invalid role selected."

        role_inputs = role_tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=512,
        ).to(device)

        with torch.no_grad():
            role_summary_ids = role_model.generate(
                role_inputs['input_ids'],
                attention_mask=role_inputs['attention_mask'],
                max_length=150,
                num_beams=4,
                no_repeat_ngram_size=3,
                repetition_penalty=1.5,
                early_stopping=True,
            )

        role_summary = role_tokenizer.decode(role_summary_ids[0], skip_special_tokens=True)
        role_summary = clean_text(role_summary)

        return rl_summary, role_summary

    except Exception as e:
        # UI boundary: report the failure in both output boxes instead of
        # crashing the request, but keep the traceback in the logs.
        logger.exception("Summarization failed")
        error_msg = f"Error during summarization: {e}"
        return error_msg, error_msg


# ======================
# Gradio UI
# ======================
iface = gr.Interface(
    fn=generate_summaries,
    inputs=[
        gr.File(label="Upload a .txt Document", type="binary"),
        gr.Slider(minimum=50, maximum=300, value=128, step=10, label="Max Length for General Summary"),
        gr.Dropdown(choices=list(_ROLE_INSTRUCTIONS), label="Select Role for Role-based Summary"),
    ],
    outputs=[
        gr.Textbox(label="General RL Summary"),
        gr.Textbox(label="Role-based Summary"),
    ],
    title="Legal Abstractive Summarizer",
    # The original passed an empty tuple here; Gradio expects a string.
    description=(
        "Upload a plain-text legal document to generate a general summary and "
        "a role-focused summary (Lawyer, Judge, or Client)."
    ),
)

if __name__ == "__main__":
    iface.launch()
eedec1e
verified

suruthik71 committed on

initial commit
c2b0b97
verified

suruthik71 committed on