|
|
|
|
|
""" |
|
|
Procesamiento de facturas: OCR, NER y visualización |
|
|
""" |
|
|
|
|
|
import numpy as np |
|
|
from PIL import Image, ImageDraw, ImageFont |
|
|
import torch |
|
|
from doctr.io import DocumentFile |
|
|
from io import BytesIO |
|
|
from config import LABEL2COLOR, MAX_LENGTH, NORMALIZATION_FACTOR |
|
|
from validator import InvoiceValidator |
|
|
|
|
|
|
|
|
class InvoiceProcessor:
    """Process invoice images: OCR, entity extraction (NER), validation and drawing.

    The models themselves come from the injected ``model_manager``; this class
    wires them together and post-processes their outputs.
    """

    def __init__(self, model_manager):
        """Cache the models and helpers exposed by ``model_manager``.

        Args:
            model_manager: ModelManager instance whose getters return the loaded
                processor, token-classification model, OCR model and torch device.
        """
        self.model_manager = model_manager
        self.processor = model_manager.get_processor()
        self.model = model_manager.get_model()
        self.ocr_model = model_manager.get_ocr_model()
        self.device = model_manager.get_device()
        self.validator = InvoiceValidator()

    @staticmethod
    def _geometry_to_box(geometry):
        """Convert a DocTR relative geometry into an int box in [0, NORMALIZATION_FACTOR].

        Accepts the straight-page form ``((xmin, ymin), (xmax, ymax))`` as well as
        polygon forms (any number of ``(x, y)`` points) by taking the extreme
        coordinates over all points. Coordinates are clipped so the box never
        leaves the normalized range expected downstream.
        """
        points = np.asarray(geometry, dtype=float).reshape(-1, 2) * NORMALIZATION_FACTOR
        points = np.clip(points, 0, NORMALIZATION_FACTOR)
        xmin, ymin = points.min(axis=0)
        xmax, ymax = points.max(axis=0)
        return [int(xmin), int(ymin), int(xmax), int(ymax)]

    def extract_ocr_data(self, image: Image.Image):
        """Extract words and their bounding boxes from ``image`` using DocTR.

        Args:
            image: PIL image of the invoice.

        Returns:
            tuple: ``(words_data, image_width, image_height)`` where ``words_data``
            is a list of ``{"text": str, "box": [xmin, ymin, xmax, ymax]}`` dicts
            with boxes scaled to ``[0, NORMALIZATION_FACTOR]``; or
            ``(None, None, None)`` if OCR raises or DocTR returns no page.
        """
        try:
            # JPEG cannot store alpha / palette modes, so encode an RGB copy.
            buffer = BytesIO()
            image.convert("RGB").save(buffer, format='JPEG')
            doctr_doc = DocumentFile.from_images([buffer.getvalue()])
            doctr_result = self.ocr_model(doctr_doc)

            if not doctr_result.pages:
                return None, None, None

            # Only one image is submitted, so only the first page is relevant.
            page = doctr_result.pages[0]
            words_data = []
            for block in page.blocks:
                for line in block.lines:
                    for word in line.words:
                        words_data.append({
                            "text": word.value,
                            "box": self._geometry_to_box(word.geometry),
                        })

            image_width, image_height = image.size
            return words_data, image_width, image_height

        except Exception as e:
            # Deliberate best-effort: callers treat None as "OCR failed".
            print(f"Error en OCR: {e}")
            return None, None, None

    def perform_ner(self, image: Image.Image, words_data: list):
        """Assign a model label to every OCR word.

        Args:
            image: PIL image of the invoice.
            words_data: list of dicts with ``"text"`` and ``"box"`` keys, as
                produced by ``extract_ocr_data``.

        Returns:
            list: exactly one label string (from ``model.config.id2label``) per
            entry of ``words_data``, taken from the first sub-token of each
            word. Words that received no token (cut off by ``MAX_LENGTH``
            truncation, or tokenized to nothing) are labelled ``"O"``.
        """
        words = [wd["text"] for wd in words_data]
        boxes = [wd["box"] for wd in words_data]

        # Mirror the OCR step and hand the processor an RGB image; presumably its
        # image processor is set up for 3-channel input — confirm against config.
        encoding = self.processor(
            image.convert("RGB"), words, boxes=boxes, max_length=MAX_LENGTH,
            truncation=True, padding="max_length", return_tensors="pt"
        )
        model_inputs = {
            key: encoding[key].to(self.device)
            for key in ("input_ids", "attention_mask", "bbox", "pixel_values")
        }

        self.model.eval()
        with torch.no_grad():
            outputs = self.model(**model_inputs)

        # Take batch element 0 explicitly; squeeze() would collapse a
        # length-1 sequence into a scalar.
        token_predictions = outputs.logits.argmax(dim=-1)[0].tolist()
        word_ids = encoding.word_ids(0)

        # Place each prediction at its own word index so a word that produced
        # no tokens cannot shift later labels onto the wrong words.
        id2label = self.model.config.id2label
        labels = ["O"] * len(words)
        labelled = set()
        for word_idx, pred_id in zip(word_ids, token_predictions):
            if word_idx is None or word_idx in labelled:
                continue  # special/padding token, or a non-first sub-token
            labelled.add(word_idx)
            labels[word_idx] = id2label[pred_id]

        return labels

    def group_entities(self, words_data: list, predictions: list):
        """Merge word-level BIO labels into entities, keeping one value per label.

        Consecutive words tagged ``B-X``, ``I-X``, ``I-X``... form one entity of
        type ``X`` whose box is the union of the word boxes. An ``I-X`` that
        does not continue an open ``X`` entity starts a new one. Any other tag
        (``O`` or anything outside the B/I scheme) closes the open entity. When
        a label yields several entities, the one with the longest text wins
        (the earliest one on ties).

        Args:
            words_data: list of ``{"text", "box"}`` dicts.
            predictions: one label string per word; items beyond the shorter of
                the two lists are ignored.

        Returns:
            list: dicts with ``'etiqueta'`` (label), ``'valor'`` (space-joined
            text) and ``'bbox_entity'`` (``[xmin, ymin, xmax, ymax]`` in the same
            coordinate space as the input boxes), ordered by first appearance
            of each label.
        """
        candidates_by_label = {}
        current_label = None
        current_words = []
        current_boxes = []

        def save_current_entity(entity_words, label, entity_boxes):
            # Tags without a "-TYPE" suffix give no label; nothing to store.
            if not entity_words or not label:
                return
            all_x = [b[0] for b in entity_boxes] + [b[2] for b in entity_boxes]
            all_y = [b[1] for b in entity_boxes] + [b[3] for b in entity_boxes]
            candidates_by_label.setdefault(label, []).append({
                'valor': " ".join(entity_words),
                'bbox_entity': [min(all_x), min(all_y), max(all_x), max(all_y)],
            })

        for word_data, pred_label in zip(words_data, predictions):
            tag_parts = pred_label.split('-', 1)
            tag_type = tag_parts[0]
            root_label = tag_parts[1] if len(tag_parts) > 1 else None

            if tag_type == 'B' or (tag_type == 'I' and root_label != current_label):
                save_current_entity(current_words, current_label, current_boxes)
                current_label = root_label
                current_words = [word_data["text"]]
                current_boxes = [word_data["box"]]
            elif tag_type == 'I':
                current_words.append(word_data["text"])
                current_boxes.append(word_data["box"])
            else:
                # 'O' or an unrecognised tag: close the open entity so a later
                # I- tag cannot glue onto it across the gap.
                save_current_entity(current_words, current_label, current_boxes)
                current_label = None
                current_words = []
                current_boxes = []

        save_current_entity(current_words, current_label, current_boxes)

        final_ner_results = []
        for label, candidates in candidates_by_label.items():
            # max() returns the first maximal element, i.e. earliest on ties.
            best_candidate = max(candidates, key=lambda c: len(c['valor']))
            final_ner_results.append({
                'etiqueta': label,
                'valor': best_candidate['valor'],
                'bbox_entity': best_candidate['bbox_entity'],
            })

        return final_ner_results

    def draw_annotations(self, image: Image.Image, entities: list):
        """Draw each entity's box and label on a copy of ``image``.

        Args:
            image: original PIL image (left unmodified).
            entities: dicts with ``'etiqueta'`` and ``'bbox_entity'`` whose box
                is expressed in ``[0, NORMALIZATION_FACTOR]`` coordinates.

        Returns:
            Image: annotated copy. Non-RGB(A) inputs (grayscale, palette,
            bilevel) are converted to RGB so the per-label colours are kept.
        """
        if image.mode in ("RGB", "RGBA"):
            annotated_image = image.copy()
        else:
            annotated_image = image.convert("RGB")
        draw = ImageDraw.Draw(annotated_image)
        image_width, image_height = annotated_image.size

        font_size = 20
        try:
            font = ImageFont.truetype("arial.ttf", font_size)
        except OSError:
            font = ImageFont.load_default()

        for entity in entities:
            label = entity['etiqueta']
            min_x_norm, min_y_norm, max_x_norm, max_y_norm = entity['bbox_entity']

            # Map normalized coordinates back to pixels.
            min_x = int(min_x_norm * image_width / NORMALIZATION_FACTOR)
            min_y = int(min_y_norm * image_height / NORMALIZATION_FACTOR)
            max_x = int(max_x_norm * image_width / NORMALIZATION_FACTOR)
            max_y = int(max_y_norm * image_height / NORMALIZATION_FACTOR)

            color = LABEL2COLOR.get(label, 'yellow')
            draw.rectangle([min_x, min_y, max_x, max_y], outline=color, width=3)
            # Label sits above the box; clamp so boxes touching the top edge
            # still get a visible label.
            draw.text((min_x, max(min_y - font_size, 0)), label, fill=color, font=font)

        return annotated_image

    def process_invoice(self, image: Image.Image, filename: str):
        """Run the full pipeline: OCR, NER, grouping, validation and drawing.

        Args:
            image: PIL image of the invoice.
            filename: name of the source file, echoed back in the result.

        Returns:
            tuple: ``(filename, annotated_image, table_data, json_data)``. On an
            OCR or NER failure, ``annotated_image`` is None, ``table_data`` is a
            single ``["ERROR", message]`` row and ``json_data`` is empty.
        """
        words_data, _, _ = self.extract_ocr_data(image)
        if words_data is None:
            return filename, None, [["ERROR", "No se pudo realizar OCR"]], []
        if not words_data:
            return filename, None, [["ERROR", "No se encontró texto en la imagen"]], []

        ocr_words = [wd["text"] for wd in words_data]

        try:
            predictions = self.perform_ner(image, words_data)
        except Exception as e:
            return filename, None, [["ERROR", f"Error en NER: {e}"]], []

        entities = self.group_entities(words_data, predictions)

        # The validator's error list is not surfaced to callers here.
        validated_table, _validation_errors = self.validator.validate_and_correct(
            entities, ocr_words
        )

        # Boxes come from the raw entities; the validated table carries no bboxes.
        annotated_image = self.draw_annotations(image, entities)

        json_data = [{'etiqueta': row[0], 'valor': row[1]} for row in validated_table]

        return filename, annotated_image, validated_table, json_data