Spaces:
Running
Running
| import streamlit as st | |
| import torch | |
| import pandas as pd | |
| import numpy as np | |
| from pathlib import Path | |
| import sys | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| from transformers import BertTokenizer | |
| import nltk | |
# Ensure the NLTK resources listed below are available locally.
# Each entry pairs the path handed to nltk.data.find() with the package
# name handed to nltk.download() when that lookup raises LookupError.
for _lookup_path, _package_id in (
    ('tokenizers/punkt', 'punkt'),
    ('corpora/stopwords', 'stopwords'),
    ('tokenizers/punkt_tab', 'punkt_tab'),
    ('corpora/wordnet', 'wordnet'),
):
    try:
        nltk.data.find(_lookup_path)
    except LookupError:
        nltk.download(_package_id)
| # Add project root to Python path | |
| project_root = Path(__file__).parent.parent | |
| sys.path.append(str(project_root)) | |
| from src.models.hybrid_model import HybridFakeNewsDetector | |
| from src.config.config import * | |
| from src.data.preprocessor import TextPreprocessor | |
# Page-wide stylesheet, injected once at import time through st.markdown.
# It is kept in a named constant so the call site stays short; the markup is
# passed with unsafe_allow_html=True so the <style> element is emitted as HTML.
_PAGE_STYLESHEET = """
<style>
/* Import Google Fonts */
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700;800&display=swap');
/* Global Styles */
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
.stApp {
font-family: 'Inter', sans-serif;
background: #f8fafc;
min-height: 100vh;
color: #1a202c;
}
/* Ensure sidebar is visible */
#MainMenu {visibility: visible;}
footer {visibility: hidden;}
.stDeployButton {display: none;}
header {visibility: hidden;}
.stApp > header {visibility: hidden;}
/* Container */
.container {
max-width: 1200px;
margin: 0 auto;
padding: 1rem;
}
/* Header */
.header {
padding: 1rem 0;
text-align: center;
}
.header-title {
font-size: 2rem;
font-weight: 800;
color: #1a202c;
display: inline-flex;
align-items: center;
gap: 0.5rem;
}
/* Hero Section */
.hero {
display: flex;
align-items: center;
gap: 2rem;
margin-bottom: 2rem;
}
.hero-left {
flex: 1;
padding: 1rem;
}
.hero-right {
flex: 1;
display: flex;
align-items: center;
justify-content: center;
}
.hero-right img {
max-width: 100%;
height: auto;
border-radius: 8px;
}
.hero-title {
font-size: 2.5rem;
font-weight: 700;
color: #1a202c;
margin-bottom: 0.5rem;
}
.hero-text {
font-size: 1rem;
color: #4a5568;
line-height: 1.5;
max-width: 450px;
}
/* About Section */
.about-section {
margin-bottom: 2rem;
text-align: center;
}
.about-title {
font-size: 1.8rem;
font-weight: 600;
color: #1a202c;
margin-bottom: 0.5rem;
}
.about-text {
font-size: 1rem;
color: #4a5568;
line-height: 1.5;
max-width: 600px;
margin: 0 auto;
}
/* Input Section */
.input-container {
max-width: 800px;
margin: 0 auto;
}
.stTextArea > div > div > textarea {
border-radius: 8px !important;
border: 1px solid #d1d5db !important;
padding: 1rem !important;
font-size: 1rem !important;
font-family: 'Inter', sans-serif !important;
background: #ffffff !important;
min-height: 150px !important;
transition: all 0.2s ease !important;
}
.stTextArea > div > div > textarea:focus {
border-color: #6366f1 !important;
box-shadow: 0 0 0 2px rgba(99, 102, 241, 0.1) !important;
outline: none !important;
}
.stTextArea > div > div > textarea::placeholder {
color: #9ca3af !important;
}
/* Button Styling */
.stButton > button {
background: #6366f1 !important;
color: white !important;
border-radius: 8px !important;
padding: 0.75rem 2rem !important;
font-size: 1rem !important;
font-weight: 600 !important;
font-family: 'Inter', sans-serif !important;
transition: all 0.2s ease !important;
border: none !important;
width: 100% !important;
}
.stButton > button:hover {
background: #4f46e5 !important;
transform: translateY(-1px) !important;
}
/* Results Section */
.results-container {
margin-top: 1rem;
padding: 1rem;
border-radius: 8px;
}
.result-card {
padding: 1rem;
border-radius: 8px;
border-left: 4px solid transparent;
margin-bottom: 1rem;
}
.fake-news {
background: #fef2f2;
border-left-color: #ef4444;
}
.real-news {
background: #ecfdf5;
border-left-color: #10b981;
}
.prediction-badge {
font-weight: 600;
font-size: 1rem;
margin-bottom: 0.5rem;
display: flex;
align-items: center;
gap: 0.5rem;
}
.confidence-score {
font-weight: 600;
margin-left: auto;
font-size: 1rem;
}
/* Chart Containers */
.chart-container {
padding: 1rem;
border-radius: 8px;
margin: 1rem 0;
}
/* Footer */
.footer {
margin-top: 2rem;
padding: 1rem 0;
text-align: center;
border-top: 1px solid #e5e7eb;
}
/* Sidebar Styling */
.stSidebar {
background: #ffffff;
border-right: 1px solid #e5e7eb;
}
.stSidebar .sidebar-content {
padding: 1rem;
}
</style>
"""
st.markdown(_PAGE_STYLESHEET, unsafe_allow_html=True)
@st.cache_resource
def load_model_and_tokenizer():
    """Build the hybrid detector, load its saved weights, and return it with its tokenizer.

    The function is wrapped in ``st.cache_resource`` so the model construction,
    checkpoint read and tokenizer load happen once and the same objects are
    reused afterwards, instead of being repeated on every prediction.

    Returns:
        tuple: ``(model, tokenizer)`` where ``model`` is a
        ``HybridFakeNewsDetector`` switched to eval mode and ``tokenizer`` is
        the ``BertTokenizer`` for ``BERT_MODEL_NAME``.
    """
    model = HybridFakeNewsDetector(
        bert_model_name=BERT_MODEL_NAME,
        lstm_hidden_size=LSTM_HIDDEN_SIZE,
        lstm_num_layers=LSTM_NUM_LAYERS,
        dropout_rate=DROPOUT_RATE,
    )

    # NOTE(review): torch.load unpickles the file; only point this at a
    # checkpoint from a trusted source.
    checkpoint = torch.load(
        SAVED_MODELS_DIR / "final_model.pt",
        map_location=torch.device('cpu'),
    )

    # Keep only checkpoint entries whose names the freshly built model also
    # declares. Combined with strict=False this means keys present on only one
    # side are skipped silently rather than raising.
    expected_keys = model.state_dict()
    compatible_weights = {
        name: tensor for name, tensor in checkpoint.items() if name in expected_keys
    }
    model.load_state_dict(compatible_weights, strict=False)
    model.eval()

    tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)
    return model, tokenizer
