Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -132,31 +132,33 @@ def main():
|
|
| 132 |
input_text = st.text_input('Enter a sentence')
|
| 133 |
# put two buttons side by side in the sidebar
|
| 134 |
# translate_button = st.button('Translate', key='translate_button')
|
| 135 |
-
viz_button = st.button('Visualize Attention', key='viz_button')
|
| 136 |
attn_type = st.selectbox('Select attention type', ['encoder', 'decoder', 'encoder-decoder'])
|
| 137 |
-
layers = st.multiselect('Select layers', list(range(
|
| 138 |
-
heads = st.multiselect('Select heads', list(range(
|
| 139 |
# allow the user to select all the layers and heads at once to visualize
|
| 140 |
if st.checkbox('Select all layers'):
|
| 141 |
-
layers = list(range(
|
| 142 |
if st.checkbox('Select all heads'):
|
| 143 |
-
heads = list(range(
|
| 144 |
|
| 145 |
-
if
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
|
|
|
|
|
|
| 160 |
else:
|
| 161 |
st.write('Enter a sentence to visualize the attention of the model')
|
| 162 |
|
|
|
|
| 132 |
input_text = st.text_input('Enter a sentence')
|
| 133 |
# put two buttons side by side in the sidebar
|
| 134 |
# translate_button = st.button('Translate', key='translate_button')
|
| 135 |
+
# viz_button = st.button('Visualize Attention', key='viz_button')
|
| 136 |
attn_type = st.selectbox('Select attention type', ['encoder', 'decoder', 'encoder-decoder'])
|
| 137 |
+
layers = st.multiselect('Select layers', list(range(6)))
|
| 138 |
+
heads = st.multiselect('Select heads', list(range(8)))
|
| 139 |
# allow the user to select all the layers and heads at once to visualize
|
| 140 |
if st.checkbox('Select all layers'):
|
| 141 |
+
layers = list(range(6))
|
| 142 |
if st.checkbox('Select all heads'):
|
| 143 |
+
heads = list(range(8))
|
| 144 |
|
| 145 |
+
if input_text != '':
|
| 146 |
+
with st.spinner("Translating..."):
|
| 147 |
+
encoder_input_tokens, decoder_input_tokens, output = process_input(input_text, tokenizer_src, tokenizer_tgt, model, config, device)
|
| 148 |
+
max_sentence_len = len(encoder_input_tokens)
|
| 149 |
+
row_tokens = encoder_input_tokens
|
| 150 |
+
col_tokens = decoder_input_tokens
|
| 151 |
+
st.write('Input:', ' '.join(encoder_input_tokens))
|
| 152 |
+
st.write('Output:', ' '.join(decoder_input_tokens))
|
| 153 |
+
st.write('Translated:', output)
|
| 154 |
+
st.write('Attention Visualization')
|
| 155 |
+
with st.spinner("Visualizing Attention..."):
|
| 156 |
+
if attn_type == 'encoder':
|
| 157 |
+
st.write(get_all_attention_maps(attn_type, layers, heads, row_tokens, row_tokens, max_sentence_len, model))
|
| 158 |
+
elif attn_type == 'decoder':
|
| 159 |
+
st.write(get_all_attention_maps(attn_type, layers, heads, col_tokens, col_tokens, max_sentence_len, model))
|
| 160 |
+
elif attn_type == 'encoder-decoder':
|
| 161 |
+
st.write(get_all_attention_maps(attn_type, layers, heads, row_tokens, col_tokens, max_sentence_len, model))
|
| 162 |
else:
|
| 163 |
st.write('Enter a sentence to visualize the attention of the model')
|
| 164 |
|