in functorch/_src/operator_authoring.py [0:0]
def compute_code(self):
    """Build the TensorExpr kernel for this pointwise op and store it on ``self.result``.

    Steps:
      1. Create one ``_te.BufHandle`` per entry of ``self.spec``. If the last
         spec entry is not marked ``out``, an extra output buffer of
         ``self.dtype`` is allocated, with shape ``self.shape_vars`` and dense
         strides laid out according to ``self.output_order``.
      2. When ``FOLD_ALIASES`` is set, arguments sharing a positive
         ``alias_group`` are folded onto the first such argument's handle.
      3. Load every input, cast it to ``self.dtype``, apply
         ``self.pointwise_fn`` and store the result into the output buffer.
      4. Wrap the store in one loop per dimension, apply a device-specific
         schedule, and compile the loop nest with ``_te.construct_codegen``.

    Side effects: when an output is allocated, appends to ``self.shapes`` and
    ``self.strides`` and calls ``self.result.add_allocated_output``; always
    finishes with ``self.result.set_code``.

    Raises:
        IndexError: an output must be allocated but no argument in
            ``self.spec`` has dtype ``self.dtype``.
        AssertionError: a negative ``alias_group`` is encountered, or CUDA
            loop flattening/splitting fails.
    """
    bufs = [_te.BufHandle(s.dtype) for s in self.spec]

    if not self.spec[-1].out:
        # No output argument was supplied, so allocate one. The index of the
        # first argument whose dtype matches the result dtype is handed to
        # add_allocated_output (presumably to copy its tensor options).
        options_from = next(
            (i for i, s in enumerate(self.spec) if s.dtype == self.dtype), None
        )
        if options_from is None:
            raise IndexError(
                f"cannot allocate output: no argument has dtype {self.dtype}"
            )
        self.result.add_allocated_output(options_from, self.output_order)
        bufs.append(_te.BufHandle(self.dtype))

        # Dense strides for the new output: the first dim in output_order
        # gets stride 1, and each following dim gets the product of the
        # sizes of the dims before it in output_order.
        self.shapes.append(list(self.shape_vars))
        alloc_strides = [None] * self.ndim
        next_stride = _one()
        for i in self.output_order:
            alloc_strides[i] = next_stride
            next_stride *= self.shape_vars[i]
        assert all(x is not None for x in alloc_strides)
        self.strides.append(alloc_strides)

    # Snapshot taken *before* alias folding: the codegen argument list keeps
    # one handle per argument, even for handles folded away below.
    bufs_args = list(bufs)

    aliases = {}
    for i, s in enumerate(self.spec):
        assert s.alias_group >= 0, "TODO: support complex aliasing"
        if s.alias_group > 0 and s.alias_group not in aliases:
            aliases[s.alias_group] = i
        elif s.alias_group > 0 and FOLD_ALIASES:
            # BufHandle in buf_args is now ignored
            bufs[i] = bufs[aliases[s.alias_group]]

    # The last buffer is the output; everything before it is an input.
    input_bufs = bufs[:-1]
    input_strides = self.strides[:-1]
    output_bufs = bufs[-1:]
    output_strides = self.strides[-1:]

    # Kernel body: out[idx] = pointwise_fn(cast(in_0[idx]), cast(in_1[idx]), ...)
    inputs = [
        _te.Cast.make(self.dtype, buf.load(self.indexing(stride)))
        for buf, stride in zip(input_bufs, input_strides)
    ]
    val = _fx_to_expr(self.pointwise_fn, self.dtype)(*inputs)
    out = _te.Block(
        [
            buf.store(self.indexing(stride), val)
            for buf, stride in zip(output_bufs, output_strides)
        ]
    )

    # Wrap the body in one loop per dimension. The first dim in output_order
    # becomes the innermost loop; `loops` ends up ordered outermost-first.
    loops: List[_te.For] = []
    for i in self.output_order:
        out = _te.For.make(self.iter_vars[i], _zero(), self.shape_vars[i], out)
        loops.insert(0, out)

    loopnest = _te.LoopNest(_te.Block([out]), output_bufs)

    if self.device == "cuda" and loops:
        # Flatten the whole nest into a single loop, split it by a factor of
        # 512 (tail handled by split_with_mask), then bind the outer loop to
        # the GPU block index and the inner loop to the GPU thread index.
        flattened = loopnest.flatten(loops)
        assert flattened
        inner = _te.LoopNest.split_with_mask(flattened, 512)
        assert inner
        flattened.set_gpu_block_index(0)
        inner.set_gpu_thread_index(0)
    elif self.compile_mode == "llvm" and loops:
        # Backend name lives in compile_mode (it is what construct_codegen
        # receives below), not in the element dtype.
        pass  # TODO(jansel): need a parallel CPU schedule

    loopnest.prepare_for_codegen()
    cg = _te.construct_codegen(
        self.compile_mode,
        loopnest.simplify(),
        bufs_args + self.stride_args + self.shape_args,
    )
    self.result.set_code(cg)