Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from transformers import AutoTokenizer, AutoModel, utils | |
| from bertviz import model_view | |
| import streamlit.components.v1 as components | |
| from train import get_or_build_tokenizer, greedy_decode | |
| from config import get_config, latest_weights_file_path | |
| from model import build_transformer | |
| import torch | |
| from bertviz import model_view | |
| import torch | |
| import altair as alt | |
| import pandas as pd | |
| import warnings | |
# Quiet third-party noise: drop Python warnings entirely and limit the
# transformers library's logging to errors only.
warnings.filterwarnings("ignore")
utils.logging.set_verbosity_error()

# Page title and layout. Streamlit documents that set_page_config should run
# before other st.* commands, so it stays at module top level.
_PAGE_CONFIG = {"page_title": 'Attention Visualizer', "layout": 'wide'}
st.set_page_config(**_PAGE_CONFIG)
def mtx2df(m, max_row, max_col, row_tokens, col_tokens):
    """Flatten the top-left corner of a 2-D matrix into a long-form DataFrame.

    One output row is produced per cell ``m[r, c]`` with
    ``r < min(max_row, m.shape[0])`` and ``c < min(max_col, m.shape[1])``.
    Rows are ordered row-major (all columns of row 0, then row 1, ...).

    Args:
        m: 2-D array-like supporting ``m.shape`` and ``m[r, c]`` indexing
            (e.g. a NumPy array or torch tensor).
        max_row: Maximum number of rows to include.
        max_col: Maximum number of columns to include.
        row_tokens: Labels for the rows; indices past its end are labelled
            ``"<blank>"``.
        col_tokens: Labels for the columns; indices past its end are labelled
            ``"<blank>"``.

    Returns:
        DataFrame with columns ``row``, ``column``, ``value``, ``row_token``
        and ``col_token``. Token labels look like ``"03 - foo"``. The index is
        zero-padded to two digits, so for indices below 100 the labels sort in
        positional order, and repeated tokens still get distinct labels.
    """

    def label(idx, tokens):
        # "%.2d" zero-pads the index to at least two digits.
        token = tokens[idx] if len(tokens) > idx else "<blank>"
        return "%.2d - %s" % (idx, token)

    # Bound the loops up front rather than visiting every cell of the
    # (possibly much larger) matrix and filtering each one out.
    n_rows = min(m.shape[0], max_row)
    n_cols = min(m.shape[1], max_col)
    return pd.DataFrame(
        [
            (r, c, float(m[r, c]), label(r, row_tokens), label(c, col_tokens))
            for r in range(n_rows)
            for c in range(n_cols)
        ],
        columns=["row", "column", "value", "row_token", "col_token"],
    )
def get_attn_map(attn_type: str, layer: int, head: int, model):
    """Return one head's stored attention scores from one layer of ``model``.

    Args:
        attn_type: ``"encoder"`` (encoder self-attention), ``"decoder"``
            (decoder self-attention) or ``"encoder-decoder"`` (the decoder's
            cross-attention block).
        layer: Index into ``model.encoder.layers`` / ``model.decoder.layers``.
        head: Head index used as the second subscript of ``attention_scores``.
        model: Transformer whose attention blocks expose an
            ``attention_scores`` attribute (presumably filled in by the last
            forward pass — confirm against the model code).

    Returns:
        ``attention_scores[0, head].data`` for the selected block.

    Raises:
        ValueError: If ``attn_type`` is not one of the three supported values.
    """
    if attn_type == "encoder":
        block = model.encoder.layers[layer].self_attention_block
    elif attn_type == "decoder":
        block = model.decoder.layers[layer].self_attention_block
    elif attn_type == "encoder-decoder":
        block = model.decoder.layers[layer].cross_attention_block
    else:
        # Fail with a clear message instead of leaving the result unbound.
        raise ValueError(
            f"Unknown attn_type {attn_type!r}; expected "
            "'encoder', 'decoder' or 'encoder-decoder'"
        )
    return block.attention_scores[0, head].data
def attn_map(attn_type, layer, head, row_tokens, col_tokens, max_sentence_len, model):
    """Build an interactive Altair heatmap for a single attention head.

    The head's matrix (from ``get_attn_map``) is cropped to
    ``max_sentence_len`` rows and columns, with rows labelled by ``row_tokens``
    and columns by ``col_tokens``. It is drawn as a 200x200 ``mark_rect``
    chart titled "Layer <layer> Head <head>".
    """
    scores = get_attn_map(attn_type, layer, head, model)
    cells = mtx2df(scores, max_sentence_len, max_sentence_len, row_tokens, col_tokens)

    # Axis titles are blanked because the tick labels already carry the tokens.
    encoding = dict(
        x=alt.X("col_token", axis=alt.Axis(title="")),
        y=alt.Y("row_token", axis=alt.Axis(title="")),
        color=alt.Color("value", scale=alt.Scale(scheme="blues")),
        tooltip=["row", "column", "value", "row_token", "col_token"],
    )
    heatmap = alt.Chart(data=cells).mark_rect().encode(**encoding)
    heatmap = heatmap.properties(height=200, width=200, title=f"Layer {layer} Head {head}")
    return heatmap.interactive()
