phoneme_transciptor / gramformer.py
thanhhungtakeshi's picture
rm language-tool
76a5fa3
from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM
import errant
ERROR_EXPLANATIONS = {
# Replacement (R)
"R:ADJ": "Sai tính từ.",
"R:ADV": "Trạng từ chưa đúng.",
"R:CONJ": "Sai từ nối.",
"R:DET": "Thiếu hoặc sai mạo từ (a, an, the).",
"R:NOUN": "Sai dạng danh từ.",
"R:NOUN:NUM": "Danh từ số ít/ số nhiều không đúng.",
"R:NOUN:POSS": "Diễn đạt sở hữu chưa đúng, hãy thay đổi (ví dụ thêm '’s').",
# "R:ORTH": "Có lỗi chính tả hoặc viết hoa chưa đúng.",
"R:PART": "Sai tiểu từ đi kèm (ví dụ: 'look at' thay vì 'look on').",
"R:PREP": "Sai giới từ.",
"R:PRON": "Đại từ chưa đúng.",
# "R:PUNCT": "Bạn dùng dấu câu sai, hãy đổi sang dấu câu thích hợp.",
"R:VERB": "Động từ chưa đúng.",
"R:VERB:FORM": "Sai dạng động từ (V-ing, to V, V-ed…).",
"R:VERB:INFL": "Sai chia động từ, cần thêm hoặc bớt '-s/-ed'.",
"R:VERB:SVA": "Chủ ngữ và động từ không phù hợp.",
"R:VERB:TENSE": "Thì động từ chưa đúng.",
"R:MORPH": "Kiểm tra chia động từ, số ít/số nhiều và dạng từ.",
"R:OTHER": "Lỗi ngữ pháp khác.",
# Missing (M)
"M:ADJ": "Cần thêm một tính từ để rõ nghĩa hơn.",
"M:ADV": "Cần thêm trạng từ để bổ nghĩa cho động từ.",
"M:CONJ": "Thiếu từ nối, cần thêm để các vế liên kết tự nhiên.",
"M:DET": "Cần thêm mạo từ (a, an, the) trước danh từ.",
"M:NOUN": "Chưa đủ danh từ, cần bổ sung để hoàn chỉnh.",
"M:PART": "Cần thêm tiểu từ đi kèm động từ.",
"M:PREP": "Cần thêm giới từ để diễn đạt đúng.",
"M:PRON": "Thiếu đại từ làm chủ ngữ/tân ngữ.",
# "M:PUNCT": "Bạn cần thêm dấu câu để câu rõ ràng hơn.",
"M:VERB": "Thiếu động từ, cần thêm để đủ ngữ pháp.",
# Unnecessary (U)
"U:ADJ": "Dùng thừa tính từ.",
"U:ADV": "Dùng trạng từ không cần thiết, nên lược bỏ.",
"U:CONJ": "Dùng thừa liên từ, hãy bỏ để câu gọn hơn.",
"U:DET": "Dùng thừa mạo từ.",
"U:NOUN": "Dùng thừa danh từ, cần lược bỏ để đúng ngữ pháp.",
"U:PART": "Dùng thừa tiểu từ không cần thiết.",
"U:PREP": "Dùng thừa giới từ, hãy bỏ để câu tự nhiên hơn.",
"U:PRON": "Dùng thừa đại từ.",
# "U:PUNCT": "Có dấu câu không cần thiết, hãy bỏ đi.",
"U:VERB": "Dùng thừa động từ, cần bỏ đi để đúng ngữ pháp."
}
class Gramformer:
def __init__(self, models=1, use_gpu=False):
self.annotator = errant.load('en')
if use_gpu:
device = "cuda:0"
else:
device = "cpu"
self.device = device
correction_model_tag = "prithivida/grammar_error_correcter_v1"
self.model_loaded = False
if models == 1:
self.correction_tokenizer = AutoTokenizer.from_pretrained(
correction_model_tag, use_auth_token=False)
self.correction_model = AutoModelForSeq2SeqLM.from_pretrained(
correction_model_tag, use_auth_token=False)
self.correction_model = self.correction_model.to(device)
self.model_loaded = True
print("[Gramformer] Grammar error correct/highlight model loaded..")
elif models == 2:
# TODO
print("TO BE IMPLEMENTED!!!")
def correct(self, input_sentence, max_candidates=1):
if self.model_loaded:
correction_prefix = "gec: "
input_sentence = correction_prefix + input_sentence
input_ids = self.correction_tokenizer.encode(
input_sentence, return_tensors='pt')
input_ids = input_ids.to(self.device)
preds = self.correction_model.generate(
input_ids,
do_sample=True,
max_length=128,
# top_k=50,
# top_p=0.95,
num_beams=7,
early_stopping=True,
num_return_sequences=max_candidates)
corrected = set()
for pred in preds:
corrected.add(self.correction_tokenizer.decode(
pred, skip_special_tokens=True).strip())
# corrected = list(corrected)
# scores = self.scorer.sentence_score(corrected, log=True)
# ranked_corrected = [(c,s) for c, s in zip(corrected, scores)]
# ranked_corrected.sort(key = lambda x:x[1], reverse=True)
return corrected
else:
print("Model is not loaded")
return None
def highlight(self, orig, cor):
edits = self._get_edits(orig, cor)
orig_tokens = orig.split()
ignore_indexes = []
for edit in edits:
edit_type = edit[0]
edit_str_start = edit[1]
edit_spos = edit[2]
edit_epos = edit[3]
edit_str_end = edit[4]
# if no_of_tokens(edit_str_start) > 1 ==> excluding the first token, mark all other tokens for deletion
for i in range(edit_spos+1, edit_epos):
ignore_indexes.append(i)
if edit_str_start == "":
if edit_spos - 1 >= 0:
new_edit_str = orig_tokens[edit_spos - 1]
edit_spos -= 1
else:
new_edit_str = orig_tokens[edit_spos + 1]
edit_spos += 1
if edit_type == "PUNCT":
st = "<a type='" + edit_type + "' edit='" + \
edit_str_end + "'>" + new_edit_str + "</a>"
else:
st = "<a type='" + edit_type + "' edit='" + new_edit_str + \
" " + edit_str_end + "'>" + new_edit_str + "</a>"
orig_tokens[edit_spos] = st
elif edit_str_end == "":
st = "<d type='" + edit_type + "' edit=''>" + edit_str_start + "</d>"
orig_tokens[edit_spos] = st
else:
st = "<c type='" + edit_type + "' edit='" + \
edit_str_end + "'>" + edit_str_start + "</c>"
orig_tokens[edit_spos] = st
for i in sorted(ignore_indexes, reverse=True):
del (orig_tokens[i])
return (" ".join(orig_tokens))
def _get_edits(self, orig, cor):
orig = self.annotator.parse(orig)
cor = self.annotator.parse(cor)
alignment = self.annotator.align(orig, cor)
edits = self.annotator.merge(alignment)
if len(edits) == 0:
return []
edit_annotations = []
for e in edits:
e = self.annotator.classify(e)
edit_annotations.append(
(e.type, e.o_str, e.o_start, e.o_end, e.c_str, e.c_start, e.c_end))
if len(edit_annotations) > 0:
return edit_annotations
else:
return []
def get_edits(self, orig, cor):
return self._get_edits(orig, cor)
if __name__ == "__main__":
gf = Gramformer(models=1, use_gpu=False)
sentences = [
"This are good.",
"He don't know.",
"She no went there.",
"I can to do it.",
"What is you name?",
"This is a test sentence.",
"I has a apple.",
"They is playing football.",
"He go to school every day.",
"She like to read books."
]
for sentence in sentences:
print("Input: ", sentence)
corrections = gf.correct(sentence, max_candidates=3)
for c in corrections:
print("Output: ", c)
print("Highlight: ", gf.highlight(sentence, c))
print()