5-4o_fine_tuning/data_explorer.py (84 lines of code) (raw):
"""
streamlit app for exploring and editing JSONL files with chat data.
"""
import json
import os
import time
import streamlit as st
def load_jsonl(file):
return [json.loads(line) for line in file]
def save_jsonl(file_path, data):
with open(file_path, 'w') as file:
for record in data:
file.write(json.dumps(record) + '\n')
def display_message(message, role, index, message_index, data, file_path):
"""Display and edit a single message."""
st.markdown(f"**{role.upper()}**")
content = st.text_area("Message Content", message['content'],
key=f"message_{index}_{message_index}", label_visibility="hidden")
if st.button("Save", key=f"save_{index}_{message_index}"):
data[index]['messages'][message_index]['content'] = content
save_jsonl(file_path, data)
st.success("Content saved!")
def get_file_stats(file_path):
"""Get file statistics such as size, creation, and modification times."""
stats = os.stat(file_path)
return {
"size": stats.st_size,
"created": time.ctime(stats.st_ctime),
"last_modified": time.ctime(stats.st_mtime)
}
def main():
st.set_page_config(layout="wide")
st.title("🔍 Fine-tuning Data Explorer (Beta)")
st.markdown(
"**Demo Purposes Only** - This app allows you to explore and edit JSONL files with chat data. ")
base_dir = os.path.dirname(os.path.abspath(__file__))
jsonl_files = [os.path.join(base_dir, file) for file in os.listdir(base_dir)
if file.endswith('.jsonl') and ('train' in file or 'val' in file) and
open(os.path.join(base_dir, file)).readline().strip().startswith('{"messages": ')]
assert jsonl_files, "No JSONL files found in the directory."
selected_file = st.selectbox(
"Choose the JSONL file with chat data", jsonl_files)
if selected_file:
if 'last_selected_file' not in st.session_state or st.session_state.last_selected_file != selected_file:
st.session_state.update(
{"row_index": 0, "last_selected_file": selected_file})
st.session_state.data = load_jsonl(open(selected_file, 'r'))
st.session_state.total_records = len(st.session_state.data)
data = st.session_state.data
st.sidebar.header("Useful Stats")
file_stats = get_file_stats(selected_file)
st.sidebar.write(f"Total records: {st.session_state.total_records}")
st.sidebar.write(f"File size: {file_stats['size']} bytes")
total_messages = sum(len(record['messages']) for record in data)
avg_messages_per_record = total_messages / st.session_state.total_records
total_characters = sum(
len(message['content']) for record in data for message in record['messages'])
avg_characters_per_message = total_characters / total_messages
st.sidebar.write(f"Total messages: {total_messages}")
st.sidebar.write(
f"Average messages per record: {avg_messages_per_record:.2f}")
st.sidebar.write(f"Total characters: {total_characters}")
st.sidebar.write(
f"Average characters per message: {avg_characters_per_message:.2f}")
if 'row_index' not in st.session_state:
st.session_state.row_index = 0
st.session_state.row_index = st.slider("Record Index", 1, st.session_state.total_records,
st.session_state.row_index + 1, label_visibility="hidden") - 1
col1, col2, col3 = st.columns([1, 1, 1])
with col1:
if st.button("⬅️ Previous", on_click=lambda: st.session_state.update(
row_index=(st.session_state.row_index - 1) % st.session_state.total_records)):
pass
with col2:
st.write(
f"Displaying record {st.session_state.row_index + 1} of {st.session_state.total_records}")
with col3:
if st.button("Next ➡️", on_click=lambda: st.session_state.update(
row_index=(st.session_state.row_index + 1) % st.session_state.total_records)):
pass
row_index = st.session_state.row_index
row_data = data[row_index]
for i, message in enumerate(row_data['messages']):
display_message(
message, message['role'], row_index, i, data, selected_file)
if i < len(row_data['messages']) - 1:
st.markdown("---")
if __name__ == "__main__":
main()