| """See the data as seen by your model.""" |
| import pandas as pd |
| import streamlit as st |
|
|
| from src.subpages.page import Context, Page |
| from src.utils import aggrid_interactive_table |
|
|
|
|
@st.cache
def convert_df(df):
    """Serialize ``df`` (including its index) to CSV and return it as UTF-8 bytes.

    Decorated with ``st.cache`` so Streamlit can reuse the encoded result
    across reruns instead of re-serializing the same frame.
    """
    csv_text = df.to_csv()
    return csv_text.encode("utf-8")
|
|
|
|
class RawDataPage(Page):
    """Page showing the data both after processing/inference and in raw form."""

    name = "Raw data"
    icon = "qr-code"

    def _render_table_with_download(self, df: pd.DataFrame, file_name: str) -> None:
        """Show ``df`` as an interactive table followed by a CSV download button.

        Args:
            df: The frame to display and offer for download.
            file_name: File name suggested to the browser for the CSV download.
        """
        aggrid_interactive_table(df)
        st.download_button(
            "Download csv",
            convert_df(df),
            file_name,
            "text/csv",
        )

    def render(self, context: Context):
        """Render the page: dataset identifiers, processed token data, raw data.

        Args:
            context: Shared app state. This page reads ``ds_name``,
                ``ds_config_name``, ``ds_split_name``, ``df_tokens`` and ``split``
                (``split`` must provide ``to_pandas()``; presumably a
                Hugging Face ``Dataset`` — TODO confirm).
        """
        st.title(self.name)
        with st.expander("💡", expanded=True):
            st.write("See the data as seen by your model.")

        st.subheader("Dataset")
        st.code(
            f"Dataset: {context.ds_name}\nConfig: {context.ds_config_name}\nSplit: {context.ds_split_name}"
        )

        st.write("**Data after processing and inference**")
        # Drop columns not meant for the table, then round numeric values for display.
        processed_df = context.df_tokens.drop(
            columns=["hidden_states", "attention_mask"]
        ).round(3)
        cols = [
            "ids",
            "input_ids",
            "token_type_ids",
            "word_ids",
            "losses",
            "tokens",
            "labels",
            "preds",
            "total_loss",
        ]
        # token_type_ids may be absent (presumably tokenizer-dependent), so only
        # select it when the frame actually has it.
        if "token_type_ids" not in processed_df.columns:
            cols.remove("token_type_ids")
        processed_df = processed_df[cols]
        self._render_table_with_download(processed_df, "processed_data.csv")

        st.write("**Raw data (exploded by tokens)**")
        # Explode each column element-wise; pd.Series.explode leaves scalar
        # (non-list) values unchanged.
        raw_data_df = context.split.to_pandas().apply(pd.Series.explode)
        self._render_table_with_download(raw_data_df, "raw_data.csv")
|
|