def make_args()

in fbgemm_gpu/codegen/embedding_backward_code_generator.py


def make_args(arg_spec: List[Tuple[int, str]]) -> Dict[str, Any]:
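    # Expand the (type, name) argument spec into per-device argument structures
    # for the code generator: one Args bundle for CPU and one for CUDA.

    # GPU kernel parameter declaration for a split argument, rendered by the
    # type-specific helper (acc_cache_tensor_arg, int_tensor_arg, long_tensor_arg,
    # int64_arg, or float_arg).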
    def make_kernel_arg(ty: int, name: str) -> str:
        return {
            TENSOR: acc_cache_tensor_arg,
            INT_TENSOR: int_tensor_arg,
            LONG_TENSOR: long_tensor_arg,
            INT: int64_arg,
            FLOAT: float_arg,
        }[ty](name)

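    # Expression that constructs the GPU kernel argument from the corresponding
    # host-side value; INT and FLOAT scalars pass through unchanged.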
    def make_kernel_arg_constructor(ty: int, name: str) -> str:
        return {
            TENSOR: acc_cache_tensor_arg_constructor,
            INT_TENSOR: int_tensor_arg_constructor,
            LONG_TENSOR: long_tensor_arg_constructor,
            INT: lambda x: x,
            FLOAT: lambda x: x,
        }[ty](name)

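    # CPU variant of make_kernel_arg: tensor-typed arguments are rendered with
    # gpu=False, while scalars use the same int64/float helpers.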
    def make_cpu_kernel_arg(ty: int, name: str) -> str:
        return {
            TENSOR: lambda x: acc_cache_tensor_arg(x, gpu=False),
            INT_TENSOR: lambda x: int_tensor_arg(x, gpu=False),
            LONG_TENSOR: lambda x: long_tensor_arg(x, gpu=False),
            INT: int64_arg,
            FLOAT: float_arg,
        }[ty](name)

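    # CPU variant of make_kernel_arg_constructor (tensor types with gpu=False).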
    def make_cpu_kernel_arg_constructor(ty: int, name: str) -> str:
        return {
            TENSOR: lambda x: acc_cache_tensor_arg_constructor(x, gpu=False),
            INT_TENSOR: lambda x: int_tensor_arg_constructor(x, gpu=False),
            LONG_TENSOR: lambda x: long_tensor_arg_constructor(x, gpu=False),
            INT: lambda x: x,
            FLOAT: lambda x: x,
        }[ty](name)

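    # Parameter declaration for the host-side function: every tensor-like type
    # becomes a plain tensor_arg, INT an int64_arg, and FLOAT a double_arg.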
    def make_function_arg(ty: int, name: str) -> str:
        return {
            TENSOR: tensor_arg,
            INT_TENSOR: tensor_arg,
            LONG_TENSOR: tensor_arg,
            INT: int64_arg,
            FLOAT: double_arg,
        }[ty](name)

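    # Argument as rendered in the operator schema: tensor-like types stay
    # tensor_arg, while INT and FLOAT use the schema-level int_arg/float_arg.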
    def make_function_schema_arg(ty: int, name: str) -> str:
        return {
            TENSOR: tensor_arg,
            INT_TENSOR: tensor_arg,
            LONG_TENSOR: tensor_arg,
            INT: int_arg,
            FLOAT: float_arg,
        }[ty](name)

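    # IValue accessor used to read back a saved non-tensor (scalar) argument.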
    def make_ivalue_cast(ty: int) -> str:
        return {INT: "toInt", FLOAT: "toDouble"}[ty]

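    # Assemble the full Args bundle (kernel args and constructors, function args,
    # schemas, saved tensors, and saved scalar data) for one compute device from
    # an already-split argument spec.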
    def make_args_for_compute_device(split_arg_spec: List[Tuple[int, str]]) -> Args:
        return Args(
            split_kernel_args=[
                make_kernel_arg(ty, name) for (ty, name) in split_arg_spec
            ],
            split_kernel_arg_constructors=[
                make_kernel_arg_constructor(ty, name) for (ty, name) in split_arg_spec
            ],
            split_cpu_kernel_args=[
                make_cpu_kernel_arg(ty, name) for (ty, name) in split_arg_spec
            ],
            split_cpu_kernel_arg_constructors=[
                make_cpu_kernel_arg_constructor(ty, name)
                for (ty, name) in split_arg_spec
            ],
            split_function_args=[
                make_function_arg(ty, name) for (ty, name) in split_arg_spec
            ],
            split_tensors=[name for (ty, name) in arg_spec if ty == TENSOR],
            split_saved_tensors=[
                name
                for (ty, name) in split_arg_spec
                if ty in (TENSOR, INT_TENSOR, LONG_TENSOR)
            ],
            saved_data=[
                (name, make_ivalue_cast(ty)) for (ty, name) in arg_spec if ty != TENSOR
            ],
            split_function_arg_names=[name for (ty, name) in split_arg_spec],
            split_function_schemas=[
                make_function_schema_arg(ty, name) for (ty, name) in split_arg_spec
            ],
            split_variables=["Variable()" for _ in split_arg_spec],
        )

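    # CPU split: each TENSOR argument expands into a host tensor plus placement
    # and offset index tensors; scalar arguments pass through unchanged.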
    split_arg_spec = []
    for (ty, arg) in arg_spec:
        if ty in (FLOAT, INT):
            split_arg_spec.append((ty, arg))
        else:
            assert ty == TENSOR
            split_arg_spec.extend(
                [
                    (TENSOR, f"{arg}_host"),
                    (INT_TENSOR, f"{arg}_placements"),
                    (LONG_TENSOR, f"{arg}_offsets"),
                ]
            )
    cpu = make_args_for_compute_device(split_arg_spec)

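    # CUDA split: each TENSOR argument expands into device and UVM tensors plus
    # placement and offset index tensors.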
    split_arg_spec = []
    for (ty, arg) in arg_spec:
        if ty in (FLOAT, INT):
            split_arg_spec.append((ty, arg))
        else:
            assert ty == TENSOR
            split_arg_spec.extend(
                [
                    (TENSOR, f"{arg}_dev"),
                    (TENSOR, f"{arg}_uvm"),
                    (INT_TENSOR, f"{arg}_placements"),
                    (LONG_TENSOR, f"{arg}_offsets"),
                ]
            )
    cuda = make_args_for_compute_device(split_arg_spec)

    return {"cpu": cpu, "cuda": cuda}