j-higgins committed on
Commit 7528883 · unverified · 1 Parent(s): 579a8ff

Update app.py


updated w/ smaller model and better processing

Files changed (1)
app.py +135 -135
app.py CHANGED
@@ -1,7 +1,7 @@
 import dash
 import dash_bootstrap_components as dbc
 import pandas as pd
-from dash import dcc, html
+from dash import dcc, html, callback_context
 from dash.dash_table import DataTable
 from dash.dependencies import Output, Input, State
 import plotly.express as px
@@ -12,6 +12,7 @@ from gliner_spacy.pipeline import GlinerSpacy
 import warnings
 import os
 
+# Suppress specific warnings
 warnings.filterwarnings("ignore", message="The sentencepiece tokenizer")
 
 # Initialize Dash app with Bootstrap theme and Font Awesome
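
Background on the filter this hunk labels: warnings.filterwarnings matches its message argument as a regex against the start of the warning text, so this single line silences only warnings beginning with "The sentencepiece tokenizer" and leaves everything else visible. A small self-contained check of that behavior (the warning strings are invented for the sketch):

import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # record every warning by default
    # Same call as in app.py: ignore only messages starting with this prefix
    warnings.filterwarnings("ignore", message="The sentencepiece tokenizer")
    warnings.warn("The sentencepiece tokenizer is slow")  # suppressed
    warnings.warn("Some other warning")                   # still recorded

print(len(caught))  # -> 1
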
@@ -27,53 +28,78 @@ CATEGORIES_FILE = os.path.join(BASE_DIR, 'google_categories(v2).txt')
 # Configuration for GLiNER integration
 custom_spacy_config = {
     "gliner_model": "urchade/gliner_small-v2.1",
-    "chunk_size": 250,
+    "chunk_size": 128,
     "labels": ["person", "organization", "location", "event", "work_of_art", "product", "service", "date", "number", "price", "address", "phone_number", "misc"],
-    "style": "ent",
-    "threshold": 0.3
+    "threshold": 0.5
 }
 
-# Model variables
+# Model variables for lazy loading
 nlp = None
 sentence_model = None
 
-# Function to load models
-def load_models():
-    global nlp, sentence_model
-    nlp = spacy.blank("en")
-    nlp.add_pipe("gliner_spacy", config=custom_spacy_config)
-    sentence_model = SentenceTransformer('all-roberta-large-v1')
+# Function to lazy load NLP model
+def get_nlp():
+    global nlp
+    if nlp is None:
+        try:
+            nlp = spacy.blank("en")
+            nlp.add_pipe("gliner_spacy", config=custom_spacy_config)
+        except Exception as e:
+            raise
+    return nlp
+
+# Function to lazy load sentence transformer model
+def get_sentence_model():
+    global sentence_model
+    if sentence_model is None:
+        sentence_model = SentenceTransformer('all-MiniLM-L6-v2')
+    return sentence_model
 
 # Load Google's content categories
-with open(CATEGORIES_FILE, 'r') as f:
-    google_categories = [line.strip() for line in f]
+try:
+    with open(CATEGORIES_FILE, 'r') as f:
+        google_categories = [line.strip() for line in f]
+except Exception as e:
+    google_categories = []
 
 # Function to perform NER using GLiNER with spaCy
 def perform_ner(text):
-    doc = nlp(text)
-    return [(ent.text, ent.label_) for ent in doc.ents]
+    try:
+        doc = get_nlp()(text)
+        return [(ent.text, ent.label_) for ent in doc.ents]
+    except Exception as e:
+        return []
 
 # Function to extract entities using GLiNER with spaCy
 def extract_entities(text):
-    doc = nlp(text)
-    entities = [(ent.text, ent.label_) for ent in doc.ents]
-    return entities if entities else ["No specific entities found"]
+    try:
+        doc = get_nlp()(text)
+        entities = [(ent.text, ent.label_) for ent in doc.ents]
+        return entities if entities else ["No specific entities found"]
+    except Exception as e:
+        return ["Error extracting entities"]
 
 # Function to precompute category embeddings
 def compute_category_embeddings():
-    return sentence_model.encode(google_categories)
+    try:
+        return get_sentence_model().encode(google_categories)
+    except Exception as e:
+        return []
 
 # Function to perform topic modeling using sentence transformers
 def perform_topic_modeling_from_similarities(similarities):
-    top_indices = similarities.argsort()[-3:][::-1]
-
-    best_match = google_categories[top_indices[0]]
-    second_best = google_categories[top_indices[1]]
-
-    if similarities[top_indices[0]] > similarities[top_indices[1]] * 1.1:
-        return best_match
-    else:
-        return f"{best_match} , {second_best}"
+    try:
+        top_indices = similarities.argsort()[-3:][::-1]
+
+        best_match = google_categories[top_indices[0]]
+        second_best = google_categories[top_indices[1]]
+
+        if similarities[top_indices[0]] > similarities[top_indices[1]] * 1.1:
+            return best_match
+        else:
+            return f"{best_match} , {second_best}"
+    except Exception as e:
+        return "Error in topic modeling"
 
 # Function to sort keywords by intent feature
 def sort_by_keyword_feature(f):
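
For reference, the lazy-loading pattern this hunk introduces, as a minimal standalone sketch. The model name, chunk_size, and threshold are copied from the config above; the shortened label list and the sample sentence are illustrative only:

import spacy
from gliner_spacy.pipeline import GlinerSpacy  # registers the "gliner_spacy" factory

_nlp = None

def get_pipeline():
    # Build the pipeline once, on first use, instead of at import time;
    # on a small Space this keeps startup fast and defers the memory cost.
    global _nlp
    if _nlp is None:
        _nlp = spacy.blank("en")
        _nlp.add_pipe("gliner_spacy", config={
            "gliner_model": "urchade/gliner_small-v2.1",
            "chunk_size": 128,  # smaller chunks -> smaller forward passes
            "labels": ["person", "organization", "location", "product"],  # trimmed for the sketch
            "threshold": 0.5,   # raised threshold -> fewer low-confidence entities
        })
    return _nlp

doc = get_pipeline()("Book a table at Quince in San Francisco for June 5th.")
print([(ent.text, ent.label_) for ent in doc.ents])
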
@@ -145,45 +171,49 @@ def sort_by_keyword_feature(f):
     return "other"
 
 # Optimized batch processing of keywords
-def batch_process_keywords(keywords, batch_size=32):
+def batch_process_keywords(keywords, batch_size=16):
     processed_data = {'Keywords': [], 'Intent': [], 'NER Entities': [], 'Google Content Topics': []}
 
-    # Precompute keyword embeddings once
-    keyword_embeddings = sentence_model.encode(keywords, batch_size=batch_size, show_progress_bar=True)
-
-    # Compute category embeddings
-    category_embeddings = compute_category_embeddings()
-
-    for i in range(0, len(keywords), batch_size):
-        batch = keywords[i:i+batch_size]
-        batch_embeddings = keyword_embeddings[i:i+batch_size]
-
-        # Batch process intents
-        intents = [sort_by_keyword_feature(kw) for kw in batch]
-
-        # Batch process entities
-        entities = [extract_entities(kw) for kw in batch]
-
-        # Batch process topics
-        similarities = cosine_similarity(batch_embeddings, category_embeddings)
-        Google_Content_Topics = [perform_topic_modeling_from_similarities(sim) for sim in similarities]
-
-        processed_data['Keywords'].extend(batch)
-        processed_data['Intent'].extend(intents)
-
-        # Convert entities to strings, handling both tuples and strings
-        processed_entities = []
-        for entity_list in entities:
-            entity_strings = []
-            for entity in entity_list:
-                if isinstance(entity, tuple):
-                    entity_strings.append(f"{entity[0]} ({entity[1]})")
-                else:
-                    entity_strings.append(str(entity))
-            processed_entities.append(", ".join(entity_strings))
-
-        processed_data['NER Entities'].extend(processed_entities)
-        processed_data['Google Content Topics'].extend(Google_Content_Topics)
+    try:
+        # Precompute keyword embeddings once
+        keyword_embeddings = get_sentence_model().encode(keywords, batch_size=batch_size, show_progress_bar=False)
+
+        # Compute category embeddings
+        category_embeddings = compute_category_embeddings()
+
+        for i in range(0, len(keywords), batch_size):
+            batch = keywords[i:i+batch_size]
+            batch_embeddings = keyword_embeddings[i:i+batch_size]
+
+            # Batch process intents
+            intents = [sort_by_keyword_feature(kw) for kw in batch]
+
+            # Batch process entities
+            entities = [extract_entities(kw) for kw in batch]
+
+            # Batch process topics
+            similarities = cosine_similarity(batch_embeddings, category_embeddings)
+            Google_Content_Topics = [perform_topic_modeling_from_similarities(sim) for sim in similarities]
+
+            processed_data['Keywords'].extend(batch)
+            processed_data['Intent'].extend(intents)
+
+            # Convert entities to strings, handling both tuples and strings
+            processed_entities = []
+            for entity_list in entities:
+                entity_strings = []
+                for entity in entity_list:
+                    if isinstance(entity, tuple):
+                        entity_strings.append(f"{entity[0]} ({entity[1]})")
+                    else:
+                        entity_strings.append(str(entity))
+                processed_entities.append(", ".join(entity_strings))
+
+            processed_data['NER Entities'].extend(processed_entities)
+            processed_data['Google Content Topics'].extend(Google_Content_Topics)
+
+    except Exception as e:
+        pass
 
     return processed_data
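
The core of batch_process_keywords is the embed-once-then-compare step. A self-contained sketch of the shapes involved, using the same libraries; the category and keyword strings are stand-ins for the real inputs read from the categories file:

from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

model = SentenceTransformer('all-MiniLM-L6-v2')  # the smaller model this commit adopts

categories = ["Air Travel", "Banking", "Restaurants"]         # stand-in for google_categories
keywords = ["cheap flights to rome", "best savings account"]

category_embeddings = model.encode(categories)               # shape (3, 384) for MiniLM-L6
keyword_embeddings = model.encode(keywords, batch_size=16)   # shape (2, 384)

# One row of similarities per keyword, one column per category
similarities = cosine_similarity(keyword_embeddings, category_embeddings)
for kw, sims in zip(keywords, similarities):
    best = sims.argsort()[-1]
    print(f"{kw} -> {categories[best]} ({sims[best]:.3f})")
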
 
@@ -206,12 +236,6 @@ app.layout = dbc.Container([
 
     dbc.Row([
         dbc.Col([
-            dbc.Alert(
-                "Models are loading. This may take a few minutes. Please wait...",
-                id="loading-alert",
-                color="info",
-                is_open=True,
-            ),
             dbc.Label('Enter keywords (one per line, maximum of 100):', className='text-light'),
             dcc.Textarea(id='keyword-input', value='', style={'width': '100%', 'height': 100}),
             dbc.Button('Submit', id='submit-button', color='primary', className='mb-3', disabled=True),
@@ -220,20 +244,16 @@
         ], width=6)
     ], justify='center'),
 
-    # Loading component
     dbc.Row([
         dbc.Col([
            dcc.Loading(
                 id="loading",
                 type="default",
-                children=[
-                    html.Div([
-                        html.Div(id="loading-output")
-                    ], className="my-4")
-                ],
+                children=[html.Div(id="loading-output", className="my-4")]
            ),
         ], width=12)
-    ], justify='center', className="mb-4"), # Added margin-bottom for separation
+    ], justify='center', className="mb-4"),
+
     dbc.Row(dbc.Col(dcc.Graph(id='bar-chart'), width=12)),
 
     dbc.Row([
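
dcc.Loading is what lets this hunk drop the extra markup: the spinner shows automatically while any callback that targets one of its children (here 'loading-output') is running, so no separate loading alert is needed. A minimal sketch of that pattern; the button id and the sleep are stand-ins for the real submit flow:

import time
from dash import Dash, dcc, html, Input, Output

app = Dash(__name__)
app.layout = html.Div([
    html.Button("Run", id="run"),
    dcc.Loading(
        id="loading",
        type="default",
        children=[html.Div(id="loading-output", className="my-4")]
    ),
])

@app.callback(Output("loading-output", "children"), Input("run", "n_clicks"))
def slow_job(n_clicks):
    # While this callback runs, dcc.Loading shows its spinner automatically
    if n_clicks:
        time.sleep(2)  # stand-in for model inference
        return "Done."
    return ""

if __name__ == "__main__":
    app.run(debug=True)  # app.run_server(...) on older Dash releases
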
@@ -261,7 +281,7 @@
     dcc.Download(id='download'),
     dcc.Store(id='processed-data'),
 
-    # Explanation content
+    # Explanation content
     dbc.Row([
         dbc.Col([
             html.Div([
@@ -340,75 +360,52 @@ app.layout = dbc.Container([
 
 ], fluid=True)
 
-# Callback to load models and update the loading alert
+# Combined callback
 @app.callback(
     [Output('models-loaded', 'data'),
-     Output('loading-alert', 'is_open'),
-     Output('submit-button', 'disabled')],
-    [Input('models-loaded', 'data')]
-)
-def load_models_callback(loaded):
-    if not loaded:
-        load_models()
-        return True, False, False
-    return loaded, False, False
-
-# Callback for smooth scrolling
-app.clientside_callback(
-    """
-    function(n_clicks) {
-        const links = document.querySelectorAll('a[href^="#"]');
-        links.forEach(link => {
-            link.addEventListener('click', function(e) {
-                e.preventDefault();
-                const targetId = this.getAttribute('href').substring(1);
-                const targetElement = document.getElementById(targetId);
-                if (targetElement) {
-                    targetElement.scrollIntoView({behavior: 'smooth'});
-                }
-            });
-        });
-        return '';
-    }
-    """,
-    Output('dummy-output', 'children'),
-    Input('dummy-input', 'children')
-)
-
-# All other callbacks
-@app.callback(
-    Output('alert', 'is_open'),
-    Output('alert', 'children'),
-    [Input('submit-button', 'n_clicks')],
-    [State('keyword-input', 'value')]
-)
-def limit_keywords(n_clicks, keyword_input):
-    if n_clicks is None:
-        return False, ""
-
-    keywords = keyword_input.split('\n')
-    if len(keywords) > 100:
-        return True, "Maximum limit of 100 keywords exceeded. Only the first 100 keywords will be processed."
-
-    return False, ""
-
-@app.callback(
-    [Output('processed-data', 'data'),
+     Output('submit-button', 'disabled'),
+     Output('alert', 'is_open'),
+     Output('alert', 'children'),
+     Output('alert', 'color'),
+     Output('processed-data', 'data'),
      Output('loading-output', 'children'),
      Output('processing-alert', 'is_open'),
      Output('processing-alert', 'children')],
-    [Input('submit-button', 'n_clicks')],
+    [Input('models-loaded', 'data'),
+     Input('submit-button', 'n_clicks')],
     [State('keyword-input', 'value')]
 )
-def process_keywords(n_clicks, keyword_input):
+def combined_callback(loaded, n_clicks, keyword_input):
+    ctx = callback_context
+    triggered_id = ctx.triggered[0]['prop_id'].split('.')[0]
+
+    if triggered_id == 'models-loaded':
+        return handle_model_loading(loaded)
+    elif triggered_id == 'submit-button':
+        return handle_keyword_processing(n_clicks, keyword_input)
+    else:
+        # Default return values
+        return loaded, False, False, "", "success", None, '', False, ''
+
+def handle_model_loading(loaded):
+    if not loaded:
+        try:
+            # Lazy loading will occur when models are first used
+            return True, False, True, "Models ready to load", "success", None, '', False, ''
+        except Exception as e:
+            return False, True, True, f"Error preparing models: {str(e)}", "danger", None, '', False, ''
+    return loaded, not loaded, False, "", "success", None, '', False, ''
+
+def handle_keyword_processing(n_clicks, keyword_input):
     if n_clicks is None or not keyword_input:
-        return None, '', False, ''
+        return True, False, False, "", "success", None, '', False, ''
 
     keywords = [kw.strip() for kw in keyword_input.split('\n')[:100] if kw.strip()]
     processed_data = batch_process_keywords(keywords)
 
-    return processed_data, '', True, "Keyword processing complete!"
+    return True, False, False, "", "success", processed_data, '', True, "Keyword processing complete!"
 
+# Callback for updating the bar chart
 @app.callback(
     Output('bar-chart', 'figure'),
     [Input('processed-data', 'data')]
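
The dispatch in combined_callback hinges on callback_context: when one callback has several Inputs, ctx.triggered records which property actually fired. A minimal sketch of that idiom; it is only meaningful when executed inside a running callback:

from dash import callback_context

def triggered_component():
    # Inside a callback, callback_context.triggered looks like
    # [{'prop_id': 'submit-button.n_clicks', 'value': 3}]
    ctx = callback_context
    if not ctx.triggered:  # nothing has fired yet (initial call)
        return None
    return ctx.triggered[0]['prop_id'].split('.')[0]  # e.g. 'submit-button'
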
@@ -418,7 +415,7 @@ def update_bar_chart(processed_data):
     return {
         'data': [],
         'layout': {
-            'height': 0, # Set height to 0 when there's no data
+            'height': 0,
             'annotations': [{
                 'text': '',
                 'xref': 'paper',
@@ -441,7 +438,7 @@
         plot_bgcolor='#222222',
         paper_bgcolor='#222222',
         font_color='white',
-        height=400, # Set a fixed height for the chart
+        height=400,
         legend=dict(
             orientation="h",
             yanchor="bottom",
@@ -453,6 +450,7 @@
 
     return fig
 
+# Callback for updating the dropdown and download button
 @app.callback(
     [Output('table-intent-dropdown', 'options'),
      Output('download-button', 'disabled')],
  Output('download-button', 'disabled')],
@@ -467,6 +465,7 @@ def update_dropdown_and_button(processed_data):
467
  options = [{'label': intent, 'value': intent} for intent in intents]
468
  return options, False
469
 
 
470
  @app.callback(
471
  Output('keywords-table', 'children'),
472
  [Input('table-intent-dropdown', 'value')],
@@ -492,6 +491,7 @@ def update_keywords_table(selected_intent, processed_data):
     )
     return table
 
+# Callback for downloading CSV
 @app.callback(
     Output('download', 'data'),
     [Input('download-button', 'n_clicks')],
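
Downstream of these callbacks, the column-oriented dict returned by batch_process_keywords feeds both the DataTable and the CSV download. A hypothetical illustration of that contract; the row values are invented, with the "text (label)" entity format and the "best , second-best" topic format matching the joins above:

import pandas as pd

processed_data = {
    'Keywords': ['cheap flights to rome'],
    'Intent': ['other'],
    'NER Entities': ['rome (location)'],
    'Google Content Topics': ['Air Travel , Flights'],
}
df = pd.DataFrame(processed_data)  # one row per keyword
print(df.to_csv(index=False))      # what a CSV download would serialize

The download callback itself falls outside this diff; presumably it hands the frame to dcc.Download with something like dcc.send_data_frame(df.to_csv, 'keywords.csv'), but that call is an assumption.
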
 