in UI/utils.py [0:0]
def _build_plot_chooser_prompt(columns, metadata, plot_chart_list):
    """Return the "plot_chooser" prompt text describing the DataFrame.

    Args:
        columns: List of the DataFrame's column names.
        metadata: Mapping of column name to its dtype rendered as a string.
        plot_chart_list: List of plot type names the reply may choose from.

    Returns:
        The prompt string, which asks for a JSON object with the keys
        ``plot_type``, ``x_column``, ``y_column`` and ``z_column``.
    """
    # The template lines are flush-left so that no extra leading whitespace
    # ends up in the prompt text. Doubled braces render as literal braces.
    return f"""
You are an expert in data visualization.
Given the following DataFrame metadata:
dataframe columns: {columns}
Data Types: {metadata}
number of columns: {len(columns)}
You are tasked to suggest:
1. The best plot type out of {plot_chart_list}
2. The dataframe columns to use for plotting.
Consider the data types and potential relationships between columns.
Remember to suggest below plots according to number of dataframe coumns
** 1 Column: Histogram, Pie Chart
** 2 Columns: Histogram, Pie Chart, Line Chart, Bar Chart,\
Scatter Plot, Area Chart, Box Plot
** More Than 2 Columns: Heatmap, Bubble Chart
Output should be strictly in json format as described below:
```{{
"plot_type": "...",
"x_column": "[...]",
"y_column": "[...]", // Optional, depending on plot type
"z_column": "[...]" // Optional, depending on plot type
}}
```
"""


def _plot_suggested(df, columns, metadata, plot_chart_list):
    """Request a plot suggestion and pass it to ``plot_data``.

    The prompt is handed to ``generate_result`` (presumably an LLM call --
    confirm against its definition); the reply is trimmed with
    ``extract_substring`` and parsed as JSON.

    Raises:
        IndexError: If ``columns`` is empty, because the fallback
            ``columns[0]`` is evaluated unconditionally.
        json.JSONDecodeError: If the extracted reply is not valid JSON.
    """
    prompt = _build_plot_chooser_prompt(columns, metadata, plot_chart_list)
    result_json_string = generate_result(prompt)
    result_json = json.loads(extract_substring(result_json_string))

    # NOTE(review): the "line" fallback matches none of the names in
    # plot_chart_list (e.g. 'Line Chart'); confirm plot_data accepts it.
    plot_type = result_json.get("plot_type", "line")
    x_column_list = result_json.get("x_column", columns[0])
    logger.info(f"Plot type is : {plot_type}")

    # The number of DataFrame columns decides how many axes are forwarded.
    if len(columns) > 2:
        y_column_list = result_json.get("y_column", columns[1])
        # Fixed: this previously read "y_column" again, so the suggested
        # z column was ignored and the y column was plotted twice.
        z_column_list = result_json.get("z_column", columns[2])
        plot_data(df, plot_type, x_column_list,
                  y_column_list, z_column_list)
    elif len(columns) == 2:
        y_column_list = result_json.get("y_column", columns[1])
        plot_data(df, plot_type, x_column_list, y_column_list)
    else:
        plot_data(df, plot_type, x_column_list)


def _plot_custom(df, plot_chart_list, key_counter):
    """Render Streamlit widgets for a manual plot choice and plot it.

    Widget keys are suffixed with ``key_counter`` so that several instances
    of these widgets can be given distinct keys.
    """
    # Plot types for which no third ("other") column selector is shown.
    plots_without_z = ['Histogram', 'Pie Chart', 'Line Chart',
                       'Bar Chart', 'Scatter Plot',
                       'Area Chart', 'Box Plot']

    with st.container():
        plot_type = st.selectbox('Select Plot Type',
                                 plot_chart_list,
                                 key=f'plot_type_{key_counter}')
        x_column_list = st.multiselect('Select Primary Column(s)',
                                       df.columns,
                                       key=f'x_column_{key_counter}')
        y_column_list = st.multiselect(
            'Select Secondary Column(s) if applicable',
            df.columns,
            key=f'y_column_{key_counter}'
        )

        if plot_type not in plots_without_z:
            z_column_list = st.multiselect(
                'Select Other Column(s) if applicable',
                df.columns,
                key=f'z_column_{key_counter}'
            )
        else:
            z_column_list = None

        if z_column_list is not None:
            plot_data(df, plot_type, x_column_list,
                      y_column_list, z_column_list)
        else:
            # st.multiselect always returns a list, so y_column_list is
            # never None; the former x-only fallback could not be reached.
            plot_data(df, plot_type, x_column_list, y_column_list)


def run_visualization(df, custom_flag, key_counter):
    """Plot ``df`` from an automatic suggestion or from manual selection.

    Args:
        df: DataFrame to visualize; its ``columns`` and ``dtypes`` are read.
        custom_flag: When truthy, the plot type and columns come from
            Streamlit widgets. When falsy, they come from the JSON reply
            returned by ``generate_result``.
        key_counter: Value interpolated into the Streamlit widget keys.

    Returns:
        None. The plot is produced by ``plot_data``.

    Raises:
        IndexError: In automatic mode, if ``df`` has no columns.
        json.JSONDecodeError: In automatic mode, if the reply is not
            valid JSON.
    """
    columns = df.columns.tolist()
    plot_chart_list = ['Line Chart', 'Bar Chart',
                       'Histogram', 'Scatter Plot',
                       'Pie Chart', 'Area Chart',
                       'Box Plot', 'Heatmap', 'Bubble Chart']

    if custom_flag:
        # Manual mode never uses the suggestion, so the generate_result
        # call and JSON parsing are skipped entirely here.
        _plot_custom(df, plot_chart_list, key_counter)
        return

    data_types = [str(dtype) for dtype in df.dtypes]
    metadata = dict(zip(columns, data_types))
    _plot_suggested(df, columns, metadata, plot_chart_list)