Vik Paruchuri committed · Commit ec69c20 · Parent(s): 9c7b7b3

Swap over to t5 editor

Files changed:
- benchmark.py +1 -1
- marker/cleaners/equations.py +1 -1
- marker/postprocessors/editor.py +40 -43
- marker/postprocessors/t5.py +111 -61
- marker/settings.py +4 -3
benchmark.py
CHANGED

@@ -119,7 +119,7 @@ if __name__ == "__main__":
     score_headers = benchmark_files
     for method in methods:
         summary_table.append([method, write_data[method]["avg_score"], write_data[method]["time_per_page"], write_data[method]["time_per_doc"]])
-        score_table.append([method, *[write_data[method]["…
+        score_table.append([method, *[write_data[method]["files"][h]["score"] for h in score_headers]])

     print(tabulate(summary_table, headers=["Method", "Average Score", "Time per page", "Time per document"]))
     print("")
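For context, the corrected line builds one score column per benchmarked file. A minimal sketch of the data shape this expects and the row it produces; the method name, file names, and scores below are invented for illustration:

write_data = {
    "marker": {
        "avg_score": 0.82,
        "time_per_page": 0.5,
        "time_per_doc": 10.2,
        "files": {"paper.pdf": {"score": 0.9}, "book.pdf": {"score": 0.74}},
    },
}
score_headers = ["paper.pdf", "book.pdf"]
score_table = []
for method in ["marker"]:
    # One row per method: method name, then per-file scores in header order
    score_table.append([method, *[write_data[method]["files"][h]["score"] for h in score_headers]])
print(score_table)  # [['marker', 0.9, 0.74]]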
marker/cleaners/equations.py
CHANGED

@@ -108,7 +108,7 @@ def get_nougat_text_batched(images, reformat_region_lens, nougat_model, batch_size…
         for j, output in enumerate(model_output["predictions"]):
             disclaimer = ""
             token_count = get_total_nougat_tokens(output, nougat_model)
-            if token_count >= max_length:
+            if token_count >= max_length - 1:
                 disclaimer = "[TRUNCATED]"

             image_idx = idx * batch_size + j
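The stricter cutoff widens truncation detection by one token, presumably because generation can stop one token short of the configured limit and such outputs should still be flagged. A minimal sketch of the flagging logic; max_length and the token counts are made up, and is_truncated is a hypothetical stand-in for the inline check:

def is_truncated(token_count: int, max_length: int) -> bool:
    # Flag outputs that reached, or came within one token of, the generation limit
    return token_count >= max_length - 1

max_length = 512
for token_count in (300, 511, 512):
    disclaimer = "[TRUNCATED]" if is_truncated(token_count, max_length) else ""
    print(token_count, disclaimer or "(ok)")  # 300 is ok; 511 and 512 are flagged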
marker/postprocessors/editor.py
CHANGED

@@ -3,11 +3,11 @@ from itertools import chain
 from typing import Optional
 import re

-from transformers import …
+from transformers import AutoTokenizer
 from marker.settings import settings
 import torch
 import torch.nn.functional as F
-from marker.postprocessors.t5 import T5ForTokenClassification
+from marker.postprocessors.t5 import T5ForTokenClassification, byt5_tokenize

 tokenizer = AutoTokenizer.from_pretrained(settings.EDITOR_MODEL_NAME)

@@ -37,24 +37,18 @@ def edit_full_text(text: str, model: Optional[T5ForTokenClassification], batch_size…
     if not model:
         return text, {}

-    tokenized = tokenizer(
-        text,
-        truncation=True,
-        max_length=settings.EDITOR_MAX_LENGTH,
-        return_overflowing_tokens=True,
-        padding="max_length",
-    )
+    tokenized = byt5_tokenize(text, settings.EDITOR_MAX_LENGTH)
     input_ids = tokenized["input_ids"]
+    char_token_lengths = tokenized["char_token_lengths"]

     # Tokenize, and make sure reverse tokenization works
     model_tokens = [tokenizer.convert_ids_to_tokens(t, skip_special_tokens=True) for t in input_ids]
-    …
+    model_tokens_str = [tokenizer.convert_tokens_to_string(t) for t in model_tokens]
+    full_text = "".join(model_tokens_str)
     assert full_text == text

     # List of characters in the text
-    flat_model_tokens = list(chain.from_iterable(model_tokens))
-    flat_str_tokens = list(text)
+    flat_input_ids = list(chain.from_iterable(input_ids))

     # Run model
     token_masks = []

@@ -72,47 +66,50 @@
         # We want to be conservative to not edit the text too much
         probs = F.softmax(logits, dim=-1)
         max_prob = torch.max(probs, dim=-1)
-        cutoff_prob = max_prob.values < …
+        cutoff_prob = max_prob.values < settings.EDITOR_CUTOFF_THRESH
         labels = logits.argmax(-1).squeeze()
         labels[cutoff_prob] = model.config.label2id["equal"]
-
         labels = labels.tolist()
         if len(labels) == settings.EDITOR_MAX_LENGTH:
             labels = [labels]
         labels = list(chain.from_iterable(labels))
         token_masks.extend(labels)

-    # Strip special tokens
-    assert len(token_masks) == len(…
-    token_masks = [mask for mask, token in zip(token_masks, …
+    # Strip special tokens 0,1. Keep unknown token, although it should never be used
+    assert len(token_masks) == len(flat_input_ids)
+    token_masks = [mask for mask, token in zip(token_masks, flat_input_ids) if token >= 2]

-    assert len(token_masks) == len(…
+    assert len(token_masks) == len(list(text.encode("utf-8")))

     edit_stats = defaultdict(int)
-    …
+    out_text = []
+    start = 0
+    for i, char in enumerate(text):
+        char_token_length = char_token_lengths[i]
+        masks = token_masks[start: start + char_token_length]
+        labels = [model.config.id2label[mask] for mask in masks]
+        if all(l == "delete" for l in labels):
+            # If we delete whitespace, roll with it, otherwise ignore
+            if char.strip():
+                out_text.append(char)
+            else:
+                edit_stats["delete"] += 1
+        elif labels[0] == "newline-1":
+            out_text.append("\n")
+            out_text.append(char)
+            edit_stats["newline-1"] += 1
+        elif labels[0] == "space-1":
+            out_text.append(" ")
+            out_text.append(char)
+            edit_stats["space-1"] += 1
+        else:
+            out_text.append(char)
+            edit_stats["equal"] += 1
+
+        start += char_token_length
+
+    out_text = "".join(out_text)
+    return out_text, edit_stats
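A note on the new flow: byt5_tokenize emits one token per UTF-8 byte, so a multi-byte character owns several consecutive entries of token_masks, and char_token_lengths is what lets the loop consume exactly that slice per character. A small standalone sketch of the mapping, with hand-written labels standing in for real model output (the real code looks labels up via model.config.id2label):

text = "naïve"
# Per-character UTF-8 byte counts; "ï" encodes to two bytes
char_token_lengths = [len(c.encode("utf-8")) for c in text]  # [1, 1, 2, 1, 1]

# Hypothetical per-byte labels; six entries for six bytes
token_labels = ["equal", "equal", "equal", "equal", "equal", "space-1"]

start = 0
out = []
for char, length in zip(text, char_token_lengths):
    labels = token_labels[start:start + length]  # masks for this character's bytes
    if labels[0] == "space-1":
        out.append(" ")
    out.append(char)
    start += length
print("".join(out))  # "naïv e": a space was inserted before the final character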
marker/postprocessors/t5.py
CHANGED

@@ -1,91 +1,141 @@
-from transformers import …
+from transformers import T5Config, T5PreTrainedModel
 import torch
+from torch import nn
+from copy import deepcopy
 from typing import Optional, Tuple, Union, List
+from itertools import chain

-from transformers.modeling_outputs import …
+from transformers.modeling_outputs import TokenClassifierOutput
+from transformers.models.t5.modeling_t5 import T5Stack
+from transformers.utils.model_parallel_utils import get_device_map, assert_device_map


-…
+def byt5_tokenize(text: str, max_length: int, pad_token_id: int = 0):
+    byte_codes = []
+    for char in text:
+        # Add 3 to account for special tokens
+        byte_codes.append([byte + 3 for byte in char.encode('utf-8')])
+
+    tokens = list(chain.from_iterable(byte_codes))
+    # Map each token to the character it represents
+    char_token_lengths = [len(b) for b in byte_codes]
+
+    batched_tokens = []
+    attention_mask = []
+    for i in range(0, len(tokens), max_length):
+        batched_tokens.append(tokens[i:i + max_length])
+        attention_mask.append([1] * len(batched_tokens[-1]))
+
+    # Pad last item
+    if len(batched_tokens[-1]) < max_length:
+        batched_tokens[-1] += [pad_token_id] * (max_length - len(batched_tokens[-1]))
+        attention_mask[-1] += [0] * (max_length - len(attention_mask[-1]))
+
+    return {"input_ids": batched_tokens, "attention_mask": attention_mask, "char_token_lengths": char_token_lengths}
+
+
+# From https://github.com/osainz59/t5-encoder
+class T5ForTokenClassification(T5PreTrainedModel):
+    _keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"]
+
+    def __init__(self, config: T5Config):
+        super().__init__(config)
+        self.model_dim = config.d_model
+
+        self.shared = nn.Embedding(config.vocab_size, config.d_model)
+
+        encoder_config = deepcopy(config)
+        encoder_config.is_decoder = False
+        encoder_config.is_encoder_decoder = False
+        encoder_config.use_cache = False
+        self.encoder = T5Stack(encoder_config, self.shared)
+
+        classifier_dropout = (
+            config.classifier_dropout if hasattr(config, 'classifier_dropout') else config.dropout_rate
+        )
+        self.dropout = nn.Dropout(classifier_dropout)
+        self.classifier = nn.Linear(config.d_model, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+        # Model parallel
+        self.model_parallel = False
+        self.device_map = None
+
+    def parallelize(self, device_map=None):
+        self.device_map = (
+            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
+            if device_map is None
+            else device_map
+        )
+        assert_device_map(self.device_map, len(self.encoder.block))
+        self.encoder.parallelize(self.device_map)
+        self.classifier.to(self.encoder.first_device)
+        self.model_parallel = True
+
+    def deparallelize(self):
+        self.encoder.deparallelize()
+        self.encoder = self.encoder.to("cpu")
+        self.classifier = self.classifier.to("cpu")
+        self.model_parallel = False
+        self.device_map = None
+        torch.cuda.empty_cache()
+
+    def get_input_embeddings(self):
+        return self.shared
+
+    def set_input_embeddings(self, new_embeddings):
+        self.shared = new_embeddings
+        self.encoder.set_input_embeddings(new_embeddings)
+
+    def get_encoder(self):
+        return self.encoder
+
+    def _prune_heads(self, heads_to_prune):
+        for layer, heads in heads_to_prune.items():
+            self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)

     def forward(
         self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.…
-        decoder_input_ids: Optional[torch.LongTensor] = None,
-        decoder_attention_mask: Optional[torch.LongTensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
-        decoder_head_mask: Optional[torch.Tensor] = None,
-        cross_attn_head_mask: Optional[torch.Tensor] = None,
-        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
-        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, …
-        r"""
-        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
-            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
-            config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
-        Returns:
-        """
+    ) -> Union[Tuple[torch.FloatTensor], TokenClassifierOutput]:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-        if input_ids is None and inputs_embeds is not None:
-            raise NotImplementedError(
-                f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
-            )
-
-        # Copied from models.bart.modeling_bart.BartModel.forward different to other models, T5 automatically creates
-        # decoder_input_ids from input_ids if no decoder_input_ids are provided
-        if decoder_input_ids is None and decoder_inputs_embeds is None:
-            if input_ids is None:
-                raise ValueError(
-                    "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
-                    "passed, `input_ids` cannot be `None`. Please pass either "
-                    "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
-                )
-            decoder_input_ids = self._shift_right(input_ids)
-
-        outputs = self.transformer(
-            input_ids,
+        outputs = self.encoder(
+            input_ids=input_ids,
             attention_mask=attention_mask,
-            decoder_input_ids=decoder_input_ids,
-            decoder_attention_mask=decoder_attention_mask,
-            head_mask=head_mask,
-            decoder_head_mask=decoder_head_mask,
-            cross_attn_head_mask=cross_attn_head_mask,
-            encoder_outputs=encoder_outputs,
             inputs_embeds=inputs_embeds,
-            decoder_inputs_embeds=decoder_inputs_embeds,
-            use_cache=use_cache,
+            head_mask=head_mask,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )

         sequence_output = outputs[0]
-        logits = self.classification_head(sequence_output)

-        …
+        sequence_output = self.dropout(sequence_output)
+        logits = self.classifier(sequence_output)
+
         loss = None

         if not return_dict:
-            output = (logits,) + outputs[…
+            output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output

-        return …(
+        return TokenClassifierOutput(
             loss=loss,
             logits=logits,
-            past_key_values=outputs.past_key_values,
-            decoder_hidden_states=outputs.decoder_hidden_states,
-            decoder_attentions=outputs.decoder_attentions,
-            cross_attentions=outputs.cross_attentions,
-            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
-            encoder_hidden_states=outputs.encoder_hidden_states,
-            encoder_attentions=outputs.encoder_attentions,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions
         )
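byt5_tokenize mirrors ByT5's byte-level scheme: every UTF-8 byte of the input becomes one id, offset by 3 to leave room for the special tokens, and the batches are fixed-width windows of max_length ids. A quick check of the returned structure; the sample string is arbitrary:

out = byt5_tokenize("héllo", max_length=8)
print(out["char_token_lengths"])  # [1, 2, 1, 1, 1]; "é" is two UTF-8 bytes
print(out["input_ids"])           # one window of 6 ids, padded to length 8 with 0
print(out["attention_mask"])      # [[1, 1, 1, 1, 1, 1, 0, 0]]

Compared with the old seq2seq-style classifier, T5ForTokenClassification keeps only the T5 encoder stack (all decoder_* arguments drop out of forward), which fits token classification, where each input position gets its own label rather than the whole sequence getting one.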
marker/settings.py
CHANGED

@@ -54,7 +54,7 @@ class Settings(BaseSettings):
     NOUGAT_MODEL_MAX: int = 512 # Max inference length for nougat
     NOUGAT_TOKEN_BUFFER: int = 256 # Number of tokens to buffer above max for nougat
     NOUGAT_HALLUCINATION_WORDS: List[str] = ["[MISSING_PAGE_POST]", "## References\n", "**Figure Captions**\n", "Footnote",
-                                             "\par\par\par", "## Chapter", "Fig.", "particle", "[REPEATS]", "[TRUNCATED]", "### "]
+                                             "\par\par\par", "## Chapter", "Fig.", "particle", "[REPEATS]", "[TRUNCATED]", "### ", "effective field strength", "\Phi_{\rm eff}"]
     NOUGAT_DPI: int = 96 # DPI to render images at, matches default settings for nougat
     NOUGAT_MODEL_NAME: str = "0.1.0-small" # Name of the model to use
     NOUGAT_BATCH_SIZE: int = 6 if TORCH_DEVICE == "cuda" else 1 # Batch size for nougat, don't batch on cpu

@@ -74,8 +74,9 @@ class Settings(BaseSettings):
     # Final editing model
     EDITOR_BATCH_SIZE: int = 4
     EDITOR_MAX_LENGTH: int = 2048
-    EDITOR_MODEL_NAME: str = "vikp/…
-    ENABLE_EDITOR_MODEL: bool = …
+    EDITOR_MODEL_NAME: str = "vikp/pdf_postprocessor_t5"
+    ENABLE_EDITOR_MODEL: bool = True # The editor model can create false positives
+    EDITOR_CUTOFF_THRESH: float = 0.75 # Ignore predictions below this probability

     # Ray
     RAY_CACHE_PATH: Optional[str] = None # Where to save ray cache
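EDITOR_CUTOFF_THRESH feeds the conservative-editing step in editor.py: any position whose top class probability falls below 0.75 has its prediction forced back to "equal", so low-confidence outputs never change the text. A minimal sketch of that thresholding on dummy logits; the label ids are illustrative, not the model's real label2id:

import torch
import torch.nn.functional as F

EDITOR_CUTOFF_THRESH = 0.75
EQUAL_ID = 0  # stand-in for model.config.label2id["equal"]

logits = torch.tensor([[4.0, 0.1, 0.2],   # confident prediction, kept
                       [0.9, 1.0, 1.1]])  # near-uniform probabilities, suppressed
probs = F.softmax(logits, dim=-1)
max_prob = torch.max(probs, dim=-1)
cutoff = max_prob.values < EDITOR_CUTOFF_THRESH
labels = logits.argmax(-1)
labels[cutoff] = EQUAL_ID
print(labels.tolist())  # [0, 0]: the second row's argmax (2) was replaced by "equal"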