@st.cache_resource
def get_preprocessor():
    """Return the text preprocessor, creating it only once.

    ``st.cache_resource`` keeps the single ``TextPreprocessor`` instance and
    hands it back on later calls, so it is not rebuilt for every prediction.

    Returns:
        TextPreprocessor: the shared preprocessor instance.
    """
    return TextPreprocessor()
def predict_news(text):
    """Classify *text* as fake or real news.

    The text is passed through the preprocessor, encoded with the tokenizer
    (padded/truncated to ``MAX_SEQUENCE_LENGTH``) and fed to the model with
    gradient tracking disabled.

    Args:
        text: Raw article text.

    Returns:
        dict: with the keys
            ``'prediction'`` - index of the largest logit,
            ``'label'`` - ``'FAKE'`` when that index is 1, else ``'REAL'``,
            ``'confidence'`` - largest softmax probability,
            ``'probabilities'`` - ``{'REAL': p[0][0], 'FAKE': p[0][1]}``,
            ``'attention_weights'`` - NumPy array taken from the first batch
            item of the model's ``'attention_weights'`` output.
    """
    model, tokenizer = load_model_and_tokenizer()
    cleaned_text = get_preprocessor().preprocess_text(text)

    encoded = tokenizer.encode_plus(
        cleaned_text,
        add_special_tokens=True,
        max_length=MAX_SEQUENCE_LENGTH,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )

    # Inference only: no gradients are needed for the forward pass.
    with torch.no_grad():
        model_out = model(encoded['input_ids'], encoded['attention_mask'])

    logits = model_out['logits']
    class_probs = torch.softmax(logits, dim=1)
    predicted_class = torch.argmax(logits, dim=1).item()
    first_item_attention = model_out['attention_weights'][0].cpu().numpy()

    if predicted_class == 1:
        label = 'FAKE'
    else:
        label = 'REAL'

    return {
        'prediction': predicted_class,
        'label': label,
        'confidence': torch.max(class_probs, dim=1)[0].item(),
        'probabilities': {
            'REAL': class_probs[0][0].item(),
            'FAKE': class_probs[0][1].item(),
        },
        'attention_weights': first_item_attention,
    }
