Expected<CoreRuntimeOp> CoreRuntime::MakeCompositeOp()

in lib/core_runtime/core_runtime.cc [240:364]


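// Wraps `fn` in a CoreRuntimeOp. The function is expected to take a Chain
// followed by TensorHandle arguments and to produce a Chain followed by
// TensorHandle results.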
Expected<CoreRuntimeOp> CoreRuntime::MakeCompositeOp(const Function* fn) {
  for (auto iter : llvm::enumerate(fn->argument_types().drop_front())) {
    size_t i = iter.index();
    auto& type = iter.value();
    if (type.GetName() != kTensorHandleType) {
      return MakeStringError("The function should only take type [",
                             kTensorHandleType, "] as input. But the ", i,
                             "-th argument is of type [", type.GetName(), "].");
    }
  }
  for (auto iter : llvm::enumerate(fn->result_types().drop_front())) {
    size_t i = iter.index();
    auto& type = iter.value();
    if (type.GetName() != kTensorHandleType) {
      return MakeStringError("The function should only return type [",
                             kTensorHandleType, "]. But the ", i,
                             "-th result is of type [", type.GetName(), "].");
    }
  }
  auto execute_fn = [fn = fn](const OpInvocation& invocation) {
    auto* host = invocation.exec_ctx.host();

    // TODO(fishx): Return an error to the client instead of asserting.
    assert(invocation.arguments.size() + 1 == fn->argument_types().size());
    assert(invocation.results.size() + 1 == fn->result_types().size());

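    // `arguments` holds the raw AsyncValue pointers passed to the function,
    // while `arguments_ref` keeps the AsyncValues created below alive while
    // the function executes.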
    llvm::SmallVector<AsyncValue*, 4> arguments;
    llvm::SmallVector<RCReference<AsyncValue>, 4> arguments_ref;
    arguments.reserve(fn->argument_types().size());
    arguments_ref.reserve(fn->argument_types().size());

    // The first argument is a chain for side-effects.
    if (invocation.chain && *invocation.chain) {
      arguments.push_back(invocation.chain->GetAsyncValue());
    } else {
      arguments_ref.push_back(GetReadyChain());
      arguments.push_back(arguments_ref.back().get());
    }

    for (size_t i = 0, e = invocation.arguments.size(); i != e; ++i) {
      arguments_ref.push_back(MakeAvailableAsyncValueRef<TensorHandle>(
          host, invocation.arguments[i].CopyRef()));
      arguments.push_back(arguments_ref.back().get());

      // Clean up the argument to enable input forwarding.
      invocation.arguments[i] = TensorHandle();
    }

    llvm::SmallVector<RCReference<AsyncValue>, 4> results;
    results.resize(fn->result_types().size());

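    // Run the wrapped function. Its results may still be unavailable when
    // Execute() returns and are filled in asynchronously.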
    fn->Execute(invocation.exec_ctx, arguments, results);

    // The first result is a chain for side-effects.
    if (invocation.chain)
      *invocation.chain = AsyncValueRef<Chain>(std::move(results[0]));

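    // The remaining results are TensorHandles. Available results are unwrapped
    // directly; for unavailable ones, build a TensorHandle whose device,
    // metadata, and tensor resolve once the underlying result does.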
    for (auto iter : llvm::enumerate(llvm::drop_begin(results, 1))) {
      size_t i = iter.index();
      auto& result_av = iter.value();
      if (result_av->IsAvailable()) {
        if (result_av->IsError()) {
          invocation.results[i] =
              TensorHandle(AsyncValueRef<TensorHandle>(std::move(result_av)));
        } else {
          assert(result_av->IsType<TensorHandle>());
          invocation.results[i] = result_av->get<TensorHandle>().CopyRef();
        }
      } else {
        auto device_av =
            MakeUnconstructedAsyncValueRef<RCReference<Device>>(host);
        auto metadata_av = MakeUnconstructedAsyncValueRef<TensorMetadata>(host);
        auto tensor_ind_av = MakeIndirectAsyncValue(host);

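        // Once the function result resolves, either propagate its error or
        // forward its device, metadata, and tensor into the placeholders
        // created above.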
        result_av->AndThen([result_av = result_av,
                            device_av = device_av.CopyRef(),
                            metadata_av = metadata_av.CopyRef(),
                            tensor_ind_av = tensor_ind_av]() mutable {
          if (result_av->IsError()) {
            device_av.SetError(result_av->GetError());
            metadata_av.SetError(result_av->GetError());
            tensor_ind_av->SetError(result_av->GetError());
            return;
          }
          auto& th = result_av->get<TensorHandle>();

          if (th.IsDeviceAvailable()) {
            device_av.emplace(th.GetAvailableDevice());
          } else {
            th.GetAsyncDevice().AndThen(
                [th_device = th.GetAsyncDevice().CopyRef(),
                 device_av = std::move(device_av)]() {
                  if (th_device.IsError()) {
                    device_av.SetError(th_device.GetError());
                  } else {
                    device_av.emplace(th_device.get());
                  }
                });
          }

          if (th.IsMetadataAvailable()) {
            metadata_av.emplace(th.GetAvailableMetadata());
          } else {
            th.GetAsyncMetadata().AndThen(
                [th_metadata = th.GetAsyncMetadata().CopyRef(),
                 metadata_av = std::move(metadata_av)]() {
                  if (th_metadata.IsError()) {
                    metadata_av.SetError(th_metadata.GetError());
                  } else {
                    metadata_av.emplace(th_metadata.get());
                  }
                });
          }

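          // Forward the underlying tensor through the indirect async value.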
          tensor_ind_av->ForwardTo(FormRef(th.GetAsyncTensor()));
        });

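        // Expose the result as an unavailable TensorHandle backed by the
        // placeholder async values.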
        invocation.results[i] =
            TensorHandle(std::move(device_av), std::move(metadata_av),
                         AsyncValueRef<Tensor>(std::move(tensor_ind_av)));
      }
    }
  };
  return CoreRuntimeOp(std::move(execute_fn), false);
}