Vik Paruchuri committed
Commit ec69c20 · 1 Parent(s): 9c7b7b3

Swap over to t5 editor
benchmark.py CHANGED
@@ -119,7 +119,7 @@ if __name__ == "__main__":
     score_headers = benchmark_files
     for method in methods:
         summary_table.append([method, write_data[method]["avg_score"], write_data[method]["time_per_page"], write_data[method]["time_per_doc"]])
-        score_table.append([method, *[write_data[method]["scores"][h] for h in score_headers]])
+        score_table.append([method, *[write_data[method]["files"][h]["score"] for h in score_headers]])
 
     print(tabulate(summary_table, headers=["Method", "Average Score", "Time per page", "Time per document"]))
     print("")
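Note: this fix assumes per-file scores now live under a nested "files" mapping rather than a flat "scores" mapping. A minimal sketch of the shape write_data[method] needs after this change (only the keys referenced in the diff are confirmed; the method name, file names, and values are illustrative):

# Hypothetical shape of write_data after this commit; values are made up.
write_data = {
    "marker": {
        "avg_score": 0.71,
        "time_per_page": 1.2,
        "time_per_doc": 14.8,
        "files": {
            # was: "scores": {"paper1.pdf": 0.68}
            "paper1.pdf": {"score": 0.68},
            "paper2.pdf": {"score": 0.74},
        },
    }
}
assert write_data["marker"]["files"]["paper1.pdf"]["score"] == 0.68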
marker/cleaners/equations.py CHANGED
@@ -108,7 +108,7 @@ def get_nougat_text_batched(images, reformat_region_lens, nougat_model, batch_size
        for j, output in enumerate(model_output["predictions"]):
            disclaimer = ""
            token_count = get_total_nougat_tokens(output, nougat_model)
-            if token_count >= max_length:
+            if token_count >= max_length - 1:
                disclaimer = "[TRUNCATED]"
 
            image_idx = idx * batch_size + j
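Note: generation is capped at max_length, so an output that comes back one token short of the cap may also have been cut off mid-equation; the loosened comparison flags that case too. A minimal sketch of the check (the reading of the off-by-one is an interpretation; the numbers are illustrative):

max_length = 512  # illustrative cap

def is_truncated(token_count: int) -> bool:
    # Outputs at the cap, or one token below it, are flagged as truncated.
    return token_count >= max_length - 1

assert is_truncated(511) and is_truncated(512)
assert not is_truncated(400)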
marker/postprocessors/editor.py CHANGED
@@ -3,11 +3,11 @@ from itertools import chain
 from typing import Optional
 import re
 
-from transformers import BloomForTokenClassification, AutoTokenizer
+from transformers import AutoTokenizer
 from marker.settings import settings
 import torch
 import torch.nn.functional as F
-from marker.postprocessors.t5 import T5ForTokenClassification
+from marker.postprocessors.t5 import T5ForTokenClassification, byt5_tokenize
 
 tokenizer = AutoTokenizer.from_pretrained(settings.EDITOR_MODEL_NAME)
 
@@ -37,24 +37,18 @@ def edit_full_text(text: str, model: Optional[T5ForTokenClassification], batch_size
     if not model:
         return text, {}
 
-    tokenized = tokenizer(
-        text,
-        truncation=True,
-        max_length=settings.EDITOR_MAX_LENGTH,
-        return_overflowing_tokens=True,
-        padding="max_length",
-    )
+    tokenized = byt5_tokenize(text, settings.EDITOR_MAX_LENGTH)
     input_ids = tokenized["input_ids"]
+    char_token_lengths = tokenized["char_token_lengths"]
 
     # Tokenize, and make sure reverse tokenization works
     model_tokens = [tokenizer.convert_ids_to_tokens(t, skip_special_tokens=True) for t in input_ids]
-    full_text = "".join(model_tokens)
+    model_tokens_str = [tokenizer.convert_tokens_to_string(t) for t in model_tokens]
+    full_text = "".join(model_tokens_str)
     assert full_text == text
 
     # List of characters in the text
-    model_tokens = [tokenizer.convert_ids_to_tokens(t) for t in input_ids]
-    flat_model_tokens = list(chain.from_iterable(model_tokens))
-    flat_str_tokens = list(text)
+    flat_input_ids = list(chain.from_iterable(input_ids))
 
     # Run model
     token_masks = []
@@ -72,47 +66,50 @@
         # We want to be conservative to not edit the text too much
         probs = F.softmax(logits, dim=-1)
         max_prob = torch.max(probs, dim=-1)
-        cutoff_prob = max_prob.values < 0.9
+        cutoff_prob = max_prob.values < settings.EDITOR_CUTOFF_THRESH
         labels = logits.argmax(-1).squeeze()
         labels[cutoff_prob] = model.config.label2id["equal"]
-
         labels = labels.tolist()
         if len(labels) == settings.EDITOR_MAX_LENGTH:
             labels = [labels]
         labels = list(chain.from_iterable(labels))
         token_masks.extend(labels)
 
-    # Strip special tokens
-    assert len(token_masks) == len(flat_model_tokens)
-    token_masks = [mask for mask, token in zip(token_masks, flat_model_tokens) if token not in ["<pad>", "<s>", "</s>"]]
+    # Strip special tokens 0,1. Keep unknown token, although it should never be used
+    assert len(token_masks) == len(flat_input_ids)
+    token_masks = [mask for mask, token in zip(token_masks, flat_input_ids) if token >= 2]
 
-    assert len(token_masks) == len(flat_str_tokens)
+    assert len(token_masks) == len(list(text.encode("utf-8")))
 
     edit_stats = defaultdict(int)
-    out_tokens = []
-    for i, (str_token, mask) in enumerate(zip(flat_str_tokens, token_masks)):
-        label = model.config.id2label[mask]
-
-        match label:
-            case "equal":
-                out_tokens.append(str_token)
-                edit_stats[label] += 1
-            case "delete":
-                # If we delete whitespace, roll with it, otherwise ignore
-                if str_token.strip():
-                    out_tokens.append(str_token)
-                else:
-                    edit_stats[label] += 1
-            case "newline-1":
-                out_tokens.append("\n")
-                out_tokens.append(str_token)
-                edit_stats[label] += 1
-            case "space-1":
-                out_tokens.append(" ")
-                out_tokens.append(str_token)
-                edit_stats[label] += 1
-
-    return "".join(out_tokens), edit_stats
+    out_text = []
+    start = 0
+    for i, char in enumerate(text):
+        char_token_length = char_token_lengths[i]
+        masks = token_masks[start: start + char_token_length]
+        labels = [model.config.id2label[mask] for mask in masks]
+        if all(l == "delete" for l in labels):
+            # If we delete whitespace, roll with it, otherwise ignore
+            if char.strip():
+                out_text.append(char)
+            else:
+                edit_stats["delete"] += 1
+        elif labels[0] == "newline-1":
+            out_text.append("\n")
+            out_text.append(char)
+            edit_stats["newline-1"] += 1
+        elif labels[0] == "space-1":
+            out_text.append(" ")
+            out_text.append(char)
+            edit_stats["space-1"] += 1
+        else:
+            out_text.append(char)
+            edit_stats["equal"] += 1
+
+        start += char_token_length
+
+    out_text = "".join(out_text)
+    return out_text, edit_stats
 
 
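Note: the rewrite leans on a byte-level alignment: byt5_tokenize (added in marker/postprocessors/t5.py below) emits one token per UTF-8 byte, offset by 3 for ByT5's special tokens (pad=0, eos=1, unk=2), and char_token_lengths records how many byte tokens each character owns, so per-token labels can be regrouped per character. A small usage sketch of that alignment:

from marker.postprocessors.t5 import byt5_tokenize

tokenized = byt5_tokenize("héllo", max_length=2048)
# "é" is two UTF-8 bytes, so it owns two byte tokens.
assert tokenized["char_token_lengths"] == [1, 2, 1, 1, 1]
# Token ids are raw byte values shifted by 3 to skip the special tokens.
assert tokenized["input_ids"][0][0] == ord("h") + 3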
marker/postprocessors/t5.py CHANGED
@@ -1,91 +1,141 @@
-from transformers import T5ForSequenceClassification, T5Config
+from transformers import T5Config, T5PreTrainedModel
 import torch
+from torch import nn
+from copy import deepcopy
 from typing import Optional, Tuple, Union, List
+from itertools import chain
 
-from transformers.modeling_outputs import Seq2SeqSequenceClassifierOutput
+from transformers.modeling_outputs import TokenClassifierOutput
+from transformers.models.t5.modeling_t5 import T5Stack
+from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
 
 
-class T5ForTokenClassification(T5ForSequenceClassification):
+def byt5_tokenize(text: str, max_length: int, pad_token_id: int = 0):
+    byte_codes = []
+    for char in text:
+        # Add 3 to account for special tokens
+        byte_codes.append([byte + 3 for byte in char.encode('utf-8')])
+
+    tokens = list(chain.from_iterable(byte_codes))
+    # Map each token to the character it represents
+    char_token_lengths = [len(b) for b in byte_codes]
+
+    batched_tokens = []
+    attention_mask = []
+    for i in range(0, len(tokens), max_length):
+        batched_tokens.append(tokens[i:i + max_length])
+        attention_mask.append([1] * len(batched_tokens[-1]))
+
+    # Pad last item
+    if len(batched_tokens[-1]) < max_length:
+        batched_tokens[-1] += [pad_token_id] * (max_length - len(batched_tokens[-1]))
+        attention_mask[-1] += [0] * (max_length - len(attention_mask[-1]))
+
+    return {"input_ids": batched_tokens, "attention_mask": attention_mask, "char_token_lengths": char_token_lengths}
+
+
+
+
+# From https://github.com/osainz59/t5-encoder
+class T5ForTokenClassification(T5PreTrainedModel):
+    _keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"]
+
+    def __init__(self, config: T5Config):
+        super().__init__(config)
+        self.model_dim = config.d_model
+
+        self.shared = nn.Embedding(config.vocab_size, config.d_model)
+
+        encoder_config = deepcopy(config)
+        encoder_config.is_decoder = False
+        encoder_config.is_encoder_decoder = False
+        encoder_config.use_cache = False
+        self.encoder = T5Stack(encoder_config, self.shared)
+
+        classifier_dropout = (
+            config.classifier_dropout if hasattr(config, 'classifier_dropout') else config.dropout_rate
+        )
+        self.dropout = nn.Dropout(classifier_dropout)
+        self.classifier = nn.Linear(config.d_model, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+        # Model parallel
+        self.model_parallel = False
+        self.device_map = None
+
+
+    def parallelize(self, device_map=None):
+        self.device_map = (
+            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
+            if device_map is None
+            else device_map
+        )
+        assert_device_map(self.device_map, len(self.encoder.block))
+        self.encoder.parallelize(self.device_map)
+        self.classifier.to(self.encoder.first_device)
+        self.model_parallel = True
+
+    def deparallelize(self):
+        self.encoder.deparallelize()
+        self.encoder = self.encoder.to("cpu")
+        self.classifier = self.classifier.to("cpu")
+        self.model_parallel = False
+        self.device_map = None
+        torch.cuda.empty_cache()
+
+    def get_input_embeddings(self):
+        return self.shared
+
+    def set_input_embeddings(self, new_embeddings):
+        self.shared = new_embeddings
+        self.encoder.set_input_embeddings(new_embeddings)
+
+    def get_encoder(self):
+        return self.encoder
+
+    def _prune_heads(self, heads_to_prune):
+        for layer, heads in heads_to_prune.items():
+            self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)
 
     def forward(
         self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        decoder_input_ids: Optional[torch.LongTensor] = None,
-        decoder_attention_mask: Optional[torch.LongTensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
-        decoder_head_mask: Optional[torch.Tensor] = None,
-        cross_attn_head_mask: Optional[torch.Tensor] = None,
-        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
-        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
-        r"""
-        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
-            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
-            config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
-        Returns:
-        """
+    ) -> Union[Tuple[torch.FloatTensor], TokenClassifierOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        if labels is not None:
-            use_cache = False
-
-        if input_ids is None and inputs_embeds is not None:
-            raise NotImplementedError(
-                f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
-            )
-
-        # Copied from models.bart.modeling_bart.BartModel.forward different to other models, T5 automatically creates
-        # decoder_input_ids from input_ids if no decoder_input_ids are provided
-        if decoder_input_ids is None and decoder_inputs_embeds is None:
-            if input_ids is None:
-                raise ValueError(
-                    "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
-                    "passed, `input_ids` cannot be `None`. Please pass either "
-                    "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
-                )
-            decoder_input_ids = self._shift_right(input_ids)
-
-        outputs = self.transformer(
-            input_ids,
+
+        outputs = self.encoder(
+            input_ids=input_ids,
             attention_mask=attention_mask,
-            decoder_input_ids=decoder_input_ids,
-            decoder_attention_mask=decoder_attention_mask,
-            head_mask=head_mask,
-            decoder_head_mask=decoder_head_mask,
-            cross_attn_head_mask=cross_attn_head_mask,
-            encoder_outputs=encoder_outputs,
             inputs_embeds=inputs_embeds,
-            decoder_inputs_embeds=decoder_inputs_embeds,
-            use_cache=use_cache,
+            head_mask=head_mask,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
-        # Make predictions for all tokens
+
         sequence_output = outputs[0]
-        logits = self.classification_head(sequence_output)
 
-        assert labels.numel() * self.config.num_labels == logits.numel()
+        sequence_output = self.dropout(sequence_output)
+        logits = self.classifier(sequence_output)
+
         loss = None
 
         if not return_dict:
-            output = (logits,) + outputs[1:]
+            output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
 
-        return Seq2SeqSequenceClassifierOutput(
+        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
-            past_key_values=outputs.past_key_values,
-            decoder_hidden_states=outputs.decoder_hidden_states,
-            decoder_attentions=outputs.decoder_attentions,
-            cross_attentions=outputs.cross_attentions,
-            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
-            encoder_hidden_states=outputs.encoder_hidden_states,
-            encoder_attentions=outputs.encoder_attentions,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions
        )
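Note: with the decoder gone, inference is a single encoder pass plus a linear head. A hedged sketch of running the classifier end to end (the checkpoint and settings names come from this commit; the sample text is illustrative, and the use of id2label mirrors editor.py above):

import torch
from marker.postprocessors.t5 import T5ForTokenClassification, byt5_tokenize
from marker.settings import settings

model = T5ForTokenClassification.from_pretrained(settings.EDITOR_MODEL_NAME)
tokenized = byt5_tokenize("Some extracted text", settings.EDITOR_MAX_LENGTH)

with torch.inference_mode():
    out = model(
        input_ids=torch.tensor(tokenized["input_ids"]),
        attention_mask=torch.tensor(tokenized["attention_mask"]),
    )
# One prediction per byte token, mapped to labels like "equal" or "delete".
labels = [model.config.id2label[i] for i in out.logits.argmax(-1).squeeze().tolist()]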
marker/settings.py CHANGED
@@ -54,7 +54,7 @@ class Settings(BaseSettings):
     NOUGAT_MODEL_MAX: int = 512 # Max inference length for nougat
     NOUGAT_TOKEN_BUFFER: int = 256 # Number of tokens to buffer above max for nougat
     NOUGAT_HALLUCINATION_WORDS: List[str] = ["[MISSING_PAGE_POST]", "## References\n", "**Figure Captions**\n", "Footnote",
-                                             "\par\par\par", "## Chapter", "Fig.", "particle", "[REPEATS]", "[TRUNCATED]", "### "]
+                                             "\par\par\par", "## Chapter", "Fig.", "particle", "[REPEATS]", "[TRUNCATED]", "### ", "effective field strength", "\Phi_{\rm eff}"]
     NOUGAT_DPI: int = 96 # DPI to render images at, matches default settings for nougat
     NOUGAT_MODEL_NAME: str = "0.1.0-small" # Name of the model to use
     NOUGAT_BATCH_SIZE: int = 6 if TORCH_DEVICE == "cuda" else 1 # Batch size for nougat, don't batch on cpu
@@ -74,8 +74,9 @@ class Settings(BaseSettings):
     # Final editing model
     EDITOR_BATCH_SIZE: int = 4
     EDITOR_MAX_LENGTH: int = 2048
-    EDITOR_MODEL_NAME: str = "vikp/pdf_postprocessor"
-    ENABLE_EDITOR_MODEL: bool = False # The editor model can create false positives
+    EDITOR_MODEL_NAME: str = "vikp/pdf_postprocessor_t5"
+    ENABLE_EDITOR_MODEL: bool = True # The editor model can create false positives
+    EDITOR_CUTOFF_THRESH: float = 0.75 # Ignore predictions below this probability
 
     # Ray
     RAY_CACHE_PATH: Optional[str] = None # Where to save ray cache
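Note: Settings extends pydantic's BaseSettings, so these defaults can presumably be overridden through environment variables (standard pydantic behavior, not something this diff shows). A minimal sketch, assuming the settings object is created when marker.settings is first imported:

import os

# Set before the first import of marker.settings so the overrides are picked up.
os.environ["ENABLE_EDITOR_MODEL"] = "false"
os.environ["EDITOR_CUTOFF_THRESH"] = "0.9"

from marker.settings import settings
assert settings.ENABLE_EDITOR_MODEL is False
assert settings.EDITOR_CUTOFF_THRESH == 0.9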