def stackoverflow_28()

in Synthesis_incorporation/benchmarks/stackoverflow_benchmarks.py [0:0]


def stackoverflow_28():
    """Build the ``stackoverflow_28`` benchmark: per-matrix column selection.

    The single example pairs a nested list of five 2x2 matrices with a list
    of five indices.  The expected output has one row per matrix, holding the
    column of that matrix selected by the index at the same position.

    Returns:
        The object produced by ``benchmark.Benchmark(...)`` for this task.
    """
    # First input: five 2x2 matrices stacked along the leading axis.
    matrices = [
        [[5, 3], [0, 2]],
        [[7, 4], [5, 1]],
        [[10, 20], [15, 30]],
        [[11, 16], [14, 12]],
        [[-2, -7], [-4, 6]],
    ]
    # Second input: one column index per matrix.
    column_indices = [1, 0, 1, 1, 0]
    # Row i is column column_indices[i] of matrices[i],
    # e.g. column 1 of [[5, 3], [0, 2]] is [3, 2].
    selected_columns = [[3, 2], [7, 5], [20, 30], [16, 12], [-2, -4]]

    example = benchmark.Example(
        inputs=[matrices, column_indices],
        output=selected_columns,
    )

    return benchmark.Benchmark(
        name="stackoverflow_28",
        description="extract columns from a 3D tensor given column indices",
        source="https://stackoverflow.com/questions/54274074/selecting-columns-from-3d-tensor-according-to-a-1d-tensor-of-indices-tensorflow",
        target_program="torch.transpose(in1, 1, 2)[torch.arange(in1.size(0)), in2, :]",
        constants=[],
        examples=[example],
    )