Spaces:
Sleeping
Sleeping
| import streamlit as st | |
# Page-level settings; Streamlit requires this call to precede any other st.* command.
st.set_page_config(
    layout="wide",
    page_icon="👋",
    page_title="Holistic AI - LLM Risks",
)
import hmac
import html
import json
import os
import re

from huggingface_hub import HfApi, login
from streamlit_cookies_manager import EncryptedCookieManager
# Hugging Face dataset repo that holds the consolidated risk annotations.
_DATASET_NAME = "holistic-ai/LLM-Risks"
# st.session_state key under which the parsed annotations are cached for the session.
_DATA_STATE_KEY = "llm_risks_data"
# Number of cards per row in the Examples / Mitigators grids.
_NUM_COLUMNS = 3
# Message shown in a tab when no annotation matches the sidebar filters.
_NO_RESULTS_MESSAGE = "No news found for the selected task and group."

# CSS for reducing the vertical spacing between <p> tags, justifying text,
# and ensuring equal-height cards.
_CARD_CSS = """
<style>
.card {
    border: 1px solid #ddd;
    border-radius: 10px;
    padding: 10px;
    margin: 10px;
    height: 100%;
    display: flex;
    flex-direction: column;
    justify-content: space-between;
    box-sizing: border-box;
    background-color: #e4e8f5;
}
.card h3 {
    margin-top: 0;
    background-color: #e4e8f5;
}
.card p {
    margin: 2px 0;
    padding: 0;
    text-align: justify;
    background-color: #e4e8f5;
}
.stApp {
    max-width: 100%;
    padding: 1rem;
}
.grid {
    display: flex;
    flex-wrap: wrap;
    justify-content: space-between;
}
.grid-item {
    flex: 1 0 23%; /* 4 items per row */
    box-sizing: border-box;
    margin: 1%;
    display: flex;
}
.grid-item .card {
    flex: 1;
    display: flex;
    flex-direction: column;
    justify-content: space-between;
    background-color: #e4e8f5;
}
@media (max-width: 1200px) {
    .grid-item {
        flex: 1 0 46%; /* 2 items per row */
    }
}
@media (max-width: 768px) {
    .grid-item {
        flex: 1 0 96%; /* 1 item per row */
    }
}
</style>
"""

# Apply a white background to the sidebar.
_SIDEBAR_CSS = """
<style>
[data-testid="stSidebar"] {
    background-color: white;
}
</style>
"""


def _camel_to_whitespace(camel_str):
    """Turn a camelCase identifier into a space-separated, Title Cased label.

    Example: "textGeneration" -> "Text Generation".
    """
    spaced_str = re.sub(r'([A-Z])', r' \1', camel_str).lower()
    return spaced_str.strip().title()


def _unique_in_order(values):
    """Return the distinct values in order of first appearance.

    Used instead of list(set(...)) so option and section order is deterministic
    rather than dependent on string hashing.
    """
    return list(dict.fromkeys(values))


def _escape(value):
    """Return str(value) HTML-escaped (quotes included) for interpolation into card markup."""
    return html.escape(str(value), quote=True)


def _load_risk_data():
    """Return the parsed annotations from risk_annotation_consolidated.json.

    Logs in with the HF_TOKEN environment variable, downloads a snapshot of the
    dataset repo and parses the JSON file. The result is kept in
    st.session_state so this work runs once per browser session instead of on
    every script rerun.
    """
    if _DATA_STATE_KEY not in st.session_state:
        token = os.getenv("HF_TOKEN")
        api = HfApi()
        login(token)
        repo_path = api.snapshot_download(repo_id=_DATASET_NAME, repo_type="dataset")
        json_path = os.path.join(repo_path, "risk_annotation_consolidated.json")
        # Explicit encoding: open()'s default encoding is locale-dependent.
        with open(json_path, encoding="utf-8") as file:
            st.session_state[_DATA_STATE_KEY] = json.load(file)
    return st.session_state[_DATA_STATE_KEY]


def _card_html(entry, body_html):
    """Build the HTML for one card.

    entry must provide 'title' and 'link'. body_html is already-escaped markup
    placed between the card title and its "Read more" link.
    """
    return "\n".join([
        '<div class="grid-item">',
        '<div class="card">',
        f"<h3>{_escape(entry['title'])}</h3>",
        body_html,
        f'<a href="{_escape(entry["link"])}" target="_blank">Read more</a>',
        "</div>",
        "</div>",
    ])


def _example_body(entry):
    """Card body for an example: its 'incident' text."""
    return f"<p>{_escape(entry['incident'])}</p>"


def _mitigator_body(entry):
    """Card body for a mitigator: its 'recommendation' text and 'year'."""
    return "\n".join([
        f"<p>{_escape(entry['recommendation'])}</p>",
        f"<p><b>Year:</b> {_escape(entry['year'])}</p>",
    ])


