orva06 committed on
Commit a11a937 · verified · 1 Parent(s): 620e204

Upload 6 files

app.py ADDED
@@ -0,0 +1,95 @@
+ # app.py
+ import os, time, io, logging
+ import pandas as pd
+ import streamlit as st
+ from inference import InferenceEngine
+
+ # --------- CONFIG ----------
+ MODEL_CKPT = "best_indobert_multipredict.pt"
+ TOKENIZER_NAME = "indobenchmark/indobert-base-p2"
+ CSV_INPUT_COL = "soal"  # expected text column in uploaded CSV
+ LOG_PATH = "inference_log.csv"
+
+ st.set_page_config(page_title="IndoBERT Multi-Predict (Topic + Taxonomy)", layout="wide")
+
+ # basic logging to csv
+ if not os.path.exists(LOG_PATH):
+     pd.DataFrame(columns=["timestamp","input_sample","topic_pred","topic_conf","tax_pred","tax_conf","runtime_s"]).to_csv(LOG_PATH, index=False)
+
+ @st.cache_resource
+ def load_engine():
+     eng = InferenceEngine(ckpt_path=MODEL_CKPT, tokenizer_name=TOKENIZER_NAME)
+     return eng
+
+ st.title("IndoBERT — Multi-Predict (Topic & Taxonomy)")
+ st.caption("Shared encoder → 2 output heads. Fast inference (CPU/GPU).")
+
+ eng = load_engine()
+
+ # Left column: input
+ c1, c2 = st.columns([1,1.2])
+
+ with c1:
+     st.header("Single prediction")
+     text = st.text_area("Paste your question / soal here:", height=160, placeholder="Tulis soal / pernyataan ...")
+     st.write("Light cleaning applied (lowercase, trim, normalize spaces).")
+     if st.button("Predict single"):
+         start = time.time()
+         res = eng.predict_texts([text])[0]
+         runtime = time.time() - start
+         # show result
+         st.subheader("Result")
+         st.metric("Topic", f"{res['topic_label']} ({res['topic_idx']})", delta=f"{res['topic_conf']:.3f}")
+         st.metric("Taxonomy", f"{res['tax_label']} ({res['tax_idx']})", delta=f"{res['tax_conf']:.3f}")
+         # optional: probability bar
+         st.write("Topic confidence:", f"{res['topic_conf']:.3f}")
+         st.write("Taxonomy confidence:", f"{res['tax_conf']:.3f}")
+         st.write("Raw probs (topic head): — first 8 shown")
+         st.write(res["topic_probs"][:8])
+         # logging
+         # logging disabled on HF Spaces
+         pass
+
+ with c2:
+     st.header("Batch prediction (CSV)")
+     st.write("CSV format: must contain column called:", f"`{CSV_INPUT_COL}`")
+     uploaded = st.file_uploader("Upload CSV file", type=["csv"])
+     if uploaded is not None:
+         try:
+             df = pd.read_csv(uploaded)
+         except Exception:
+             df = pd.read_csv(uploaded, encoding="latin1")
+         st.write("Preview uploaded data (first 5 rows):")
+         st.dataframe(df.head())
+         if CSV_INPUT_COL not in df.columns:
+             st.error(f"CSV must contain column `{CSV_INPUT_COL}`. Rename your text column accordingly.")
+         else:
+             if st.button("Predict batch"):
+                 texts = df[CSV_INPUT_COL].astype(str).tolist()
+                 t0 = time.time()
+                 results = eng.predict_texts(texts)
+                 elapsed = time.time() - t0
+                 # join results into dataframe
+                 out = pd.DataFrame(results)
+                 out = out.rename(columns={
+                     "topic_label":"pred_topic",
+                     "topic_conf":"pred_topic_conf",
+                     "tax_label":"pred_tax",
+                     "tax_conf":"pred_tax_conf"
+                 })
+                 # attach to original
+                 df_out = pd.concat([df.reset_index(drop=True), out[["pred_topic","pred_topic_conf","pred_tax","pred_tax_conf"]]], axis=1)
+                 st.success(f"Done — {len(df_out)} rows in {elapsed:.2f}s")
+                 st.dataframe(df_out.head(50))
+                 # allow download
+                 csv_bytes = df_out.to_csv(index=False).encode("utf-8")
+                 st.download_button("Download predictions (CSV)", csv_bytes, file_name="predictions.csv", mime="text/csv")
+                 # append logs
+                 # logging disabled on HF Spaces
+                 pass  # logging disabled on HF Spaces
+
+ # logging disabled on HF Spaces
+
+ st.write("---")
+ st.markdown("**Model info:** IndoBERT shared encoder (multi-head). Checkpoint: `" + MODEL_CKPT + "`")
+ st.markdown("**Notes:** For best label names ensure `le_topic_classes.npy` and `le_tax_classes.npy` are present in the app folder.")
best_indobert_multipredict.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e24acc8a8a4af4b79019ec97b7b921ec5512e5e62b9001f7946daa150e2c3054
+ size 43092
inference.py ADDED
@@ -0,0 +1,95 @@
le_tax_classes.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fa81876133d73c7c61dcb583cfa4e74f6dc72a5b1d30fe25f800576333a279be
+ size 160
le_topic_classes.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8a3c127a53576a488189251889d52c07bb3d0bb33ff8b8a5a25b26e8f0ad90dc
+ size 535
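Per the note at the bottom of app.py, these two .npy files (stored via Git LFS) are expected to hold the label-encoder class arrays that map the topic and taxonomy head indices to readable names. A quick sanity check, assuming they were saved from scikit-learn LabelEncoder `.classes_` arrays:

# assumed: LabelEncoder.classes_ dumps; allow_pickle handles string arrays
import numpy as np
topic_classes = np.load("le_topic_classes.npy", allow_pickle=True)
tax_classes = np.load("le_tax_classes.npy", allow_pickle=True)
print(len(topic_classes), "topic labels,", len(tax_classes), "taxonomy labels")
print(topic_classes[0])  # the label that topic index 0 would map to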
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ torch==2.2.1
+ transformers==4.44.2
+ streamlit==1.38.0
+ numpy==1.26.4
+ accelerate
+ protobuf==3.20.3
+ huggingface-hub==0.17.4
+ joblib
+ scikit-learn
+ sentencepiece
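With these pins the app can also be run locally with `pip install -r requirements.txt` followed by `streamlit run app.py`. Note that the `huggingface-hub==0.17.4` pin is older than what `transformers==4.44.2` typically requires; if pip reports a dependency-resolver conflict, relaxing or removing that pin is the likely fix.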