Files changed (10)
  1. .env +27 -0
  2. Dockerfile +36 -0
  3. README.md +40 -10
  4. app.py +1508 -0
  5. auth.py +633 -0
  6. check_routes.py +61 -0
  7. fix_users_table.py +180 -0
  8. initialize_plans.py +25 -0
  9. paypal_integration.py +1004 -0
  10. requirements.txt +21 -0
.env ADDED
@@ -0,0 +1,27 @@
1
+ # Backend API URL (will be updated to your Hugging Face Space URL)
2
+ BACKEND_URL=https://your-username-legal-document-analyzer.hf.space
3
+
4
+ # Frontend App URL (if you're deploying the frontend separately)
5
+ APP_URL=http://localhost:3000
6
+
7
+ # PayPal Configuration
8
+ PAYPAL_BASE_URL=https://api-m.sandbox.paypal.com
9
+ PAYPAL_CLIENT_ID=ASOzKSgawVTyJpK_1vOKap4TTJ3OsHQ9syDvEX43O2Vi7ZVoto1z6zYWWm20LrrJ-dA9wqD33jrT5qLu
10
+ PAYPAL_SECRET=ED0YJoSvOq6sUjOfQvz88Z-NsFhDyK2Dv2TUI3LOoEZ11rkNev92Cp6O8d2mBLw7fdunKRfHhcMyNuXN
11
+
12
+ # JWT Secret
13
+ JWT_SECRET=13105030e5bfb231ebf59b8cdf91a39571e51a51fe62a4e1c323079f9945cb7d
14
+
15
+ # Database Path (adjusted for Hugging Face container structure)
16
+ DB_PATH=/app/data/user_data.db
17
+
18
+ # Plan IDs Path (adjusted for Hugging Face container structure)
19
+ PLAN_IDS_PATH=/app/data/plan_ids.json
20
+
21
+ # Model Loading Configuration
22
+ LOAD_MODELS=True
23
+ MODELS_CACHE_DIR=/app/models_cache
24
+
25
+ # Hugging Face Spaces Configuration
26
+ PORT=7860
27
+ HOST=0.0.0.0
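These values are read at startup with python-dotenv (`load_dotenv()` in `app.py`) and `os.getenv`. A minimal sketch of how the keys above are consumed; the fallback values shown here are illustrative rather than authoritative:

```python
import os
from dotenv import load_dotenv

load_dotenv()  # pull the key=value pairs from .env into the process environment

BACKEND_URL = os.getenv("BACKEND_URL")
DB_PATH = os.getenv("DB_PATH", "data/user_data.db")
LOAD_MODELS = os.getenv("LOAD_MODELS", "True").lower() in ("true", "1", "t")
PORT = int(os.getenv("PORT", "7860"))
```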
Dockerfile ADDED
@@ -0,0 +1,36 @@
1
+ FROM python:3.9-slim
2
+
3
+ WORKDIR /app
4
+
5
+ # Install system dependencies
6
+ RUN apt-get update && apt-get install -y \
7
+ build-essential \
8
+ libffi-dev \
9
+ git \
10
+ ffmpeg \
11
+ && rm -rf /var/lib/apt/lists/*
12
+
13
+ # Copy requirements first for better caching
14
+ COPY requirements.txt .
15
+ RUN pip install --no-cache-dir -r requirements.txt
16
+
17
+ # Create necessary directories
18
+ RUN mkdir -p static uploads temp data models_cache
19
+
20
+ # Copy only the necessary files
21
+ COPY app.py .
22
+ COPY auth.py .
23
+ COPY paypal_integration.py .
24
+ COPY .env .
25
+
26
+ # Set environment variables
27
+ ENV LOAD_MODELS=True
28
+ ENV MODELS_CACHE_DIR=/app/models_cache
29
+ ENV PORT=7860
30
+ ENV HOST=0.0.0.0
31
+
32
+ # Expose the port
33
+ EXPOSE 7860
34
+
35
+ # Command to run the application
36
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
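Once the image is built and a container is running, the `/health` endpoint defined in `app.py` gives a quick smoke test. A minimal sketch, assuming the container's port 7860 is published to localhost:

```python
import requests

# e.g. after: docker build -t legal-analyzer . && docker run -p 7860:7860 legal-analyzer
resp = requests.get("http://localhost:7860/health", timeout=10)
print(resp.json())  # expected: {"status": "ok", "message": "API is running"}
```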
README.md CHANGED
@@ -1,10 +1,40 @@
1
- ---
2
- title: Testing
3
- emoji: 📚
4
- colorFrom: gray
5
- colorTo: blue
6
- sdk: docker
7
- pinned: false
8
- ---
9
-
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ ---
2
+ title: Legal Document Analysis API
3
+ emoji: 📄
4
+ colorFrom: blue
5
+ colorTo: indigo
6
+ sdk: docker
7
+ pinned: false
8
+ license: mit
9
+ ---
10
+
11
+ # Legal Document Analysis API
12
+
13
+ This API provides tools for analyzing legal documents, videos, and audio files. It uses NLP models to extract insights, summarize content, and answer legal questions.
14
+
15
+ ## Features
16
+
17
+ - Document analysis (PDF)
18
+ - Video and audio transcription and analysis
19
+ - Legal question answering
20
+ - Risk assessment and visualization
21
+ - Contract clause analysis
22
+
23
+ ## Deployment
24
+
25
+ This API is deployed on Hugging Face Spaces.
26
+
27
+ ## API Endpoints
28
+
29
+ - `/analyze_legal_document` - Analyze legal documents (PDF)
30
+ - `/analyze_legal_video` - Analyze legal videos
31
+ - `/analyze_legal_audio` - Analyze legal audio
32
+ - `/legal_chatbot/{task_id}` - Ask questions about an analyzed document (see the example request below)
33
+
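+ ## Example Request
+
+ All analysis endpoints require a Bearer token obtained from `/token`. A minimal client sketch (the base URL, e-mail, and password below are placeholders; substitute your own deployment and credentials):
+
+ ```python
+ import requests
+
+ BASE_URL = "https://your-username-legal-document-analyzer.hf.space"  # placeholder Space URL
+
+ # Obtain an access token (OAuth2 password flow served by /token)
+ token_resp = requests.post(
+     f"{BASE_URL}/token",
+     data={"username": "user@example.com", "password": "your-password"},
+ )
+ access_token = token_resp.json()["access_token"]
+
+ # Upload a PDF for analysis
+ with open("contract.pdf", "rb") as f:
+     resp = requests.post(
+         f"{BASE_URL}/analyze_legal_document",
+         headers={"Authorization": f"Bearer {access_token}"},
+         files={"file": ("contract.pdf", f, "application/pdf")},
+     )
+ print(resp.json()["summary"])
+ ```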
34
+ ## Technologies
35
+
36
+ - FastAPI
37
+ - Hugging Face Transformers
38
+ - SpaCy
39
+ - PyTorch
40
+ - MoviePy
app.py ADDED
@@ -0,0 +1,1508 @@
1
+ import os
2
+ import io
3
+ import time
4
+ import uuid
5
+ import tempfile
6
+ import numpy as np
7
+ import matplotlib.pyplot as plt
8
+ import pdfplumber
9
+ import spacy
10
+ import torch
11
+ import sqlite3
12
+ import uvicorn
13
+ import moviepy.editor as mp
14
+ from threading import Thread
15
+ from datetime import datetime, timedelta
16
+ from typing import List, Dict, Optional
17
+ from fastapi import FastAPI, File, UploadFile, Form, Depends, HTTPException, status, Header
18
+ from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
19
+ from fastapi.staticfiles import StaticFiles
20
+ from fastapi.middleware.cors import CORSMiddleware
21
+ import logging
22
+ from pydantic import BaseModel
23
+ from transformers import (
24
+ AutoTokenizer,
25
+ AutoModelForQuestionAnswering,
26
+ pipeline,
27
+ TrainingArguments,
28
+ Trainer
29
+ )
30
+ from sentence_transformers import SentenceTransformer
31
+ from passlib.context import CryptContext
32
+ from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
33
+ import jwt
34
+ from dotenv import load_dotenv
35
+ # Import get_db_connection from auth
36
+ from auth import (
37
+ User, UserCreate, Token, get_current_active_user, authenticate_user,
38
+ create_access_token, hash_password, register_user, check_subscription_access,
39
+ SUBSCRIPTION_TIERS, JWT_EXPIRATION_DELTA, get_db_connection, update_auth_db_schema
40
+ )
42
+ from paypal_integration import (
43
+ create_user_subscription, verify_subscription_payment,
44
+ update_user_subscription, handle_subscription_webhook, initialize_database,
+ verify_paypal_subscription, initialize_subscription_plans
45
+ )
46
+ from fastapi import Request
47
+
48
+ logging.basicConfig(
49
+ level=logging.INFO,
50
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
51
+ )
52
+ logger = logging.getLogger("app")
53
+
54
+ # Initialize the database
55
+ # Initialize FastAPI app
56
+ app = FastAPI(
57
+ title="Legal Document Analysis API",
58
+ description="API for analyzing legal documents, videos, and audio",
59
+ version="1.0.0"
60
+ )
61
+
62
+ # Set up CORS middleware
63
+ app.add_middleware(
64
+ CORSMiddleware,
65
+ allow_origins=["*"], # Allow all origins; restrict this to the frontend URL in production
66
+ allow_credentials=True,
67
+ allow_methods=["*"],
68
+ allow_headers=["*"],
69
+ )
70
+ # Initialize the database
+ initialize_database()
71
+ try:
72
+ update_auth_db_schema()
73
+ logger.info("Database schema updated successfully")
74
+ except Exception as e:
75
+ logger.error(f"Database schema update error: {e}")
76
+
77
+ # Create static directory for file storage
78
+ os.makedirs("static", exist_ok=True)
79
+ os.makedirs("uploads", exist_ok=True)
80
+ os.makedirs("temp", exist_ok=True)
81
+ app.mount("/static", StaticFiles(directory="static"), name="static")
82
+
83
+ # Set device for model inference
84
+ device = "cuda" if torch.cuda.is_available() else "cpu"
85
+ print(f"Using device: {device}")
86
+
87
+ # Initialize chat history
88
+ chat_history = []
89
+
90
+ # Document context storage
91
+ document_contexts = {}
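+ # NOTE: contexts are held in this in-process dict only; they are not persisted and are lost on restart.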
92
+
93
+ def store_document_context(task_id, text):
94
+ """Store document text for later retrieval."""
95
+ document_contexts[task_id] = text
96
+
97
+ def load_document_context(task_id):
98
+ """Load document text for a given task ID."""
99
+ return document_contexts.get(task_id, "")
100
+
101
+ def get_db_connection():
102
+ """Get a connection to the SQLite user database (overrides the version imported from auth to use DB_PATH)."""
103
+ db_path = os.getenv("DB_PATH", os.path.join(os.path.dirname(__file__), "data/user_data.db"))
104
+ conn = sqlite3.connect(db_path)
105
+ conn.row_factory = sqlite3.Row
106
+ return conn
107
+
108
+ load_dotenv()
109
+ DB_PATH = os.getenv("DB_PATH", os.path.join(os.path.dirname(__file__), "data/user_data.db"))
110
+ os.makedirs(os.path.join(os.path.dirname(__file__), "data"), exist_ok=True)
111
+
112
+ def fine_tune_qa_model():
113
+ """Fine-tunes a QA model on the CUAD dataset."""
114
+ print("Loading base model for fine-tuning...")
115
+ tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
116
+ model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")
117
+
118
+ # Load and preprocess CUAD dataset
119
+ print("Loading CUAD dataset...")
120
+ from datasets import load_dataset
121
+
122
+ try:
123
+ dataset = load_dataset("cuad")
124
+ except Exception as e:
125
+ print(f"Error loading CUAD dataset: {str(e)}")
126
+ print("Downloading CUAD dataset from alternative source...")
127
+ # Implement alternative dataset loading here
128
+ return tokenizer, model
129
+
130
+ print(f"Dataset loaded with {len(dataset['train'])} training examples")
131
+
132
+ # Preprocess the dataset
133
+ def preprocess_function(examples):
134
+ questions = [q.strip() for q in examples["question"]]
135
+ contexts = [c.strip() for c in examples["context"]]
136
+
137
+ inputs = tokenizer(
138
+ questions,
139
+ contexts,
140
+ max_length=384,
141
+ truncation="only_second",
142
+ stride=128,
143
+ return_overflowing_tokens=True,
144
+ return_offsets_mapping=True,
145
+ padding="max_length",
146
+ )
147
+
148
+ offset_mapping = inputs.pop("offset_mapping")
149
+ sample_map = inputs.pop("overflow_to_sample_mapping")
150
+
151
+ answers = examples["answers"]
152
+ start_positions = []
153
+ end_positions = []
154
+
155
+ for i, offset in enumerate(offset_mapping):
156
+ sample_idx = sample_map[i]
157
+ answer = answers[sample_idx]
158
+
159
+ start_char = answer["answer_start"][0] if len(answer["answer_start"]) > 0 else 0
160
+ end_char = start_char + len(answer["text"][0]) if len(answer["text"]) > 0 else 0
161
+
162
+ sequence_ids = inputs.sequence_ids(i)
163
+
164
+ # Find the start and end of the context
165
+ idx = 0
166
+ while sequence_ids[idx] != 1:
167
+ idx += 1
168
+ context_start = idx
169
+
170
+ while idx < len(sequence_ids) and sequence_ids[idx] == 1:
171
+ idx += 1
172
+ context_end = idx - 1
173
+
174
+ # If the answer is not fully inside the context, label is (0, 0)
175
+ if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
176
+ start_positions.append(0)
177
+ end_positions.append(0)
178
+ else:
179
+ # Otherwise it's the start and end token positions
180
+ idx = context_start
181
+ while idx <= context_end and offset[idx][0] <= start_char:
182
+ idx += 1
183
+ start_positions.append(idx - 1)
184
+
185
+ idx = context_end
186
+ while idx >= context_start and offset[idx][1] >= end_char:
187
+ idx -= 1
188
+ end_positions.append(idx + 1)
189
+
190
+ inputs["start_positions"] = start_positions
191
+ inputs["end_positions"] = end_positions
192
+ return inputs
193
+
194
+ print("Preprocessing dataset...")
195
+ processed_dataset = dataset.map(
196
+ preprocess_function,
197
+ batched=True,
198
+ remove_columns=dataset["train"].column_names,
199
+ )
200
+
201
+ print("Splitting dataset...")
202
+ train_dataset = processed_dataset["train"]
203
+ val_dataset = processed_dataset["validation"]
204
+
205
+ train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "start_positions", "end_positions"])
206
+ val_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "start_positions", "end_positions"])
207
+
208
+ training_args = TrainingArguments(
209
+ output_dir="./fine_tuned_legal_qa",
210
+ evaluation_strategy="steps",
211
+ eval_steps=100,
212
+ learning_rate=2e-5,
213
+ per_device_train_batch_size=16,
214
+ per_device_eval_batch_size=16,
215
+ num_train_epochs=1,
216
+ weight_decay=0.01,
217
+ logging_steps=50,
218
+ save_steps=100,
219
+ load_best_model_at_end=True,
220
+ report_to=[]
221
+ )
222
+
223
+ print("✅ Starting fine tuning on CUAD QA dataset...")
224
+ trainer = Trainer(
225
+ model=model,
226
+ args=training_args,
227
+ train_dataset=train_dataset,
228
+ eval_dataset=val_dataset,
229
+ tokenizer=tokenizer,
230
+ )
231
+
232
+ trainer.train()
233
+ print("✅ Fine tuning completed. Saving model...")
234
+
235
+ model.save_pretrained("./fine_tuned_legal_qa")
236
+ tokenizer.save_pretrained("./fine_tuned_legal_qa")
237
+
238
+ return tokenizer, model
239
+
240
+ #############################
241
+ # Load NLP Models #
242
+ #############################
243
+
244
+ # Initialize model variables
245
+ nlp = None
246
+ summarizer = None
247
+ embedding_model = None
248
+ ner_model = None
249
+ speech_to_text = None
250
+ cuad_model = None
251
+ cuad_tokenizer = None
252
+ qa_model = None
253
+
254
+ # Add model caching functionality
255
+ import pickle
256
+ import os.path
257
+
258
+ MODELS_CACHE_DIR = os.getenv("MODELS_CACHE_DIR", os.path.join(os.path.dirname(__file__), "models_cache"))
259
+ os.makedirs(MODELS_CACHE_DIR, exist_ok=True)
260
+
261
+ def save_model_to_cache(model, model_name):
262
+ """Save a model to the cache directory"""
263
+ try:
264
+ cache_path = os.path.join(MODELS_CACHE_DIR, f"{model_name}.pkl")
265
+ with open(cache_path, 'wb') as f:
266
+ pickle.dump(model, f)
267
+ print(f"✅ Saved {model_name} to cache")
268
+ return True
269
+ except Exception as e:
270
+ print(f"⚠️ Failed to save {model_name} to cache: {str(e)}")
271
+ return False
272
+
273
+ def load_model_from_cache(model_name):
274
+ """Load a model from the cache directory"""
275
+ try:
276
+ cache_path = os.path.join(MODELS_CACHE_DIR, f"{model_name}.pkl")
277
+ if os.path.exists(cache_path):
278
+ with open(cache_path, 'rb') as f:
279
+ model = pickle.load(f)
280
+ print(f"✅ Loaded {model_name} from cache")
281
+ return model
282
+ return None
283
+ except Exception as e:
284
+ print(f"⚠️ Failed to load {model_name} from cache: {str(e)}")
285
+ return None
286
+
287
+ # Add a flag to control model loading
288
+ LOAD_MODELS = os.getenv("LOAD_MODELS", "True").lower() in ("true", "1", "t")
289
+
290
+ try:
291
+ if LOAD_MODELS:
292
+ # Try to load SpaCy from cache first
293
+ nlp = load_model_from_cache("spacy_model")
294
+ if nlp is None:
295
+ try:
296
+ nlp = spacy.load("en_core_web_sm")
297
+ save_model_to_cache(nlp, "spacy_model")
298
+ except Exception:
299
+ print("⚠️ SpaCy model not found, downloading...")
300
+ spacy.cli.download("en_core_web_sm")
301
+ nlp = spacy.load("en_core_web_sm")
302
+ save_model_to_cache(nlp, "spacy_model")
303
+
304
+ print("✅ Loading NLP models...")
305
+
306
+ # Load the summarizer with caching
307
+ print("Loading summarizer model...")
308
+ summarizer = load_model_from_cache("summarizer_model")
309
+ if summarizer is None:
310
+ try:
311
+ summarizer = pipeline("summarization", model="facebook/bart-large-cnn",
312
+ device=0 if torch.cuda.is_available() else -1)
313
+ save_model_to_cache(summarizer, "summarizer_model")
314
+ print("✅ Summarizer loaded successfully")
315
+ except Exception as e:
316
+ print(f"⚠️ Error loading summarizer: {str(e)}")
317
+ try:
318
+ print("Trying alternative summarizer model...")
319
+ summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6",
320
+ device=0 if torch.cuda.is_available() else -1)
321
+ save_model_to_cache(summarizer, "summarizer_model")
322
+ print("✅ Alternative summarizer loaded successfully")
323
+ except Exception as e2:
324
+ print(f"⚠️ Error loading alternative summarizer: {str(e2)}")
325
+ summarizer = None
326
+
327
+ # Load the embedding model with caching
328
+ print("Loading embedding model...")
329
+ embedding_model = load_model_from_cache("embedding_model")
330
+ if embedding_model is None:
331
+ try:
332
+ embedding_model = SentenceTransformer("all-mpnet-base-v2", device=device)
333
+ save_model_to_cache(embedding_model, "embedding_model")
334
+ print("✅ Embedding model loaded successfully")
335
+ except Exception as e:
336
+ print(f"⚠️ Error loading embedding model: {str(e)}")
337
+ embedding_model = None
338
+
339
+ # Load the NER model with caching
340
+ print("Loading NER model...")
341
+ ner_model = load_model_from_cache("ner_model")
342
+ if ner_model is None:
343
+ try:
344
+ ner_model = pipeline("ner", model="dslim/bert-base-NER",
345
+ device=0 if torch.cuda.is_available() else -1)
346
+ save_model_to_cache(ner_model, "ner_model")
347
+ print("✅ NER model loaded successfully")
348
+ except Exception as e:
349
+ print(f"⚠️ Error loading NER model: {str(e)}")
350
+ ner_model = None
351
+
352
+ # Speech to text model with caching
353
+ print("Loading speech to text model...")
354
+ speech_to_text = load_model_from_cache("speech_to_text_model")
355
+ if speech_to_text is None:
356
+ try:
357
+ speech_to_text = pipeline("automatic-speech-recognition",
358
+ model="openai/whisper-medium",
359
+ chunk_length_s=30,
360
+ device_map="auto" if torch.cuda.is_available() else "cpu")
361
+ save_model_to_cache(speech_to_text, "speech_to_text_model")
362
+ print("✅ Speech to text model loaded successfully")
363
+ except Exception as e:
364
+ print(f"⚠️ Error loading speech to text model: {str(e)}")
365
+ speech_to_text = None
366
+
367
+ # Load the fine-tuned model with caching
368
+ print("Loading fine-tuned CUAD QA model...")
369
+ cuad_model = load_model_from_cache("cuad_model")
370
+ cuad_tokenizer = load_model_from_cache("cuad_tokenizer")
371
+
372
+ if cuad_model is None or cuad_tokenizer is None:
373
+ try:
374
+ cuad_tokenizer = AutoTokenizer.from_pretrained("hardik8588/fine-tuned-legal-qa")
375
+ from transformers import AutoModelForQuestionAnswering
376
+ cuad_model = AutoModelForQuestionAnswering.from_pretrained("hardik8588/fine-tuned-legal-qa")
377
+ cuad_model.to(device)
378
+ save_model_to_cache(cuad_tokenizer, "cuad_tokenizer")
379
+ save_model_to_cache(cuad_model, "cuad_model")
380
+ print("✅ Successfully loaded fine-tuned model")
381
+ except Exception as e:
382
+ print(f"⚠️ Error loading fine-tuned model: {str(e)}")
383
+ print("⚠️ Falling back to pre-trained model...")
384
+ try:
385
+ cuad_tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
386
+ from transformers import AutoModelForQuestionAnswering
387
+ cuad_model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")
388
+ cuad_model.to(device)
389
+ save_model_to_cache(cuad_tokenizer, "cuad_tokenizer")
390
+ save_model_to_cache(cuad_model, "cuad_model")
391
+ print("✅ Pre-trained model loaded successfully")
392
+ except Exception as e2:
393
+ print(f"⚠️ Error loading pre-trained model: {str(e2)}")
394
+ cuad_model = None
395
+ cuad_tokenizer = None
396
+
397
+ # Load a general QA model with caching
398
+ print("Loading general QA model...")
399
+ qa_model = load_model_from_cache("qa_model")
400
+ if qa_model is None:
401
+ try:
402
+ qa_model = pipeline("question-answering", model="deepset/roberta-base-squad2")
403
+ save_model_to_cache(qa_model, "qa_model")
404
+ print("✅ QA model loaded successfully")
405
+ except Exception as e:
406
+ print(f"⚠️ Error loading QA model: {str(e)}")
407
+ qa_model = None
408
+
409
+ print("✅ All models loaded successfully")
410
+ else:
411
+ print("⚠️ Model loading skipped (LOAD_MODELS=False)")
412
+
413
+ except Exception as e:
414
+ print(f"⚠️ Error loading models: {str(e)}")
415
+ # Instead of raising an error, set fallback behavior
416
+ nlp = None
417
+ summarizer = None
418
+ embedding_model = None
419
+ ner_model = None
420
+ speech_to_text = None
421
+ cuad_model = None
422
+ cuad_tokenizer = None
423
+ qa_model = None
424
+ print("⚠️ Running with limited functionality due to model loading errors")
425
+
426
+ def legal_chatbot(user_input, context):
427
+ """Uses a real NLP model for legal Q&A."""
428
+ global chat_history
429
+ chat_history.append({"role": "user", "content": user_input})
430
+ if qa_model is None:
+ return "The question-answering model is not available. Please try again later."
+ response = qa_model(question=user_input, context=context)["answer"]
431
+ chat_history.append({"role": "assistant", "content": response})
432
+ return response
433
+
434
+ def extract_text_from_pdf(pdf_file):
435
+ """Extracts text from a PDF file using pdfplumber."""
436
+ try:
437
+ # Suppress pdfplumber warnings about CropBox
438
+ import logging
439
+ logging.getLogger("pdfminer").setLevel(logging.ERROR)
440
+
441
+ with pdfplumber.open(pdf_file) as pdf:
442
+ print(f"Processing PDF with {len(pdf.pages)} pages")
443
+ text = ""
444
+ for i, page in enumerate(pdf.pages):
445
+ page_text = page.extract_text() or ""
446
+ text += page_text + "\n"
447
+ if (i + 1) % 10 == 0: # Log progress every 10 pages
448
+ print(f"Processed {i + 1} pages...")
449
+
450
+ print(f"✅ PDF text extraction complete: {len(text)} characters extracted")
451
+ return text.strip() if text else None
452
+ except Exception as e:
453
+ print(f"❌ PDF extraction error: {str(e)}")
454
+ raise HTTPException(status_code=400, detail=f"PDF extraction failed: {str(e)}")
455
+
456
+ def process_video_to_text(video_file_path):
457
+ """Extract audio from video and convert to text."""
458
+ try:
459
+ print(f"Processing video file at {video_file_path}")
460
+ temp_audio_path = os.path.join("temp", "extracted_audio.wav")
461
+ video = mp.VideoFileClip(video_file_path)
462
+ video.audio.write_audiofile(temp_audio_path, codec='pcm_s16le')
463
+ print(f"Audio extracted to {temp_audio_path}")
464
+ result = speech_to_text(temp_audio_path)
465
+ transcript = result["text"]
466
+ print(f"Transcription completed: {len(transcript)} characters")
467
+ if os.path.exists(temp_audio_path):
468
+ os.remove(temp_audio_path)
469
+ return transcript
470
+ except Exception as e:
471
+ print(f"Error in video processing: {str(e)}")
472
+ raise HTTPException(status_code=400, detail=f"Video processing failed: {str(e)}")
473
+
474
+ def process_audio_to_text(audio_file_path):
475
+ """Process audio file and convert to text."""
476
+ try:
477
+ print(f"Processing audio file at {audio_file_path}")
478
+ result = speech_to_text(audio_file_path)
479
+ transcript = result["text"]
480
+ print(f"Transcription completed: {len(transcript)} characters")
481
+ return transcript
482
+ except Exception as e:
483
+ print(f"Error in audio processing: {str(e)}")
484
+ raise HTTPException(status_code=400, detail=f"Audio processing failed: {str(e)}")
485
+
486
+ def extract_named_entities(text):
487
+ """Extracts named entities from legal text."""
488
+ max_length = 10000
489
+ entities = []
490
+ for i in range(0, len(text), max_length):
491
+ chunk = text[i:i+max_length]
492
+ doc = nlp(chunk)
493
+ entities.extend([{"entity": ent.text, "label": ent.label_} for ent in doc.ents])
494
+ return entities
495
+
496
+ def analyze_risk(text):
497
+ """Analyzes legal risk in the document using keyword-based analysis."""
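+ # Illustrative example:
+ #   analyze_risk("The seller shall indemnify the buyer against claims.")
+ #   -> {"Liability": 0, "Termination": 0, "Indemnification": 1, "Payment Risk": 0, "Insurance": 0}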
498
+ risk_keywords = {
499
+ "Liability": ["liability", "responsible", "responsibility", "legal obligation"],
500
+ "Termination": ["termination", "breach", "contract end", "default"],
501
+ "Indemnification": ["indemnification", "indemnify", "hold harmless", "compensate", "compensation"],
502
+ "Payment Risk": ["payment", "terms", "reimbursement", "fee", "schedule", "invoice", "money"],
503
+ "Insurance": ["insurance", "coverage", "policy", "claims"],
504
+ }
505
+ risk_scores = {category: 0 for category in risk_keywords}
506
+ lower_text = text.lower()
507
+ for category, keywords in risk_keywords.items():
508
+ for keyword in keywords:
509
+ risk_scores[category] += lower_text.count(keyword.lower())
510
+ return risk_scores
511
+
512
+ def extract_context_for_risk_terms(text, risk_keywords, window=1):
513
+ """
514
+ Extracts and summarizes the context around risk terms.
515
+ """
516
+ doc = nlp(text)
517
+ sentences = list(doc.sents)
518
+ risk_contexts = {category: [] for category in risk_keywords}
519
+ for i, sent in enumerate(sentences):
520
+ sent_text_lower = sent.text.lower()
521
+ for category, details in risk_keywords.items():
522
+ for keyword in details["keywords"]:
523
+ if keyword.lower() in sent_text_lower:
524
+ start_idx = max(0, i - window)
525
+ end_idx = min(len(sentences), i + window + 1)
526
+ context_chunk = " ".join([s.text for s in sentences[start_idx:end_idx]])
527
+ risk_contexts[category].append(context_chunk)
528
+ summarized_contexts = {}
529
+ for category, contexts in risk_contexts.items():
530
+ if contexts:
531
+ combined_context = " ".join(contexts)
532
+ try:
533
+ summary_result = summarizer(combined_context, max_length=100, min_length=30, do_sample=False)
534
+ summary = summary_result[0]['summary_text']
535
+ except Exception as e:
536
+ summary = "Context summarization failed."
537
+ summarized_contexts[category] = summary
538
+ else:
539
+ summarized_contexts[category] = "No contextual details found."
540
+ return summarized_contexts
541
+
542
+ def get_detailed_risk_info(text):
543
+ """
544
+ Returns detailed risk information by merging risk scores with descriptive details
545
+ and contextual summaries from the document.
546
+ """
547
+ risk_details = {
548
+ "Liability": {
549
+ "description": "Liability refers to the legal responsibility for losses or damages.",
550
+ "common_concerns": "Broad liability clauses may expose parties to unforeseen risks.",
551
+ "recommendations": "Review and negotiate clear limits on liability.",
552
+ "example": "E.g., 'The party shall be liable for direct damages due to negligence.'"
553
+ },
554
+ "Termination": {
555
+ "description": "Termination involves conditions under which a contract can be ended.",
556
+ "common_concerns": "Unilateral termination rights or ambiguous conditions can be risky.",
557
+ "recommendations": "Ensure termination clauses are balanced and include notice periods.",
558
+ "example": "E.g., 'Either party may terminate the agreement with 30 days notice.'"
559
+ },
560
+ "Indemnification": {
561
+ "description": "Indemnification requires one party to compensate for losses incurred by the other.",
562
+ "common_concerns": "Overly broad indemnification can shift significant risk.",
563
+ "recommendations": "Negotiate clear limits and carve-outs where necessary.",
564
+ "example": "E.g., 'The seller shall indemnify the buyer against claims from product defects.'"
565
+ },
566
+ "Payment Risk": {
567
+ "description": "Payment risk pertains to terms regarding fees, schedules, and reimbursements.",
568
+ "common_concerns": "Vague payment terms or hidden charges increase risk.",
569
+ "recommendations": "Clarify payment conditions and include penalties for delays.",
570
+ "example": "E.g., 'Payments must be made within 30 days, with a 2% late fee thereafter.'"
571
+ },
572
+ "Insurance": {
573
+ "description": "Insurance risk covers the adequacy and scope of required coverage.",
574
+ "common_concerns": "Insufficient insurance can leave parties exposed in unexpected events.",
575
+ "recommendations": "Review insurance requirements to ensure they meet the risk profile.",
576
+ "example": "E.g., 'The contractor must maintain liability insurance with at least $1M coverage.'"
577
+ }
578
+ }
579
+ risk_scores = analyze_risk(text)
580
+ risk_keywords_context = {
581
+ "Liability": {"keywords": ["liability", "responsible", "responsibility", "legal obligation"]},
582
+ "Termination": {"keywords": ["termination", "breach", "contract end", "default"]},
583
+ "Indemnification": {"keywords": ["indemnification", "indemnify", "hold harmless", "compensate", "compensation"]},
584
+ "Payment Risk": {"keywords": ["payment", "terms", "reimbursement", "fee", "schedule", "invoice", "money"]},
585
+ "Insurance": {"keywords": ["insurance", "coverage", "policy", "claims"]}
586
+ }
587
+ risk_contexts = extract_context_for_risk_terms(text, risk_keywords_context, window=1)
588
+ detailed_info = {}
589
+ for risk_term, score in risk_scores.items():
590
+ if score > 0:
591
+ info = risk_details.get(risk_term, {"description": "No details available."})
592
+ detailed_info[risk_term] = {
593
+ "score": score,
594
+ "description": info.get("description", ""),
595
+ "common_concerns": info.get("common_concerns", ""),
596
+ "recommendations": info.get("recommendations", ""),
597
+ "example": info.get("example", ""),
598
+ "context_summary": risk_contexts.get(risk_term, "No context available.")
599
+ }
600
+ return detailed_info
601
+
602
+ def analyze_contract_clauses(text):
603
+ """Analyzes contract clauses using the fine-tuned CUAD QA model."""
604
+ max_length = 512
605
+ step = 256
606
+ clauses_detected = []
607
+ try:
608
+ clause_types = list(cuad_model.config.id2label.values())
609
+ except Exception as e:
610
+ clause_types = [
611
+ "Obligations of Seller", "Governing Law", "Termination", "Indemnification",
612
+ "Confidentiality", "Insurance", "Non-Compete", "Change of Control",
613
+ "Assignment", "Warranty", "Limitation of Liability", "Arbitration",
614
+ "IP Rights", "Force Majeure", "Revenue/Profit Sharing", "Audit Rights"
615
+ ]
616
+ chunks = [text[i:i+max_length] for i in range(0, len(text), step)]
617
+ for chunk in chunks:
618
+ inputs = cuad_tokenizer(chunk, return_tensors="pt", truncation=True, max_length=512).to(device)
619
+ with torch.no_grad():
620
+ outputs = cuad_model(**inputs)
621
+ predictions = torch.sigmoid(outputs.start_logits).cpu().numpy()[0]
622
+ for idx, confidence in enumerate(predictions):
623
+ if confidence > 0.5 and idx < len(clause_types):
624
+ clauses_detected.append({"type": clause_types[idx], "confidence": float(confidence)})
625
+ aggregated_clauses = {}
626
+ for clause in clauses_detected:
627
+ clause_type = clause["type"]
628
+ if clause_type not in aggregated_clauses or clause["confidence"] > aggregated_clauses[clause_type]["confidence"]:
629
+ aggregated_clauses[clause_type] = clause
630
+ return list(aggregated_clauses.values())
631
+
632
+ def summarize_text(text):
633
+ """Summarizes legal text using the summarizer model."""
634
+ try:
635
+ if summarizer is None:
636
+ return "Basic analysis (NLP models not available)"
637
+
638
+ # Split text into chunks if it's too long
639
+ max_chunk_size = 1024
640
+ if len(text) > max_chunk_size:
641
+ chunks = [text[i:i+max_chunk_size] for i in range(0, len(text), max_chunk_size)]
642
+ summaries = []
643
+ for chunk in chunks:
644
+ summary = summarizer(chunk, max_length=100, min_length=30, do_sample=False)
645
+ summaries.append(summary[0]['summary_text'])
646
+ return " ".join(summaries)
647
+ else:
648
+ summary = summarizer(text, max_length=100, min_length=30, do_sample=False)
649
+ return summary[0]['summary_text']
650
+ except Exception as e:
651
+ print(f"Error in summarization: {str(e)}")
652
+ return "Summarization failed. Please try again later."
653
+
654
+ @app.post("/analyze_legal_document")
655
+ async def analyze_legal_document(
656
+ file: UploadFile = File(...),
657
+ current_user: User = Depends(get_current_active_user)
658
+ ):
659
+ """Analyzes a legal document (PDF) and returns insights based on subscription tier."""
660
+ try:
661
+ # Calculate file size in MB
662
+ file_content = await file.read()
663
+ file_size_mb = len(file_content) / (1024 * 1024)
664
+
665
+ # Check subscription access for document analysis
666
+ check_subscription_access(current_user, "document_analysis", file_size_mb)
667
+
668
+ print(f"Processing file: {file.filename}")
669
+
670
+ # Create a temporary file to store the uploaded PDF
671
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp:
672
+ tmp.write(file_content)
673
+ tmp_path = tmp.name
674
+
675
+ # Extract text from PDF
676
+ text = extract_text_from_pdf(tmp_path)
677
+
678
+ # Clean up the temporary file
679
+ os.unlink(tmp_path)
680
+
681
+ if not text:
682
+ raise HTTPException(status_code=400, detail="Could not extract text from PDF")
683
+
684
+ # Generate a task ID
685
+ task_id = str(uuid.uuid4())
686
+
687
+ # Store document context for later retrieval
688
+ store_document_context(task_id, text)
689
+
690
+ # Basic analysis available to all tiers
691
+ summary = summarize_text(text)
692
+ entities = extract_named_entities(text)
693
+ risk_scores = analyze_risk(text)
694
+
695
+ # Prepare response based on subscription tier
696
+ response = {
697
+ "task_id": task_id,
698
+ "summary": summary,
699
+ "entities": entities,
700
+ "risk_assessment": risk_scores,
701
+ "subscription_tier": current_user.subscription_tier
702
+ }
703
+
704
+ # Add premium features if user has access
705
+ if current_user.subscription_tier == "premium_tier":
706
+ # Add detailed risk assessment
707
+ if "detailed_risk_assessment" in SUBSCRIPTION_TIERS[current_user.subscription_tier]["features"]:
708
+ detailed_risk = get_detailed_risk_info(text)
709
+ response["detailed_risk_assessment"] = detailed_risk
710
+
711
+ # Add contract clause analysis
712
+ if "contract_clause_analysis" in SUBSCRIPTION_TIERS[current_user.subscription_tier]["features"]:
713
+ clauses = analyze_contract_clauses(text)
714
+ response["contract_clauses"] = clauses
715
+
716
+ return response
717
+
718
+ except Exception as e:
719
+ print(f"Error analyzing document: {str(e)}")
720
+ raise HTTPException(status_code=500, detail=f"Error analyzing document: {str(e)}")
721
+
722
+ # Add this function to check resource limits based on subscription tier
723
+ def check_resource_limits(user: User, resource_type: str, size_mb: float = None, count: int = 1):
724
+ """
725
+ Check if the user has exceeded their subscription limits for a specific resource
726
+
727
+ Args:
728
+ user: The user making the request
729
+ resource_type: Type of resource (document, video, audio)
730
+ size_mb: Size of the resource in MB
731
+ count: Number of resources being used (default 1)
732
+
733
+ Returns:
734
+ bool: True if within limits, raises HTTPException otherwise
735
+ """
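+ # Illustrative example: check_resource_limits(user, "document", size_mb=2.5)
+ # raises HTTPException(403) when a tier limit is exceeded; otherwise records the usage and returns True.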
736
+ # Get the user's subscription tier limits
737
+ tier = user.subscription_tier
738
+ tier_limits = SUBSCRIPTION_TIERS.get(tier, SUBSCRIPTION_TIERS["free_tier"])["limits"]
739
+
740
+ # Check size limits
741
+ if size_mb is not None:
742
+ if resource_type == "document" and size_mb > tier_limits["document_size_mb"]:
743
+ raise HTTPException(
744
+ status_code=status.HTTP_403_FORBIDDEN,
745
+ detail=f"Document size exceeds the {tier_limits['document_size_mb']}MB limit for your {tier} subscription"
746
+ )
747
+ elif resource_type == "video" and size_mb > tier_limits["video_size_mb"]:
748
+ raise HTTPException(
749
+ status_code=status.HTTP_403_FORBIDDEN,
750
+ detail=f"Video size exceeds the {tier_limits['video_size_mb']}MB limit for your {tier} subscription"
751
+ )
752
+ elif resource_type == "audio" and size_mb > tier_limits["audio_size_mb"]:
753
+ raise HTTPException(
754
+ status_code=status.HTTP_403_FORBIDDEN,
755
+ detail=f"Audio size exceeds the {tier_limits['audio_size_mb']}MB limit for your {tier} subscription"
756
+ )
757
+
758
+ # Check monthly document count
759
+ if resource_type == "document":
760
+ # Get current month and year
761
+ now = datetime.now()
762
+ month, year = now.month, now.year
763
+
764
+ # Check usage stats for current month
765
+ conn = get_db_connection()
766
+ cursor = conn.cursor()
767
+ cursor.execute(
768
+ "SELECT analyses_used FROM usage_stats WHERE user_id = ? AND month = ? AND year = ?",
769
+ (user.id, month, year)
770
+ )
771
+ result = cursor.fetchone()
772
+
773
+ current_usage = result[0] if result else 0
774
+
775
+ # Check if adding this usage would exceed the limit
776
+ if current_usage + count > tier_limits["documents_per_month"]:
777
+ conn.close()
778
+ raise HTTPException(
779
+ status_code=status.HTTP_403_FORBIDDEN,
780
+ detail=f"You have reached your monthly limit of {tier_limits['documents_per_month']} document analyses for your {tier} subscription"
781
+ )
782
+
783
+ # Update usage stats
784
+ if result:
785
+ cursor.execute(
786
+ "UPDATE usage_stats SET analyses_used = ? WHERE user_id = ? AND month = ? AND year = ?",
787
+ (current_usage + count, user.id, month, year)
788
+ )
789
+ else:
790
+ usage_id = str(uuid.uuid4())
791
+ cursor.execute(
792
+ "INSERT INTO usage_stats (id, user_id, month, year, analyses_used) VALUES (?, ?, ?, ?, ?)",
793
+ (usage_id, user.id, month, year, count)
794
+ )
795
+
796
+ conn.commit()
797
+ conn.close()
798
+
799
+ # Check if feature is available in the tier
800
+ if resource_type == "video" and tier_limits["video_size_mb"] == 0:
801
+ raise HTTPException(
802
+ status_code=status.HTTP_403_FORBIDDEN,
803
+ detail=f"Video analysis is not available in your {tier} subscription"
804
+ )
805
+
806
+ if resource_type == "audio" and tier_limits["audio_size_mb"] == 0:
807
+ raise HTTPException(
808
+ status_code=status.HTTP_403_FORBIDDEN,
809
+ detail=f"Audio analysis is not available in your {tier} subscription"
810
+ )
811
+
812
+ return True
813
+
814
+ @app.post("/analyze_legal_video")
815
+ async def analyze_legal_video(
816
+ file: UploadFile = File(...),
817
+ current_user: User = Depends(get_current_active_user)
818
+ ):
819
+ """Analyzes legal video by transcribing and analyzing the transcript."""
820
+ try:
821
+ # Calculate file size in MB
822
+ file_content = await file.read()
823
+ file_size_mb = len(file_content) / (1024 * 1024)
824
+
825
+ # Check subscription access for video analysis
826
+ check_subscription_access(current_user, "video_analysis", file_size_mb)
827
+
828
+ print(f"Processing video file: {file.filename}")
829
+
830
+ # Create a temporary file to store the uploaded video
831
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp:
832
+ tmp.write(file_content)
833
+ tmp_path = tmp.name
834
+
835
+ # Process video to extract transcript
836
+ transcript = process_video_to_text(tmp_path)
837
+
838
+ # Clean up the temporary file
839
+ os.unlink(tmp_path)
840
+
841
+ if not transcript:
842
+ raise HTTPException(status_code=400, detail="Could not extract transcript from video")
843
+
844
+ # Generate a task ID
845
+ task_id = str(uuid.uuid4())
846
+
847
+ # Store document context for later retrieval
848
+ store_document_context(task_id, transcript)
849
+
850
+ # Basic analysis
851
+ summary = summarize_text(transcript)
852
+ entities = extract_named_entities(transcript)
853
+ risk_scores = analyze_risk(transcript)
854
+
855
+ # Prepare response
856
+ response = {
857
+ "task_id": task_id,
858
+ "transcript": transcript,
859
+ "summary": summary,
860
+ "entities": entities,
861
+ "risk_assessment": risk_scores,
862
+ "subscription_tier": current_user.subscription_tier
863
+ }
864
+
865
+ # Add premium features if user has access
866
+ if current_user.subscription_tier == "premium_tier":
867
+ # Add detailed risk assessment
868
+ if "detailed_risk_assessment" in SUBSCRIPTION_TIERS[current_user.subscription_tier]["features"]:
869
+ detailed_risk = get_detailed_risk_info(transcript)
870
+ response["detailed_risk_assessment"] = detailed_risk
871
+
872
+ return response
873
+
874
+ except Exception as e:
875
+ print(f"Error analyzing video: {str(e)}")
876
+ raise HTTPException(status_code=500, detail=f"Error analyzing video: {str(e)}")
877
+
878
+
879
+ @app.post("/legal_chatbot/{task_id}")
880
+ async def chat_with_document(
881
+ task_id: str,
882
+ question: str = Form(...),
883
+ current_user: User = Depends(get_current_active_user)
884
+ ):
885
+ """Chat with a document using the legal chatbot."""
886
+ try:
887
+ # Check if user has access to chatbot feature
888
+ if "chatbot" not in SUBSCRIPTION_TIERS[current_user.subscription_tier]["features"]:
889
+ raise HTTPException(
890
+ status_code=403,
891
+ detail=f"The chatbot feature is not available in your {current_user.subscription_tier} subscription. Please upgrade to access this feature."
892
+ )
893
+
894
+ # Check if document context exists
895
+ context = load_document_context(task_id)
896
+ if not context:
897
+ raise HTTPException(status_code=404, detail="Document context not found. Please analyze a document first.")
898
+
899
+ # Use the chatbot to answer the question
900
+ answer = legal_chatbot(question, context)
901
+
902
+ return {"answer": answer, "chat_history": chat_history}
903
+
904
+ except Exception as e:
905
+ print(f"Error in chatbot: {str(e)}")
906
+ raise HTTPException(status_code=500, detail=f"Error in chatbot: {str(e)}")
907
+
908
+ @app.get("/")
909
+ async def root():
910
+ """Root endpoint that returns a welcome message."""
911
+ return HTMLResponse(content="""
912
+ <html>
913
+ <head>
914
+ <title>Legal Document Analysis API</title>
915
+ <style>
916
+ body {
917
+ font-family: Arial, sans-serif;
918
+ max-width: 800px;
919
+ margin: 0 auto;
920
+ padding: 20px;
921
+ }
922
+ h1 {
923
+ color: #2c3e50;
924
+ }
925
+ .endpoint {
926
+ background-color: #f8f9fa;
927
+ padding: 15px;
928
+ margin-bottom: 10px;
929
+ border-radius: 5px;
930
+ }
931
+ .method {
932
+ font-weight: bold;
933
+ color: #e74c3c;
934
+ }
935
+ </style>
936
+ </head>
937
+ <body>
938
+ <h1>Legal Document Analysis API</h1>
939
+ <p>Welcome to the Legal Document Analysis API. This API provides tools for analyzing legal documents, videos, and audio.</p>
940
+ <h2>Available Endpoints:</h2>
941
+ <div class="endpoint">
942
+ <p><span class="method">POST</span> /analyze_legal_document - Analyze a legal document (PDF)</p>
943
+ </div>
944
+ <div class="endpoint">
945
+ <p><span class="method">POST</span> /analyze_legal_video - Analyze a legal video</p>
946
+ </div>
947
+ <div class="endpoint">
948
+ <p><span class="method">POST</span> /analyze_legal_audio - Analyze legal audio</p>
949
+ </div>
950
+ <div class="endpoint">
951
+ <p><span class="method">POST</span> /legal_chatbot/{task_id} - Chat with a document</p>
952
+ </div>
953
+ <div class="endpoint">
954
+ <p><span class="method">POST</span> /register - Register a new user</p>
955
+ </div>
956
+ <div class="endpoint">
957
+ <p><span class="method">POST</span> /token - Login to get an access token</p>
958
+ </div>
959
+ <div class="endpoint">
960
+ <p><span class="method">GET</span> /users/me - Get current user information</p>
961
+ </div>
962
+ <div class="endpoint">
963
+ <p><span class="method">POST</span> /subscribe/{tier} - Subscribe to a plan</p>
964
+ </div>
965
+ <p>For more details, visit the <a href="/docs">API documentation</a>.</p>
966
+ </body>
967
+ </html>
968
+ """)
969
+
970
+ @app.post("/register", response_model=Token)
971
+ async def register_new_user(user_data: UserCreate):
972
+ """Register a new user with a free subscription"""
973
+ try:
974
+ success, result = register_user(user_data.email, user_data.password)
975
+
976
+ if not success:
977
+ raise HTTPException(status_code=400, detail=result)
978
+
979
+ return {"access_token": result["access_token"], "token_type": "bearer"}
980
+
981
+ except HTTPException:
982
+ # Re-raise HTTP exceptions
983
+ raise
984
+ except Exception as e:
985
+ print(f"Registration error: {str(e)}")
986
+ raise HTTPException(status_code=500, detail=f"Registration failed: {str(e)}")
987
+
988
+ @app.post("/token", response_model=Token)
989
+ async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
990
+ """Endpoint for OAuth2 token generation"""
991
+ try:
992
+ # Add debug logging
993
+ logger.info(f"Token request for username: {form_data.username}")
994
+
995
+ user = authenticate_user(form_data.username, form_data.password)
996
+ if not user:
997
+ logger.warning(f"Authentication failed for: {form_data.username}")
998
+ raise HTTPException(
999
+ status_code=status.HTTP_401_UNAUTHORIZED,
1000
+ detail="Incorrect username or password",
1001
+ headers={"WWW-Authenticate": "Bearer"},
1002
+ )
1003
+
1004
+ access_token = create_access_token(user.id)
1005
+ if not access_token:
1006
+ logger.error(f"Failed to create access token for user: {user.id}")
1007
+ raise HTTPException(
1008
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
1009
+ detail="Could not create access token",
1010
+ )
1011
+
1012
+ logger.info(f"Login successful for: {form_data.username}")
1013
+ return {"access_token": access_token, "token_type": "bearer"}
1014
+ except Exception as e:
1015
+ logger.error(f"Token endpoint error: {e}")
1016
+ raise HTTPException(
1017
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
1018
+ detail=f"Login error: {str(e)}",
1019
+ )
1020
+
1021
+
1022
+ @app.get("/debug/token")
1023
+ async def debug_token(authorization: str = Header(None)):
1024
+ """Debug endpoint to check token validity"""
1025
+ try:
1026
+ if not authorization:
1027
+ return {"valid": False, "error": "No authorization header provided"}
1028
+
1029
+ # Extract token from Authorization header
1030
+ scheme, token = authorization.split()
1031
+ if scheme.lower() != 'bearer':
1032
+ return {"valid": False, "error": "Not a bearer token"}
1033
+
1034
+ # Log the token for debugging
1035
+ logger.info(f"Debugging token: {token[:10]}...")
1036
+
1037
+ # Try to validate the token
1038
+ try:
1039
+ user = await get_current_active_user(token)
1040
+ return {"valid": True, "user_id": user.id, "email": user.email}
1041
+ except Exception as e:
1042
+ return {"valid": False, "error": str(e)}
1043
+ except Exception as e:
1044
+ return {"valid": False, "error": f"Token debug error: {str(e)}"}
1045
+
1046
+
1047
+ @app.post("/login")
1048
+ async def api_login(email: str, password: str):
1049
+ success, result = login_user(email, password)
1050
+ if not success:
1051
+ raise HTTPException(
1052
+ status_code=status.HTTP_401_UNAUTHORIZED,
1053
+ detail=result
1054
+ )
1055
+ return result
1056
+
1057
+ @app.get("/health")
1058
+ def health_check():
1059
+ """Simple health check endpoint to verify the API is running"""
1060
+ return {"status": "ok", "message": "API is running"}
1061
+
1062
+ @app.get("/users/me", response_model=User)
1063
+ async def read_users_me(current_user: User = Depends(get_current_active_user)):
1064
+ return current_user
1065
+
1066
+ @app.post("/analyze_legal_audio")
1067
+ async def analyze_legal_audio(
1068
+ file: UploadFile = File(...),
1069
+ current_user: User = Depends(get_current_active_user)
1070
+ ):
1071
+ """Analyzes legal audio by transcribing and analyzing the transcript."""
1072
+ try:
1073
+ # Calculate file size in MB
1074
+ file_content = await file.read()
1075
+ file_size_mb = len(file_content) / (1024 * 1024)
1076
+
1077
+ # Check subscription access for audio analysis
1078
+ check_subscription_access(current_user, "audio_analysis", file_size_mb)
1079
+
1080
+ print(f"Processing audio file: {file.filename}")
1081
+
1082
+ # Create a temporary file to store the uploaded audio
1083
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp:
1084
+ tmp.write(file_content)
1085
+ tmp_path = tmp.name
1086
+
1087
+ # Process audio to extract transcript
1088
+ transcript = process_audio_to_text(tmp_path)
1089
+
1090
+ # Clean up the temporary file
1091
+ os.unlink(tmp_path)
1092
+
1093
+ if not transcript:
1094
+ raise HTTPException(status_code=400, detail="Could not extract transcript from audio")
1095
+
1096
+ # Generate a task ID
1097
+ task_id = str(uuid.uuid4())
1098
+
1099
+ # Store document context for later retrieval
1100
+ store_document_context(task_id, transcript)
1101
+
1102
+ # Basic analysis
1103
+ summary = summarize_text(transcript)
1104
+ entities = extract_named_entities(transcript)
1105
+ risk_scores = analyze_risk(transcript)
1106
+
1107
+ # Prepare response
1108
+ response = {
1109
+ "task_id": task_id,
1110
+ "transcript": transcript,
1111
+ "summary": summary,
1112
+ "entities": entities,
1113
+ "risk_assessment": risk_scores,
1114
+ "subscription_tier": current_user.subscription_tier
1115
+ }
1116
+
1117
+ # Add premium features if user has access
1118
+ if current_user.subscription_tier == "premium_tier":
1119
+ # Add detailed risk assessment
1120
+ if "detailed_risk_assessment" in SUBSCRIPTION_TIERS[current_user.subscription_tier]["features"]:
1121
+ detailed_risk = get_detailed_risk_info(transcript)
1122
+ response["detailed_risk_assessment"] = detailed_risk
1123
+
1124
+ return response
1125
+
1126
+ except Exception as e:
1127
+ print(f"Error analyzing audio: {str(e)}")
1128
+ raise HTTPException(status_code=500, detail=f"Error analyzing audio: {str(e)}")
1129
+
1130
+
1131
+
1132
+ # Subscription management endpoints
1133
+ @app.get("/users/me/subscription")
1134
+ async def get_user_subscription(current_user: User = Depends(get_current_active_user)):
1135
+ """Get the current user's subscription details"""
1136
+ try:
1137
+ # Get subscription details from database
1138
+ conn = get_db_connection()
1139
+ cursor = conn.cursor()
1140
+
1141
+ # Get the most recent active subscription
1142
+ try:
1143
+ cursor.execute(
1144
+ "SELECT id, tier, status, created_at, expires_at, paypal_subscription_id FROM subscriptions "
1145
+ "WHERE user_id = ? AND status = 'active' ORDER BY created_at DESC LIMIT 1",
1146
+ (current_user.id,)
1147
+ )
1148
+ subscription = cursor.fetchone()
1149
+ except sqlite3.OperationalError as e:
1150
+ # Handle missing tier column
1151
+ if "no such column: tier" in str(e):
1152
+ logger.warning("Subscriptions table missing 'tier' column. Returning default subscription.")
1153
+ subscription = None
1154
+ else:
1155
+ raise
1156
+
1157
+ # Get subscription tiers with pricing directly from SUBSCRIPTION_TIERS
1158
+ subscription_tiers = {
1159
+ "free_tier": {
1160
+ "price": SUBSCRIPTION_TIERS["free_tier"]["price"],
1161
+ "currency": SUBSCRIPTION_TIERS["free_tier"]["currency"],
1162
+ "features": SUBSCRIPTION_TIERS["free_tier"]["features"]
1163
+ },
1164
+ "standard_tier": {
1165
+ "price": SUBSCRIPTION_TIERS["standard_tier"]["price"],
1166
+ "currency": SUBSCRIPTION_TIERS["standard_tier"]["currency"],
1167
+ "features": SUBSCRIPTION_TIERS["standard_tier"]["features"]
1168
+ },
1169
+ "premium_tier": {
1170
+ "price": SUBSCRIPTION_TIERS["premium_tier"]["price"],
1171
+ "currency": SUBSCRIPTION_TIERS["premium_tier"]["currency"],
1172
+ "features": SUBSCRIPTION_TIERS["premium_tier"]["features"]
1173
+ }
1174
+ }
1175
+
1176
+ if subscription:
1177
+ sub_id, tier, status, created_at, expires_at, paypal_id = subscription
1178
+ result = {
1179
+ "id": sub_id,
1180
+ "tier": tier,
1181
+ "status": status,
1182
+ "created_at": created_at,
1183
+ "expires_at": expires_at,
1184
+ "paypal_subscription_id": paypal_id,
1185
+ "current_tier": current_user.subscription_tier,
1186
+ "subscription_tiers": subscription_tiers
1187
+ }
1188
+ else:
1189
+ result = {
1190
+ "tier": "free_tier",
1191
+ "status": "active",
1192
+ "current_tier": current_user.subscription_tier,
1193
+ "subscription_tiers": subscription_tiers
1194
+ }
1195
+
1196
+ conn.close()
1197
+ return result
1198
+ except Exception as e:
1199
+ logger.error(f"Error getting subscription: {str(e)}")
1200
+ raise HTTPException(status_code=500, detail=f"Error getting subscription: {str(e)}")
1201
+ # Request body model for subscription creation
1202
+ class SubscriptionCreate(BaseModel):
1203
+ tier: str
1204
+
1205
+ @app.post("/create_subscription")
1206
+ async def create_subscription(
1207
+ subscription: SubscriptionCreate,
1208
+ current_user: User = Depends(get_current_active_user)
1209
+ ):
1210
+ """Create a subscription for the current user"""
1211
+ try:
1212
+ # Log the request for debugging
1213
+ logger.info(f"Creating subscription for user {current_user.email} with tier {subscription.tier}")
1214
+ logger.info(f"Available tiers: {list(SUBSCRIPTION_TIERS.keys())}")
1215
+
1216
+ # Validate tier
1217
+ valid_tiers = ["standard_tier", "premium_tier"]
1218
+ if subscription.tier not in valid_tiers:
1219
+ logger.warning(f"Invalid tier requested: {subscription.tier}")
1220
+ raise HTTPException(status_code=400, detail=f"Invalid tier: {subscription.tier}. Must be one of {valid_tiers}")
1221
+
1222
+ # Create subscription
1223
+ logger.info(f"Calling create_user_subscription with email: {current_user.email}, tier: {subscription.tier}")
1224
+ success, result = create_user_subscription(current_user.email, subscription.tier)
1225
+
1226
+ if not success:
1227
+ logger.error(f"Failed to create subscription: {result}")
1228
+ raise HTTPException(status_code=400, detail=result)
1229
+
1230
+ logger.info(f"Subscription created successfully: {result}")
1231
+ return result
1232
+ except Exception as e:
1233
+ logger.error(f"Error creating subscription: {str(e)}")
1234
+ # Include the full traceback for better debugging
1235
+ import traceback
1236
+ logger.error(f"Traceback: {traceback.format_exc()}")
1237
+ raise HTTPException(status_code=500, detail=f"Error creating subscription: {str(e)}")
1238
+
1239
+ @app.post("/subscribe/{tier}")
1240
+ async def subscribe_to_tier(
1241
+ tier: str,
1242
+ current_user: User = Depends(get_current_active_user)
1243
+ ):
1244
+ """Subscribe to a specific tier"""
1245
+ try:
1246
+ # Validate tier
1247
+ valid_tiers = ["standard_tier", "premium_tier"]
1248
+ if tier not in valid_tiers:
1249
+ raise HTTPException(status_code=400, detail=f"Invalid tier: {tier}. Must be one of {valid_tiers}")
1250
+
1251
+ # Create subscription
1252
+ success, result = create_user_subscription(current_user.email, tier)
1253
+
1254
+ if not success:
1255
+ raise HTTPException(status_code=400, detail=result)
1256
+
1257
+ return result
1258
+ except Exception as e:
1259
+ logger.error(f"Error creating subscription: {str(e)}")
1260
+ raise HTTPException(status_code=500, detail=f"Error creating subscription: {str(e)}")
1261
+
1262
+ @app.post("/subscription/create")
1263
+ async def create_subscription_json(request: Request, current_user: User = Depends(get_current_active_user)):
1264
+ """Create a subscription for the current user"""
1265
+ try:
1266
+ data = await request.json()
1267
+ tier = data.get("tier")
1268
+
1269
+ if not tier:
1270
+ return JSONResponse(
1271
+ status_code=400,
1272
+ content={"detail": "Tier is required"}
1273
+ )
1274
+
1275
+ # Log the request for debugging
1276
+ logger.info(f"Creating subscription for user {current_user.email} with tier {tier}")
1277
+
1278
+ # Create the subscription using the imported function directly
1279
+ success, result = create_user_subscription(current_user.email, tier)
1280
+
1281
+ if success:
1282
+ # Make sure we're returning the approval_url in the response
1283
+ logger.info(f"Subscription created successfully: {result}")
1284
+ logger.info(f"Approval URL: {result.get('approval_url')}")
1285
+
1286
+ return {
1287
+ "success": True,
1288
+ "data": {
1289
+ "approval_url": result["approval_url"],
1290
+ "subscription_id": result["subscription_id"],
1291
+ "tier": result["tier"]
1292
+ }
1293
+ }
1294
+ else:
1295
+ logger.error(f"Failed to create subscription: {result}")
1296
+ return JSONResponse(
1297
+ status_code=400,
1298
+ content={"success": False, "detail": result}
1299
+ )
1300
+ except Exception as e:
1301
+ logger.error(f"Error creating subscription: {str(e)}")
1302
+ import traceback
1303
+ logger.error(f"Traceback: {traceback.format_exc()}")
1304
+ return JSONResponse(
1305
+ status_code=500,
1306
+ content={"success": False, "detail": f"Error creating subscription: {str(e)}"}
1307
+ )
1308
+
1309
+ @app.post("/admin/initialize-paypal-plans")
1310
+ async def initialize_paypal_plans(request: Request):
1311
+ """Initialize PayPal subscription plans"""
1312
+ try:
1313
+ # This should be protected with admin authentication in production
1314
+ plans = initialize_subscription_plans()
1315
+
1316
+ if plans:
1317
+ return JSONResponse(
1318
+ status_code=200,
1319
+ content={"success": True, "plans": plans}
1320
+ )
1321
+ else:
1322
+ return JSONResponse(
1323
+ status_code=500,
1324
+ content={"success": False, "detail": "Failed to initialize plans"}
1325
+ )
1326
+ except Exception as e:
1327
+ logger.error(f"Error initializing PayPal plans: {str(e)}")
1328
+ return JSONResponse(
1329
+ status_code=500,
1330
+ content={"success": False, "detail": f"Error initializing plans: {str(e)}"}
1331
+ )
1332
+
1333
+
1334
+ @app.post("/subscription/verify")
1335
+ async def verify_subscription(request: Request, current_user: User = Depends(get_current_active_user)):
1336
+ """Verify a subscription after payment"""
1337
+ try:
1338
+ data = await request.json()
1339
+ subscription_id = data.get("subscription_id")
1340
+
1341
+ if not subscription_id:
1342
+ return JSONResponse(
1343
+ status_code=400,
1344
+ content={"success": False, "detail": "Subscription ID is required"}
1345
+ )
1346
+
1347
+ logger.info(f"Verifying subscription: {subscription_id}")
1348
+
1349
+ # Verify the subscription with PayPal
1350
+ success, result = verify_paypal_subscription(subscription_id)
1351
+
1352
+ if not success:
1353
+ logger.error(f"Subscription verification failed: {result}")
1354
+ return JSONResponse(
1355
+ status_code=400,
1356
+ content={"success": False, "detail": str(result)}
1357
+ )
1358
+
1359
+ # Update the user's subscription in the database
1360
+ conn = get_db_connection()
1361
+ cursor = conn.cursor()
1362
+
1363
+ # Get the subscription details
1364
+ cursor.execute(
1365
+ "SELECT tier FROM subscriptions WHERE paypal_subscription_id = ?",
1366
+ (subscription_id,)
1367
+ )
1368
+ subscription = cursor.fetchone()
1369
+
1370
+ if not subscription:
1371
+ # This is a new subscription, get the tier from the PayPal response
1372
+ tier = "standard_tier" # Default to standard tier
1373
+ # You could extract the tier from the PayPal plan ID if needed
1374
+
1375
+ # Create a new subscription record
1376
+ sub_id = str(uuid.uuid4())
1377
+ start_date = datetime.now()
1378
+ expires_at = start_date + timedelta(days=30)
1379
+
1380
+ cursor.execute(
1381
+ "INSERT INTO subscriptions (id, user_id, tier, status, created_at, expires_at, paypal_subscription_id) VALUES (?, ?, ?, ?, ?, ?, ?)",
1382
+ (sub_id, current_user.id, tier, "active", start_date, expires_at, subscription_id)
1383
+ )
1384
+ else:
1385
+ # Update existing subscription
1386
+ tier = subscription[0]
1387
+ cursor.execute(
1388
+ "UPDATE subscriptions SET status = 'active' WHERE paypal_subscription_id = ?",
1389
+ (subscription_id,)
1390
+ )
1391
+
1392
+ # Update user's subscription tier
1393
+ cursor.execute(
1394
+ "UPDATE users SET subscription_tier = ? WHERE id = ?",
1395
+ (tier, current_user.id)
1396
+ )
1397
+
1398
+ conn.commit()
1399
+ conn.close()
1400
+
1401
+ return JSONResponse(
1402
+ status_code=200,
1403
+ content={"success": True, "detail": "Subscription verified successfully"}
1404
+ )
1405
+
1406
+ except Exception as e:
1407
+ logger.error(f"Error verifying subscription: {str(e)}")
1408
+ return JSONResponse(
1409
+ status_code=500,
1410
+ content={"success": False, "detail": f"Error verifying subscription: {str(e)}"}
1411
+ )
1412
+
1413
+ @app.post("/subscription/webhook")
1414
+ async def subscription_webhook(request: Request):
1415
+ """Handle PayPal subscription webhooks"""
1416
+ try:
1417
+ payload = await request.json()
1418
+ success, result = handle_subscription_webhook(payload)
1419
+
1420
+ if not success:
1421
+ logger.error(f"Webhook processing failed: {result}")
1422
+ return {"status": "error", "message": result}
1423
+
1424
+ return {"status": "success", "message": result}
1425
+ except Exception as e:
1426
+ logger.error(f"Error processing webhook: {str(e)}")
1427
+ return {"status": "error", "message": f"Error processing webhook: {str(e)}"}
1428
+
1429
+ @app.get("/subscription/verify/{subscription_id}")
1430
+ async def verify_subscription_by_id(  # renamed so it does not shadow the POST /subscription/verify handler above
1431
+ subscription_id: str,
1432
+ current_user: User = Depends(get_current_active_user)
1433
+ ):
1434
+ """Verify a subscription payment and update user tier"""
1435
+ try:
1436
+ # Verify the subscription
1437
+ success, result = verify_subscription_payment(subscription_id)
1438
+
1439
+ if not success:
1440
+ raise HTTPException(status_code=400, detail=f"Subscription verification failed: {result}")
1441
+
1442
+ # Get the plan ID from the subscription to determine tier
1443
+ plan_id = result.get("plan_id", "")
1444
+
1445
+ # Connect to DB to get the tier for this plan
1446
+ conn = get_db_connection()
1447
+ cursor = conn.cursor()
1448
+ cursor.execute("SELECT tier FROM paypal_plans WHERE plan_id = ?", (plan_id,))
1449
+ tier_result = cursor.fetchone()
1450
+ conn.close()
1451
+
1452
+ if not tier_result:
1453
+ raise HTTPException(status_code=400, detail="Could not determine subscription tier")
1454
+
1455
+ tier = tier_result[0]
1456
+
1457
+ # Update the user's subscription
1458
+ success, update_result = update_user_subscription(current_user.email, subscription_id, tier)
1459
+
1460
+ if not success:
1461
+ raise HTTPException(status_code=500, detail=f"Failed to update subscription: {update_result}")
1462
+
1463
+ return {
1464
+ "message": f"Successfully subscribed to {tier} tier",
1465
+ "subscription_id": subscription_id,
1466
+ "status": result.get("status", ""),
1467
+ "next_billing_time": result.get("billing_info", {}).get("next_billing_time", "")
1468
+ }
1469
+
1470
+ except HTTPException:
1471
+ raise
1472
+ except Exception as e:
1473
+ print(f"Subscription verification error: {str(e)}")
1474
+ raise HTTPException(status_code=500, detail=f"Subscription verification failed: {str(e)}")
1475
+
1476
+ @app.post("/webhook/paypal")
1477
+ async def paypal_webhook(request: Request):
1478
+ """Handle PayPal subscription webhooks"""
1479
+ try:
1480
+ payload = await request.json()
1481
+ logger.info(f"Received PayPal webhook: {payload.get('event_type', 'unknown event')}")
1482
+
1483
+ # Process the webhook
1484
+ success, message = handle_subscription_webhook(payload)  # returns a (success, message) tuple
1485
+
1486
+ return {"status": "success", "message": "Webhook processed"}
1487
+ except Exception as e:
1488
+ logger.error(f"Webhook processing error: {str(e)}")
1489
+ # Return 200 even on error to acknowledge receipt to PayPal
1490
+ return {"status": "error", "message": str(e)}
1491
+
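+ # --- Optional hardening sketch (not part of the original handlers above) ---
+ # The webhook endpoints above trust the posted payload as-is. PayPal can confirm a
+ # delivery's authenticity via POST /v1/notifications/verify-webhook-signature. This
+ # sketch assumes a PAYPAL_WEBHOOK_ID environment variable (the webhook ID from the
+ # PayPal developer dashboard) and reuses call_paypal_api from paypal_integration.
+ def verify_webhook_signature(request_headers, webhook_event):
+ """Ask PayPal whether a webhook delivery is genuine (sketch, not wired into the routes)."""
+ from paypal_integration import call_paypal_api
+ payload = {
+ "auth_algo": request_headers.get("paypal-auth-algo"),
+ "cert_url": request_headers.get("paypal-cert-url"),
+ "transmission_id": request_headers.get("paypal-transmission-id"),
+ "transmission_sig": request_headers.get("paypal-transmission-sig"),
+ "transmission_time": request_headers.get("paypal-transmission-time"),
+ "webhook_id": os.environ["PAYPAL_WEBHOOK_ID"],  # assumed env var
+ "webhook_event": webhook_event,
+ }
+ ok, result = call_paypal_api("/v1/notifications/verify-webhook-signature", "POST", payload)
+ return ok and result.get("verification_status") == "SUCCESS"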
1492
+ # Application startup hook
1493
+ @app.on_event("startup")
1494
+ async def startup_event():
1495
+ """Initialize subscription plans on startup"""
1496
+ try:
1497
+ # Initialize PayPal subscription plans if needed
1498
+ # initialize_subscription_plans() from paypal_integration could be called here
+ # (see the commented sketch after this function)
1500
+ print("Application started successfully")
1501
+ except Exception as e:
1502
+ print(f"Error during startup: {str(e)}")
1503
+
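+ # A minimal sketch of what the startup hook above could do instead of just printing,
+ # assuming initialize_subscription_plans is imported from paypal_integration:
+ #
+ # @app.on_event("startup")
+ # async def startup_event():
+ #     try:
+ #         plans = initialize_subscription_plans()
+ #         logger.info(f"PayPal plans ready: {plans}")
+ #     except Exception as e:
+ #         logger.error(f"Error during startup: {str(e)}")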
1504
+ if __name__ == "__main__":
1505
+ import uvicorn
1506
+ port = int(os.environ.get("PORT", 7860))
1507
+ host = os.environ.get("HOST", "0.0.0.0")
1508
+ uvicorn.run("app:app", host=host, port=port, reload=True)
auth.py ADDED
@@ -0,0 +1,633 @@
1
+ import sqlite3
2
+ import uuid
3
+ import os
4
+ import logging
5
+ from datetime import datetime, timedelta
6
+ import hashlib  # note: token signing below uses PyJWT, not hashlib
7
+ from passlib.hash import bcrypt
8
+ from dotenv import load_dotenv
9
+ from fastapi import Depends, HTTPException
10
+ from fastapi.security import OAuth2PasswordBearer
11
+ from pydantic import BaseModel
12
+ from typing import Optional
13
+ from fastapi import status  # HTTPException is already imported above
14
+ import jwt
15
+ from jwt.exceptions import PyJWTError
17
+
18
+ # Load environment variables
19
+ load_dotenv()
20
+
21
+ # Configure logging
22
+ logging.basicConfig(
23
+ level=logging.INFO,
24
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
25
+ )
26
+ logger = logging.getLogger('auth')
27
+
28
+ # Security configuration
29
+ SECRET_KEY = os.getenv("JWT_SECRET", "your-secret-key-for-development-only")
30
+ ALGORITHM = "HS256"
31
+ JWT_EXPIRATION_DELTA = timedelta(days=1) # Token valid for 1 day
32
+ # Database path from environment variable or default
33
34
+ DB_PATH = os.getenv("DB_PATH", os.path.join(os.path.dirname(__file__), "data/user_data.db"))
35
+
36
+ # FastAPI OAuth2 scheme
37
+ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
38
+
39
+ # Pydantic models for FastAPI
40
+ class User(BaseModel):
41
+ id: str
42
+ email: str
43
+ subscription_tier: str = "free_tier"
44
+ subscription_expiry: Optional[datetime] = None
45
+ api_calls_remaining: int = 5
46
+ last_reset_date: Optional[datetime] = None
47
+
48
+ class UserCreate(BaseModel):
49
+ email: str
50
+ password: str
51
+
52
+ class Token(BaseModel):
53
+ access_token: str
54
+ token_type: str
55
+
56
+ class TokenData(BaseModel):
57
+ user_id: Optional[str] = None
58
+
59
+ # Subscription tiers and limits
60
61
+ SUBSCRIPTION_TIERS = {
62
+ "free_tier": {
63
+ "price": 0,
64
+ "currency": "INR",
65
+ "features": ["basic_document_analysis", "basic_risk_assessment"],
66
+ "limits": {
67
+ "document_size_mb": 5,
68
+ "documents_per_month": 3,
69
+ "video_size_mb": 0,
70
+ "audio_size_mb": 0
71
+ }
72
+ },
73
+ "standard_tier": {
74
+ "price": 799,
75
+ "currency": "INR",
76
+ "features": ["basic_document_analysis", "basic_risk_assessment", "video_analysis", "audio_analysis", "chatbot"],
77
+ "limits": {
78
+ "document_size_mb": 20,
79
+ "documents_per_month": 20,
80
+ "video_size_mb": 100,
81
+ "audio_size_mb": 50
82
+ }
83
+ },
84
+ "premium_tier": {
85
+ "price": 1499,
86
+ "currency": "INR",
87
+ "features": ["basic_document_analysis", "basic_risk_assessment", "video_analysis", "audio_analysis", "chatbot", "detailed_risk_assessment", "contract_clause_analysis"],
88
+ "limits": {
89
+ "document_size_mb": 50,
90
+ "documents_per_month": 999999, # Unlimited
91
+ "video_size_mb": 500,
92
+ "audio_size_mb": 200
93
+ }
94
+ }
95
+ }
96
+
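+ # Example: per-tier limits live under the nested "limits" key, e.g.
+ # SUBSCRIPTION_TIERS["standard_tier"]["limits"]["document_size_mb"]  -> 20
+ # SUBSCRIPTION_TIERS["premium_tier"]["price"]                        -> 1499 (INR)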
97
+ # Database connection management
98
+ def get_db_connection():
99
+ """Create and return a database connection with proper error handling"""
100
+ try:
101
+ # Ensure the directory exists
102
+ db_dir = os.path.dirname(DB_PATH)
103
+ os.makedirs(db_dir, exist_ok=True)
104
+
105
+ conn = sqlite3.connect(DB_PATH)
106
+ conn.row_factory = sqlite3.Row # Return rows as dictionaries
107
+ return conn
108
+ except sqlite3.Error as e:
109
+ logger.error(f"Database connection error: {e}")
110
+ raise Exception(f"Database connection failed: {e}")
111
+
112
+ # Database setup
113
114
+ def init_auth_db():
115
+ """Initialize the authentication database with required tables"""
116
+ conn = None  # so the finally block can close safely even if the connection fails
+ try:
117
+ conn = get_db_connection()
118
+ c = conn.cursor()
119
+
120
+ # Create users table with the correct schema
121
+ c.execute('''
122
+ CREATE TABLE IF NOT EXISTS users (
123
+ id TEXT PRIMARY KEY,
124
+ email TEXT UNIQUE NOT NULL,
125
+ hashed_password TEXT NOT NULL,
126
+ password TEXT,
127
+ subscription_tier TEXT DEFAULT 'free_tier',
128
+ is_active BOOLEAN DEFAULT 1,
129
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
130
+ api_calls_remaining INTEGER DEFAULT 10,
131
+ last_reset_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP
132
+ )
133
+ ''')
134
+
135
+ # Create subscriptions table
136
+ c.execute('''
137
+ CREATE TABLE IF NOT EXISTS subscriptions (
138
+ id TEXT PRIMARY KEY,
139
+ user_id TEXT,
140
+ tier TEXT,
141
+ plan_id TEXT,
142
+ status TEXT,
143
+ created_at TIMESTAMP,
144
+ expires_at TIMESTAMP,
145
+ paypal_subscription_id TEXT,
146
+ FOREIGN KEY (user_id) REFERENCES users (id)
147
+ )
148
+ ''')
149
+
150
+ # Create usage stats table
151
+ c.execute('''
152
+ CREATE TABLE IF NOT EXISTS usage_stats (
153
+ id TEXT PRIMARY KEY,
154
+ user_id TEXT,
155
+ month INTEGER,
156
+ year INTEGER,
157
+ analyses_used INTEGER,
158
+ FOREIGN KEY (user_id) REFERENCES users (id)
159
+ )
160
+ ''')
161
+
162
+ # Create tokens table for refresh tokens
163
+ c.execute('''
164
+ CREATE TABLE IF NOT EXISTS refresh_tokens (
165
+ user_id TEXT,
166
+ token TEXT,
167
+ expires_at TIMESTAMP,
168
+ FOREIGN KEY (user_id) REFERENCES users (id)
169
+ )
170
+ ''')
171
+
172
+ conn.commit()
173
+ logger.info("Database initialized successfully")
174
+ except Exception as e:
175
+ logger.error(f"Database initialization error: {e}")
176
+ raise
177
+ finally:
178
+ if conn:
179
+ conn.close()
180
+
181
+ # Initialize the database
182
+ init_auth_db()
183
+
184
+ # Password hashing with bcrypt
185
+ # (An earlier passlib-based hash_password/verify_password pair was removed in favor of
+ # the direct bcrypt implementations below, which avoid passlib/bcrypt version issues.)
198
+ def hash_password(password):
199
+ """Hash a password using bcrypt"""
200
+ # Use a more direct approach to avoid bcrypt version issues
201
+ import bcrypt
202
+ # Convert password to bytes if it's not already
203
+ if isinstance(password, str):
204
+ password = password.encode('utf-8')
205
+ # Generate salt and hash
206
+ salt = bcrypt.gensalt()
207
+ hashed = bcrypt.hashpw(password, salt)
208
+ # Return as string for storage
209
+ return hashed.decode('utf-8')
210
+
211
+ def verify_password(plain_password, hashed_password):
212
+ """Verify a password against its hash"""
213
+ import bcrypt
214
+ # Convert inputs to bytes if they're not already
215
+ if isinstance(plain_password, str):
216
+ plain_password = plain_password.encode('utf-8')
217
+ if isinstance(hashed_password, str):
218
+ hashed_password = hashed_password.encode('utf-8')
219
+
220
+ try:
221
+ # Use direct bcrypt verification
222
+ return bcrypt.checkpw(plain_password, hashed_password)
223
+ except Exception as e:
224
+ logger.error(f"Password verification error: {e}")
225
+ return False
226
+
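+ # Example round trip with the helpers above (the stored value is a UTF-8 bcrypt hash string):
+ # hashed = hash_password("s3cret!")
+ # assert verify_password("s3cret!", hashed) is True
+ # assert verify_password("wrong", hashed) is False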
227
+ # User registration
228
+ def register_user(email, password):
229
+ conn = None  # so the finally block can close safely even if the connection fails
+ try:
230
+ conn = get_db_connection()
231
+ c = conn.cursor()
232
+
233
+ # Check if user already exists
234
+ c.execute("SELECT * FROM users WHERE email = ?", (email,))
235
+ if c.fetchone():
236
+ return False, "Email already registered"
237
+
238
+ # Create new user
239
+ user_id = str(uuid.uuid4())
240
+
241
+ # Add more detailed logging
242
+ logger.info(f"Registering new user with email: {email}")
243
+ hashed_pw = hash_password(password)
244
+ logger.info(f"Password hashed successfully: {bool(hashed_pw)}")
245
+
246
+ c.execute("""
247
+ INSERT INTO users
248
+ (id, email, hashed_password, subscription_tier, api_calls_remaining, last_reset_date)
249
+ VALUES (?, ?, ?, ?, ?, ?)
250
+ """, (user_id, email, hashed_pw, "free_tier", 5, datetime.now()))
251
+
252
+ conn.commit()
253
+ logger.info(f"User registered successfully: {email}")
254
+
255
+ # Verify the user was actually stored
256
+ c.execute("SELECT * FROM users WHERE email = ?", (email,))
257
+ stored_user = c.fetchone()
258
+ logger.info(f"User verification after registration: {bool(stored_user)}")
259
+
260
+ access_token = create_access_token(user_id)
261
+ return True, {
262
+ "user_id": user_id,
263
+ "access_token": access_token,
264
+ "token_type": "bearer"
265
+ }
266
+ except Exception as e:
267
+ logger.error(f"User registration error: {e}")
268
+ return False, f"Registration failed: {str(e)}"
269
+ finally:
270
+ if conn:
271
+ conn.close()
272
+
273
+ # User login
274
+ # authenticate_user verifies the supplied password against the stored hashed_password
276
+ def authenticate_user(email, password):
277
+ """Authenticate a user and return user data with tokens"""
278
+ conn = None  # so the finally block can close safely even if the connection fails
+ try:
279
+ conn = get_db_connection()
280
+ c = conn.cursor()
281
+
282
+ # Get user by email
283
+ c.execute("SELECT * FROM users WHERE email = ? AND is_active = 1", (email,))
284
+ user = c.fetchone()
285
+
286
+ if not user:
287
+ logger.warning(f"User not found: {email}")
288
+ return None
289
+
290
+ # Add debug logging for password verification
291
+ logger.info(f"Verifying password for user: {email}")
292
+ logger.info(f"Stored hashed password: {user['hashed_password'][:20]}...")
293
+
294
+ try:
295
+ # Check if password verification works
296
+ is_valid = verify_password(password, user['hashed_password'])
297
+ logger.info(f"Password verification result: {is_valid}")
298
+
299
+ if not is_valid:
300
+ logger.warning(f"Password verification failed for user: {email}")
301
+ return None
302
+ except Exception as e:
303
+ logger.error(f"Password verification error: {e}")
304
+ return None
305
+
306
+ # Update last login time if column exists
307
+ try:
308
+ c.execute("UPDATE users SET last_login = ? WHERE id = ?",
309
+ (datetime.now(), user['id']))
310
+ conn.commit()
311
+ except sqlite3.OperationalError:
312
+ # last_login column might not exist
313
+ pass
314
+
315
+ # Convert sqlite3.Row to dict to use get() method
316
+ user_dict = dict(user)
317
+
318
+ # Create and return a User object
319
+ return User(
320
+ id=user_dict['id'],
321
+ email=user_dict['email'],
322
+ subscription_tier=user_dict.get('subscription_tier', 'free_tier'),
323
+ subscription_expiry=None, # Handle this properly if needed
324
+ api_calls_remaining=user_dict.get('api_calls_remaining', 5),
325
+ last_reset_date=user_dict.get('last_reset_date')
326
+ )
327
+ except Exception as e:
328
+ logger.error(f"Login error: {e}")
329
+ return None
330
+ finally:
331
+ if conn:
332
+ conn.close()
333
+
334
+ # Token generation and validation
335
+ def create_access_token(user_id):
336
+ """Create a new access token for a user"""
337
+ try:
338
+ # Create a JWT token with user_id and expiration
339
+ expiration = datetime.now() + JWT_EXPIRATION_DELTA
340
+
341
+ # Create a token payload
342
+ payload = {
343
+ "sub": user_id,
344
+ "exp": expiration.timestamp()
345
+ }
346
+
347
+ # Generate the JWT token
348
+ token = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
349
+
350
+ logger.info(f"Created access token for user: {user_id}")
351
+ return token
352
+ except Exception as e:
353
+ logger.error(f"Token creation error: {e}")
354
+ return None
355
+
356
+
357
+ def update_auth_db_schema():
358
+ """Update the authentication database schema with any missing columns"""
359
+ try:
360
+ conn = get_db_connection()
361
+ c = conn.cursor()
362
+
363
+ # Check if tier column exists in subscriptions table
364
+ c.execute("PRAGMA table_info(subscriptions)")
365
+ columns = [column[1] for column in c.fetchall()]
366
+
367
+ # Add tier column if it doesn't exist
368
+ if "tier" not in columns:
369
+ logger.info("Adding 'tier' column to subscriptions table")
370
+ c.execute("ALTER TABLE subscriptions ADD COLUMN tier TEXT")
371
+ conn.commit()
372
+ logger.info("Database schema updated successfully")
373
+
374
+ conn.close()
375
+ except Exception as e:
376
+ logger.error(f"Database schema update error: {e}")
377
+ raise HTTPException(
378
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
379
+ detail=f"Database schema update error: {str(e)}"
380
+ )
381
+
382
+ # Resolve the current user from a JWT bearer token
383
+ async def get_current_user(token: str = Depends(oauth2_scheme)):
384
+ credentials_exception = HTTPException(
385
+ status_code=status.HTTP_401_UNAUTHORIZED,
386
+ detail="Could not validate credentials",
387
+ headers={"WWW-Authenticate": "Bearer"},
388
+ )
389
+ try:
390
+ # Decode the JWT token
391
+ payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
392
+ user_id: str = payload.get("sub")
393
+ if user_id is None:
394
+ logger.error("Token missing 'sub' field")
395
+ raise credentials_exception
396
+ except Exception as e:
397
+ logger.error(f"Token validation error: {str(e)}")
398
+ raise credentials_exception
399
+
400
+ # Get user from database
401
+ conn = get_db_connection()
402
+ cursor = conn.cursor()
403
+ cursor.execute("SELECT id, email, subscription_tier, is_active FROM users WHERE id = ?", (user_id,))
404
+ user_data = cursor.fetchone()
405
+ conn.close()
406
+
407
+ if user_data is None:
408
+ logger.error(f"User not found: {user_id}")
409
+ raise credentials_exception
410
+
411
+ user = User(
412
+ id=user_data[0],
413
+ email=user_data[1],
414
+ subscription_tier=user_data[2],
415
+ is_active=bool(user_data[3])
416
+ )
417
+
418
+ return user
419
+
420
+ async def get_current_active_user(current_user: User = Depends(get_current_user)):
421
+ """Get the current active user"""
422
+ return current_user
423
+
424
+ def create_user_subscription(email, tier):
425
+ """Create a subscription for a user"""
426
+ conn = None  # so the finally block can close safely even if the connection fails
+ try:
427
+ # Get user by email
428
+ conn = get_db_connection()
429
+ c = conn.cursor()
430
+
431
+ # Get user ID
432
+ c.execute("SELECT id FROM users WHERE email = ?", (email,))
433
+ user_data = c.fetchone()
434
+
435
+ if not user_data:
436
+ return False, "User not found"
437
+
438
+ user_id = user_data['id']
439
+
440
+ # Check if tier is valid
441
+ valid_tiers = ["standard_tier", "premium_tier"]
442
+ if tier not in valid_tiers:
443
+ return False, f"Invalid tier: {tier}. Must be one of {valid_tiers}"
444
+
445
+ # Create subscription
446
+ subscription_id = str(uuid.uuid4())
447
+ created_at = datetime.now()
448
+ expires_at = created_at + timedelta(days=30) # 30-day subscription
449
+
450
+ # Insert subscription
451
+ c.execute("""
452
+ INSERT INTO subscriptions
453
+ (id, user_id, tier, status, created_at, expires_at)
454
+ VALUES (?, ?, ?, ?, ?, ?)
455
+ """, (subscription_id, user_id, tier, "active", created_at, expires_at))
456
+
457
+ # Update user's subscription tier
458
+ c.execute("""
459
+ UPDATE users
460
+ SET subscription_tier = ?
461
+ WHERE id = ?
462
+ """, (tier, user_id))
463
+
464
+ conn.commit()
465
+
466
+ return True, {
467
+ "id": subscription_id,
468
+ "user_id": user_id,
469
+ "tier": tier,
470
+ "status": "active",
471
+ "created_at": created_at.isoformat(),
472
+ "expires_at": expires_at.isoformat()
473
+ }
474
+ except Exception as e:
475
+ logger.error(f"Subscription creation error: {e}")
476
+ return False, f"Failed to create subscription: {str(e)}"
477
+ finally:
478
+ if conn:
479
+ conn.close()
480
+
481
+ def get_user(user_id: str):
482
+ """Get user by ID"""
483
+ conn = None  # so the finally block can close safely even if the connection fails
+ try:
484
+ conn = get_db_connection()
485
+ c = conn.cursor()
486
+
487
+ # Get user
488
+ c.execute("SELECT * FROM users WHERE id = ? AND is_active = 1", (user_id,))
489
+ user_data = c.fetchone()
490
+
491
+ if not user_data:
492
+ return None
493
+
494
+ # Convert to User model
495
+ user_dict = dict(user_data)
496
+
497
+ # Handle datetime conversions if needed
498
+ if user_dict.get("subscription_expiry") and isinstance(user_dict["subscription_expiry"], str):
499
+ user_dict["subscription_expiry"] = datetime.fromisoformat(user_dict["subscription_expiry"])
500
+ if user_dict.get("last_reset_date") and isinstance(user_dict["last_reset_date"], str):
501
+ user_dict["last_reset_date"] = datetime.fromisoformat(user_dict["last_reset_date"])
502
+
503
+ return User(
504
+ id=user_dict['id'],
505
+ email=user_dict['email'],
506
+ subscription_tier=user_dict['subscription_tier'],
507
+ subscription_expiry=user_dict.get('subscription_expiry'),
508
+ api_calls_remaining=user_dict.get('api_calls_remaining', 5),
509
+ last_reset_date=user_dict.get('last_reset_date')
510
+ )
511
+ except Exception as e:
512
+ logger.error(f"Get user error: {e}")
513
+ return None
514
+ finally:
515
+ if conn:
516
+ conn.close()
517
+
518
+ def check_subscription_access(user: User, feature: str, file_size_mb: Optional[float] = None):
519
+ """Check if the user has access to the requested feature and file size"""
520
+ # Check if subscription is expired
521
+ if user.subscription_tier != "free_tier" and user.subscription_expiry and user.subscription_expiry < datetime.now():
522
+ # Downgrade to free tier if subscription expired
523
+ user.subscription_tier = "free_tier"
524
+ user.api_calls_remaining = SUBSCRIPTION_TIERS["free_tier"]["limits"].get("daily_api_calls", 5)  # the tier config defines no daily_api_calls key, so fall back to the free-tier default of 5
525
+ with get_db_connection() as conn:
526
+ c = conn.cursor()
527
+ c.execute("""
528
+ UPDATE users
529
+ SET subscription_tier = ?, api_calls_remaining = ?
530
+ WHERE id = ?
531
+ """, (user.subscription_tier, user.api_calls_remaining, user.id))
532
+ conn.commit()
533
+
534
+ # Reset API calls if needed
535
+ user = reset_api_calls_if_needed(user)
536
+
537
+ # Check if user has API calls remaining
538
+ if user.api_calls_remaining <= 0:
539
+ raise HTTPException(
540
+ status_code=429,
541
+ detail="API call limit reached for today. Please upgrade your subscription or try again tomorrow."
542
+ )
543
+
544
+ # Check if feature is available in user's subscription tier
545
+ tier_features = SUBSCRIPTION_TIERS[user.subscription_tier]["features"]
546
+ if feature not in tier_features:
547
+ raise HTTPException(
548
+ status_code=403,
549
+ detail=f"The {feature} feature is not available in your {user.subscription_tier} subscription. Please upgrade to access this feature."
550
+ )
551
+
552
+ # Check file size limit if applicable
553
+ if file_size_mb:
554
+ max_size = SUBSCRIPTION_TIERS[user.subscription_tier]["limits"]["document_size_mb"]
555
+ if file_size_mb > max_size:
556
+ raise HTTPException(
557
+ status_code=413,
558
+ detail=f"File size exceeds the {max_size}MB limit for your {user.subscription_tier} subscription. Please upgrade or use a smaller file."
559
+ )
560
+
561
+ # Decrement API calls remaining
562
+ user.api_calls_remaining -= 1
563
+ with get_db_connection() as conn:
564
+ c = conn.cursor()
565
+ c.execute("""
566
+ UPDATE users
567
+ SET api_calls_remaining = ?
568
+ WHERE id = ?
569
+ """, (user.api_calls_remaining, user.id))
570
+ conn.commit()
571
+
572
+ return True
573
+
574
+ def reset_api_calls_if_needed(user: User):
575
+ """Reset API call counter if it's a new day"""
576
+ today = datetime.now().date()
577
+ if user.last_reset_date is None or user.last_reset_date.date() < today:
578
+ tier_limits = SUBSCRIPTION_TIERS[user.subscription_tier]
579
+ user.api_calls_remaining = tier_limits["limits"].get("daily_api_calls", 5)  # no daily_api_calls in the tier config; fall back to 5
580
+ user.last_reset_date = datetime.now()
581
+ # Update the user in the database
582
+ with get_db_connection() as conn:
583
+ c = conn.cursor()
584
+ c.execute("""
585
+ UPDATE users
586
+ SET api_calls_remaining = ?, last_reset_date = ?
587
+ WHERE id = ?
588
+ """, (user.api_calls_remaining, user.last_reset_date, user.id))
589
+ conn.commit()
590
+
591
+ return user
592
+
593
+ def login_user(email, password):
594
+ """Login a user with email and password"""
595
+ try:
596
+ # Authenticate user
597
+ user = authenticate_user(email, password)
598
+ if not user:
599
+ return False, "Incorrect username or password"
600
+
601
+ # Create access token
602
+ access_token = create_access_token(user.id)
603
+
604
+ # Create refresh token
605
+ refresh_token = str(uuid.uuid4())
606
+ expires_at = datetime.now() + timedelta(days=30)
607
+
608
+ # Store refresh token
609
+ conn = get_db_connection()
610
+ c = conn.cursor()
611
+ c.execute("INSERT INTO refresh_tokens VALUES (?, ?, ?)",
612
+ (user.id, refresh_token, expires_at))
613
+ conn.commit()
614
+
615
+ # Get subscription info
616
+ c.execute("SELECT * FROM subscriptions WHERE user_id = ? AND status = 'active'", (user.id,))
617
+ subscription = c.fetchone()
618
+
619
+ # Convert subscription to dict if it exists, otherwise set to None
620
+ subscription_dict = dict(subscription) if subscription else None
621
+
622
+ conn.close()
623
+
624
+ return True, {
625
+ "user_id": user.id,
626
+ "email": user.email,
627
+ "access_token": access_token,
628
+ "refresh_token": refresh_token,
629
+ "subscription": subscription_dict
630
+ }
631
+ except Exception as e:
632
+ logger.error(f"Login error: {e}")
633
+ return False, f"Login failed: {str(e)}"
check_routes.py ADDED
@@ -0,0 +1,61 @@
1
+ import os
2
+ import sys
3
+ import importlib.util
4
+
5
+ def check_fastapi_routes():
6
+ """Check the FastAPI routes in the app.py file."""
7
+ print("Checking FastAPI routes...")
8
+
9
+ # Store current directory
10
+ current_dir = os.getcwd()
11
+
12
+ try:
13
+ # Change to the backend directory
14
+ os.chdir("doc-vid-analyze-main")
15
+ print(f"Current directory: {os.getcwd()}")
16
+
17
+ # Check if app.py exists
18
+ if not os.path.exists("app.py"):
19
+ print(f"❌ Error: app.py not found in {os.getcwd()}")
20
+ return False
21
+
22
+ # Load the app.py module
23
+ print("Loading app.py module...")
24
+ spec = importlib.util.spec_from_file_location("app", "app.py")
25
+ app_module = importlib.util.module_from_spec(spec)
26
+ sys.modules["app"] = app_module
27
+ spec.loader.exec_module(app_module)
28
+
29
+ # Check if app is defined
30
+ if not hasattr(app_module, "app"):
31
+ print("❌ Error: 'app' object not found in app.py")
32
+ return False
33
+
34
+ app = app_module.app
35
+
36
+ # Print app information
37
+ print("\n📋 FastAPI App Information:")
38
+ print(f"App title: {app.title}")
39
+ print(f"App version: {app.version}")
40
+ print(f"App description: {app.description}")
41
+
42
+ # Print routes
43
+ print("\n📋 FastAPI Routes:")
44
+ for route in app.routes:
45
+ print(f"Route: {route.path}")
46
+ print(f" Methods: {route.methods}")
47
+ print(f" Name: {route.name}")
48
+ print(f" Endpoint: {route.endpoint.__name__ if hasattr(route.endpoint, '__name__') else route.endpoint}")
49
+ print()
50
+
51
+ return True
52
+
53
+ except Exception as e:
54
+ print(f"❌ Error checking routes: {e}")
55
+ return False
56
+ finally:
57
+ # Return to original directory
58
+ os.chdir(current_dir)
59
+
60
+ if __name__ == "__main__":
61
+ check_fastapi_routes()
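+ # Usage note: run this from the directory that contains the assumed "doc-vid-analyze-main"
+ # folder (the script chdirs into it before loading app.py):
+ #   python check_routes.py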
fix_users_table.py ADDED
@@ -0,0 +1,180 @@
1
+ import sqlite3
2
+ import os
3
+ import uuid
4
+ import datetime
5
+
6
+ # Define both database paths
7
+ DB_PATH_1 = os.path.join(os.path.dirname(__file__), "../data/user_data.db")
8
+ DB_PATH_2 = os.path.join(os.path.dirname(__file__), "data/user_data.db")
9
+
10
+ # Define the function to create users table
11
+ # (hashed_password gets a temporary default so existing rows can be migrated before real hashes exist)
12
+ def create_users_table(cursor):
13
+ """Create the users table with all required columns"""
14
+ cursor.execute('''
15
+ CREATE TABLE users (
16
+ id TEXT PRIMARY KEY,
17
+ email TEXT UNIQUE NOT NULL,
18
+ hashed_password TEXT DEFAULT 'temp_hash_for_migration',
19
+ password TEXT,
20
+ subscription_tier TEXT DEFAULT 'free_tier',  -- match the tier name used in auth.py
21
+ is_active BOOLEAN DEFAULT 1,
22
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
23
+ api_calls_remaining INTEGER DEFAULT 10,
24
+ last_reset_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP
25
+ )
26
+ ''')
27
+
28
+ # Repair an existing users table so it contains all required columns
29
+ def fix_users_table(db_path):
30
+ # Make sure the data directory exists
31
+ data_dir = os.path.dirname(db_path)
32
+ if not os.path.exists(data_dir):
33
+ print(f"Creating data directory: {data_dir}")
34
+ os.makedirs(data_dir, exist_ok=True)
35
+
36
+ if not os.path.exists(db_path):
37
+ print(f"Database does not exist at: {os.path.abspath(db_path)}")
38
+ return False
39
+
40
+ print(f"Using database path: {os.path.abspath(db_path)}")
41
+
42
+ # Connect to the database
43
+ conn = sqlite3.connect(db_path)
44
+ cursor = conn.cursor()
45
+
46
+ # Check if users table exists
47
+ cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='users'")
48
+ if cursor.fetchone():
49
+ print("Users table exists, checking schema...")
50
+
51
+ # Check columns
52
+ cursor.execute("PRAGMA table_info(users)")
53
+ columns_info = cursor.fetchall()
54
+ columns = [column[1] for column in columns_info]
55
+
56
+ # List of all required columns
57
+ required_columns = ['id', 'email', 'hashed_password', 'password', 'subscription_tier',
58
+ 'is_active', 'created_at', 'api_calls_remaining', 'last_reset_date']
59
+
60
+ # Check if any required column is missing
61
+ missing_columns = [col for col in required_columns if col not in columns]
62
+
63
+ if missing_columns:
64
+ print(f"Schema needs fixing. Missing columns: {', '.join(missing_columns)}")
65
+
66
+ # Dynamically build the SELECT query based on available columns
67
+ available_columns = [col for col in columns if col != 'id'] # Exclude id as we'll generate new ones
68
+
69
+ if not available_columns:
70
+ print("No usable columns found in users table, creating new table...")
71
+ cursor.execute("DROP TABLE users")
72
+ create_users_table(cursor)
73
+ print("Created new empty users table with correct schema")
74
+ else:
75
+ # Backup existing users with available columns
76
+ select_query = f"SELECT {', '.join(available_columns)} FROM users"
77
+ print(f"Backing up users with query: {select_query}")
78
+ cursor.execute(select_query)
79
+ existing_users = cursor.fetchall()
80
+
81
+ # Drop the existing table
82
+ cursor.execute("DROP TABLE users")
83
+
84
+ # Create the table with the correct schema
85
+ create_users_table(cursor)
86
+
87
+ # Restore the users with new UUIDs for IDs
88
+ if existing_users:
89
+ current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
90
+ for user in existing_users:
91
+ user_id = str(uuid.uuid4())
92
+
93
+ # Create a dictionary to map column names to values
94
+ user_data = {'id': user_id}
95
+ for i, col in enumerate(available_columns):
96
+ user_data[col] = user[i]
97
+
98
+ # Set default values for missing columns
99
100
+ if 'hashed_password' not in user_data:
101
+ user_data['hashed_password'] = 'temp_hash_for_migration' # Temporary hash for migration
102
+ if 'subscription_tier' not in user_data:
103
+ user_data['subscription_tier'] = 'free_tier'  # match the tier name used in auth.py
104
+ if 'is_active' not in user_data:
105
+ user_data['is_active'] = 1
106
+ if 'created_at' not in user_data:
107
+ user_data['created_at'] = current_time
108
+ if 'api_calls_remaining' not in user_data:
109
+ user_data['api_calls_remaining'] = 10
110
+ if 'last_reset_date' not in user_data:
111
+ user_data['last_reset_date'] = current_time
112
+
113
+ # Build INSERT query with all required columns
114
+ insert_columns = ['id']
115
+ insert_values = [user_id]
116
+
117
+ # Add values for columns that exist in the old table
118
+ for col in available_columns:
119
+ insert_columns.append(col)
120
+ insert_values.append(user_data[col])
121
+
122
+ # Add default values for columns that don't exist in the old table
123
+ for col in required_columns:
124
125
+ if col not in ['id'] + available_columns:
126
+ insert_columns.append(col)
127
+ if col == 'subscription_tier':
128
+ insert_values.append('free_tier')  # match the tier name used in auth.py
129
+ elif col == 'is_active':
130
+ insert_values.append(1)
131
+ elif col == 'created_at':
132
+ insert_values.append(current_time)
133
+ elif col == 'api_calls_remaining':
134
+ insert_values.append(10)
135
+ elif col == 'last_reset_date':
136
+ insert_values.append(current_time)
137
+ elif col == 'hashed_password':
138
+ insert_values.append('temp_hash_for_migration') # Temporary hash for migration
139
+ else:
140
+ insert_values.append(None) # Default to NULL for other columns
141
+
142
+ placeholders = ', '.join(['?'] * len(insert_columns))
143
+ insert_query = f"INSERT INTO users ({', '.join(insert_columns)}) VALUES ({placeholders})"
144
+
145
+ cursor.execute(insert_query, insert_values)
146
+
147
+ print(f"Fixed users table, restored {len(existing_users)} users")
148
+ else:
149
+ print("Users table schema is correct")
150
+ else:
151
+ print("Users table doesn't exist, creating it now...")
152
+ create_users_table(cursor)
153
+ print("Users table created successfully")
154
+
155
+ # Commit changes and close connection
156
+ conn.commit()
157
+ conn.close()
158
+ return True
159
+
160
+ if __name__ == "__main__":
161
+ print("Checking first database location...")
162
+ success1 = fix_users_table(DB_PATH_1)
163
+
164
+ print("\nChecking second database location...")
165
+ success2 = fix_users_table(DB_PATH_2)
166
+
167
+ if not (success1 or success2):
168
+ print("\nWarning: Could not find any existing database files.")
169
+ print("Creating a new database at the primary location...")
170
+ # Create a new database at the primary location
171
+ data_dir = os.path.dirname(DB_PATH_1)
172
+ if not os.path.exists(data_dir):
173
+ os.makedirs(data_dir, exist_ok=True)
174
+
175
+ conn = sqlite3.connect(DB_PATH_1)
176
+ cursor = conn.cursor()
177
+ create_users_table(cursor)
178
+ conn.commit()
179
+ conn.close()
180
+ print(f"Created new database at: {os.path.abspath(DB_PATH_1)}")
initialize_plans.py ADDED
@@ -0,0 +1,25 @@
1
+ import os
2
+ import sys
3
+ from dotenv import load_dotenv
4
+
+ # Load environment variables before importing paypal_integration,
+ # which reads PayPal credentials from the environment at import time
+ load_dotenv()
+
+ from paypal_integration import initialize_subscription_plans
8
+
9
+ def main():
10
+ """Initialize PayPal subscription plans"""
11
+ print("Initializing PayPal subscription plans...")
12
+ plans = initialize_subscription_plans()
13
+
14
+ if plans:
15
+ print("✅ Plans initialized successfully:")
16
+ for tier, plan_id in plans.items():
17
+ print(f" - {tier}: {plan_id}")
18
+ return True
19
+ else:
20
+ print("❌ Failed to initialize plans. Check the logs for details.")
21
+ return False
22
+
23
+ if __name__ == "__main__":
24
+ success = main()
25
+ sys.exit(0 if success else 1)
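+ # Usage note (assumes PayPal sandbox credentials are available via .env or the environment):
+ #   python initialize_plans.py
+ # Exit code 0 means the plans were created (and cached by paypal_integration); 1 means failure.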
paypal_integration.py ADDED
@@ -0,0 +1,1004 @@
1
+ import requests
2
+ import json
3
+ import sqlite3
4
+ from datetime import datetime, timedelta
5
+ import uuid
6
+ import os
7
+ import logging
8
+ from requests.adapters import HTTPAdapter
9
+ from requests.packages.urllib3.util.retry import Retry
10
+ from auth import get_db_connection
11
+ from dotenv import load_dotenv
12
+
13
+ # Load .env values before reading PayPal credentials (auth.py also does this, but this
+ # module may be imported on its own, e.g. by initialize_plans.py)
+ load_dotenv()
+
+ # PayPal API configuration (credentials come from the environment; no hard-coded defaults)
14
+ PAYPAL_CLIENT_ID = os.getenv("PAYPAL_CLIENT_ID")
15
+ PAYPAL_SECRET = os.getenv("PAYPAL_SECRET")
16
+ PAYPAL_BASE_URL = os.getenv("PAYPAL_BASE_URL", "https://api-m.sandbox.paypal.com")
17
+
18
+ # Set up logging (create the log directory first so the FileHandler does not fail on a fresh checkout)
+ os.makedirs(os.path.join(os.path.dirname(__file__), "../logs"), exist_ok=True)
+ logging.basicConfig(
21
+ level=logging.INFO,
22
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
23
+ handlers=[
24
+ logging.FileHandler(os.path.join(os.path.dirname(__file__), "../logs/paypal.log")),
25
+ logging.StreamHandler()
26
+ ]
27
+ )
28
+ logger = logging.getLogger("paypal_integration")
29
+
30
+ # Warn early if PayPal credentials are missing from the environment
32
+ if not PAYPAL_CLIENT_ID or not PAYPAL_SECRET:
33
+ logger.warning("PayPal credentials not found in environment variables")
34
+
35
+
36
+ # Create a requests session that retries transient PayPal API failures
39
+ def create_retry_session(retries=3, backoff_factor=0.3):
40
+ session = requests.Session()
41
+ retry = Retry(
42
+ total=retries,
43
+ read=retries,
44
+ connect=retries,
45
+ backoff_factor=backoff_factor,
46
+ status_forcelist=[500, 502, 503, 504],
47
+ )
48
+ adapter = HTTPAdapter(max_retries=retry)
49
+ session.mount('http://', adapter)
50
+ session.mount('https://', adapter)
51
+ return session
52
+
53
+ # Get a PayPal OAuth2 access token (client-credentials grant)
55
+ def get_access_token():
56
+ url = f"{PAYPAL_BASE_URL}/v1/oauth2/token"
57
+ headers = {
58
+ "Accept": "application/json",
59
+ "Accept-Language": "en_US"
60
+ }
61
+ data = "grant_type=client_credentials"
62
+
63
+ try:
64
+ session = create_retry_session()
65
+ response = session.post(
66
+ url,
67
+ auth=(PAYPAL_CLIENT_ID, PAYPAL_SECRET),
68
+ headers=headers,
69
+ data=data
70
+ )
71
+
72
+ if response.status_code == 200:
73
+ return response.json()["access_token"]
74
+ else:
75
+ logger.error(f"Error getting access token: {response.status_code}")
76
+ return None
77
+ except Exception as e:
78
+ logger.error(f"Exception in get_access_token: {str(e)}")
79
+ return None
80
+
81
+ def call_paypal_api(endpoint, method="GET", data=None, token=None):
82
+ """
83
+ Helper function to make PayPal API calls
84
+
85
+ Args:
86
+ endpoint: API endpoint (without base URL)
87
+ method: HTTP method (GET, POST, etc.)
88
+ data: Request payload (for POST/PUT)
89
+ token: PayPal access token (will be fetched if None)
90
+
91
+ Returns:
92
+ tuple: (success, response_data or error_message)
93
+ """
94
+ try:
95
+ if not token:
96
+ token = get_access_token()
97
+ if not token:
98
+ return False, "Failed to get PayPal access token"
99
+
100
+ url = f"{PAYPAL_BASE_URL}{endpoint}"
101
+ headers = {
102
+ "Content-Type": "application/json",
103
+ "Authorization": f"Bearer {token}"
104
+ }
105
+
106
+ session = create_retry_session()
107
+
108
+ if method.upper() == "GET":
109
+ response = session.get(url, headers=headers)
110
+ elif method.upper() == "POST":
111
+ response = session.post(url, headers=headers, data=json.dumps(data) if data else None)
112
+ elif method.upper() == "PUT":
113
+ response = session.put(url, headers=headers, data=json.dumps(data) if data else None)
114
+ else:
115
+ return False, f"Unsupported HTTP method: {method}"
116
+
117
+ if response.status_code in [200, 201, 204]:
118
+ if response.status_code == 204: # No content
119
+ return True, {}
120
+ return True, response.json() if response.text else {}
121
+ else:
122
+ logger.error(f"PayPal API error: {response.status_code} - {response.text}")
123
+ return False, f"PayPal API error: {response.status_code} - {response.text}"
124
+
125
+ except Exception as e:
126
+ logger.error(f"Error calling PayPal API: {str(e)}")
127
+ return False, f"Error calling PayPal API: {str(e)}"
128
+
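+ # Example usage of the helper above (GET the details of one subscription; the ID is illustrative):
+ # ok, data = call_paypal_api("/v1/billing/subscriptions/I-BW452GLLEP1G")
+ # if ok:
+ #     logger.info(f"Subscription status: {data.get('status')}")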
129
+ def create_paypal_subscription(user_id, tier):
130
+ """Create a PayPal subscription for a user"""
131
+ try:
132
+ # Get the price from the subscription tier
133
+ from auth import SUBSCRIPTION_TIERS
134
+
135
+ if tier not in SUBSCRIPTION_TIERS:
136
+ return False, f"Invalid tier: {tier}"
137
+
138
+ price = SUBSCRIPTION_TIERS[tier]["price"]
139
+ currency = SUBSCRIPTION_TIERS[tier]["currency"]
140
+
141
+ # Create a PayPal subscription (implement PayPal API calls here)
142
+ # For now, just return a success response
143
+ return True, {
144
+ "subscription_id": f"test_sub_{uuid.uuid4()}",
145
+ "status": "ACTIVE",
146
+ "tier": tier,
147
+ "price": price,
148
+ "currency": currency
149
+ }
150
+ except Exception as e:
151
+ logger.error(f"Error creating PayPal subscription: {str(e)}")
152
+ return False, f"Failed to create PayPal subscription: {str(e)}"
153
+
154
+
155
+ # Create a product in PayPal
156
+ def create_product(name, description):
157
+ """Create a product in PayPal"""
158
+ payload = {
159
+ "name": name,
160
+ "description": description,
161
+ "type": "SERVICE",
162
+ "category": "SOFTWARE"
163
+ }
164
+
165
+ success, result = call_paypal_api("/v1/catalogs/products", "POST", payload)
166
+ if success:
167
+ return result["id"]
168
+ else:
169
+ logger.error(f"Failed to create product: {result}")
170
+ return None
171
+
172
+ # Create a subscription plan in PayPal
173
+ # Update create_plan to use INR instead of USD
174
+ def create_plan(product_id, name, price, interval="MONTH", interval_count=1):
175
+ """Create a subscription plan in PayPal"""
176
+ payload = {
177
+ "product_id": product_id,
178
+ "name": name,
179
+ "billing_cycles": [
180
+ {
181
+ "frequency": {
182
+ "interval_unit": interval,
183
+ "interval_count": interval_count
184
+ },
185
+ "tenure_type": "REGULAR",
186
+ "sequence": 1,
187
+ "total_cycles": 0, # Infinite cycles
188
+ "pricing_scheme": {
189
+ "fixed_price": {
190
+ "value": str(price),
191
+ "currency_code": "USD"
192
+ }
193
+ }
194
+ }
195
+ ],
196
+ "payment_preferences": {
197
+ "auto_bill_outstanding": True,
198
+ "setup_fee": {
199
+ "value": "0",
200
+ "currency_code": "USD"
201
+ },
202
+ "setup_fee_failure_action": "CONTINUE",
203
+ "payment_failure_threshold": 3
204
+ }
205
+ }
206
+
207
+ success, result = call_paypal_api("/v1/billing/plans", "POST", payload)
208
+ if success:
209
+ return result["id"]
210
+ else:
211
+ logger.error(f"Failed to create plan: {result}")
212
+ return None
213
+
214
+ # Create the PayPal products and subscription plans (pricing below is in USD)
215
+ def initialize_subscription_plans():
216
+ """
217
+ Initialize PayPal subscription plans for the application.
218
+ This should be called once to set up the plans in PayPal.
219
+ """
220
+ try:
221
+ # Check if plans already exist
222
+ existing_plans = get_subscription_plans()
223
+ if existing_plans and len(existing_plans) >= 2:
224
+ logger.info("PayPal plans already initialized")
225
+ return existing_plans
226
+
227
+ # First, create products for each tier
228
+ products = {
229
+ "standard_tier": {
230
+ "name": "Standard Legal Document Analysis",
231
+ "description": "Standard subscription with document analysis features",
232
+ "type": "SERVICE",
233
+ "category": "SOFTWARE"
234
+ },
235
+ "premium_tier": {
236
+ "name": "Premium Legal Document Analysis",
237
+ "description": "Premium subscription with all document analysis features",
238
+ "type": "SERVICE",
239
+ "category": "SOFTWARE"
240
+ }
241
+ }
242
+
243
+ product_ids = {}
244
+ for tier, product_data in products.items():
245
+ success, result = call_paypal_api("/v1/catalogs/products", "POST", product_data)
246
+ if success:
247
+ product_ids[tier] = result["id"]
248
+ logger.info(f"Created PayPal product for {tier}: {result['id']}")
249
+ else:
250
+ logger.error(f"Failed to create product for {tier}: {result}")
251
+ return None
252
+
253
+ # Define the plans with product IDs (currency: USD)
254
+ plans = {
255
+ "standard_tier": {
256
+ "product_id": product_ids["standard_tier"],
257
+ "name": "Standard Plan",
258
+ "description": "Standard subscription with basic features",
259
+ "billing_cycles": [
260
+ {
261
+ "frequency": {
262
+ "interval_unit": "MONTH",
263
+ "interval_count": 1
264
+ },
265
+ "tenure_type": "REGULAR",
266
+ "sequence": 1,
267
+ "total_cycles": 0,
268
+ "pricing_scheme": {
269
+ "fixed_price": {
270
+ "value": "9.99",
271
+ "currency_code": "USD"
272
+ }
273
+ }
274
+ }
275
+ ],
276
+ "payment_preferences": {
277
+ "auto_bill_outstanding": True,
278
+ "setup_fee": {
279
+ "value": "0",
280
+ "currency_code": "USD"
281
+ },
282
+ "setup_fee_failure_action": "CONTINUE",
283
+ "payment_failure_threshold": 3
284
+ }
285
+ },
286
+ "premium_tier": {
287
+ "product_id": product_ids["premium_tier"],
288
+ "name": "Premium Plan",
289
+ "description": "Premium subscription with all features",
290
+ "billing_cycles": [
291
+ {
292
+ "frequency": {
293
+ "interval_unit": "MONTH",
294
+ "interval_count": 1
295
+ },
296
+ "tenure_type": "REGULAR",
297
+ "sequence": 1,
298
+ "total_cycles": 0,
299
+ "pricing_scheme": {
300
+ "fixed_price": {
301
+ "value": "19.99",
302
+ "currency_code": "USD"
303
+ }
304
+ }
305
+ }
306
+ ],
307
+ "payment_preferences": {
308
+ "auto_bill_outstanding": True,
309
+ "setup_fee": {
310
+ "value": "0",
311
+ "currency_code": "USD"
312
+ },
313
+ "setup_fee_failure_action": "CONTINUE",
314
+ "payment_failure_threshold": 3
315
+ }
316
+ }
317
+ }
318
+
319
+ # Create the plans in PayPal
320
+ created_plans = {}
321
+ for tier, plan_data in plans.items():
322
+ success, result = call_paypal_api("/v1/billing/plans", "POST", plan_data)
323
+ if success:
324
+ created_plans[tier] = result["id"]
325
+ logger.info(f"Created PayPal plan for {tier}: {result['id']}")
326
+ else:
327
+ logger.error(f"Failed to create plan for {tier}: {result}")
328
+
329
+ # Save the plan IDs to a file
330
+ if created_plans:
331
+ save_subscription_plans(created_plans)
332
+ return created_plans
333
+ else:
334
+ logger.error("Failed to create any PayPal plans")
335
+ return None
336
+ except Exception as e:
337
+ logger.error(f"Error initializing subscription plans: {str(e)}")
338
+ return None
339
+
340
+ # Create a PayPal subscription and return its approval link
+ def create_subscription_link(plan_id):
+ """plan_id is the tier key (e.g. "standard_tier"); the stored PayPal plan ID is looked up below."""
342
+ # Get the plan IDs
343
+ plans = get_subscription_plans()
344
+ if not plans:
345
+ return None
346
+
347
+ # Use environment variable for the app URL to make it work in different environments
348
+ app_url = os.getenv("APP_URL", "http://localhost:8501")
349
+
350
+ payload = {
351
+ "plan_id": plans[plan_id],
352
+ "application_context": {
353
+ "brand_name": "Legal Document Analyzer",
354
+ "locale": "en_US",
355
+ "shipping_preference": "NO_SHIPPING",
356
+ "user_action": "SUBSCRIBE_NOW",
357
+ "return_url": f"{app_url}?status=success&subscription_id={{id}}",
358
+ "cancel_url": f"{app_url}?status=cancel"
359
+ }
360
+ }
361
+
362
+ success, data = call_paypal_api("/v1/billing/subscriptions", "POST", payload)
363
+ if not success:
364
+ logger.error(f"Error creating subscription: {data}")
365
+ return None
366
+
367
+ try:
368
+ return {
369
+ "subscription_id": data["id"],
370
+ "approval_url": next(link["href"] for link in data["links"] if link["rel"] == "approve")
371
+ }
372
+ except Exception as e:
373
+ logger.error(f"Exception processing subscription response: {str(e)}")
374
+ return None
375
+
376
+ # Handle PayPal subscription webhook events
377
+ def handle_subscription_webhook(payload):
378
+ """
379
+ Handle PayPal subscription webhooks
380
+
381
+ Args:
382
+ payload: The full webhook payload
383
+
384
+ Returns:
385
+ tuple: (success, result)
386
+ - success: True if successful, False otherwise
387
+ - result: Success message or error message
388
+ """
389
+ try:
390
+ event_type = payload.get("event_type")
391
+ resource = payload.get("resource", {})
392
+
393
+ logger.info(f"Received PayPal webhook: {event_type}")
394
+
395
+ # Handle different event types
396
+ if event_type == "BILLING.SUBSCRIPTION.CREATED":
397
+ # A subscription was created
398
+ subscription_id = resource.get("id")
399
+ if not subscription_id:
400
+ return False, "Missing subscription ID in webhook"
401
+
402
+ # Update subscription status in database
403
+ conn = get_db_connection()
404
+ cursor = conn.cursor()
405
+ cursor.execute(
406
+ "UPDATE subscriptions SET status = 'pending' WHERE paypal_subscription_id = ?",
407
+ (subscription_id,)
408
+ )
409
+ conn.commit()
410
+ conn.close()
411
+
412
+ return True, "Subscription created successfully"
413
+
414
+ elif event_type == "BILLING.SUBSCRIPTION.ACTIVATED":
415
+ # A subscription was activated
416
+ subscription_id = resource.get("id")
417
+ if not subscription_id:
418
+ return False, "Missing subscription ID in webhook"
419
+
420
+ # Update subscription status in database
421
+ conn = get_db_connection()
422
+ cursor = conn.cursor()
423
+ cursor.execute(
424
+ "UPDATE subscriptions SET status = 'active' WHERE paypal_subscription_id = ?",
425
+ (subscription_id,)
426
+ )
427
+ conn.commit()
428
+ conn.close()
429
+
430
+ return True, "Subscription activated successfully"
431
+
432
+ elif event_type == "BILLING.SUBSCRIPTION.CANCELLED":
433
+ # A subscription was cancelled
434
+ subscription_id = resource.get("id")
435
+ if not subscription_id:
436
+ return False, "Missing subscription ID in webhook"
437
+
438
+ # Update subscription status in database
439
+ conn = get_db_connection()
440
+ cursor = conn.cursor()
441
+ cursor.execute(
442
+ "UPDATE subscriptions SET status = 'cancelled' WHERE paypal_subscription_id = ?",
443
+ (subscription_id,)
444
+ )
445
+ conn.commit()
446
+ conn.close()
447
+
448
+ return True, "Subscription cancelled successfully"
449
+
450
+ elif event_type == "BILLING.SUBSCRIPTION.SUSPENDED":
451
+ # A subscription was suspended
452
+ subscription_id = resource.get("id")
453
+ if not subscription_id:
454
+ return False, "Missing subscription ID in webhook"
455
+
456
+ # Update subscription status in database
457
+ conn = get_db_connection()
458
+ cursor = conn.cursor()
459
+ cursor.execute(
460
+ "UPDATE subscriptions SET status = 'suspended' WHERE paypal_subscription_id = ?",
461
+ (subscription_id,)
462
+ )
463
+ conn.commit()
464
+ conn.close()
465
+
466
+ return True, "Subscription suspended successfully"
467
+
468
+ else:
469
+ # Unhandled event type
470
+ logger.info(f"Unhandled webhook event type: {event_type}")
471
+ return True, f"Unhandled event type: {event_type}"
472
+
473
+ except Exception as e:
474
+ logger.error(f"Error handling webhook: {str(e)}")
475
+ return False, f"Error handling webhook: {str(e)}"
476
+ # Update a user's subscription after PayPal confirms it
477
+ def update_user_subscription(user_email, subscription_id, tier):
478
+ """
479
+ Update a user's subscription status
480
+
481
+ Args:
482
+ user_email: The email of the user
483
+ subscription_id: The PayPal subscription ID
484
+ tier: The subscription tier
485
+
486
+ Returns:
487
+ tuple: (success, result)
488
+ - success: True if successful, False otherwise
489
+ - result: Success message or error message
490
+ """
491
+ try:
492
+ # Get user ID from email
493
+ conn = get_db_connection()
494
+ cursor = conn.cursor()
495
+ cursor.execute("SELECT id FROM users WHERE email = ?", (user_email,))
496
+ user_result = cursor.fetchone()
497
+
498
+ if not user_result:
499
+ conn.close()
500
+ return False, f"User not found: {user_email}"
501
+
502
+ user_id = user_result[0]
503
+
504
+ # Update the subscription status
505
+ cursor.execute(
506
+ "UPDATE subscriptions SET status = 'active' WHERE user_id = ? AND paypal_subscription_id = ?",
507
+ (user_id, subscription_id)
508
+ )
509
+
510
+ # Deactivate any other active subscriptions for this user
511
+ cursor.execute(
512
+ "UPDATE subscriptions SET status = 'inactive' WHERE user_id = ? AND paypal_subscription_id != ? AND status = 'active'",
513
+ (user_id, subscription_id)
514
+ )
515
+
516
+ # Update the user's subscription tier
517
+ cursor.execute(
518
+ "UPDATE users SET subscription_tier = ? WHERE email = ?",
519
+ (tier, user_email)
520
+ )
521
+
522
+ conn.commit()
523
+ conn.close()
524
+
525
+ return True, f"Subscription updated to {tier} tier"
526
+
527
+ except Exception as e:
528
+ logger.error(f"Error updating user subscription: {str(e)}")
529
+ return False, f"Error updating subscription: {str(e)}"
530
+
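An illustrative call site for the function above; the email, subscription ID, and tier are placeholders (PayPal subscription IDs generally look like "I-...").

```python
success, message = update_user_subscription(
    user_email="user@example.com",
    subscription_id="I-BW452GLLEP1G",   # placeholder PayPal subscription ID
    tier="standard_tier",
)
if not success:
    logger.warning(f"Could not activate subscription: {message}")
```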
531
+ # Path where the PayPal plan IDs are persisted (kept under the module's data/ directory)
533
+ PLAN_IDS_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "data", "plan_ids.json"))
534
+
535
+ # Make sure the data directory exists
536
+ os.makedirs(os.path.dirname(PLAN_IDS_PATH), exist_ok=True)
537
+
538
+ # Log where the plan IDs file is expected
539
+ logger.info(f"PayPal plans will be stored at: {PLAN_IDS_PATH}")
540
+
541
+ # SQLite connection helper (DB_PATH can be overridden via the environment)
542
+ def get_db_connection():
543
+ """Get a connection to the SQLite database"""
544
+ DB_PATH = os.getenv("DB_PATH", os.path.join(os.path.dirname(__file__), "../data/user_data.db"))
545
+ # Make sure the data directory exists
546
+ os.makedirs(os.path.dirname(DB_PATH), exist_ok=True)
547
+ return sqlite3.connect(DB_PATH)
548
+
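As a small usage sketch, connections from this helper work with ordinary parameterized queries; setting `sqlite3.Row` is optional but lets rows be addressed by column name. The query and subscription ID below are illustrative only, and callers close the connection themselves, as the functions in this file do.

```python
import sqlite3

conn = get_db_connection()
conn.row_factory = sqlite3.Row          # optional: dict-style access to columns
cursor = conn.cursor()
cursor.execute(
    "SELECT tier, status FROM subscriptions WHERE paypal_subscription_id = ?",
    ("I-BW452GLLEP1G",),                # placeholder subscription ID
)
row = cursor.fetchone()
if row:
    print(row["tier"], row["status"])
conn.close()
```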
549
+ # Create the subscription tables (and migrate older schemas) if needed
550
+ def initialize_database():
551
+ """Initialize the database tables needed for subscriptions"""
552
+ conn = get_db_connection()
553
+ cursor = conn.cursor()
554
+
555
+ # Check if subscriptions table exists
556
+ cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='subscriptions'")
557
+ if cursor.fetchone():
558
+ # Table exists, check if required columns exist
559
+ cursor.execute("PRAGMA table_info(subscriptions)")
560
+ columns = [column[1] for column in cursor.fetchall()]
561
+
562
+ # Check for missing columns and add them if needed
563
+ if "user_id" not in columns:
564
+ logger.info("Adding 'user_id' column to subscriptions table")
565
+ cursor.execute("ALTER TABLE subscriptions ADD COLUMN user_id TEXT NOT NULL DEFAULT ''")
566
+
567
+ if "created_at" not in columns:
568
+ logger.info("Adding 'created_at' column to subscriptions table")
569
+ cursor.execute("ALTER TABLE subscriptions ADD COLUMN created_at TIMESTAMP")
570
+
571
+ if "expires_at" not in columns:
572
+ logger.info("Adding 'expires_at' column to subscriptions table")
573
+ cursor.execute("ALTER TABLE subscriptions ADD COLUMN expires_at TIMESTAMP")
574
+
575
+ if "paypal_subscription_id" not in columns:
576
+ logger.info("Adding 'paypal_subscription_id' column to subscriptions table")
577
+ cursor.execute("ALTER TABLE subscriptions ADD COLUMN paypal_subscription_id TEXT")
578
+ else:
579
+ # Create subscriptions table with all required columns
580
+ cursor.execute('''
581
+ CREATE TABLE IF NOT EXISTS subscriptions (
582
+ id TEXT PRIMARY KEY,
583
+ user_id TEXT NOT NULL,
584
+ tier TEXT NOT NULL,
585
+ status TEXT NOT NULL,
586
+ created_at TIMESTAMP NOT NULL,
587
+ expires_at TIMESTAMP,
588
+ paypal_subscription_id TEXT
589
+ )
590
+ ''')
591
+ logger.info("Created subscriptions table with all required columns")
592
+
593
+ # Create PayPal plans table if it doesn't exist
594
+ cursor.execute('''
595
+ CREATE TABLE IF NOT EXISTS paypal_plans (
596
+ plan_id TEXT PRIMARY KEY,
597
+ tier TEXT NOT NULL,
598
+ price REAL NOT NULL,
599
+ currency TEXT NOT NULL,
600
+ created_at TIMESTAMP NOT NULL
601
+ )
602
+ ''')
603
+
604
+ conn.commit()
605
+ conn.close()
606
+ logger.info("Database initialization completed")
607
+
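A quick way to confirm the migration above took effect is to inspect the table definition. This is a throwaway check rather than application code; the expected column set mirrors the CREATE TABLE statement above.

```python
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute("PRAGMA table_info(subscriptions)")
columns = [row[1] for row in cursor.fetchall()]
conn.close()

expected = {"id", "user_id", "tier", "status", "created_at",
            "expires_at", "paypal_subscription_id"}
missing = expected - set(columns)
print("missing columns:", missing or "none")
```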
608
+
609
+ def create_user_subscription_mock(user_email, tier):
610
+ """
611
+ Create a mock subscription for testing
612
+
613
+ Args:
614
+ user_email: The email of the user
615
+ tier: The subscription tier
616
+
617
+ Returns:
618
+ tuple: (success, result)
619
+ """
620
+ try:
621
+ logger.info(f"Creating mock subscription for {user_email} at tier {tier}")
622
+
623
+ # Get user ID from email
624
+ conn = get_db_connection()
625
+ cursor = conn.cursor()
626
+ cursor.execute("SELECT id FROM users WHERE email = ?", (user_email,))
627
+ user_result = cursor.fetchone()
628
+
629
+ if not user_result:
630
+ conn.close()
631
+ return False, f"User not found: {user_email}"
632
+
633
+ user_id = user_result[0]
634
+
635
+ # Create a mock subscription ID
636
+ subscription_id = f"mock_sub_{uuid.uuid4()}"
637
+
638
+ # Store the subscription in database
639
+ sub_id = str(uuid.uuid4())
640
+ start_date = datetime.now()
641
+
642
+ cursor.execute(
643
+ "INSERT INTO subscriptions (id, user_id, tier, status, created_at, expires_at, paypal_subscription_id) VALUES (?, ?, ?, ?, ?, ?, ?)",
644
+ (sub_id, user_id, tier, "active", start_date, start_date + timedelta(days=30), subscription_id)
645
+ )
646
+
647
+ # Update user's subscription tier
648
+ cursor.execute(
649
+ "UPDATE users SET subscription_tier = ? WHERE id = ?",
650
+ (tier, user_id)
651
+ )
652
+
653
+ conn.commit()
654
+ conn.close()
655
+
656
+ # Use environment variable for the app URL
657
+ app_url = os.getenv("APP_URL", "http://localhost:3000")
658
+
659
+ # Return success with mock approval URL that matches the real PayPal URL pattern
660
+ return True, {
661
+ "subscription_id": subscription_id,
662
+ "approval_url": f"{app_url}/subscription/callback?status=success&subscription_id={subscription_id}",
663
+ "tier": tier
664
+ }
665
+
666
+ except Exception as e:
667
+ logger.error(f"Error creating mock subscription: {str(e)}")
668
+ return False, f"Error creating subscription: {str(e)}"
669
+
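A minimal pytest-style check for the mock path, assuming a test user with this email already exists in the users table; the assertions reflect the dictionary returned above.

```python
def test_create_user_subscription_mock():
    success, result = create_user_subscription_mock("test@example.com", "standard_tier")
    assert success, result
    assert result["tier"] == "standard_tier"
    assert result["subscription_id"].startswith("mock_sub_")
    assert "approval_url" in result
```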
670
+ # Module initialization: sets up directories, database tables, and subscription plans
671
+ def initialize():
672
+ """Initialize the PayPal integration module"""
673
+ try:
674
+ # Create necessary directories
675
+ os.makedirs(os.path.dirname(PLAN_IDS_PATH), exist_ok=True)
676
+
677
+ # Initialize database
678
+ initialize_database()
679
+
680
+ # Initialize subscription plans
681
+ plans = get_subscription_plans()
682
+ if plans:
683
+ logger.info(f"Subscription plans initialized: {plans}")
684
+ else:
685
+ logger.warning("Failed to initialize subscription plans")
686
+
687
+ return True
688
+ except Exception as e:
689
+ logger.error(f"Error initializing PayPal integration: {str(e)}")
690
+ return False
691
+
692
+ # initialize() is invoked at the very end of this module, after get_subscription_plans()
+ # and the remaining helpers it relies on have been defined; calling it this early would
+ # raise a NameError that initialize()'s own try/except silently swallows.
694
+
695
+ # Load the configured subscription plans
696
+ def get_subscription_plans():
697
+ """
698
+ Get all available subscription plans with correct pricing
699
+ """
700
+ try:
701
+ # Check if we have plan IDs saved in a file
702
+ if os.path.exists(PLAN_IDS_PATH):
703
+ try:
704
+ with open(PLAN_IDS_PATH, 'r') as f:
705
+ plans = json.load(f)
706
+ logger.info(f"Loaded subscription plans from {PLAN_IDS_PATH}: {plans}")
707
+ return plans
708
+ except Exception as e:
709
+ logger.error(f"Error reading plan IDs file: {str(e)}")
710
+ return {}
711
+
712
+ # If no file exists, return empty dict
713
+ logger.warning(f"No plan IDs file found at {PLAN_IDS_PATH}. Please initialize subscription plans.")
714
+ return {}
715
+
716
+ except Exception as e:
717
+ logger.error(f"Error getting subscription plans: {str(e)}")
718
+ return {}
719
+
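The file read here is a flat tier-to-plan-ID mapping. Writing it by hand looks like the sketch below (this mirrors what `save_subscription_plans` does later in the file); the `P-...` values are placeholders, since real plan IDs come from PayPal's create-plan API.

```python
import json

plans = {
    "standard_tier": "P-1AB23456CD789012EXAMPLE",   # placeholder plan ID
    "premium_tier": "P-9ZY87654XW321098EXAMPLE",    # placeholder plan ID
}
with open(PLAN_IDS_PATH, "w") as f:
    json.dump(plans, f)
```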
720
+ # NOTE: initialize_database() is defined once, earlier in this file, together with the
+ # column-migration checks; a second, simpler definition here would shadow it, because
+ # Python keeps whichever def runs last.
752
+
753
+
754
+ def create_user_subscription(user_email, tier):
755
+ """
756
+ Create a real PayPal subscription for a user
757
+
758
+ Args:
759
+ user_email: The email of the user
760
+ tier: The subscription tier (standard_tier or premium_tier)
761
+
762
+ Returns:
763
+ tuple: (success, result)
764
+ - success: True if successful, False otherwise
765
+ - result: Dictionary with subscription details or error message
766
+ """
767
+ try:
768
+ # Validate tier
769
+ valid_tiers = ["standard_tier", "premium_tier"]
770
+ if tier not in valid_tiers:
771
+ return False, f"Invalid tier: {tier}. Must be one of {valid_tiers}"
772
+
773
+ # Get the plan IDs
774
+ plans = get_subscription_plans()
775
+
776
+ # Log the plans for debugging
777
+ logger.info(f"Available subscription plans: {plans}")
778
+
779
+ # If no plans found, check if the file exists and try to load it directly
780
+ if not plans:
781
+ if os.path.exists(PLAN_IDS_PATH):
782
+ logger.info(f"Plan IDs file exists at {PLAN_IDS_PATH}, but couldn't load plans. Trying direct load.")
783
+ try:
784
+ with open(PLAN_IDS_PATH, 'r') as f:
785
+ plans = json.load(f)
786
+ logger.info(f"Directly loaded plans: {plans}")
787
+ except Exception as e:
788
+ logger.error(f"Error directly loading plans: {str(e)}")
789
+ else:
790
+ logger.error(f"Plan IDs file does not exist at {PLAN_IDS_PATH}")
791
+
792
+ # If still no plans, return error
793
+ if not plans:
794
+ logger.error("No PayPal plans found. Please initialize plans first.")
795
+ return False, "PayPal plans not configured. Please contact support."
796
+
797
+ # Check if the tier exists in plans
798
+ if tier not in plans:
799
+ return False, f"No plan found for tier: {tier}"
800
+
801
+ # Use environment variable for the app URL
802
+ app_url = os.getenv("APP_URL", "http://localhost:3000")
803
+
804
+ # Create the subscription with PayPal
805
+ payload = {
806
+ "plan_id": plans[tier],
807
+ "subscriber": {
808
+ "email_address": user_email
809
+ },
810
+ "application_context": {
811
+ "brand_name": "Legal Document Analyzer",
812
+ "locale": "en-US", # Changed from en_US to en-US
813
+ "shipping_preference": "NO_SHIPPING",
814
+ "user_action": "SUBSCRIBE_NOW",
815
+ "return_url": f"{app_url}/subscription/callback?status=success",
816
+ "cancel_url": f"{app_url}/subscription/callback?status=cancel"
817
+ }
818
+ }
819
+
820
+ # Make the API call to PayPal
821
+ success, subscription_data = call_paypal_api("/v1/billing/subscriptions", "POST", payload)
822
+ if not success:
823
+ return False, subscription_data # This is already an error message
824
+
825
+ # Extract the approval URL
826
+ approval_url = next((link["href"] for link in subscription_data["links"]
827
+ if link["rel"] == "approve"), None)
828
+
829
+ if not approval_url:
830
+ return False, "No approval URL found in PayPal response"
831
+
832
+ # Get user ID from email
833
+ conn = get_db_connection()
834
+ cursor = conn.cursor()
835
+ cursor.execute("SELECT id FROM users WHERE email = ?", (user_email,))
836
+ user_result = cursor.fetchone()
837
+
838
+ if not user_result:
839
+ conn.close()
840
+ return False, f"User not found: {user_email}"
841
+
842
+ user_id = user_result[0]
843
+
844
+ # Store pending subscription in database
845
+ sub_id = str(uuid.uuid4())
846
+ start_date = datetime.now()
847
+
848
+ cursor.execute(
849
+ "INSERT INTO subscriptions (id, user_id, tier, status, created_at, expires_at, paypal_subscription_id) VALUES (?, ?, ?, ?, ?, ?, ?)",
850
+ (sub_id, user_id, tier, "pending", start_date, None, subscription_data["id"])
851
+ )
852
+
853
+ conn.commit()
854
+ conn.close()
855
+
856
+ # Return success with approval URL
857
+ return True, {
858
+ "subscription_id": subscription_data["id"],
859
+ "approval_url": approval_url,
860
+ "tier": tier
861
+ }
862
+
863
+ except Exception as e:
864
+ logger.error(f"Error creating user subscription: {str(e)}")
865
+ return False, f"Error creating subscription: {str(e)}"
866
+
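A sketch of how an endpoint in app.py could drive this function and hand the approval URL back to the frontend; the route path, request parameters, and response shape are assumptions rather than the actual app.py code.

```python
from fastapi import APIRouter, HTTPException

from paypal_integration import create_user_subscription

router = APIRouter()

@router.post("/subscribe/{tier}")
async def subscribe(tier: str, user_email: str):
    success, result = create_user_subscription(user_email, tier)
    if not success:
        raise HTTPException(status_code=400, detail=result)
    # The frontend redirects the user to result["approval_url"] to approve the plan on PayPal
    return result
```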
867
+ # Cancel an existing subscription
868
+ def cancel_subscription(subscription_id, reason="Customer requested cancellation"):
869
+ """
870
+ Cancel a PayPal subscription
871
+
872
+ Args:
873
+ subscription_id: The PayPal subscription ID
874
+ reason: The reason for cancellation
875
+
876
+ Returns:
877
+ tuple: (success, result)
878
+ - success: True if successful, False otherwise
879
+ - result: Success message or error message
880
+ """
881
+ try:
882
+ # Cancel the subscription with PayPal
883
+ payload = {
884
+ "reason": reason
885
+ }
886
+
887
+ success, result = call_paypal_api(
888
+ f"/v1/billing/subscriptions/{subscription_id}/cancel",
889
+ "POST",
890
+ payload
891
+ )
892
+
893
+ if not success:
894
+ return False, result
895
+
896
+ # Update subscription status in database
897
+ conn = get_db_connection()
898
+ cursor = conn.cursor()
899
+ cursor.execute(
900
+ "UPDATE subscriptions SET status = 'cancelled' WHERE paypal_subscription_id = ?",
901
+ (subscription_id,)
902
+ )
903
+
904
+ # Get the user ID for this subscription
905
+ cursor.execute(
906
+ "SELECT user_id FROM subscriptions WHERE paypal_subscription_id = ?",
907
+ (subscription_id,)
908
+ )
909
+ user_result = cursor.fetchone()
910
+
911
+ if user_result:
912
+ # Update user to free tier
913
+ cursor.execute(
914
+ "UPDATE users SET subscription_tier = 'free_tier' WHERE id = ?",
915
+ (user_result[0],)
916
+ )
917
+
918
+ conn.commit()
919
+ conn.close()
920
+
921
+ return True, "Subscription cancelled successfully"
922
+
923
+ except Exception as e:
924
+ logger.error(f"Error cancelling subscription: {str(e)}")
925
+ return False, f"Error cancelling subscription: {str(e)}"
926
+
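Note that PayPal's cancel endpoint replies 204 No Content, so `call_paypal_api` has to tolerate an empty response body. An illustrative call site (the subscription ID is a placeholder and would normally come from the user's active row in the subscriptions table):

```python
success, message = cancel_subscription(
    "I-BW452GLLEP1G",                       # placeholder PayPal subscription ID
    reason="Downgrade requested from account settings",
)
if success:
    logger.info(message)    # the user has already been moved back to free_tier by the function
else:
    logger.error(message)
```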
927
+ def verify_subscription_payment(subscription_id):
928
+ """
929
+ Verify a subscription payment with PayPal
930
+
931
+ Args:
932
+ subscription_id: The PayPal subscription ID
933
+
934
+ Returns:
935
+ tuple: (success, result)
936
+ - success: True if successful, False otherwise
937
+ - result: Dictionary with subscription details or error message
938
+ """
939
+ try:
940
+ # Get subscription details from PayPal using our helper
941
+ success, subscription_data = call_paypal_api(f"/v1/billing/subscriptions/{subscription_id}")
942
+ if not success:
943
+ return False, subscription_data # This is already an error message
944
+
945
+ # Check subscription status
946
+ status = subscription_data.get("status", "").upper()
947
+
948
+ if status not in ["ACTIVE", "APPROVED"]:
949
+ return False, f"Subscription is not active: {status}"
950
+
951
+ # Return success with subscription data
952
+ return True, subscription_data
953
+
954
+ except Exception as e:
955
+ logger.error(f"Error verifying subscription: {str(e)}")
956
+ return False, f"Error verifying subscription: {str(e)}"
957
+
958
+ def verify_paypal_subscription(subscription_id):
959
+ """
960
+ Verify a PayPal subscription
961
+
962
+ Args:
963
+ subscription_id: The PayPal subscription ID
964
+
965
+ Returns:
966
+ tuple: (success, result)
967
+ """
968
+ try:
969
+ # Skip verification for mock subscriptions
970
+ if subscription_id.startswith("mock_sub_"):
971
+ return True, {"status": "ACTIVE"}
972
+
973
+ # For real subscriptions, call PayPal API
974
+ success, result = call_paypal_api(f"/v1/billing/subscriptions/{subscription_id}", "GET")
975
+
976
+ if success:
977
+ # Check subscription status
978
+ if result.get("status") == "ACTIVE":
979
+ return True, result
980
+ else:
981
+ return False, f"Subscription is not active: {result.get('status')}"
982
+ else:
983
+ logger.error(f"PayPal API error: {result}")
984
+ return False, f"Failed to verify subscription: {result}"
985
+ except Exception as e:
986
+ logger.error(f"Error verifying PayPal subscription: {str(e)}")
987
+ return False, f"Error verifying subscription: {str(e)}"
988
+
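Putting the pieces together, the success callback would typically verify the approved subscription with PayPal and only then promote the user. A short sketch (the function name and parameters are illustrative):

```python
def activate_after_approval(user_email: str, subscription_id: str, tier: str):
    """Verify the approved subscription with PayPal, then unlock the paid tier."""
    ok, details = verify_paypal_subscription(subscription_id)
    if not ok:
        return False, details
    return update_user_subscription(user_email, subscription_id, tier)
```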
989
+ # Persist subscription plans to disk
990
+ def save_subscription_plans(plans):
991
+ """
992
+ Save subscription plans to a file
993
+
994
+ Args:
995
+ plans: Dictionary of plan IDs by tier
996
+ """
997
+ try:
998
+ with open(PLAN_IDS_PATH, 'w') as f:
999
+ json.dump(plans, f)
1000
+ logger.info(f"Saved subscription plans to {PLAN_IDS_PATH}")
1001
+ return True
1002
+ except Exception as e:
1003
+ logger.error(f"Error saving subscription plans: {str(e)}")
1004
+ return False
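+ 
+ # Call initialize when the module is imported, now that every helper it relies on is defined
+ initialize()

The PR also adds an `initialize_plans.py` script whose contents are not shown in this hunk. Below is only a hedged sketch of what such a bootstrap script could look like using the helpers above; the product/plan payload fields follow PayPal's catalog and billing-plan APIs, but the names, prices, and billing cycle are placeholders, not the script's actual code.

```python
from paypal_integration import call_paypal_api, save_subscription_plans

def create_plans():
    # 1. Create a catalog product to attach the billing plans to
    ok, product = call_paypal_api("/v1/catalogs/products", "POST", {
        "name": "Legal Document Analyzer",
        "type": "SERVICE",
    })
    if not ok:
        raise RuntimeError(f"Product creation failed: {product}")

    plans = {}
    # 2. Create one monthly billing plan per paid tier (prices are placeholders)
    for tier, price in {"standard_tier": "9.99", "premium_tier": "19.99"}.items():
        ok, plan = call_paypal_api("/v1/billing/plans", "POST", {
            "product_id": product["id"],
            "name": f"Legal Document Analyzer {tier}",
            "billing_cycles": [{
                "frequency": {"interval_unit": "MONTH", "interval_count": 1},
                "tenure_type": "REGULAR",
                "sequence": 1,
                "total_cycles": 0,  # 0 = renew until cancelled
                "pricing_scheme": {"fixed_price": {"value": price, "currency_code": "USD"}},
            }],
            "payment_preferences": {"auto_bill_outstanding": True},
        })
        if not ok:
            raise RuntimeError(f"Plan creation failed for {tier}: {plan}")
        plans[tier] = plan["id"]

    # 3. Persist the tier -> plan ID mapping for get_subscription_plans() to load
    save_subscription_plans(plans)
    return plans

if __name__ == "__main__":
    print(create_plans())
```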
requirements.txt ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi>=0.95.0
2
+ uvicorn>=0.21.1
3
+ pydantic>=1.10.7
4
+ python-multipart>=0.0.6
5
+ python-dotenv>=1.0.0
6
+ pdfplumber>=0.9.0
7
+ spacy>=3.5.2
8
+ torch>=2.0.0
9
+ transformers>=4.28.1
10
+ sentence-transformers>=2.2.2
11
+ moviepy>=1.0.3
12
+ matplotlib>=3.7.1
13
+ numpy>=1.24.2
14
+ passlib>=1.7.4
15
+ python-jose[cryptography]>=3.3.0
16
+ bcrypt>=4.0.1
17
+ requests>=2.28.2
18
+ SQLAlchemy>=2.0.9
19
+ aiofiles>=23.1.0
20
+ huggingface_hub>=0.16.4
21
+ en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.5.0/en_core_web_sm-3.5.0-py3-none-any.whl