# layoutlmv3-facturas-extractor / invoice_processor.py
# Lucas Gagneten
# less log
# 0f407f9
# invoice_processor.py
"""
Procesamiento de facturas: OCR, NER y visualización
"""
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import torch
from doctr.io import DocumentFile
from io import BytesIO
from config import LABEL2COLOR, MAX_LENGTH, NORMALIZATION_FACTOR
from validator import InvoiceValidator
class InvoiceProcessor:
    """Run OCR, token classification (NER), validation and visualization on invoices."""

    def __init__(self, model_manager):
        """
        Initialize the invoice processor.

        Args:
            model_manager: ModelManager instance with the models already loaded.
                The processor, NER model, OCR model and device are read from it
                once and cached on this instance.
        """
        self.model_manager = model_manager
        self.processor = model_manager.get_processor()
        self.model = model_manager.get_model()
        self.ocr_model = model_manager.get_ocr_model()
        self.device = model_manager.get_device()
        self.validator = InvoiceValidator()

    def extract_ocr_data(self, image: Image.Image):
        """
        Extract words and their bounding boxes with DocTR.

        Boxes are expressed on a 0..NORMALIZATION_FACTOR scale relative to the
        image size; this is the coordinate space later passed to the NER
        processor and undone in draw_annotations.

        Args:
            image: PIL image of the invoice.

        Returns:
            tuple: ``(words_data, image_width, image_height)`` where
            ``words_data`` is a list of ``{"text": str, "box": [xmin, ymin,
            xmax, ymax]}`` dicts, or ``(None, None, None)`` when OCR raises or
            DocTR returns no pages.
        """
        try:
            # DocTR is fed encoded image bytes, so re-encode as an RGB JPEG.
            buffer = BytesIO()
            image.convert("RGB").save(buffer, format='JPEG')
            doctr_doc = DocumentFile.from_images([buffer.getvalue()])
            doctr_result = self.ocr_model(doctr_doc)

            if not doctr_result.pages:
                return None, None, None

            page = doctr_result.pages[0]
            words_data = []
            for block in page.blocks:
                for line in block.lines:
                    for word in line.words:
                        # Geometry is relative to the page. Taking min/max over
                        # all points yields the same box for a 2-point
                        # ((xmin, ymin), (xmax, ymax)) geometry and the
                        # enclosing box if a polygon is returned instead.
                        # Clipping keeps coordinates inside
                        # 0..NORMALIZATION_FACTOR (assumed to match the
                        # model's bbox scale).
                        points = np.clip(np.asarray(word.geometry, dtype=float), 0.0, 1.0)
                        points = points * NORMALIZATION_FACTOR
                        xmin, ymin = (int(v) for v in points.min(axis=0))
                        xmax, ymax = (int(v) for v in points.max(axis=0))
                        words_data.append({"text": word.value, "box": [xmin, ymin, xmax, ymax]})

            image_width, image_height = image.size
            return words_data, image_width, image_height
        except Exception as e:
            # Best-effort boundary: report and let the caller build an error row.
            print(f"Error en OCR: {e}")
            return None, None, None

    def perform_ner(self, image: Image.Image, words_data: list):
        """
        Run token classification over the OCR words.

        Args:
            image: PIL image of the invoice.
            words_data: List of dicts with 'text' and 'box' keys.

        Returns:
            list: One label string (from ``model.config.id2label``) per word,
            index-aligned with ``words_data``. If the encoding is truncated at
            MAX_LENGTH, words past the truncation point get no label, so the
            list can be shorter than ``words_data``.
        """
        words = [wd["text"] for wd in words_data]
        boxes = [wd["box"] for wd in words_data]

        # Preprocessing
        encoding = self.processor(
            image, words, boxes=boxes, max_length=MAX_LENGTH,
            truncation=True, padding="max_length", return_tensors="pt"
        )
        model_inputs = {
            key: encoding[key].to(self.device)
            for key in ("input_ids", "attention_mask", "bbox", "pixel_values")
        }

        # Inference
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(**model_inputs)
        # A single document is encoded, so take batch row 0. squeeze() would
        # collapse a length-1 sequence into a bare scalar.
        predictions = outputs.logits.argmax(dim=-1)[0].tolist()

        # Label each word with the prediction of its first sub-token.
        id2label = self.model.config.id2label
        label_by_word = {}
        for word_idx, pred_id in zip(encoding.word_ids(), predictions):
            if word_idx is not None and word_idx not in label_by_word:
                label_by_word[word_idx] = id2label[pred_id]

        if not label_by_word:
            return []
        # Truncation keeps a prefix of the words. Inside that prefix, a word
        # that produced no tokens gets "O" so every later label stays on the
        # word it belongs to (appending positionally would shift them).
        num_encoded = max(label_by_word) + 1
        return [label_by_word.get(i, "O") for i in range(num_encoded)]

    def group_entities(self, words_data: list, predictions: list):
        """
        Group word-level BIO predictions into entities and keep one per label.

        Consecutive words are merged while ``I-X`` continues an open ``X``
        entity; an ``I-X`` that does not continue ``X`` starts a new entity.
        When a label occurs several times, the candidate with the longest text
        wins (the earliest one on ties).

        Args:
            words_data: List of dicts with 'text' and 'box' keys.
            predictions: Label strings aligned with ``words_data``; trailing
                words without a prediction are ignored (``zip`` stops early).

        Returns:
            list: Dicts with 'etiqueta' (label), 'valor' (space-joined text)
            and 'bbox_entity' (enclosing [xmin, ymin, xmax, ymax] of the words).
        """
        ner_candidates = {}
        current_entity = []
        current_label = None
        current_bbox_group = []

        def save_current_entity(entity_list, label, bbox_list):
            # Nothing to save for an empty run or a tag without a label.
            if not entity_list or not label:
                return
            all_x = [b[0] for b in bbox_list] + [b[2] for b in bbox_list]
            all_y = [b[1] for b in bbox_list] + [b[3] for b in bbox_list]
            ner_candidates.setdefault(label, []).append({
                'valor': " ".join(entity_list),
                'bbox_entity': [min(all_x), min(all_y), max(all_x), max(all_y)],
            })

        for word_data, pred_label in zip(words_data, predictions):
            word_text = word_data["text"]
            word_box = word_data["box"]
            tag_parts = pred_label.split('-', 1)
            tag_type = tag_parts[0]
            root_label = tag_parts[1] if len(tag_parts) > 1 else None

            if tag_type == 'B':
                save_current_entity(current_entity, current_label, current_bbox_group)
                current_label = root_label
                current_entity = [word_text]
                current_bbox_group = [word_box]
            elif tag_type == 'I':
                if current_label == root_label:
                    current_entity.append(word_text)
                    current_bbox_group.append(word_box)
                else:
                    # Lenient IOB handling: a dangling I-X opens a new entity.
                    save_current_entity(current_entity, current_label, current_bbox_group)
                    current_label = root_label
                    current_entity = [word_text]
                    current_bbox_group = [word_box]
            elif tag_type == 'O':
                save_current_entity(current_entity, current_label, current_bbox_group)
                current_entity = []
                current_label = None
                current_bbox_group = []

        save_current_entity(current_entity, current_label, current_bbox_group)

        # De-duplication: keep the longest value per label. max() returns the
        # first maximal candidate, i.e. the earliest one on ties.
        final_ner_results = []
        for label, candidates in ner_candidates.items():
            best_candidate = max(candidates, key=lambda c: len(c['valor']))
            final_ner_results.append({
                'etiqueta': label,
                'valor': best_candidate['valor'],
                'bbox_entity': best_candidate['bbox_entity'],
            })
        return final_ner_results

    def draw_annotations(self, image: Image.Image, entities: list):
        """
        Draw each entity's bounding box and label on a copy of the image.

        Args:
            image: Original PIL image (left unmodified).
            entities: Dicts with 'etiqueta' and 'bbox_entity'; the box is on
                the 0..NORMALIZATION_FACTOR scale produced by extract_ocr_data.

        Returns:
            Image: Annotated copy. Images whose mode is not RGB/RGBA are
            converted to RGB so the label colors are drawn as real colors.
        """
        if image.mode in ("RGB", "RGBA"):
            annotated_image = image.copy()
        else:
            annotated_image = image.convert("RGB")
        draw = ImageDraw.Draw(annotated_image)
        image_width, image_height = image.size

        try:
            font = ImageFont.truetype("arial.ttf", 20)
        except IOError:
            # Fall back to Pillow's built-in font when arial.ttf cannot be loaded.
            font = ImageFont.load_default()

        for entity in entities:
            label = entity['etiqueta']
            min_x_norm, min_y_norm, max_x_norm, max_y_norm = entity['bbox_entity']

            # Map normalized coordinates back to pixels.
            min_x = int(min_x_norm * image_width / NORMALIZATION_FACTOR)
            min_y = int(min_y_norm * image_height / NORMALIZATION_FACTOR)
            max_x = int(max_x_norm * image_width / NORMALIZATION_FACTOR)
            max_y = int(max_y_norm * image_height / NORMALIZATION_FACTOR)

            color = LABEL2COLOR.get(label, 'yellow')
            draw.rectangle([min_x, min_y, max_x, max_y], outline=color, width=3)
            # Keep the label on-canvas for boxes touching the top edge.
            draw.text((min_x, max(0, min_y - 20)), label, fill=color, font=font)

        return annotated_image

    def process_invoice(self, image: Image.Image, filename: str):
        """
        Process one invoice end to end: OCR, NER, grouping, validation and
        visualization.

        Args:
            image: PIL image of the invoice.
            filename: File name, echoed back in the result.

        Returns:
            tuple: ``(filename, annotated_image, table_data, json_data)``.
            On failure ``annotated_image`` is None, ``table_data`` is a single
            ``["ERROR", message]`` row and ``json_data`` is an empty list.
        """
        # 1. OCR
        words_data, _, _ = self.extract_ocr_data(image)
        if words_data is None:
            return filename, None, [["ERROR", "No se pudo realizar OCR"]], []
        if not words_data:
            return filename, None, [["ERROR", "No se encontró texto en la imagen"]], []

        # Raw OCR words, handed to the validator.
        ocr_words = [wd["text"] for wd in words_data]

        # 2. NER
        try:
            predictions = self.perform_ner(image, words_data)
        except Exception as e:
            return filename, None, [["ERROR", f"Error en NER: {e}"]], []

        # 3. Group entities
        entities = self.group_entities(words_data, predictions)

        # 4. Validate and correct entities (validation_errors is not used here).
        validated_table, validation_errors = self.validator.validate_and_correct(entities, ocr_words)

        # 5. Draw annotations for the entities as originally detected (pre-validation).
        annotated_image = self.draw_annotations(image, entities)

        # 6. Build results; validated rows are expected to be [etiqueta, valor].
        json_data = [{'etiqueta': row[0], 'valor': row[1]} for row in validated_table]
        return filename, annotated_image, validated_table, json_data