def _render_risk_cards(records, list_key, card_body):
    """Render one section per risk, with its entries laid out as a grid of cards.

    records: annotation dicts already filtered to the selected task and group.
    list_key: key of the per-risk list to show ("examples" or "mitigators").
    card_body: callable mapping one entry dict to the HTML placed between the
        card title and its "Read more" link.
    """
    if not records:
        st.write(_NO_RESULTS_MESSAGE)
        return

    # Keep the first record seen for each risk, preserving dataset order.
    first_record_by_risk = {}
    for record in records:
        first_record_by_risk.setdefault(record['risk'], record)

    for risk, item in first_record_by_risk.items():
        st.header(risk)
        st.write(f"Risk Description: {item['description']}")
        grid = st.container()
        row = None
        for index, entry in enumerate(item[list_key]):
            col_index = index % _NUM_COLUMNS
            if col_index == 0:
                # Open a new row only when there is a card to place in it, so no
                # empty trailing row is created when the count fills the last row.
                row = grid.columns(_NUM_COLUMNS)
            with row[col_index]:
                st.markdown(_card_html(entry, card_body(entry)), unsafe_allow_html=True)


def program():
    """Render the LLM-Risks explorer for an authenticated user.

    The sidebar selects a task and a risk group. The "Examples" tab shows example
    cards and the "Mitigators" tab shows mitigator cards for every risk in
    that group.
    """
    data = _load_risk_data()
    if not data:
        # Without annotations there are no tasks to select; avoid a KeyError below.
        st.write("No risk annotations found in the dataset.")
        return

    task_names = _unique_in_order(item['task'] for item in data)
    task_2_task_string = {task: _camel_to_whitespace(task) for task in task_names}
    task_string_2_task = {label: task for task, label in task_2_task_string.items()}
    task_strings = list(task_2_task_string.values())

    # Sidebar filters.
    with st.sidebar:
        st.image("hai_logo.png", width=150, use_column_width=True)
        st.header("Filters")
        selected_task_string = st.selectbox("Select a Task", task_strings)
        selected_task = task_string_2_task[selected_task_string]
        filtered_data_by_task = [item for item in data if item['task'] == selected_task]
        groups = _unique_in_order(item['group'] for item in filtered_data_by_task)
        selected_group = st.selectbox("Select a Risk Group", groups)
        filtered_data_by_group = [
            item for item in filtered_data_by_task if item['group'] == selected_group
        ]
        st.divider()
        st.markdown(f"**Task**: {selected_task_string}")
        st.markdown(f"**Risk Group**: {selected_group}")

    st.markdown(_CARD_CSS, unsafe_allow_html=True)
    st.markdown(_SIDEBAR_CSS, unsafe_allow_html=True)

    examples_tab, mitigators_tab = st.tabs(["Examples", "Mitigators"])
    with examples_tab:
        _render_risk_cards(filtered_data_by_group, "examples", _example_body)
    with mitigators_tab:
        _render_risk_cards(filtered_data_by_group, "mitigators", _mitigator_body)
# Password that unlocks the app; main() compares the user's input against it.
SECRET_KEY = os.environ.get('SECRET_KEY')

# Encrypted cookie store; main() reads and writes its "authenticated" entry.
# COOKIES_PASSWORD is passed as the manager's encryption password.
_cookies_password = os.environ.get('COOKIES_PASSWORD')
cookies = EncryptedCookieManager(prefix="login", password=_cookies_password)

# If the cookie manager is not ready yet, stop this script run here.
if not cookies.ready():
    st.stop()
def _key_matches(user_key, secret_key):
    """Return True when user_key equals secret_key.

    An unset or empty secret_key never matches, so a missing or blank
    SECRET_KEY cannot be satisfied by submitting an empty password. Both values
    are compared as UTF-8 bytes with hmac.compare_digest (constant-time; str
    arguments would have to be ASCII-only).
    """
    if not secret_key or not user_key:
        return False
    return hmac.compare_digest(user_key.encode("utf-8"), secret_key.encode("utf-8"))


def main():
    """Show the password gate, or the app itself once the session is authenticated.

    Authentication is remembered in the encrypted ``cookies`` store under the
    "authenticated" key. A correct password sets that key and reruns the script.
    """
    # Application title.
    st.title("LLM Mitigation")

    if cookies.get("authenticated"):
        program()
        return

    # Secret-key input.
    user_key = st.text_input("Password:", type="password")
    if st.button("Login"):
        # Check whether the entered key matches the configured secret key.
        if _key_matches(user_key, SECRET_KEY):
            cookies["authenticated"] = "True"
            # st.rerun superseded st.experimental_rerun; use whichever this
            # Streamlit installation provides.
            rerun = getattr(st, "rerun", None) or st.experimental_rerun
            rerun()
        else:
            st.error("Acceso denegado. Clave incorrecta.")


if __name__ == "__main__":
    main()