File size: 3,733 Bytes
51a8322
3ce8aff
51a8322
0ad83bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51a8322
 
 
0ad83bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51a8322
 
 
0ad83bb
 
 
 
 
 
 
 
 
 
 
 
51a8322
 
 
0ad83bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51a8322
 
 
3ce8aff
0ad83bb
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
""""
TEDS Code Adapted from https://github.com/ibm-aur-nlp/EDD
"""

import distance
from apted import APTED, Config
from apted.helpers import Tree
from lxml import html
from collections import deque

def wrap_table_html(table_html:str)->str:
    return f'<html><body>{table_html}</body></html>'

class TableTree(Tree):
    def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
        self.tag = tag
        self.colspan = colspan
        self.rowspan = rowspan
        self.content = content

        # Sets self.name and self.children
        super().__init__(tag, *children)

    def bracket(self):
        """Show tree using brackets notation"""
        if self.tag == 'td':
            result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % \
                     (self.tag, self.colspan, self.rowspan, self.content)
        else:
            result = '"tag": %s' % self.tag
        for child in self.children:
            result += child.bracket()
        return "{{{}}}".format(result)

class CustomConfig(Config):
    @staticmethod
    def maximum(*sequences):
        return max(map(len, sequences))

    def normalized_distance(self, *sequences):
        return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)

    def rename(self, node1, node2):
        if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
            return 1.
        if node1.tag == 'td':
            if node1.content or node2.content:
                return self.normalized_distance(node1.content, node2.content)
        return 0.

def tokenize(node):
    """
    Tokenizes table cells
    """
    global __tokens__
    __tokens__.append('<%s>' % node.tag)
    if node.text is not None:
        __tokens__ += list(node.text)
    for n in node.getchildren():
        tokenize(n)
    if node.tag != 'unk':
        __tokens__.append('</%s>' % node.tag)
    if node.tag != 'td' and node.tail is not None:
            __tokens__ += list(node.tail)

def tree_convert_html(node, convert_cell=False, parent=None):
    """
    Converts HTML tree to the format required by apted
    """
    global __tokens__
    if node.tag == 'td':
        if convert_cell:
            __tokens__ = []
            tokenize(node)
            cell = __tokens__[1:-1].copy()
        else:
            cell = []
        new_node = TableTree(node.tag,
                             int(node.attrib.get('colspan', '1')),
                             int(node.attrib.get('rowspan', '1')),
                             cell, *deque())
    else:
        new_node = TableTree(node.tag, None, None, None, *deque())
    if parent is not None:
        parent.children.append(new_node)
    if node.tag != 'td':
        for n in node.getchildren():
            tree_convert_html(n, convert_cell, new_node)
    if parent is None:
        return new_node

def similarity_eval_html(pred, true, structure_only=False):
    """
    Computes TEDS score between the prediction and the ground truth of a given samples
    """
    pred, true = html.fromstring(pred), html.fromstring(true)
    if pred.xpath('body/table') and true.xpath('body/table'):
        pred = pred.xpath('body/table')[0]
        true = true.xpath('body/table')[0]
        n_nodes_pred = len(pred.xpath(".//*"))
        n_nodes_true = len(true.xpath(".//*"))
        tree_pred = tree_convert_html(pred, convert_cell=not structure_only)
        tree_true = tree_convert_html(true, convert_cell=not structure_only)
        n_nodes = max(n_nodes_pred, n_nodes_true)
        distance = APTED(tree_pred, tree_true, CustomConfig()).compute_edit_distance()
        return 1.0 - (float(distance) / n_nodes)
    else:
        return 0.0