in lib/core_runtime/core_runtime.cc [240:364]
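// Wraps `fn` as a CoreRuntimeOp. The function is expected to take a leading
// Chain (for side-effect ordering) followed by TensorHandle arguments, and to
// return a leading Chain followed by TensorHandle results.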
Expected<CoreRuntimeOp> CoreRuntime::MakeCompositeOp(const Function* fn) {
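// Check that every argument after the leading chain is a TensorHandle.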
for (auto iter : llvm::enumerate(fn->argument_types().drop_front())) {
size_t i = iter.index();
auto& type = iter.value();
if (type.GetName() != kTensorHandleType) {
return MakeStringError("The function should only takes type [",
kTensorHandleType, "] as input. But the ", i,
"-th argument is type [", type.GetName(), "].");
}
}
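// Likewise, every result after the leading chain must be a TensorHandle.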
for (auto iter : llvm::enumerate(fn->result_types().drop_front())) {
size_t i = iter.index();
auto& type = iter.value();
if (type.GetName() != kTensorHandleType) {
return MakeStringError("The function should only returns type [",
kTensorHandleType, "]. But the ", i,
"-th results is type [", type.GetName(), "].");
}
}
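// The returned op prepends a chain to the invocation arguments, wraps each
// TensorHandle in an AsyncValue, runs the function, and then unwraps the
// results back into TensorHandles.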
auto execute_fn = [fn = fn](const OpInvocation& invocation) {
auto* host = invocation.exec_ctx.host();
// TODO(fishx): Return an error to the client instead of asserting.
assert(invocation.arguments.size() + 1 == fn->argument_types().size());
assert(invocation.results.size() + 1 == fn->result_types().size());
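// `arguments` holds the raw AsyncValue pointers passed to Execute();
// `arguments_ref` keeps ownership of any AsyncValues created here.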
llvm::SmallVector<AsyncValue*, 4> arguments;
llvm::SmallVector<RCReference<AsyncValue>, 4> arguments_ref;
arguments.reserve(fn->argument_types().size());
arguments_ref.reserve(fn->argument_types().size());
// The first argument is a chain for side-effects.
if (invocation.chain && *invocation.chain) {
arguments.push_back(invocation.chain->GetAsyncValue());
} else {
arguments_ref.push_back(GetReadyChain());
arguments.push_back(arguments_ref.back().get());
}
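// Wrap each TensorHandle argument in an available AsyncValue.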
for (size_t i = 0, e = invocation.arguments.size(); i != e; ++i) {
arguments_ref.push_back(MakeAvailableAsyncValueRef<TensorHandle>(
host, invocation.arguments[i].CopyRef()));
arguments.push_back(arguments_ref.back().get());
// Clean up the argument to enable input forwarding.
invocation.arguments[i] = TensorHandle();
}
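// Run the function. results[0] is the output chain; the rest are the
// TensorHandle results.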
llvm::SmallVector<RCReference<AsyncValue>, 4> results;
results.resize(fn->result_types().size());
fn->Execute(invocation.exec_ctx, arguments, results);
// The first result is a chain for side-effects.
if (invocation.chain)
*invocation.chain = AsyncValueRef<Chain>(std::move(results[0]));
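// Unwrap the remaining results into the invocation's TensorHandle results.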
for (auto iter : llvm::enumerate(llvm::drop_begin(results, 1))) {
size_t i = iter.index();
auto& result_av = iter.value();
if (result_av->IsAvailable()) {
if (result_av->IsError()) {
invocation.results[i] =
TensorHandle(AsyncValueRef<TensorHandle>(std::move(result_av)));
} else {
assert(result_av->IsType<TensorHandle>());
invocation.results[i] = result_av->get<TensorHandle>().CopyRef();
}
} else {
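// The result is not ready yet: create placeholder async values for the
// device, metadata, and tensor, and fill them in once the result resolves.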
auto device_av =
MakeUnconstructedAsyncValueRef<RCReference<Device>>(host);
auto metadata_av = MakeUnconstructedAsyncValueRef<TensorMetadata>(host);
auto tensor_ind_av = MakeIndirectAsyncValue(host);
result_av->AndThen([result_av = result_av,
device_av = device_av.CopyRef(),
metadata_av = metadata_av.CopyRef(),
tensor_ind_av = tensor_ind_av]() mutable {
if (result_av->IsError()) {
device_av.SetError(result_av->GetError());
metadata_av.SetError(result_av->GetError());
tensor_ind_av->SetError(result_av->GetError());
return;
}
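// The TensorHandle is available now; forward its device, metadata, and
// tensor to the placeholders, waiting on each piece if necessary.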
auto& th = result_av->get<TensorHandle>();
if (th.IsDeviceAvailable()) {
device_av.emplace(th.GetAvailableDevice());
} else {
th.GetAsyncDevice().AndThen(
[th_device = th.GetAsyncDevice().CopyRef(),
device_av = std::move(device_av)]() {
if (th_device.IsError()) {
device_av.SetError(th_device.GetError());
} else {
device_av.emplace(th_device.get());
}
});
}
if (th.IsMetadataAvailable()) {
metadata_av.emplace(th.GetAvailableMetadata());
} else {
th.GetAsyncMetadata().AndThen(
[th_metadata = th.GetAsyncMetadata().CopyRef(),
metadata_av = std::move(metadata_av)]() {
if (th_metadata.IsError()) {
metadata_av.SetError(th_metadata.GetError());
} else {
metadata_av.emplace(th_metadata.get());
}
});
}
tensor_ind_av->ForwardTo(FormRef(th.GetAsyncTensor()));
});
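// Hand back a TensorHandle assembled from the placeholder async values.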
invocation.results[i] =
TensorHandle(std::move(device_av), std::move(metadata_av),
AsyncValueRef<Tensor>(std::move(tensor_ind_av)));
}
}
};
return CoreRuntimeOp(std::move(execute_fn), false);
}
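// Illustrative usage (a minimal sketch, not part of the original file; the
// function name "my_composite_fn" and the `bef_file`/`corert` variables are
// assumptions):
//
//   const Function* fn = bef_file->GetFunction("my_composite_fn");
//   Expected<CoreRuntimeOp> composite_op = corert->MakeCompositeOp(fn);
//   if (!composite_op) return composite_op.takeError();
//   // The resulting op is invoked like any other CoreRuntimeOp, with
//   // TensorHandle arguments/results and an optional chain in the
//   // OpInvocation.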