void PipeImpl::onReadWhileClientWaitingForBrochureAnswer()

in tensorpipe/core/pipe_impl.cc [1098:1206]


void PipeImpl::onReadWhileClientWaitingForBrochureAnswer(
    const BrochureAnswer& nopBrochureAnswer) {
  TP_DCHECK(context_->inLoop());
  TP_DCHECK_EQ(state_, CLIENT_WAITING_FOR_BROCHURE_ANSWER);

  const std::string& transport = nopBrochureAnswer.transport;
  std::string address = nopBrochureAnswer.address;
  std::shared_ptr<transport::Context> transportContext =
      context_->getTransport(transport);
  TP_DCHECK(transportContext->canCommunicateWithRemote(
      nopBrochureAnswer.transportDomainDescriptor))
      << "The two endpoints disagree on whether transport " << transport
      << " can be used to communicate";

  if (transport != transport_) {
    TP_VLOG(3) << "Pipe " << id_
               << " is opening connection (descriptor, as replacement)";
    std::shared_ptr<transport::Connection> connection =
        transportContext->connect(address);
    connection->setId(id_ + ".d.tr_" + transport);
    const auto& transportRegistrationIter =
        nopBrochureAnswer.transportRegistrationIds.find(
            ConnectionId::DESCRIPTOR);
    TP_DCHECK(
        transportRegistrationIter !=
        nopBrochureAnswer.transportRegistrationIds.end());
    initConnection(*connection, transportRegistrationIter->second);

    transport_ = transport;
    descriptorConnection_ = std::move(connection);
  }

  {
    TP_VLOG(3) << "Pipe " << id_ << " is opening connection (descriptor_reply)";
    std::shared_ptr<transport::Connection> connection =
        transportContext->connect(address);
    connection->setId(id_ + ".r.tr_" + transport);
    const auto& transportRegistrationIter =
        nopBrochureAnswer.transportRegistrationIds.find(
            ConnectionId::DESCRIPTOR_REPLY);
    TP_DCHECK(
        transportRegistrationIter !=
        nopBrochureAnswer.transportRegistrationIds.end());
    initConnection(*connection, transportRegistrationIter->second);

    descriptorReplyConnection_ = std::move(connection);
  }

  // Recompute the channel map based on this side's channels and priorities.
  SelectedChannels selectedChannels = selectChannels(
      context_->getOrderedChannels(),
      nopBrochureAnswer.channelDeviceDescriptors);
  channelForDevicePair_ = std::move(selectedChannels.channelForDevicePair);

  // Verify that the locally and remotely computed channel maps are consistent.
  TP_THROW_ASSERT_IF(
      nopBrochureAnswer.channelForDevicePair.size() !=
      channelForDevicePair_.size())
      << "Inconsistent channel selection";
  for (const auto& iter : channelForDevicePair_) {
    Device localDevice;
    Device remoteDevice;
    std::tie(localDevice, remoteDevice) = iter.first;
    const std::string& channelName = iter.second;

    const auto& answerIter = nopBrochureAnswer.channelForDevicePair.find(
        {remoteDevice, localDevice});

    TP_THROW_ASSERT_IF(
        answerIter == nopBrochureAnswer.channelForDevicePair.end())
        << "Inconsistent channel selection";
    TP_THROW_ASSERT_IF(answerIter->second != channelName)
        << "Inconsistent channel selection";
  }

  for (const auto& channelDeviceDescriptorsIter :
       selectedChannels.descriptorsMap) {
    const std::string& channelName = channelDeviceDescriptorsIter.first;
    std::shared_ptr<channel::Context> channelContext =
        context_->getChannel(channelName);

    const std::vector<uint64_t>& registrationIds =
        nopBrochureAnswer.channelRegistrationIds.at(channelName);
    const size_t numConnectionsNeeded = channelContext->numConnectionsNeeded();
    TP_DCHECK_EQ(numConnectionsNeeded, registrationIds.size());
    std::vector<std::shared_ptr<transport::Connection>> connections(
        numConnectionsNeeded);
    for (size_t connId = 0; connId < numConnectionsNeeded; ++connId) {
      TP_VLOG(3) << "Pipe " << id_ << " is opening connection " << connId << "/"
                 << numConnectionsNeeded << " (for channel " << channelName
                 << ")";
      std::shared_ptr<transport::Connection> connection =
          transportContext->connect(address);
      connection->setId(
          id_ + ".ch_" + channelName + "_" + std::to_string(connId));
      initConnection(*connection, registrationIds[connId]);
      connections[connId] = std::move(connection);
    }

    std::shared_ptr<channel::Channel> channel = channelContext->createChannel(
        std::move(connections), channel::Endpoint::kConnect);
    channel->setId(id_ + ".ch_" + channelName);
    channels_.emplace(channelName, std::move(channel));
  }

  state_ = ESTABLISHED;
  readOps_.advanceAllOperations();
  writeOps_.advanceAllOperations();
}