in 5-4o_fine_tuning/data_explorer.py [0:0]
def _looks_like_chat_jsonl(path):
    """Return True when the first line of *path* starts with '{"messages": '.

    Only the first line is read, and the handle is closed before returning.
    Paths that cannot be opened or decoded as UTF-8 (including directories
    whose name happens to end in ``.jsonl``) are reported as non-matching
    instead of aborting the whole directory scan.
    """
    try:
        with open(path, encoding="utf-8") as handle:
            first_line = handle.readline()
    except (OSError, UnicodeDecodeError):
        return False
    return first_line.strip().startswith('{"messages": ')


def main():
    """Render the Streamlit page for browsing chat-format JSONL files.

    The page is built in this order:

    1. Scan the directory containing this script for ``*.jsonl`` files whose
       name contains ``train`` or ``val`` and whose first line starts with
       ``{"messages": ``.
    2. Let the user pick one; its records are loaded with ``load_jsonl`` and
       cached in ``st.session_state`` until a different file is selected.
    3. Show summary statistics in the sidebar.
    4. Show one record at a time, with a slider and Previous/Next buttons
       (wrapping at both ends) for navigation. Each message of the current
       record is rendered through ``display_message``.

    If no matching file exists, or the selected file holds no records, a
    message is shown on the page and the function returns early.
    """
    st.set_page_config(layout="wide")
    st.title("🔍 Fine-tuning Data Explorer (Beta)")
    st.markdown(
        "**Demo Purposes Only** - This app allows you to explore and edit JSONL files with chat data. ")

    base_dir = os.path.dirname(os.path.abspath(__file__))
    # Sorted so the selectbox options keep a stable order across reruns;
    # os.listdir() makes no ordering guarantee.
    jsonl_files = [
        os.path.join(base_dir, name)
        for name in sorted(os.listdir(base_dir))
        if name.endswith('.jsonl')
        and ('train' in name or 'val' in name)
        and _looks_like_chat_jsonl(os.path.join(base_dir, name))
    ]
    if not jsonl_files:
        # Reported in the UI rather than via `assert`, which is removed
        # entirely when Python runs with -O.
        st.error("No JSONL files found in the directory.")
        return

    selected_file = st.selectbox(
        "Choose the JSONL file with chat data", jsonl_files)
    if not selected_file:
        return

    if ('last_selected_file' not in st.session_state
            or st.session_state.last_selected_file != selected_file):
        # Load first, then record the selection: if loading raises, the next
        # rerun still sees the file as "not yet loaded" and retries instead
        # of silently showing data from the previously selected file.
        with open(selected_file, 'r', encoding="utf-8") as handle:
            records = load_jsonl(handle)
        st.session_state.update(
            {"row_index": 0, "last_selected_file": selected_file})
        st.session_state.data = records
        st.session_state.total_records = len(records)

    data = st.session_state.data
    total_records = st.session_state.total_records

    st.sidebar.header("Useful Stats")
    file_stats = get_file_stats(selected_file)
    st.sidebar.write(f"Total records: {total_records}")
    st.sidebar.write(f"File size: {file_stats['size']} bytes")

    if not total_records:
        # Nothing to average over and nothing to display; the statistics and
        # the record lookup below would otherwise fail on an empty file.
        st.warning("The selected file contains no records.")
        return

    total_messages = sum(len(record['messages']) for record in data)
    avg_messages_per_record = total_messages / total_records
    # `content` may be missing or None on some messages; count those as
    # zero characters instead of raising.
    total_characters = sum(
        len(message.get('content') or '')
        for record in data
        for message in record['messages'])
    avg_characters_per_message = (
        total_characters / total_messages if total_messages else 0.0)
    st.sidebar.write(f"Total messages: {total_messages}")
    st.sidebar.write(
        f"Average messages per record: {avg_messages_per_record:.2f}")
    st.sidebar.write(f"Total characters: {total_characters}")
    st.sidebar.write(
        f"Average characters per message: {avg_characters_per_message:.2f}")

    if 'row_index' not in st.session_state:
        st.session_state.row_index = 0

    if total_records > 1:
        # The slider is 1-based for display; row_index is 0-based.
        st.session_state.row_index = st.slider(
            "Record Index", 1, total_records,
            st.session_state.row_index + 1, label_visibility="hidden") - 1
    else:
        # A single record leaves nothing to slide between, and a slider
        # whose minimum equals its maximum is a degenerate range.
        st.session_state.row_index = 0

    def _step(delta):
        """Move the current record index by *delta*, wrapping at both ends."""
        st.session_state.update(
            row_index=(st.session_state.row_index + delta)
            % st.session_state.total_records)

    col1, col2, col3 = st.columns([1, 1, 1])
    with col1:
        st.button("⬅️ Previous", on_click=_step, args=(-1,))
    with col2:
        st.write(
            f"Displaying record {st.session_state.row_index + 1} of {st.session_state.total_records}")
    with col3:
        st.button("Next ➡️", on_click=_step, args=(1,))

    row_index = st.session_state.row_index
    row_data = data[row_index]
    for i, message in enumerate(row_data['messages']):
        display_message(
            message, message['role'], row_index, i, data, selected_file)
        # Separator between messages, but not after the last one. The length
        # is re-read on every iteration, as in the original loop.
        if i < len(row_data['messages']) - 1:
            st.markdown("---")