def plot_confidence(probabilities):
    """Build a bar chart of the per-class probabilities.

    Args:
        probabilities: Mapping of class label to probability, e.g.
            ``{'REAL': 0.2, 'FAKE': 0.8}``. Bars are drawn in the mapping's
            iteration order.

    Returns:
        plotly.graph_objects.Figure: the configured bar chart.
    """
    labels = list(probabilities)
    values = list(probabilities.values())

    # Colour each bar by its label rather than by position, so the green/red
    # meaning stays correct even if the mapping's order changes. Labels
    # outside the palette fall back to a neutral grey.
    palette = {'REAL': '#10b981', 'FAKE': '#ef4444'}
    bar_colors = [palette.get(label, '#9ca3af') for label in labels]

    fig = go.Figure(data=[
        go.Bar(
            x=labels,
            y=values,
            text=[f'{p:.1%}' for p in values],
            textposition='auto',
            marker=dict(
                color=bar_colors,
                line=dict(color='#ffffff', width=1),
            ),
        )
    ])

    # Axis titles use the nested title=dict(text=..., font=...) form; the
    # older `titlefont` shorthand is deprecated and rejected by recent Plotly
    # releases.
    fig.update_layout(
        title={'text': 'Prediction Confidence', 'x': 0.5, 'xanchor': 'center', 'font': {'size': 18}},
        xaxis=dict(
            title=dict(text='Classification', font={'size': 12}),
            tickfont={'size': 10},
        ),
        yaxis=dict(
            title=dict(text='Probability', font={'size': 12}),
            range=[0, 1],
            tickformat='.0%',
            tickfont={'size': 10},
        ),
        template='plotly_white',
        height=300,
        margin=dict(t=60, b=60),
    )
    return fig
def plot_attention(text, attention_weights, max_tokens=20):
    """Build a bar chart pairing the leading words of *text* with attention scores.

    Args:
        text: Text whose whitespace-separated words label the bars.
        attention_weights: Array-like of attention scores; it is flattened to
            one dimension before use.
        max_tokens: Maximum number of words/bars to show (default 20).

    Returns:
        plotly.graph_objects.Figure: the configured bar chart. The figure has
        no bars when either the text or the weights are empty.
    """
    # NOTE(review): words and weights are matched purely by position. Whether
    # the model's attention positions line up with whitespace-separated words
    # of this text is not established here - confirm against the model.
    tokens = text.split()[:max_tokens]

    # Flatten first, then trim both sequences to a common length so x and y
    # always have the same number of entries regardless of the input's shape.
    weights = np.asarray(attention_weights).ravel()
    shown = min(len(tokens), weights.size)
    tokens = tokens[:shown]
    weights = weights[:shown]

    # Scale by the peak only when it is positive; an empty input has no peak,
    # and a non-positive peak leaves the weights as they are.
    peak = weights.max() if shown else 0
    scaled = weights / peak if peak > 0 else weights
    colors = [f'rgba(99, 102, 241, {0.4 + 0.6 * float(w)})' for w in scaled]

    fig = go.Figure(data=[
        go.Bar(
            x=tokens,
            y=weights,
            text=[f'{float(w):.3f}' for w in weights],
            textposition='auto',
            marker=dict(color=colors),
        )
    ])

    # Axis titles use the nested title=dict(text=..., font=...) form; the
    # older `titlefont` shorthand is deprecated and rejected by recent Plotly
    # releases.
    fig.update_layout(
        title={'text': 'Attention Weights', 'x': 0.5, 'xanchor': 'center', 'font': {'size': 18}},
        xaxis=dict(
            title=dict(text='Words', font={'size': 12}),
            tickangle=45,
            tickfont={'size': 10},
        ),
        yaxis=dict(
            title=dict(text='Attention Score', font={'size': 12}),
            tickfont={'size': 10},
        ),
        template='plotly_white',
        height=350,
        margin=dict(t=60, b=80),
    )
    return fig
