in src/sal/utils/math.py [0:0]
def subsample_completions(x: Dict[str, List[Any]], n: int) -> Dict[str, List[Any]]:
    """Keep only the first ``n`` completions of a row, along with their aggregate scores.

    Args:
        x: A mapping holding parallel lists under the ``"completions"`` and
            ``"agg_scores"`` keys.
        n: How many leading entries to keep. If ``n`` is larger than the list
            length, every entry is returned.

    Returns:
        A new dict with the prefix slices stored under ``"completions@{n}"``
        and ``"agg_scores@{n}"``.

    Raises:
        KeyError: If either expected key is missing from ``x``.
        ValueError: If the two lists have different lengths.
    """
    candidates, scores = x["completions"], x["agg_scores"]
    num_candidates = len(candidates)
    num_scores = len(scores)

    if num_candidates == num_scores:
        # Take a prefix rather than an arbitrary subset. The completions are
        # described as ordered in groups of size m, e.g.
        # [0,0,0,0, 1,1,1,1, 2,2,2,2, ...], and a prefix keeps that grouping
        # comparable at smaller n.
        # NOTE(review): whole groups are only preserved when n is a multiple
        # of m. Confirm callers choose n accordingly.
        return {
            f"completions@{n}": candidates[:n],
            f"agg_scores@{n}": scores[:n],
        }

    raise ValueError(
        f"The number of completions and agg_scores should be the same. Got {num_candidates} completions and {num_scores} agg_scores."
    )