ax/plot/js/common/helpers.js (189 lines of code) (raw):
/**
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
// helper functions used across multiple plots
function rgb(rgb_array) {
return 'rgb(' + rgb_array.join() + ')';
}
function copy_and_reverse(arr) {
const copy = arr.slice();
copy.reverse();
return copy;
}
function axis_range(grid, is_log) {
return is_log
? [Math.log10(Math.min(...grid)), Math.log10(Math.max(...grid))]
: [Math.min(...grid), Math.max(...grid)];
}
function relativize_data(f, sd, rel, arm_data, metric) {
// if relative, extract status quo & compute ratio
const f_final = rel === true ? [] : f;
const sd_final = rel === true ? [] : sd;
if (rel === true) {
const f_sq =
arm_data['in_sample'][arm_data['status_quo_name']]['y'][metric];
const sd_sq =
arm_data['in_sample'][arm_data['status_quo_name']]['se'][metric];
for (let i = 0; i < f.length; i++) {
res = relativize(f[i], sd[i], f_sq, sd_sq);
f_final.push(100 * res[0]);
sd_final.push(100 * res[1]);
}
}
return [f_final, sd_final];
}
function relativize(m_t, sem_t, m_c, sem_c) {
r_hat =
(m_t - m_c) / Math.abs(m_c) -
(Math.pow(sem_c, 2) * m_t) / Math.pow(Math.abs(m_c), 3);
variance =
(Math.pow(sem_t, 2) + Math.pow((m_t / m_c) * sem_c, 2)) / Math.pow(m_c, 2);
return [r_hat, Math.sqrt(variance)];
}
function slice_config_to_trace(
arm_data,
arm_name_to_parameters,
f,
fit_data,
grid,
metric,
param,
rel,
setx,
sd,
is_log,
visible,
) {
// format data
const res = relativize_data(f, sd, rel, arm_data, metric);
const f_final = res[0];
const sd_final = res[1];
// get data for standard deviation fill plot
const sd_upper = [];
const sd_lower = [];
for (let i = 0; i < sd.length; i++) {
sd_upper.push(f_final[i] + 2 * sd_final[i]);
sd_lower.push(f_final[i] - 2 * sd_final[i]);
}
const grid_rev = copy_and_reverse(grid);
const sd_lower_rev = copy_and_reverse(sd_lower);
const sd_x = grid.concat(grid_rev);
const sd_y = sd_upper.concat(sd_lower_rev);
// get data for observed arms and error bars
const arm_x = [];
const arm_y = [];
const arm_sem = [];
fit_data.forEach(row => {
parameters = arm_name_to_parameters[row['arm_name']];
plot = true;
Object.keys(setx).forEach(p => {
if (p !== param && parameters[p] !== setx[p]) {
plot = false;
}
});
if (plot === true) {
arm_x.push(parameters[param]);
arm_y.push(row['mean']);
arm_sem.push(row['sem']);
}
});
const arm_res = relativize_data(arm_y, arm_sem, rel, arm_data, metric);
const arm_y_final = arm_res[0];
const arm_sem_final = arm_res[1].map(x => x * 2);
// create traces
const f_trace = {
x: grid,
y: f_final,
showlegend: false,
hoverinfo: 'x+y',
line: {
color: 'rgba(128, 177, 211, 1)',
},
visible: visible,
};
const arms_trace = {
x: arm_x,
y: arm_y_final,
mode: 'markers',
error_y: {
type: 'data',
array: arm_sem_final,
visible: true,
color: 'black',
},
line: {
color: 'black',
},
showlegend: false,
hoverinfo: 'x+y',
visible: visible,
};
const sd_trace = {
x: sd_x,
y: sd_y,
fill: 'toself',
fillcolor: 'rgba(128, 177, 211, 0.2)',
line: {
color: 'transparent',
},
showlegend: false,
hoverinfo: 'none',
visible: visible,
};
traces = [sd_trace, f_trace, arms_trace];
// iterate over out-of-sample arms
let i = 1;
Object.keys(arm_data['out_of_sample']).forEach(generator_run_name => {
const ax = [];
const ay = [];
const asem = [];
const atext = [];
Object.keys(arm_data['out_of_sample'][generator_run_name]).forEach(
arm_name => {
const parameters =
arm_data['out_of_sample'][generator_run_name][arm_name]['parameters'];
plot = true;
Object.keys(setx).forEach(p => {
if (p !== param && parameters[p] !== setx[p]) {
plot = false;
}
});
if (plot === true) {
ax.push(parameters[param]);
ay.push(
arm_data['out_of_sample'][generator_run_name][arm_name]['y_hat'][
metric
],
);
asem.push(
arm_data['out_of_sample'][generator_run_name][arm_name]['se_hat'][
metric
],
);
atext.push('<em>Candidate ' + arm_name + '</em>');
}
},
);
const out_of_sample_arm_res = relativize_data(
ay,
asem,
rel,
arm_data,
metric,
);
const ay_final = out_of_sample_arm_res[0];
const asem_final = out_of_sample_arm_res[1].map(x => x * 2);
traces.push({
hoverinfo: 'text',
legendgroup: generator_run_name,
marker: {color: 'black', symbol: i, opacity: 0.5},
mode: 'markers',
error_y: {
type: 'data',
array: asem_final,
visible: true,
color: 'black',
},
name: generator_run_name,
text: atext,
type: 'scatter',
xaxis: 'x',
x: ax,
yaxis: 'y',
y: ay_final,
visible: visible,
});
i += 1;
});
return traces;
}