def main():
    """Render the TruthCheck page and run an analysis when requested.

    Draws, in order: the sidebar, header, hero and about sections, the article
    text area with its "Analyze Now" button, the result card and charts (only
    after the button is pressed with a long enough article), and the footer.

    The article must contain at least 10 whitespace-separated words; shorter
    input produces the validation error instead of a prediction. Any exception
    raised while predicting or rendering results is shown to the user with
    ``st.error`` rather than propagating.
    """
    # Sidebar. The buttons' return values are not used, so pressing them only
    # triggers a rerun of the script.
    with st.sidebar:
        st.markdown("## TruthCheck Menu")
        st.markdown("Navigate through the options below:")
        st.button("Home", disabled=True)
        st.button("Analyze News", key="nav_analyze")
        st.button("About", key="nav_about")
        st.markdown("---")
        st.markdown("**Contact**")
        st.markdown("📧 support@truthcheck.ai")

    # Header
    st.markdown("""
<div class="header">
<div class="container">
<h1 class="header-title">🛡️ TruthCheck</h1>
</div>
</div>
""", unsafe_allow_html=True)

    # Hero Section
    st.markdown("""
<div class="container">
<div class="hero">
<div class="hero-left">
<h2 class="hero-title">Instant Fake News Detection</h2>
<p class="hero-text">
Verify news articles with our AI-powered tool, driven by BERT and BiLSTM for fast and accurate authenticity analysis.
</p>
</div>
<div class="hero-right">
<img src="hero.png" alt="TruthCheck Illustration">
</div>
</div>
</div>
""", unsafe_allow_html=True)

    # About Section
    st.markdown("""
<div class="container">
<div class="about-section">
<h2 class="about-title">About TruthCheck</h2>
<p class="about-text">
TruthCheck uses a hybrid BERT-BiLSTM model to detect fake news with high accuracy. Paste an article below for instant analysis.
</p>
</div>
</div>
""", unsafe_allow_html=True)

    # Input Section
    # NOTE(review): presumably each st.markdown call is rendered as its own
    # element, so the opening/closing <div> fragments below may not actually
    # wrap the widgets placed between them - confirm in the browser.
    st.markdown('<div class="container"><div class="input-container">', unsafe_allow_html=True)
    news_text = st.text_area(
        "Analyze a News Article",
        height=150,
        placeholder="Paste your news article here for instant AI analysis...",
        key="news_input"
    )
    st.markdown('</div>', unsafe_allow_html=True)

    # Analyze Button, centred in the middle of three columns.
    st.markdown('<div class="container">', unsafe_allow_html=True)
    _, center_col, _ = st.columns([1, 2, 1])
    with center_col:
        analyze_button = st.button("🔍 Analyze Now", key="analyze_button")
    st.markdown('</div>', unsafe_allow_html=True)

    if analyze_button:
        # Count words, not characters, so the gate matches the "at least 10
        # words" wording of the validation message below.
        if news_text and len(news_text.split()) >= 10:
            with st.spinner("Analyzing article..."):
                try:
                    result = predict_news(news_text)
                    st.markdown('<div class="container"><div class="results-container">', unsafe_allow_html=True)

                    # Prediction Result: verdict card on the left, confidence
                    # chart on the right.
                    col1, col2 = st.columns([1, 1], gap="medium")
                    with col1:
                        if result['label'] == 'FAKE':
                            st.markdown(f'''
<div class="result-card fake-news">
<div class="prediction-badge">🚨 Fake News Detected <span class="confidence-score">{result["confidence"]:.1%}</span></div>
<p>Our AI has identified this content as likely misinformation based on linguistic patterns and content analysis.</p>
</div>
''', unsafe_allow_html=True)
                        else:
                            st.markdown(f'''
<div class="result-card real-news">
<div class="prediction-badge">✅ Authentic News <span class="confidence-score">{result["confidence"]:.1%}</span></div>
<p>This content appears to be legitimate based on professional writing style and factual consistency.</p>
</div>
''', unsafe_allow_html=True)
                    with col2:
                        st.markdown('<div class="chart-container">', unsafe_allow_html=True)
                        st.plotly_chart(plot_confidence(result['probabilities']), use_container_width=True)
                        st.markdown('</div>', unsafe_allow_html=True)

                    # Attention Analysis
                    st.markdown('<div class="chart-container">', unsafe_allow_html=True)
                    st.plotly_chart(plot_attention(news_text, result['attention_weights']), use_container_width=True)
                    st.markdown('</div></div></div>', unsafe_allow_html=True)
                except Exception as e:
                    # Top-level UI boundary: report the failure on the page
                    # instead of letting the script crash.
                    st.markdown('<div class="container">', unsafe_allow_html=True)
                    st.error(f"Error: {str(e)}. Please try again or contact support.")
                    st.markdown('</div>', unsafe_allow_html=True)
        else:
            st.markdown('<div class="container">', unsafe_allow_html=True)
            st.error("Please enter a news article (at least 10 words) for analysis.")
            st.markdown('</div>', unsafe_allow_html=True)

    # Footer
    st.markdown("""
<div class="footer">
<p style="text-align: center; font-weight: 600; font-size: 16px;">💻 Developed with ❤️ using Streamlit | © 2025</p>
</div>
""", unsafe_allow_html=True)


# Run the app only when this file is executed as the entry script.
if __name__ == "__main__":
    main()