File size: 10,674 Bytes
809b92e
 
 
 
 
 
 
 
 
 
 
2779464
809b92e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
959ec8a
809b92e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
959ec8a
809b92e
 
 
 
 
 
 
 
 
 
 
546c83b
809b92e
 
546c83b
959ec8a
 
 
809b92e
 
 
 
 
546c83b
809b92e
 
 
 
0f407f9
959ec8a
 
 
809b92e
 
959ec8a
546c83b
809b92e
 
959ec8a
546c83b
809b92e
959ec8a
809b92e
 
959ec8a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
# invoice_processor.py
"""
Procesamiento de facturas: OCR, NER y visualización
"""

import numpy as np
from PIL import Image, ImageDraw, ImageFont
import torch
from doctr.io import DocumentFile
from io import BytesIO
from config import LABEL2COLOR, MAX_LENGTH, NORMALIZATION_FACTOR
from validator import InvoiceValidator


class InvoiceProcessor:
    """Clase para procesar facturas y extraer entidades."""
    
    def __init__(self, model_manager):
        """
        Build the invoice processor from a model manager.

        Args:
            model_manager: ModelManager instance with the models already
                loaded. It is stored on the instance and asked once for the
                processor, the model, the OCR model and the device.
        """
        manager = model_manager
        self.model_manager = manager

        # Resolve every dependency up front so the other methods can use
        # plain attributes instead of going back to the manager.
        self.processor = manager.get_processor()
        self.model = manager.get_model()
        self.ocr_model = manager.get_ocr_model()
        self.device = manager.get_device()

        # Validator used by process_invoice to validate and correct entities.
        self.validator = InvoiceValidator()
    
    def extract_ocr_data(self, image: Image.Image):
        """
        Extrae texto y bounding boxes usando DocTR.
        
        Args:
            image: Imagen PIL de la factura
            
        Returns:
            tuple: (words_data, image_width, image_height) o (None, None, None) en caso de error
        """
        try:
            rgb_image = image.convert("RGB")
            img_byte_arr = BytesIO()
            rgb_image.save(img_byte_arr, format='JPEG')
            img_byte_arr.seek(0)
            image_bytes = img_byte_arr.read()
            
            doctr_doc = DocumentFile.from_images([image_bytes])
            doctr_result = self.ocr_model(doctr_doc)
            
            if not doctr_result.pages:
                return None, None, None
            
            page = doctr_result.pages[0]
            words_data = []
            
            for block in page.blocks:
                for line in block.lines:
                    for word in line.words:
                        text = word.value
                        geom = np.array(word.geometry) * NORMALIZATION_FACTOR
                        xmin, ymin = map(int, geom[0])
                        xmax, ymax = map(int, geom[1])
                        words_data.append({"text": text, "box": [xmin, ymin, xmax, ymax]})
            
            image_width, image_height = image.size
            return words_data, image_width, image_height
            
        except Exception as e:
            print(f"Error en OCR: {e}")
            return None, None, None
    
    def perform_ner(self, image: Image.Image, words_data: list):
        """
        Predict one label per OCR word.

        The words and boxes are encoded with self.processor (padded and
        truncated to MAX_LENGTH), run through self.model without gradients,
        and each word is labelled with the prediction of its first token.

        Args:
            image: PIL image of the invoice.
            words_data: List of dicts with 'text' and 'box' for each word.

        Returns:
            list: Label strings (from model.config.id2label), one per word
            that has at least one token in the encoding. NOTE(review):
            presumably shorter than words_data when truncation drops trailing
            words - confirm against the processor.
        """
        words, boxes = [], []
        for entry in words_data:
            words.append(entry["text"])
            boxes.append(entry["box"])

        # Preprocessing
        encoding = self.processor(
            image,
            words,
            boxes=boxes,
            max_length=MAX_LENGTH,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        model_inputs = {
            name: encoding[name].to(self.device)
            for name in ("input_ids", "attention_mask", "bbox", "pixel_values")
        }

        # Inference
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(**model_inputs)

        token_label_ids = outputs.logits.argmax(dim=-1).squeeze().tolist()

        # Map token predictions back to words: only the first token of each
        # word contributes a label; tokens without a word id are skipped.
        id2label = self.model.config.id2label
        word_limit = len(words)
        labels = []
        previous_word = None

        for label_id, word_idx in zip(token_label_ids, encoding.word_ids()):
            if word_idx is None or word_idx == previous_word:
                continue
            previous_word = word_idx
            if len(labels) < word_limit:
                labels.append(id2label[label_id])

        return labels
    
    def group_entities(self, words_data: list, predictions: list):
        """
        Merge per-word BIO tags into entities and keep one entity per label.

        Each tag is split on its first '-' into a prefix and a root label.
        'B' closes the pending entity and opens a new one; 'I' extends the
        pending entity when the root label matches and otherwise opens a new
        one; 'O' closes the pending entity. A tag with any other prefix is
        ignored and leaves the pending entity untouched. When a label yields
        several entities, the one with the longest text is kept (the earliest
        one on ties).

        Args:
            words_data: List of dicts with 'text' and 'box'
                ([xmin, ymin, xmax, ymax]) for each word.
            predictions: Tag for each word, paired positionally with
                words_data; surplus items on either side are ignored.

        Returns:
            list: One dict per label with 'etiqueta' (the label), 'valor'
            (the entity words joined by spaces) and 'bbox_entity' (the box
            enclosing all of the entity's word boxes).
        """
        grouped = {}

        def flush(tokens, label, boxes):
            """Record the pending entity under its label, if there is one."""
            if not (tokens and label):
                return
            xs = [coord for box in boxes for coord in (box[0], box[2])]
            ys = [coord for box in boxes for coord in (box[1], box[3])]
            grouped.setdefault(label, []).append({
                'valor': " ".join(tokens),
                'bbox_entity': [min(xs), min(ys), max(xs), max(ys)],
            })

        tokens, label, boxes = [], None, []

        for word, tag in zip(words_data, predictions):
            text, box = word["text"], word["box"]
            prefix, dash, suffix = tag.partition('-')
            root = suffix if dash else None

            if prefix == 'I' and root == label:
                # Continuation of the entity currently being built.
                tokens.append(text)
                boxes.append(box)
            elif prefix in ('B', 'I'):
                # 'B', or an 'I' whose label differs: start a fresh entity.
                flush(tokens, label, boxes)
                tokens, label, boxes = [text], root, [box]
            elif prefix == 'O':
                flush(tokens, label, boxes)
                tokens, label, boxes = [], None, []

        # The last entity is still pending when the words run out.
        flush(tokens, label, boxes)

        # Deduplication: keep the longest value found for each label.
        results = []
        for found_label, found in grouped.items():
            best = max(found, key=lambda item: len(item['valor']))
            results.append({
                'etiqueta': found_label,
                'valor': best['valor'],
                'bbox_entity': best['bbox_entity'],
            })

        return results
    
    def draw_annotations(self, image: Image.Image, entities: list):
        """
        Draw the bounding box and label of every entity on a copy of the image.

        Box coordinates are converted from the NORMALIZATION_FACTOR scale to
        pixels using the image size. Each label is written just above its
        box; for boxes too close to the top edge the label is moved down to
        y=0 so it stays inside the image instead of being clipped.

        Args:
            image: Original PIL image; it is not modified.
            entities: List of dicts with 'etiqueta' (label) and 'bbox_entity'
                ([xmin, ymin, xmax, ymax] on the normalized scale).

        Returns:
            Image: Annotated copy of the image.
        """
        annotated_image = image.copy()
        draw = ImageDraw.Draw(annotated_image)
        image_width, image_height = image.size

        # The label is lifted above the box by the requested font size.
        font_size = 20
        try:
            font = ImageFont.truetype("arial.ttf", font_size)
        except IOError:
            # Arial is not available on this system: use Pillow's default font.
            font = ImageFont.load_default()

        for entity in entities:
            label = entity['etiqueta']
            min_x_norm, min_y_norm, max_x_norm, max_y_norm = entity['bbox_entity']

            # Denormalize coordinates
            min_x = int(min_x_norm * image_width / NORMALIZATION_FACTOR)
            min_y = int(min_y_norm * image_height / NORMALIZATION_FACTOR)
            max_x = int(max_x_norm * image_width / NORMALIZATION_FACTOR)
            max_y = int(max_y_norm * image_height / NORMALIZATION_FACTOR)

            # Labels without a configured color fall back to yellow.
            color = LABEL2COLOR.get(label, 'yellow')

            draw.rectangle([min_x, min_y, max_x, max_y], outline=color, width=3)

            # Never place the label at a negative y, where it would be cut off.
            label_y = max(min_y - font_size, 0)
            draw.text((min_x, label_y), label, fill=color, font=font)

        return annotated_image
    
    def process_invoice(self, image: Image.Image, filename: str):
        """
        Procesa una factura completa: OCR + NER + visualización + validación.
        
        Args:
            image: Imagen PIL de la factura
            filename: Nombre del archivo
            
        Returns:
            tuple: (filename, annotated_image, table_data, json_data)
        """
        # 1. OCR
        words_data, image_width, image_height = self.extract_ocr_data(image)
        if words_data is None:
            return filename, None, [["ERROR", "No se pudo realizar OCR"]], []
        
        if not words_data:
            return filename, None, [["ERROR", "No se encontró texto en la imagen"]], []
        
        # Extraer lista de palabras para el validador
        ocr_words = [wd["text"] for wd in words_data]
        
        # 2. NER
        try:
            predictions = self.perform_ner(image, words_data)
        except Exception as e:
            return filename, None, [["ERROR", f"Error en NER: {e}"]], []
        
        # 3. Agrupar entidades
        entities = self.group_entities(words_data, predictions)
        
        # 4. VALIDAR Y CORREGIR ENTIDADES
        validated_table, validation_errors = self.validator.validate_and_correct(entities, ocr_words)
        
        # 5. Dibujar anotaciones (solo las entidades detectadas originalmente)
        annotated_image = self.draw_annotations(image, entities)
        
        # 6. Preparar resultados
        # validated_table ya viene como [etiqueta, valor] (sin columna de validación)
        json_data = [
            {
                'etiqueta': row[0], 
                'valor': row[1]
            } 
            for row in validated_table
        ]
        
        return filename, annotated_image, validated_table, json_data