def get_all_attention_maps(attn_type: str, layers: list[int], heads: list[int], row_tokens: list, col_tokens, max_sentence_len: int, model):
    """Arrange one heatmap per (layer, head) pair into a grid.

    Each entry of ``layers`` becomes one horizontal row holding a chart for
    every entry of ``heads``. The rows are stacked vertically in the order of
    ``layers``.
    """
    grid_rows = [
        alt.hconcat(*(
            attn_map(attn_type, layer, head, row_tokens, col_tokens, max_sentence_len, model)
            for head in heads
        ))
        for layer in layers
    ]
    return alt.vconcat(*grid_rows)
def initiate_model(config, device):
    """Build the translation model and its tokenizers, then load trained weights.

    Args:
        config: Project configuration dict. Reads ``lang_src``, ``lang_tgt``,
            ``seq_len`` and ``d_model``, and is passed to
            ``latest_weights_file_path`` to locate the checkpoint.
        device: ``torch.device`` the model is moved to.

    Returns:
        Tuple ``(model, tokenizer_src, tokenizer_tgt)``; the model is in eval mode.
    """
    # The dataset argument is None. Presumably the tokenizers already exist on
    # disk and are only loaded here — confirm against get_or_build_tokenizer.
    tokenizer_src = get_or_build_tokenizer(config, None, config["lang_src"])
    tokenizer_tgt = get_or_build_tokenizer(config, None, config["lang_tgt"])
    model = build_transformer(
        tokenizer_src.get_vocab_size(),
        tokenizer_tgt.get_vocab_size(),
        config["seq_len"],
        config["seq_len"],
        d_model=config["d_model"],
    ).to(device)

    # Map the checkpoint to CPU first so it also loads on CPU-only hosts;
    # load_state_dict copies the tensors into the model's parameters.
    model_filename = latest_weights_file_path(config)
    state = torch.load(model_filename, map_location=torch.device('cpu'))
    model.load_state_dict(state['model_state_dict'])
    # This app only runs inference: switch off train-time behaviour (e.g. any
    # dropout layers) so translations and attention maps are deterministic.
    model.eval()
    return model, tokenizer_src, tokenizer_tgt
def process_input(input_text, tokenizer_src, tokenizer_tgt, model, config, device):
    """Translate ``input_text`` and return the token sequences used for plotting.

    Args:
        input_text: Sentence in the source language.
        tokenizer_src: Source tokenizer; must know ``[SOS]``, ``[EOS]``, ``[PAD]``.
        tokenizer_tgt: Target tokenizer used to read back the decoded ids.
        model: Transformer passed to ``greedy_decode``.
        config: Config dict; ``seq_len`` is the fixed encoder input length.
        device: Device the input tensors are placed on.

    Returns:
        ``(encoder_input_tokens, decoder_input_tokens, output)``: the encoder
        tokens with ``[PAD]`` removed, the decoded target tokens, and the
        detokenized translation string.

    Raises:
        ValueError: If the sentence plus ``[SOS]``/``[EOS]`` does not fit in
            ``seq_len`` tokens.
    """
    seq_len = config['seq_len']
    sos_id = tokenizer_src.token_to_id('[SOS]')
    eos_id = tokenizer_src.token_to_id('[EOS]')
    pad_id = tokenizer_src.token_to_id('[PAD]')

    src_ids = tokenizer_src.encode(input_text).ids
    num_padding = seq_len - len(src_ids) - 2  # room left after [SOS] and [EOS]
    if num_padding < 0:
        # A negative count would make the padding list empty and let an
        # over-long sequence reach the model.
        raise ValueError(
            f"Input is too long: {len(src_ids)} tokens, but at most {seq_len - 2} fit."
        )

    src = torch.tensor(
        [sos_id, *src_ids, eos_id] + [pad_id] * num_padding, dtype=torch.int64
    ).to(device)
    # Shape (1, 1, seq_len): 1 for real tokens, 0 for padding.
    source_mask = (src != pad_id).unsqueeze(0).unsqueeze(0).int().to(device)

    # .tolist() gives plain Python ints for the tokenizer; padding is hidden.
    encoder_input_tokens = [
        tok for tok in (tokenizer_src.id_to_token(i) for i in src.tolist()) if tok != '[PAD]'
    ]

    # Inference only: no autograd graph is needed while decoding.
    with torch.no_grad():
        model_out = greedy_decode(model, src, source_mask, tokenizer_src, tokenizer_tgt, seq_len, device)

    out_ids = model_out.detach().cpu().tolist()
    decoder_input_tokens = [tokenizer_tgt.id_to_token(i) for i in out_ids]
    output = tokenizer_tgt.decode(out_ids)
    return encoder_input_tokens, decoder_input_tokens, output
