def find_inflections()

in grok/visualization.py [0:0]


def find_inflections(Y, smoothing_steps=100):
    """Locate local extrema of ``Y`` guided by a smoothed copy of it.

    ``Y`` is first smoothed with ``moving_avg(Y, smoothing_steps)``. The sign
    of each step of the smoothed curve gives a direction (+1 rising, -1
    falling, 0 flat). Every index where that direction changes starts a new
    segment, and the bounds ``0`` and ``len(Y) - 1`` are added at the ends.
    For each window spanning two adjacent segments, the raw ``Y`` is searched
    over that window:
      - the argmin is taken when the first segment is falling;
      - otherwise the argmax is taken.

    NOTE(review): change points found on the smoothed curve are used to slice
    ``Y`` directly. This assumes ``moving_avg`` returns a series index-aligned
    with ``Y`` — TODO confirm against its implementation.

    Args:
        Y: 1-D series supporting slicing, ``argmin``/``argmax`` and ``len``
            (presumably a torch tensor or numpy array — TODO confirm).
        smoothing_steps: window size forwarded to ``moving_avg``.

    Returns:
        ``torch.LongTensor`` of indices into ``Y``, one per pair of adjacent
        segments. It is empty when the smoothed curve never changes direction
        or has fewer than two points.
    """
    avg_Y = moving_avg(Y, smoothing_steps)
    avg_direction = torch.FloatTensor(np.sign(avg_Y[1:] - avg_Y[:-1]))
    if avg_direction.numel() == 0:
        # Fewer than two smoothed points: no direction can be computed, so
        # there are no extrema (previously avg_direction[0] raised IndexError).
        return torch.LongTensor([])
    # Repeat the first direction so that avg_direction[k] is the direction
    # arriving at avg_Y[k] and the tensor has the same length as avg_Y.
    avg_direction = torch.cat([avg_direction[0].unsqueeze(0), avg_direction])
    # flatten() rather than squeeze(): with exactly one change point,
    # squeeze() yields a 0-d tensor whose tolist() is a bare int, which made
    # the list concatenation below raise TypeError.
    change_points = torch.nonzero(avg_direction[1:] - avg_direction[:-1]).flatten()
    avg_inflections = [0] + (change_points + 1).tolist() + [len(Y) - 1]
    logger.debug(f"avg_inflections = {avg_inflections}")
    inflections = []
    for i in range(2, len(avg_inflections)):
        # The window [low, high) covers two consecutive segments, so it
        # contains the turning point between them.
        low = avg_inflections[i - 2]
        high = avg_inflections[i]
        logger.debug(f"low={low}")
        logger.debug(f"high={high}")
        if avg_direction[low + 1] < 0:
            # First segment is falling: the turning point is a minimum.
            indices = Y[low:high].argmin() + low
            logger.debug(f"min = (Y[{indices}] = {Y[int(indices)]}")
        else:
            # First segment is rising or flat: take the maximum.
            indices = Y[low:high].argmax() + low
            logger.debug(f"max = (Y[{indices}] = {Y[int(indices)]}")
        # Store a plain int so LongTensor is built the same way whether
        # argmin/argmax returned a 0-d tensor or a numpy scalar.
        inflections.append(int(indices))
    return torch.LongTensor(inflections)