File size: 8,717 Bytes
7dfe46c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
#!/usr/bin/env python3
"""
Fix Qdrant collection dimensions for Manufacturing RAG Agent
"""

import os
from dotenv import load_dotenv
from qdrant_client import QdrantClient
from qdrant_client.http import models

# Read key=value pairs from a .env file (when one is found) into the process
# environment so the os.environ / os.getenv lookups below can see them.
load_dotenv()



# Connection settings come from the environment (or the .env file loaded above);
# never commit real credentials to this file:
#   QDRANT_API_KEY - API key for the Qdrant cluster
#   QDRANT_URL     - cluster endpoint, e.g. http://localhost:6333 for a local instance

def fix_qdrant_collection(collection_name='manufacturing_docs', vector_size=4096):
    """Ensure the Qdrant collection exists with the expected vector size.

    Connection details are read from the ``QDRANT_URL`` (required) and
    ``QDRANT_API_KEY`` (optional, e.g. for a local instance) environment
    variables. If the collection already exists with a different vector
    size, the user is asked for confirmation before it is deleted and
    recreated; payload indexes are then created on a best-effort basis.

    Args:
        collection_name: Name of the collection to check or (re)create.
        vector_size: Expected embedding dimensionality. The default of 4096
            is the size this script uses for Qwen/Qwen3-Embedding-8B.

    Returns:
        True if the collection already had the expected size or was
        (re)created; False if ``QDRANT_URL`` is missing, the user declined
        the deletion, or any Qdrant call failed.
    """

    print("πŸ”§ Fixing Qdrant Collection Dimensions")
    print("=" * 50)

    # os.getenv (not os.environ[...]) so that a missing variable reaches the
    # friendly message below instead of raising KeyError.
    qdrant_url = os.getenv("QDRANT_URL")
    qdrant_api_key = os.getenv("QDRANT_API_KEY")

    if not qdrant_url:
        print("❌ QDRANT_URL not found in environment variables")
        return False

    try:
        # Connect with the configured endpoint and key; credentials must
        # never be hard-coded in this file.
        print(f"πŸ”— Connecting to Qdrant: {qdrant_url}")
        client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)

        # Check if collection exists
        existing_names = {col.name for col in client.get_collections().collections}

        if collection_name in existing_names:
            print(f"πŸ“‹ Collection '{collection_name}' exists")

            # NOTE(review): assumes a single unnamed vector config; a
            # collection with named vectors has no `.size` here and ends up
            # in the generic error handler below.
            collection_info = client.get_collection(collection_name)
            current_dim = collection_info.config.params.vectors.size
            print(f"πŸ“ Current vector dimensions: {current_dim}")

            # Compare against the same size that is used for creation, so a
            # correct collection is never offered for deletion.
            if current_dim == vector_size:
                print("βœ… Collection already has correct dimensions")
                return True

            print(f"⚠️  Need to recreate collection with correct dimensions ({vector_size})")

            # Deleting drops all stored points, so ask for confirmation.
            response = input("πŸ—‘οΈ  Delete existing collection and recreate? (y/N): ").strip().lower()
            if response != 'y':
                print("❌ Aborted by user")
                return False

            print(f"πŸ—‘οΈ  Deleting collection '{collection_name}'...")
            client.delete_collection(collection_name)
            print("βœ… Collection deleted")

        # Create new collection with correct dimensions
        print(f"πŸ†• Creating collection '{collection_name}' with {vector_size} dimensions...")

        client.create_collection(
            collection_name=collection_name,
            vectors_config=models.VectorParams(
                size=vector_size,
                distance=models.Distance.COSINE,
            ),
        )

        # Create payload indexes
        print("πŸ” Creating payload indexes...")

        # The schema type is passed directly as `field_schema`, matching the
        # documented create_payload_index usage.
        indexes_to_create = [
            ("document_id", models.PayloadSchemaType.KEYWORD),
            ("document_type", models.PayloadSchemaType.KEYWORD),
            ("page_number", models.PayloadSchemaType.INTEGER),
            ("worksheet_name", models.PayloadSchemaType.KEYWORD),
        ]

        for field_name, field_schema in indexes_to_create:
            # Best effort: a failed index must not abort the whole fix.
            try:
                client.create_payload_index(
                    collection_name=collection_name,
                    field_name=field_name,
                    field_schema=field_schema,
                )
                print(f"βœ… Created index for '{field_name}'")
            except Exception as e:
                print(f"⚠️  Failed to create index for '{field_name}': {e}")

        print("βœ… Collection recreated successfully with correct dimensions!")
        return True

    except Exception as e:
        # Top-level boundary of this script step: report and signal failure.
        print(f"❌ Error: {e}")
        return False