| # def get_html_data(model_name, input_text): | |
| # model_name ="microsoft/xtremedistil-l12-h384-uncased" | |
| # model = AutoModel.from_pretrained(model_name, output_attentions=True, cache_dir='__pycache__') # Configure model to return attention values | |
| # tokenizer = AutoTokenizer.from_pretrained(model_name) | |
| # inputs = tokenizer.encode(input_text, return_tensors='pt') # Tokenize input text | |
| # outputs = model(inputs) # Run model | |
| # attention = outputs[-1] # Retrieve attention from model outputs | |
| # tokens = tokenizer.convert_ids_to_tokens(inputs[0]) # Convert input ids to token strings | |
| # model_html = model_view(attention, tokens, html_action="return") # Display model view | |
| # with open("static/model_view.html", 'w') as file: | |
| # file.write(model_html.data) | |
def main():
    """Render the app: translate a sentence and plot the selected attention heads."""
    st.title('Transformer Visualizer')
    st.write('This app visualizes the attention of a transformer model on a given sentence.')

    config = get_config()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # NOTE(review): Streamlit re-runs the whole script on every widget change,
    # so the checkpoint is reloaded each time; consider caching initiate_model.
    model, tokenizer_src, tokenizer_tgt = initiate_model(config, device)

    # Sidebar: prompt plus attention type / layer / head selection.
    with st.sidebar:
        input_text = st.text_input('Enter a sentence')
        attn_type = st.selectbox('Select attention type', ['encoder', 'decoder', 'encoder-decoder'])
        # Offer exactly the layers of the stack being inspected, rather than
        # assuming 6 (decoder and cross-attention both live in decoder.layers).
        stack = model.encoder if attn_type == 'encoder' else model.decoder
        all_layers = list(range(len(stack.layers)))
        # NOTE(review): assumes the model was built with 8 heads — the head
        # count is not visible here; confirm against build_transformer.
        all_heads = list(range(8))
        layers = st.multiselect('Select layers', all_layers)
        heads = st.multiselect('Select heads', all_heads)
        # Shortcuts that override the multiselects above.
        if st.checkbox('Select all layers'):
            layers = all_layers
        if st.checkbox('Select all heads'):
            heads = all_heads

    if not input_text.strip():
        st.write('Enter a sentence to visualize the attention of the model')
    else:
        with st.spinner("Translating..."):
            encoder_input_tokens, decoder_input_tokens, output = process_input(
                input_text, tokenizer_src, tokenizer_tgt, model, config, device
            )
        st.write('Input:', ' '.join(encoder_input_tokens))
        st.write('Output:', ' '.join(decoder_input_tokens))
        st.write('Translated:', output)
        st.write('Attention Visualization')

        # Choose axis labels and how many positions to show for each map type.
        if attn_type == 'encoder':
            row_tokens = col_tokens = encoder_input_tokens
            max_sentence_len = len(encoder_input_tokens)
        elif attn_type == 'decoder':
            row_tokens = col_tokens = decoder_input_tokens
            # Crop decoder self-attention by the decoder's own length.
            max_sentence_len = len(decoder_input_tokens)
        else:  # 'encoder-decoder'
            # NOTE(review): rows are labelled with encoder tokens and columns
            # with decoder tokens; confirm this matches the (query, key)
            # layout of cross_attention_block.attention_scores.
            row_tokens, col_tokens = encoder_input_tokens, decoder_input_tokens
            max_sentence_len = len(encoder_input_tokens)

        if not layers or not heads:
            st.info('Select at least one layer and one head to visualize attention.')
        else:
            with st.spinner("Visualizing Attention..."):
                st.write(get_all_attention_maps(
                    attn_type, layers, heads, row_tokens, col_tokens, max_sentence_len, model
                ))

    # Footer with author, repository and dataset links.
    st.markdown('---')
    st.write('Made by [Pratik Dwivedi](https://github.com/Dekode1859)')
    st.write('Check out the Scratch Implementation and Visualizer Code on [GitHub](https://github.com/Dekode1859/transformer-visualizer)')
    st.write('Dataset: [Opus-books: english-Italian](https://huggingface.co/datasets/Helsinki-NLP/opus_books)')


if __name__ == '__main__':
    main()