in simulation/decai/simulation/simulate.py [0:0]
def simulate(self,
agents: List[Agent],
baseline_accuracy: float = None,
init_train_data_portion: float = 0.1,
pm_test_sets: list = None,
accuracy_plot_wait_s=2E5,
train_size: int = None, test_size: int = None,
filename_indicator: str = None
):
"""
Run a simulation.
:param agents: The agents that will interact with the data.
:param baseline_accuracy: The baseline accuracy of the model.
Usually the accuracy on a hidden test set when the model is trained with all data.
:param init_train_data_portion: The portion of the data to initially use for training. Must be [0,1].
:param pm_test_sets: The test sets for the prediction market incentive mechanism.
:param accuracy_plot_wait_s: The amount of time to wait in seconds between plotting the accuracy.
:param train_size: The amount of training data to use.
:param test_size: The amount of test data to use.
:param filename_indicator: Path of the filename to create for the run.
"""
assert 0 <= init_train_data_portion <= 1
# Data to save.
save_data = dict(agents=[asdict(a) for a in agents],
baselineAccuracy=baseline_accuracy,
initTrainDataPortion=init_train_data_portion,
accuracies=[],
balances=[],
)
time_for_filenames = int(time.time())
save_path = f'saved_runs/{time_for_filenames}-{filename_indicator}-simulation_data.json'
model_save_path = f'saved_runs/{time_for_filenames}-{filename_indicator}-model.json'
plot_save_path = f'saved_runs/{time_for_filenames}-{filename_indicator}.png'
self._logger.info("Saving run info to \"%s\".", save_path)
os.makedirs(os.path.dirname(save_path), exist_ok=True)
# Set up plots.
doc: Document = curdoc()
doc.title = "DeCAI Simulation"
plot = figure(title="Balances & Accuracy on Hidden Test Set",
)
plot.width = 800
plot.height = 600
plot.xaxis.axis_label = "Time (days)"
plot.yaxis.axis_label = "Percent"
plot.title.text_font_size = '20pt'
plot.xaxis.major_label_text_font_size = '20pt'
plot.xaxis.axis_label_text_font_size = '20pt'
plot.yaxis.major_label_text_font_size = '20pt'
plot.yaxis.axis_label_text_font_size = '20pt'
plot.xaxis[0].ticker = AdaptiveTicker(base=5 * 24 * 60 * 60)
plot.xgrid[0].ticker = AdaptiveTicker(base=24 * 60 * 60)
balance_plot_sources_per_agent = dict()
good_colors = cycle([
colors.named.green,
colors.named.lawngreen,
colors.named.darkgreen,
colors.named.limegreen,
])
bad_colors = cycle([
colors.named.red,
colors.named.darkred,
])
for agent in agents:
source = ColumnDataSource(dict(t=[], b=[]))
assert agent.address not in balance_plot_sources_per_agent
balance_plot_sources_per_agent[agent.address] = source
if agent.calls_model:
color = 'blue'
line_dash = 'dashdot'
elif agent.good:
color = next(good_colors)
line_dash = 'dotted'
else:
color = next(bad_colors)
line_dash = 'dashed'
plot.line(x='t', y='b',
line_dash=line_dash,
line_width=2,
source=source,
color=color,
legend=f"{agent.address} Balance")
plot.legend.location = 'top_left'
plot.legend.label_text_font_size = '12pt'
# JavaScript code.
plot.xaxis[0].formatter = FuncTickFormatter(code="""