| |
| """ |
| Created on Fri May 26 14:07:22 2023 |
| |
| @author: vibin |
| """ |
|
|
| import streamlit as st |
| from pandasql import sqldf |
| import pandas as pd |
| import re |
| from typing import List |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline |
| import re |
|
|
|
|
@st.cache_resource()
def tapas_model():
    """Build the table-question-answering pipeline used by the TAPAS page.

    The function is wrapped in ``st.cache_resource()`` so the pipeline object
    is created once and reused instead of being rebuilt on every call.

    Returns:
        A ``transformers`` pipeline for the ``table-question-answering`` task
        backed by the ``google/tapas-base-finetuned-wtq`` checkpoint.
    """
    tqa_pipeline = pipeline(
        task="table-question-answering",
        model="google/tapas-base-finetuned-wtq",
    )
    return tqa_pipeline
|
|
@st.cache_resource()
def prepare_input(question: str, table: List[str]):
    """Tokenize a question and a column list into model input ids.

    The prompt has the form ``question: <question> table: <col1>,<col2>,...``.
    It is encoded with the module-level ``tokenizer``, which must already be
    bound when this function is called.

    Args:
        question: Natural-language question to translate.
        table: Column names; they are joined with commas inside the prompt.

    Returns:
        The ``input_ids`` entry of the tokenizer output, requested as PyTorch
        tensors (``return_tensors="pt"``).
    """
    table_prefix = "table:"
    question_prefix = "question:"
    join_table = ",".join(table)
    inputs = f"{question_prefix} {question} {table_prefix} {join_table}"
    # truncation=True makes the 512-token cap explicit. With only max_length
    # set, the tokenizer falls back to the same strategy but logs a warning.
    encoded = tokenizer(
        inputs,
        max_length=512,
        truncation=True,
        return_tensors="pt",
    )
    return encoded.input_ids
|
|
@st.cache_resource()
def inference(question: str, table: List[str]) -> str:
    """Generate text for ``question`` given the column names in ``table``.

    Relies on the module-level ``model`` and ``tokenizer``; both must be
    bound before this function is called.

    Args:
        question: Natural-language question.
        table: Column names forwarded to ``prepare_input``.

    Returns:
        The first generated sequence, decoded with special tokens removed.
    """
    # Move the encoded prompt onto whatever device the model lives on.
    token_ids = prepare_input(question=question, table=table).to(model.device)
    # NOTE(review): top_k looks like a sampling-only parameter; with
    # num_beams=10 and no do_sample it presumably has no effect -- confirm
    # against the transformers generation documentation.
    generated = model.generate(
        inputs=token_ids,
        num_beams=10,
        top_k=10,
        max_length=700,
    )
    return tokenizer.decode(token_ids=generated[0], skip_special_tokens=True)