def update_config_file(config_path="src/config.yaml", vector_size=4096):
    """Set ``vector_size`` in the YAML config file via text substitution.

    If a ``vector_size:`` entry with a numeric value exists, its value is
    replaced. Otherwise the entry is inserted directly below the first
    ``vector_store:`` line. The file is treated as plain text (no YAML
    parsing) and is only rewritten when its content actually changes.

    Args:
        config_path: Path of the config file to edit.
        vector_size: Value to write for ``vector_size``.

    Returns:
        True if the file now carries the requested ``vector_size`` entry;
        False if the file is missing, has no ``vector_store`` section to
        extend (manual instructions are printed), or could not be read or
        written.
    """
    import re

    print("\nπŸ”§ Updating Configuration")
    print("=" * 30)

    if not os.path.exists(config_path):
        print(f"❌ Config file not found: {config_path}")
        return False

    try:
        # Explicit encoding so the result does not depend on the platform
        # default.
        with open(config_path, 'r', encoding='utf-8') as f:
            content = f.read()

        if 'vector_size:' in content:
            # Replace the existing numeric value.
            updated = re.sub(
                r'vector_size:\s*\d+', f'vector_size: {vector_size}', content
            )
            print(f"βœ… Updated vector_size to {vector_size}")
        elif 'vector_store:' in content:
            # Insert the entry right below the first vector_store header.
            updated = re.sub(
                r'(vector_store:\s*\n)',
                lambda m: f"{m.group(1)}  vector_size: {vector_size}\n",
                content,
                count=1,
            )
            print(f"βœ… Added vector_size: {vector_size} to vector_store section")
        else:
            # Nothing to edit: report failure so the caller does not claim
            # the configuration was updated.
            print("⚠️  No vector_store section found, please add manually:")
            print("vector_store:")
            print(f"  vector_size: {vector_size}")
            return False

        if updated != content:
            with open(config_path, 'w', encoding='utf-8') as f:
                f.write(updated)

        print(f"βœ… Updated {config_path}")
        return True

    except (OSError, UnicodeError) as e:
        print(f"❌ Error updating config: {e}")
        return False

def test_embedding_dimensions():
    """Request one embedding from SiliconFlow and report its length.

    Posts a single short input for the Qwen/Qwen3-Embedding-8B model to the
    SiliconFlow embeddings endpoint, authenticated with the
    ``SILICONFLOW_API_KEY`` environment variable.

    Returns:
        The number of values in the returned embedding, or None when the
        API key is missing, the request fails, or no embedding comes back.
    """

    print("\nπŸ§ͺ Testing Embedding Dimensions")
    print("=" * 35)

    try:
        # Imported here so a missing `requests` package is reported by the
        # handler below instead of breaking the whole script at import time.
        import requests

        token = os.getenv('SILICONFLOW_API_KEY')
        if not token:
            print("❌ SILICONFLOW_API_KEY not found")
            return None

        # Ask for a single embedding of a throwaway sentence.
        reply = requests.post(
            "https://api.siliconflow.com/v1/embeddings",
            json={
                "model": "Qwen/Qwen3-Embedding-8B",
                "input": ["test embedding dimension"],
                "encoding_format": "float"
            },
            headers={
                'Authorization': f'Bearer {token}',
                'Content-Type': 'application/json'
            },
            timeout=10
        )

        if reply.status_code != 200:
            print(f"❌ API error: {reply.status_code} - {reply.text}")
            return None

        items = reply.json().get('data')
        if not (items and len(items) > 0):
            print("❌ No embedding data returned")
            return None

        vector_length = len(items[0]['embedding'])
        print(f"βœ… Actual embedding dimensions: {vector_length}")
        return vector_length

    except Exception as exc:
        print(f"❌ Error testing embeddings: {exc}")
        return None

def main():
    """Run the full fix: probe embeddings, fix the collection, update config.

    Stops early (after printing a failure message) when the Qdrant
    collection could not be fixed; a failed config update only produces a
    warning.
    """

    print("🏭 Manufacturing RAG Agent - Dimension Fix")
    print("=" * 60)

    # Probe the embedding service and warn when it disagrees with 4096.
    measured = test_embedding_dimensions()
    if measured and measured != 4096:
        print(f"⚠️  Warning: Expected 4096 dimensions, but got {measured}")
        print("You may need to update the vector_size in your config")

    # Without a usable collection there is no point in continuing.
    if not fix_qdrant_collection():
        print("\n❌ Failed to fix Qdrant collection")
        return
    print("\nβœ… Qdrant collection fixed successfully!")

    # The config update is optional; fall back to a manual hint.
    config_message = (
        "βœ… Configuration updated successfully!"
        if update_config_file()
        else "⚠️  Please update config manually"
    )
    print(config_message)

    closing_lines = (
        "\nπŸŽ‰ Fix Complete!",
        "\nπŸ“‹ Next Steps:",
        "1. Restart your Gradio demo",
        "2. Re-upload your documents",
        "3. Test question answering",
        "\nπŸš€ To restart the demo:",
        "python fixed_gradio_demo.py",
    )
    for line in closing_lines:
        print(line)

# Run the fix only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()