in fbgemm_gpu/codegen/embedding_backward_code_generator.py [0:0]
def make_args(arg_spec: List[Tuple[int, str]]) -> Dict[str, Any]:
    """Expand an optimizer argument spec into per-device codegen ``Args``.

    ``arg_spec`` is a list of ``(type_tag, name)`` pairs, where the tag is one
    of TENSOR / INT / FLOAT. Scalar arguments pass through unchanged; each
    TENSOR argument is split into the set of tensors the split-embedding
    kernels expect (host or dev/uvm storage plus placement and offset index
    tensors), with the split differing between the CPU and CUDA variants.

    Returns:
        ``{"cpu": Args, "cuda": Args}`` — one fully-expanded ``Args`` bundle
        per compute device.
    """

    def make_kernel_arg(ty: int, name: str) -> str:
        # GPU kernel parameter declaration for one split argument.
        return {
            TENSOR: acc_cache_tensor_arg,
            INT_TENSOR: int_tensor_arg,
            LONG_TENSOR: long_tensor_arg,
            INT: int64_arg,
            FLOAT: float_arg,
        }[ty](name)

    def make_kernel_arg_constructor(ty: int, name: str) -> str:
        # Expression that constructs the GPU kernel argument at the call
        # site; scalars are passed through verbatim.
        return {
            TENSOR: acc_cache_tensor_arg_constructor,
            INT_TENSOR: int_tensor_arg_constructor,
            LONG_TENSOR: long_tensor_arg_constructor,
            INT: lambda x: x,
            FLOAT: lambda x: x,
        }[ty](name)

    def make_cpu_kernel_arg(ty: int, name: str) -> str:
        # CPU kernel parameter declaration (gpu=False selects the CPU
        # flavor of the tensor argument helpers).
        return {
            TENSOR: lambda x: acc_cache_tensor_arg(x, gpu=False),
            INT_TENSOR: lambda x: int_tensor_arg(x, gpu=False),
            LONG_TENSOR: lambda x: long_tensor_arg(x, gpu=False),
            INT: int64_arg,
            FLOAT: float_arg,
        }[ty](name)

    def make_cpu_kernel_arg_constructor(ty: int, name: str) -> str:
        # CPU counterpart of make_kernel_arg_constructor.
        return {
            TENSOR: lambda x: acc_cache_tensor_arg_constructor(x, gpu=False),
            INT_TENSOR: lambda x: int_tensor_arg_constructor(x, gpu=False),
            LONG_TENSOR: lambda x: long_tensor_arg_constructor(x, gpu=False),
            INT: lambda x: x,
            FLOAT: lambda x: x,
        }[ty](name)

    def make_function_arg(ty: int, name: str) -> str:
        # C++ host-function parameter declaration. All tensor variants are
        # plain Tensors here; FLOAT maps to double (function signature),
        # unlike the kernel/schema mappings below.
        return {
            TENSOR: tensor_arg,
            INT_TENSOR: tensor_arg,
            LONG_TENSOR: tensor_arg,
            INT: int64_arg,
            FLOAT: double_arg,
        }[ty](name)

    def make_function_schema_arg(ty: int, name: str) -> str:
        # Operator-schema argument declaration (int/float schema types).
        return {
            TENSOR: tensor_arg,
            INT_TENSOR: tensor_arg,
            LONG_TENSOR: tensor_arg,
            INT: int_arg,
            FLOAT: float_arg,
        }[ty](name)

    def make_ivalue_cast(ty: int) -> str:
        # IValue accessor used to recover a saved scalar in backward.
        return {INT: "toInt", FLOAT: "toDouble"}[ty]

    def make_args_for_compute_device(split_arg_spec: List[Tuple[int, str]]) -> Args:
        # Assemble every per-device string list the templates consume.
        # NOTE: split_tensors and saved_data iterate the ORIGINAL arg_spec
        # (pre-split names), not split_arg_spec — that is intentional.
        return Args(
            split_kernel_args=[
                make_kernel_arg(ty, name) for (ty, name) in split_arg_spec
            ],
            split_kernel_arg_constructors=[
                make_kernel_arg_constructor(ty, name) for (ty, name) in split_arg_spec
            ],
            split_cpu_kernel_args=[
                make_cpu_kernel_arg(ty, name) for (ty, name) in split_arg_spec
            ],
            split_cpu_kernel_arg_constructors=[
                make_cpu_kernel_arg_constructor(ty, name)
                for (ty, name) in split_arg_spec
            ],
            split_function_args=[
                make_function_arg(ty, name) for (ty, name) in split_arg_spec
            ],
            split_tensors=[name for (ty, name) in arg_spec if ty == TENSOR],
            split_saved_tensors=[
                name
                for (ty, name) in split_arg_spec
                if ty in (TENSOR, INT_TENSOR, LONG_TENSOR)
            ],
            saved_data=[
                (name, make_ivalue_cast(ty)) for (ty, name) in arg_spec if ty != TENSOR
            ],
            split_function_arg_names=[name for (ty, name) in split_arg_spec],
            split_function_schemas=[
                make_function_schema_arg(ty, name) for (ty, name) in split_arg_spec
            ],
            split_variables=["Variable()" for _ in split_arg_spec],
        )

    def make_split_arg_spec(
        tensor_expansion: List[Tuple[int, str]]
    ) -> List[Tuple[int, str]]:
        # Expand arg_spec: scalars pass through; every TENSOR becomes one
        # entry per (type, suffix) in tensor_expansion. Factored out because
        # the CPU and CUDA paths differ only in the expansion list.
        split_arg_spec: List[Tuple[int, str]] = []
        for (ty, arg) in arg_spec:
            if ty in (FLOAT, INT):
                split_arg_spec.append((ty, arg))
            else:
                assert ty == TENSOR
                split_arg_spec.extend(
                    (sub_ty, f"{arg}{suffix}") for (sub_ty, suffix) in tensor_expansion
                )
        return split_arg_spec

    # CPU: a single host tensor plus placement/offset index tensors.
    cpu = make_args_for_compute_device(
        make_split_arg_spec(
            [
                (TENSOR, "_host"),
                (INT_TENSOR, "_placements"),
                (LONG_TENSOR, "_offsets"),
            ]
        )
    )
    # CUDA: separate device and UVM tensors plus placement/offset tensors.
    cuda = make_args_for_compute_device(
        make_split_arg_spec(
            [
                (TENSOR, "_dev"),
                (TENSOR, "_uvm"),
                (INT_TENSOR, "_placements"),
                (LONG_TENSOR, "_offsets"),
            ]
        )
    )
    return {"cpu": cpu, "cuda": cuda}