|
|
@st.cache_resource()
def tokmod(tok_md):
    """Load the tokenizer and the seq2seq model for one checkpoint.

    Wrapped in ``st.cache_resource()`` so the pair is loaded once per
    checkpoint name rather than on every call.

    Args:
        tok_md: Checkpoint name or path passed to both ``from_pretrained``
            calls.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    return (
        AutoTokenizer.from_pretrained(tok_md),
        AutoModelForSeq2SeqLM.from_pretrained(tok_md),
    )
|
|
|
|
| |
|
|
# Matches an aggregate keyword followed by a bare column name ("SUM col").
_AGGREGATE_CALL = re.compile(r"\b(AVG|COUNT|DISTINCT|MAX|MIN|SUM) (\w+)")
# A comparison value runs until the next standalone AND or the end of the query.
_CONDITION_END = re.compile(r"\bAND\b")


def _quote_condition_values(sql):
    """Wrap the right-hand side of every ``=`` comparison in single quotes.

    Each value is rewritten at its own position, so text that merely looks
    like a value elsewhere in the query (for example inside a column name)
    is left alone. Empty values and values that are already quoted are not
    touched.

    Args:
        sql: Query text with unquoted comparison values.

    Returns:
        The query with each comparison value quoted.
    """
    pieces = []
    copied = 0    # sql[:copied] has already been emitted into ``pieces``
    consumed = 0  # sql[:consumed] belongs to a condition already handled
    for equals in re.finditer("=", sql):
        value_start = equals.end()
        if value_start <= consumed:
            # This '=' sits inside a value that was handled earlier.
            continue
        boundary = _CONDITION_END.search(sql, value_start)
        value_end = boundary.start() if boundary else len(sql)
        consumed = value_end
        raw = sql[value_start:value_end]
        value = raw.strip()
        already_quoted = len(value) >= 2 and value[0] == value[-1] == "'"
        if not value or already_quoted:
            continue
        # Keep the original whitespace around the value.
        lead = raw[: len(raw) - len(raw.lstrip())]
        trail = raw[len(raw.rstrip()):]
        pieces.append(sql[copied:value_start])
        pieces.append(f"{lead}'{value}'{trail}")
        copied = value_end
    pieces.append(sql[copied:])
    return "".join(pieces)


def _wrap_aggregate_args(sql):
    """Rewrite ``SUM col`` style fragments into ``SUM(col)``.

    Only the column directly after the keyword is wrapped; a query that has
    no such fragment (for example one already written as ``COUNT(col)``) is
    returned unchanged.
    """
    return _AGGREGATE_CALL.sub(r"\1(\2)", sql)


nav = st.sidebar.radio("Navigation", ["TAPAS", "Text2SQL"])
if nav == "TAPAS":

    # Page header: title in the middle column, subtitle in the wide column.
    _, title_col, _ = st.columns(3)
    title_col.title("TAPAS")

    _, subtitle_col = st.columns([3, 12])
    subtitle_col.text("Tabular Data Text Extraction using text")

    # Every cell is cast to str before the table is handed to the pipeline.
    table = pd.read_csv("data.csv")
    table = table.astype(str)
    st.text("DataSet - ")
    st.dataframe(table, width=3000, height=400)

    st.title("")  # empty title used as vertical spacing

    lst_q = [
        "Which country has low medicare",
        "Who are the patients from india",
        "Who are the patients from india",
        "Patients who have Edema",
        "CUI code for diabetes patients",
        "Patients having oxygen less than 94 but 91",
    ]

    v2 = st.selectbox("Choose your text", lst_q, index=0)

    st.title("")

    # The selected sample pre-fills the text area; the user may edit it.
    sql_txt = st.text_area("TAPAS Input", v2)

    if st.button("Predict"):
        tqa = tapas_model()
        txt_sql = tqa(table=table, query=sql_txt)["answer"]
        st.text("Output - ")
        st.success(f"{txt_sql}")

elif nav == "Text2SQL":

    _, title_col, _ = st.columns(3)
    title_col.title("Text2SQL")

    _, subtitle_col = st.columns([1, 20])
    subtitle_col.text(
        "Text will be converted to SQL Query and can extract the data from DataSet"
    )

    # The generated SQL refers to this DataFrame by its variable name
    # (see table_name below), so the name df_qna must not change.
    df_qna = pd.read_csv("qnacsv.csv", encoding='unicode_escape')

    st.title("")

    st.text("DataSet - ")
    st.dataframe(df_qna, width=3000, height=500)

    st.title("")

    lst_q = [
        "what interface is measure indicator code = 72_HR_ABX and version is 1 and source is TD",
        "get class code with measure = 72_HR_ABX",
        "get sum of version for Class_Code is Antibiotic Stewardship",
        "what interface is measure indicator code = 72_HR_ABX",
    ]
    v2 = st.selectbox("Choose your text", lst_q, index=0)

    st.title("")

    sql_txt = st.text_area("Text for SQL Conversion", v2)

    if st.button("Predict"):

        tok_model = "juierror/flan-t5-text2sql-with-schema"
        # prepare_input() and inference() read `tokenizer` and `model` as
        # module-level globals, so they have to be bound under these names.
        tokenizer, model = tokmod(tok_model)

        table_name = "df_qna"
        table_col = [
            "Type",
            "Class_Code",
            "Version",
            "Measure_Indicator_Code",
            "Measure_Indicator_Name",
            "Description_Definition",
            "Source",
            "Interfaces",
        ]

        txt_sql = inference(question=sql_txt, table=table_col)

        # Swap the placeholder table name for the DataFrame's name. The word
        # boundary keeps longer words that contain "table" intact.
        txt_sql = re.sub(r"\btable\b", table_name, txt_sql)
        txt_sql = _quote_condition_values(txt_sql)
        txt_sql = _wrap_aggregate_args(txt_sql)

        st.success(f"{txt_sql}")
        # NOTE(review): the query is generated by the model from free-form
        # user text and is executed as-is.
        query_result = sqldf(txt_sql)

        st.text("Output - ")
        st.write(query_result)
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |