Spaces:
Sleeping
Sleeping
| from transformers import AutoTokenizer | |
| from transformers import AutoModelForSeq2SeqLM | |
| import errant | |
| ERROR_EXPLANATIONS = { | |
| # Replacement (R) | |
| "R:ADJ": "Sai tính từ.", | |
| "R:ADV": "Trạng từ chưa đúng.", | |
| "R:CONJ": "Sai từ nối.", | |
| "R:DET": "Thiếu hoặc sai mạo từ (a, an, the).", | |
| "R:NOUN": "Sai dạng danh từ.", | |
| "R:NOUN:NUM": "Danh từ số ít/ số nhiều không đúng.", | |
| "R:NOUN:POSS": "Diễn đạt sở hữu chưa đúng, hãy thay đổi (ví dụ thêm '’s').", | |
| # "R:ORTH": "Có lỗi chính tả hoặc viết hoa chưa đúng.", | |
| "R:PART": "Sai tiểu từ đi kèm (ví dụ: 'look at' thay vì 'look on').", | |
| "R:PREP": "Sai giới từ.", | |
| "R:PRON": "Đại từ chưa đúng.", | |
| # "R:PUNCT": "Bạn dùng dấu câu sai, hãy đổi sang dấu câu thích hợp.", | |
| "R:VERB": "Động từ chưa đúng.", | |
| "R:VERB:FORM": "Sai dạng động từ (V-ing, to V, V-ed…).", | |
| "R:VERB:INFL": "Sai chia động từ, cần thêm hoặc bớt '-s/-ed'.", | |
| "R:VERB:SVA": "Chủ ngữ và động từ không phù hợp.", | |
| "R:VERB:TENSE": "Thì động từ chưa đúng.", | |
| "R:MORPH": "Kiểm tra chia động từ, số ít/số nhiều và dạng từ.", | |
| "R:OTHER": "Lỗi ngữ pháp khác.", | |
| # Missing (M) | |
| "M:ADJ": "Cần thêm một tính từ để rõ nghĩa hơn.", | |
| "M:ADV": "Cần thêm trạng từ để bổ nghĩa cho động từ.", | |
| "M:CONJ": "Thiếu từ nối, cần thêm để các vế liên kết tự nhiên.", | |
| "M:DET": "Cần thêm mạo từ (a, an, the) trước danh từ.", | |
| "M:NOUN": "Chưa đủ danh từ, cần bổ sung để hoàn chỉnh.", | |
| "M:PART": "Cần thêm tiểu từ đi kèm động từ.", | |
| "M:PREP": "Cần thêm giới từ để diễn đạt đúng.", | |
| "M:PRON": "Thiếu đại từ làm chủ ngữ/tân ngữ.", | |
| # "M:PUNCT": "Bạn cần thêm dấu câu để câu rõ ràng hơn.", | |
| "M:VERB": "Thiếu động từ, cần thêm để đủ ngữ pháp.", | |
| # Unnecessary (U) | |
| "U:ADJ": "Dùng thừa tính từ.", | |
| "U:ADV": "Dùng trạng từ không cần thiết, nên lược bỏ.", | |
| "U:CONJ": "Dùng thừa liên từ, hãy bỏ để câu gọn hơn.", | |
| "U:DET": "Dùng thừa mạo từ.", | |
| "U:NOUN": "Dùng thừa danh từ, cần lược bỏ để đúng ngữ pháp.", | |
| "U:PART": "Dùng thừa tiểu từ không cần thiết.", | |
| "U:PREP": "Dùng thừa giới từ, hãy bỏ để câu tự nhiên hơn.", | |
| "U:PRON": "Dùng thừa đại từ.", | |
| # "U:PUNCT": "Có dấu câu không cần thiết, hãy bỏ đi.", | |
| "U:VERB": "Dùng thừa động từ, cần bỏ đi để đúng ngữ pháp." | |
| } | |
| class Gramformer: | |
| def __init__(self, models=1, use_gpu=False): | |
| self.annotator = errant.load('en') | |
| if use_gpu: | |
| device = "cuda:0" | |
| else: | |
| device = "cpu" | |
| self.device = device | |
| correction_model_tag = "prithivida/grammar_error_correcter_v1" | |
| self.model_loaded = False | |
| if models == 1: | |
| self.correction_tokenizer = AutoTokenizer.from_pretrained( | |
| correction_model_tag, use_auth_token=False) | |
| self.correction_model = AutoModelForSeq2SeqLM.from_pretrained( | |
| correction_model_tag, use_auth_token=False) | |
| self.correction_model = self.correction_model.to(device) | |
| self.model_loaded = True | |
| print("[Gramformer] Grammar error correct/highlight model loaded..") | |
| elif models == 2: | |
| # TODO | |
| print("TO BE IMPLEMENTED!!!") | |
| def correct(self, input_sentence, max_candidates=1): | |
| if self.model_loaded: | |
| correction_prefix = "gec: " | |
| input_sentence = correction_prefix + input_sentence | |
| input_ids = self.correction_tokenizer.encode( | |
| input_sentence, return_tensors='pt') | |
| input_ids = input_ids.to(self.device) | |
| preds = self.correction_model.generate( | |
| input_ids, | |
| do_sample=True, | |
| max_length=128, | |
| # top_k=50, | |
| # top_p=0.95, | |
| num_beams=7, | |
| early_stopping=True, | |
| num_return_sequences=max_candidates) | |
| corrected = set() | |
| for pred in preds: | |
| corrected.add(self.correction_tokenizer.decode( | |
| pred, skip_special_tokens=True).strip()) | |
| # corrected = list(corrected) | |
| # scores = self.scorer.sentence_score(corrected, log=True) | |
| # ranked_corrected = [(c,s) for c, s in zip(corrected, scores)] | |
| # ranked_corrected.sort(key = lambda x:x[1], reverse=True) | |
| return corrected | |
| else: | |
| print("Model is not loaded") | |
| return None | |
| def highlight(self, orig, cor): | |
| edits = self._get_edits(orig, cor) | |
| orig_tokens = orig.split() | |
| ignore_indexes = [] | |
| for edit in edits: | |
| edit_type = edit[0] | |
| edit_str_start = edit[1] | |
| edit_spos = edit[2] | |
| edit_epos = edit[3] | |
| edit_str_end = edit[4] | |
| # if no_of_tokens(edit_str_start) > 1 ==> excluding the first token, mark all other tokens for deletion | |
| for i in range(edit_spos+1, edit_epos): | |
| ignore_indexes.append(i) | |
| if edit_str_start == "": | |
| if edit_spos - 1 >= 0: | |
| new_edit_str = orig_tokens[edit_spos - 1] | |
| edit_spos -= 1 | |
| else: | |
| new_edit_str = orig_tokens[edit_spos + 1] | |
| edit_spos += 1 | |
| if edit_type == "PUNCT": | |
| st = "<a type='" + edit_type + "' edit='" + \ | |
| edit_str_end + "'>" + new_edit_str + "</a>" | |
| else: | |
| st = "<a type='" + edit_type + "' edit='" + new_edit_str + \ | |
| " " + edit_str_end + "'>" + new_edit_str + "</a>" | |
| orig_tokens[edit_spos] = st | |
| elif edit_str_end == "": | |
| st = "<d type='" + edit_type + "' edit=''>" + edit_str_start + "</d>" | |
| orig_tokens[edit_spos] = st | |
| else: | |
| st = "<c type='" + edit_type + "' edit='" + \ | |
| edit_str_end + "'>" + edit_str_start + "</c>" | |
| orig_tokens[edit_spos] = st | |
| for i in sorted(ignore_indexes, reverse=True): | |
| del (orig_tokens[i]) | |
| return (" ".join(orig_tokens)) | |
| def _get_edits(self, orig, cor): | |
| orig = self.annotator.parse(orig) | |
| cor = self.annotator.parse(cor) | |
| alignment = self.annotator.align(orig, cor) | |
| edits = self.annotator.merge(alignment) | |
| if len(edits) == 0: | |
| return [] | |
| edit_annotations = [] | |
| for e in edits: | |
| e = self.annotator.classify(e) | |
| edit_annotations.append( | |
| (e.type, e.o_str, e.o_start, e.o_end, e.c_str, e.c_start, e.c_end)) | |
| if len(edit_annotations) > 0: | |
| return edit_annotations | |
| else: | |
| return [] | |
| def get_edits(self, orig, cor): | |
| return self._get_edits(orig, cor) | |
| if __name__ == "__main__": | |
| gf = Gramformer(models=1, use_gpu=False) | |
| sentences = [ | |
| "This are good.", | |
| "He don't know.", | |
| "She no went there.", | |
| "I can to do it.", | |
| "What is you name?", | |
| "This is a test sentence.", | |
| "I has a apple.", | |
| "They is playing football.", | |
| "He go to school every day.", | |
| "She like to read books." | |
| ] | |
| for sentence in sentences: | |
| print("Input: ", sentence) | |
| corrections = gf.correct(sentence, max_candidates=3) | |
| for c in corrections: | |
| print("Output: ", c) | |
| print("Highlight: ", gf.highlight(sentence, c)) | |
| print() | |