Error Provisioner::provisionNetwork()

in lib/Runtime/Provisioner/Provisioner.cpp [243:784]


Error Provisioner::provisionNetwork(std::unique_ptr<Network> network) {
  VLOG(1) << "Started provisioner";
  DAGListTy &networks = network->networks;
  Module &module = network->module;
  CompilationContext &cctx = network->cctx;
  // Check that the requested networks don't collide with the names of any other
  // networks being added.
  std::vector<std::string> localActiveNames;
  RETURN_IF_ERR(checkActiveNetworks(networks, localActiveNames));

  // Mapping from function name to its compiled function. NB: compiledFunctions
  // holds compiled functions that may be used during cleanup by cleanupGuard,
  // hence it needs to be declared before cleanupGuard. We should probably
  // clean up the compiledFunctions logic to make this more intuitive.
  llvm::StringMap<std::unique_ptr<CompiledFunction>> compiledFunctions;

  // If any error happens during the provisioning process, we will clean up the
  // compiled networks.
  std::map<DeviceIDTy, std::vector<std::string>> addedNetworks;
  ScopeGuard cleanupGuard([&localActiveNames, &addedNetworks, this]() {
    cleanupProvision(localActiveNames, addedNetworks);
  });
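  // NB: cleanupGuard is dismissed at the end of this function once
  // provisioning has succeeded, so cleanupProvision runs from here only on
  // error paths.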

  // Walk the networks and group by logicalDeviceId.
  auto logicalDevices = generateLogicalDevices(networks);
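  // logicalDevices maps each logical device ID to the DAGNodes assigned to it
  // (roughly a std::map<DeviceIDTy, std::vector<DAGNode *>>); nodes that share
  // a logical device ID must be placed together on one physical device.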

  if (cctx.backendOpts.collectConstants) {
    VLOG(1) << "Warning: collectConstants is set in a Runtime compile, "
               "ignoring it.";
  }
  if (cctx.backendOpts.backendHints.SRAMPrioritization.size() != 0 ||
      cctx.backendOpts.backendHints.executionUnits) {
    VLOG(1) << "Warning: backendHints is set in a Runtime compile, "
               "ignoring it.";
  }

  // Set collectConstants to false; the DeviceManager will handle moving
  // constants to the device, which lets us eliminate one copy operation.
  cctx.backendOpts.collectConstants = false;

  // Calculate the size of each logical device.
  auto logicalDeviceSize = calculateLogicalDeviceSize(logicalDevices);

  // Get available memory for all devices.
  std::vector<std::pair<DeviceIDTy, uint64_t>> deviceMemory;
  for (unsigned i = 0; i < devices_.size(); i++) {
    uint64_t availableMemory = devices_[i]->getAvailableMemory();
    deviceMemory.push_back(std::make_pair(i, availableMemory));
  }
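  // NB: deviceMemory is not consulted again below; the per-backend
  // deviceMemoryMap built next is what actually drives device assignment.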

  // Get available device memory, creating a map of vectors for each backend
  // kind.
  std::map<std::string, std::vector<std::pair<DeviceIDTy, uint64_t>>>
      deviceMemoryMap;
  for (unsigned i = 0; i < devices_.size(); i++) {
    uint64_t availableMemory = devices_[i]->getAvailableMemory();

    deviceMemoryMap[devices_[i]->getBackendName().str()].push_back(
        std::make_pair(i, availableMemory));
  }

  // Sort all vectors in descending order of available memory.
  for (auto &sizes : deviceMemoryMap) {
    std::sort(sizes.second.begin(), sizes.second.end(), sortMostMemory);
  }
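  // sortMostMemory is not shown here; presumably it is a descending-by-memory
  // comparator along the lines of:
  //   static bool sortMostMemory(const std::pair<DeviceIDTy, uint64_t> &a,
  //                              const std::pair<DeviceIDTy, uint64_t> &b) {
  //     return a.second > b.second;
  //   }
  // so each backend's devices are tried roomiest-first during assignment.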

  // Generate assignments between physical and logical devices.
  auto deviceAssignments = generateDeviceAssignments(
      logicalDeviceSize, deviceMemoryMap, logicalDevices);

  VLOG(1) << "Before device assignment";
  // Check for errors.
  if (!deviceAssignments) {
    RETURN_ERR(deviceAssignments.takeError());
  }
  auto assignments = std::move(*deviceAssignments);
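  // generateDeviceAssignments returns an Expected-style value: operator bool()
  // reports success, takeError() extracts the failure, and dereferencing
  // yields the mapping from logical device ID to physical device ID.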

  VLOG(1) << "Before compile";

  // Maps each function name to the remaining logical device count for that
  // function.
  llvm::StringMap<size_t> remainingDeviceCount;
  // Mapping from function name to its backend options.
  llvm::StringMap<BackendOptions> optsMap;

  // Compile and load.
  // This is done one logical device at a time. All functions in a logical
  // device are compiled and then added to their assigned device. If a function
  // is in multiple logical devices, it is stored so that it only needs to be
  // compiled once.
  if (network->networkType == NetworkType::GLOW_NETWORK) {
    for (auto &assignment : assignments) {
      auto logicalDevice = assignment.first;
      auto physicalDevice = assignment.second;
      auto deviceBackendName = logicalDevices[logicalDevice][0]->backendName;

      if (backends_.find(deviceBackendName) == backends_.end()) {
        // Return an error: the requested device type was not found.
        return MAKE_ERR(ErrorValue::ErrorCode::RUNTIME_DEVICE_NOT_FOUND,
                        "Unable to find device of type: " + deviceBackendName);
      }

      // Stores all the functions in a logical device.
      std::vector<glow::Function *> functionsToCompile;
      // Stores the compiled functions that will be added to physical device.
      FunctionMapTy functionMap;

      // Collect all the functions in a logical device.
      for (auto &node : logicalDevices[logicalDevice]) {
        // If the function name already exists we don't need to compile it
        // again.
        if (optsMap.count(node->name)) {
          remainingDeviceCount[node->name] -= 1;
          continue;
        }

        auto options = cctx.backendOpts;
        options.backendHints = node->backendHints;
        // Insert all options loaded in the Partitioner alongside options
        // previously inserted, with Partitioner options taking precedence in
        // case of a collision of keys.
        for (auto &it : node->backendSpecificOpts) {
          options.backendSpecificOpts[it.first] = it.second;
        }
        std::lock_guard<std::mutex> functionsLock(functionsLock_);
        Function *function = module.getFunction(node->name);

        functionsToCompile.push_back(function);
        optsMap.insert({function->getName(), options});
        functionReplicaCount_.emplace(node->name, node->replicationCount);
        remainingDeviceCount.insert(
            {node->name, node->logicalDevices.size() - 1});
      }
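      // NB: remainingDeviceCount starts at one less than the number of logical
      // devices a function appears in and is decremented on each reuse above;
      // a function's compilation resources are freed below only once its count
      // reaches zero.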

      // Compile all the functions in the logical device together.
      // We add a lock here because some backends are not thread-safe (e.g. the
      // CPU backend).
      std::unique_lock<std::mutex> compileLock(functionsLock_);
      auto compiledOrErr = backends_[deviceBackendName]->compileFunctions(
          functionsToCompile, optsMap);
      VLOG(1) << "After compile";
      compileLock.unlock();
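      // NB: this reacquires functionsLock_; the per-iteration lock_guard in
      // the collection loop above has already been released. A unique_lock is
      // used (rather than lock_guard) so the mutex can be dropped explicitly
      // before the slower graph-dumping work below.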

      // Dump graph and logs
      for (auto *function : functionsToCompile) {
        // Note: This needs to come after compile above because compile may
        // modify the Function as well.
        if (cctx.dumpFinalGraph) {
          auto fname = strFormat(
              "%sfinal_graph_%s_%s.dot", cctx.dumpGraphPath.c_str(),
              deviceBackendName.c_str(), function->getName().str().c_str());
          LOG(INFO) << "Dumping final graph to " << fname;
          function->dumpDAG(fname);
          // Print per-node-kind statistics.
          std::map<std::string, int> opCounter;
          for (const auto &node : function->getNodes()) {
            opCounter[node.getKindName()]++;
          }
          std::ostringstream ss;
          ss << "Dump of Node stats for Function:\n";
          ss << folly::stringPrintf("%30s %13s \n", "NodeKind", "Count");
          for (const auto &p : opCounter) {
            ss << folly::stringPrintf("%30s %13d \n", p.first.c_str(),
                                      p.second);
          }
          LOG(INFO) << ss.str();
        }

        if (glow::flags::DumpCompilationLog) {
          llvm::SmallString<64> path;
          std::string prefix =
              llvm::formatv("{0}-{1}", cctx.compilationLogPrefix,
                            function->getName())
                  .str();
          auto tempFileRes =
              llvm::sys::fs::createTemporaryFile(prefix, "log", path);
          if (tempFileRes.value() != 0) {
            LOG(ERROR)
                << "Failed to create temp file for Glow compilation log: "
                << tempFileRes;
          }

          function->getLogContext()->dumpLog(path);
        }
      }

      // If an error occurred return it; otherwise store the compiled
      // functions in compiledFunctions.
      if (!compiledOrErr) {
        RETURN_ERR(compiledOrErr.takeError());
      }
      auto compiled = std::move(*compiledOrErr);
      for (auto &compiledFunction : compiled) {

        // Deserialize the compiled function from cctx.nameToFunctions.
        if (cctx.backendOpts.useDeserialize) {
          std::string name = compiledFunction.first().str();
          if (cctx.nameToFunctions.find(name) == cctx.nameToFunctions.end()) {
            return MAKE_ERR(
                ErrorValue::ErrorCode::UNKNOWN,
                "Cannot find compiled function when deserializing " + name);
          }
          RETURN_IF_ERR(compiledFunction.second->deserialize(
              *(cctx.nameToFunctions.find(name)->second)));
        }
        compiledFunctions.try_emplace(compiledFunction.first(),
                                      std::move(compiledFunction.second));
      }
      // Construct the functionMap for the physical device.
      for (auto &node : logicalDevices[logicalDevice]) {
        RETURN_ERR_IF_NOT(compiledFunctions.count(node->name),
                          "Can't find corresponding compiled function " +
                              node->name);

        auto *compiledFunction = compiledFunctions[node->name].get();
        functionMap.emplace(node->name, compiledFunction);

        for (unsigned i = 1; i < node->replicationCount; i++) {
          auto replicatedName = getReplicatedName(node->name, i);
          functionMap.emplace(replicatedName, compiledFunction);
        }
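        // getReplicatedName presumably derives a unique per-replica name from
        // the base name and replica index; all replicas share the same
        // underlying CompiledFunction.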

        // Dump backend-specific IR
        if (glow::flags::DumpBackendSpecificIRJSON) {
          compiledFunction->dumpJSON(strFormat("%sbackend_specific_ir_%s.json",
                                               cctx.dumpGraphPath.c_str(),
                                               node->name.c_str()));
        }

        node->runtimeBundle = glow::make_unique<RuntimeBundle>(
            compiledFunction->getRuntimeBundle());
      }

      // Now that the functions are compiled, add them to their assigned device
      // and then clean up.
      std::promise<void> addPromise;
      auto ready = addPromise.get_future();
      std::unique_ptr<Error> addErr;
      devices_[physicalDevice]->addNetwork(
          &module, functionMap,
          [&addErr, &addPromise](const Module *, Error err) {
            addErr = glow::make_unique<Error>(std::move(err));
            addPromise.set_value();
          });
      ready.wait();
      DCHECK_NOTNULL(addErr.get());
      if (*addErr.get()) {
        return std::move(*addErr.get());
      }
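      // addNetwork is asynchronous; the promise/future pair above turns it
      // into a blocking call, with the callback handing the resulting Error
      // back through addErr. DCHECK_NOTNULL guards against the case where the
      // callback never fired.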

      // Add networks successfully loaded on the device to addedNetworks so
      // that we can evict them if we fail later.
      for (const auto &func : functionMap) {
        addedNetworks[physicalDevice].push_back(func.first);
      }
      VLOG(1) << "Added networks";

      // Free up memory no longer needed by the compiledFunction.
      for (auto &node : logicalDevices[logicalDevice]) {
        // If the compiled function still needs to be added to other devices,
        // don't free its resources.
        if (remainingDeviceCount[node->name] > 0) {
          continue;
        }

        // Free compilation resources. This needs to be done after the network
        // is added and before moving on to the next logical device. If
        // DisableFreeCompilationResource is true, we will not free them here;
        // this is used in scenarios like model serialization.
        auto &functionPtr = compiledFunctions[node->name];
        if (!glow::flags::DisableFreeCompilationResource) {
          functionPtr->freeCompilationResources();
        }

        // Move the compiled function from compiledFunctions to functions_.
        {
          std::lock_guard<std::mutex> functionsLock(functionsLock_);
          functions_.emplace(node->name, std::move(functionPtr));
        }

        compiledFunctions.erase(node->name);
      }
    }
  } else if (network->networkType == NetworkType::FX_NETWORK) {
#if FACEBOOK_INTERNAL
    // Container for duplicated functions, plus a map tracking the remaining
    // installs for each duplicated function.
    std::map<std::string, std::unique_ptr<CompiledFunction>>
        duplicatedFunctions;
    std::map<DAGNode *, unsigned> remainingDuplications;
    for (auto &assignment : assignments) {
      auto logicalDevice = assignment.first;
      auto physicalDevice = assignment.second;
      auto deviceBackendName = logicalDevices[logicalDevice][0]->backendName;
      FunctionMapTy functionMap;
      // Container for the compiledFunctions for this logicalDevice.
      std::map<std::string, std::unique_ptr<CompiledFunction>>
          compiledFunctions;

      for (auto &node : logicalDevices[logicalDevice]) {
        // Check if this is a duplicated function that has already been
        // compiled.
        if (duplicatedFunctions.find(node->name) != duplicatedFunctions.end()) {
          functionMap.emplace(node->name,
                              duplicatedFunctions[node->name].get());
          remainingDuplications[node] -= 1;
        } else {
          // Compile and add to function map.
          auto options = cctx.backendOpts;
          options.backendHints = node->backendHints;
          // Insert all options loaded in the Partitioner alongside options
          // previously inserted, with Partitioner options taking precedence in
          // case of a collision of keys.
          for (auto &it : node->backendSpecificOpts) {
            options.backendSpecificOpts[it.first] = it.second;
          }
          if (backends_.find(deviceBackendName) == backends_.end()) {
            // Return an error: the requested device type was not found.
            return MAKE_ERR(ErrorValue::ErrorCode::RUNTIME_DEVICE_NOT_FOUND,
                            "Unable to find device of type: " +
                                deviceBackendName);
          }
          auto fxNetwork = static_cast<FXNetwork *>(network.get());
          auto compiledOrErr = backends_[deviceBackendName]->compileFX(
              fxNetwork->FXIR, node->name, fxNetwork->constants, options,
              &module);

          // Check to see if an error was encountered while compiling.
          if (!compiledOrErr) {
            // If an error occurred, return the error.
            RETURN_ERR(compiledOrErr.takeError());
          }
          auto compiled = std::move(*compiledOrErr);

          node->runtimeBundle =
              glow::make_unique<RuntimeBundle>(compiled->getRuntimeBundle());

          functionMap.emplace(node->name, compiled.get());
          // If this function is in more than one logical device, store it for
          // reuse.
          if (node->logicalDevices.size() > 1) {
            duplicatedFunctions.emplace(node->name, std::move(compiled));
            remainingDuplications[node] = node->logicalDevices.size() - 1;
          } else {
            compiledFunctions.emplace(node->name, std::move(compiled));
          }
        }
      }
      VLOG(1) << "After compile";

      // Now that the functions are compiled, add them to their assigned device
      // and then clean up.
      std::promise<void> addPromise;
      auto ready = addPromise.get_future();
      std::unique_ptr<Error> addErr;
      devices_[physicalDevice]->addNetwork(
          &module, functionMap,
          [&addErr, &addPromise](const Module *, Error err) {
            addErr = glow::make_unique<Error>(std::move(err));
            addPromise.set_value();
          });
      ready.wait();
      DCHECK_NOTNULL(addErr.get());
      if (*addErr.get()) {
        return std::move(*addErr.get());
      }
      // Add networks successfully loaded on the device to addedNetworks so
      // that we can evict them if we fail later.
      for (auto &node : logicalDevices[logicalDevice]) {
        addedNetworks[physicalDevice].push_back(node->name);
      }
      VLOG(1) << "Added networks";

      // Free up memory no longer needed by the compiled functions.
      for (auto &func : compiledFunctions) {
        func.second->freeCompilationResources();
      }
      {
        // Move compiled functions from compiledFunctions to functions_.
        std::lock_guard<std::mutex> functionsLock(functionsLock_);
        for (auto &func : compiledFunctions) {
          functions_.emplace(func.first, std::move(func.second));
        }
        // Check if any of the duplicated functions can also be moved.
        for (auto iter = remainingDuplications.begin();
             iter != remainingDuplications.end();) {
          const auto &func = *iter;
          if (func.second == 0) {
            duplicatedFunctions[func.first->name]->freeCompilationResources();
            functions_.emplace(
                func.first->name,
                std::move(duplicatedFunctions[func.first->name]));
            duplicatedFunctions.erase(func.first->name);
            iter = remainingDuplications.erase(iter);
          } else {
            ++iter;
          }
        }
      }
    }
#endif
  }
  RETURN_ERR_IF_NOT(compiledFunctions.empty(),
                    "compiledFunctions should be empty because all compiled "
                    "functions should be moved to Provisioner::functions_");

  // Map from Placeholder* to DeviceManager, this is used for deferred weight
  // loading.
  std::unordered_map<Placeholder *, std::vector<unsigned>>
      placeholderToDeviceManager;
  if (cctx.deferredWeightLoader) {
    // Populate the placeholderToDeviceManager map.
    for (auto &assignment : assignments) {
      for (const auto &node : logicalDevices[assignment.first]) {
        auto symbolTable = node->runtimeBundle->getSymbolTable();
        for (auto info : symbolTable) {
          if (info.second.symbolCategory ==
              glow::runtime::SymbolCategory::Placeholder) {
            auto PH = module.getPlaceholderByNameSlow(info.first);
            if (PH->isStatic()) {
              placeholderToDeviceManager[PH].push_back(assignment.second);
            }
          }
        }
      }
    }
  } else {
    // Make sure there are no static placeholders.
    for (auto PH : module.getPlaceholders()) {
      if (PH->isStatic()) {
        return MAKE_ERR(
            ErrorValue::ErrorCode::RUNTIME_ERROR,
            llvm::formatv("Error Placeholder: {0} is marked as static but no "
                          "deferredWeightLoader is provided.",
                          PH->getName())
                .str());
      }
    }
  }
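  // NB: at this point placeholderToDeviceManager maps each static placeholder
  // to every physical device whose symbol table references it; the
  // deferred-weight loading below uses it to know which DeviceManagers need a
  // copy of each weight.
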
  // If a deferredWeightLoader is provided, use it to load the deferred
  // weights.
  if (cctx.deferredWeightLoader) {
    const size_t totalNumDeferredWeights = placeholderToDeviceManager.size();
    LOG(INFO) << "Loading " << totalNumDeferredWeights << " deferred weights";

    auto startTime = std::chrono::steady_clock::now();
    auto loader = cctx.deferredWeightLoader;
    // Load the first weight.
    auto err = loader->loadNextWeight();
    if (err) {
      auto val = takeErrorValue(std::move(err));
      std::string msg = val->logToString();
      return MAKE_ERR(ErrorValue::ErrorCode::RUNTIME_DEFERRED_WEIGHT_ERROR,
                      msg);
    }
    std::string weightName = loader->getName();
    // Load weights while there are weights to be loaded.
    unsigned int weightCount = 0;
    while (weightName != "") {
      LOG(INFO) << "Loading deferred weight (" << ++weightCount << " / "
                << totalNumDeferredWeights << "): " << weightName;
      const auto PH = module.getPlaceholderByNameSlow(weightName);
      if (!PH) {
        return MAKE_ERR(ErrorValue::ErrorCode::RUNTIME_DEFERRED_WEIGHT_ERROR,
                        llvm::formatv("Error loading deferred weight. Name: "
                                      "{0} not found in module.",
                                      weightName)
                            .str());
      }
      // Convert the weight if needed.
      auto newTy = PH->getType();
      auto weight = loader->getTensor();
      auto oldKind = weight->getElementType();
      // Ensure we are working with a static PH.
      assert(PH->isStatic());
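      // If the placeholder's type is quantized but the loaded weight is not,
      // quantize the weight with the scale/offset taken from the placeholder's
      // type; any other element-type mismatch falls back to a plain
      // conversion.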
      if (!weight->getType().isEqual(newTy)) {
        ElemKind newK = newTy->getElementType();

        if (!isQuantizedElemKind(oldKind) && isQuantizedElemKind(newK)) {
          Tensor QT = quantization::quantizeTensor(
              *weight, {newTy->getScale(), newTy->getOffset()}, newK);
          weight->assign(&QT);
        } else {
          weight->convertToType(newK);
        }
      }
      // Transfer the weight to all devices that need it.
      std::list<Error> errors;
      std::list<std::future<void>> futures;
      for (const auto &device : placeholderToDeviceManager[PH]) {
        std::promise<void> transferPromise;
        errors.emplace_back(Error::empty());
        futures.emplace_back(transferPromise.get_future());
        devices_[device]->transferStaticPlaceholderToDevice(
            PH, weight,
            [&transferPromise, &error = errors.back()](Error err) mutable {
              error = std::move(err);
              transferPromise.set_value();
            });
      }
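      // NB: all transfers were issued before any future is awaited, so the
      // copies to different devices can proceed in parallel; per-transfer
      // errors are checked only after everything has completed.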

      for (auto &done : futures) {
        done.get();
      }

      for (auto &error : errors) {
        RETURN_IF_ERR(error);
      }

      err = loader->loadNextWeight();
      if (err) {
        auto val = takeErrorValue(std::move(err));
        std::string msg = val->logToString();
        return MAKE_ERR(ErrorValue::ErrorCode::RUNTIME_DEFERRED_WEIGHT_ERROR,
                        msg);
      }
      weightName = loader->getName();
      // Remove the PH from the map; this way we can verify that all static
      // PHs were initialized.
      placeholderToDeviceManager.erase(PH);
    }
    if (placeholderToDeviceManager.size()) {
      return MAKE_ERR(ErrorValue::ErrorCode::RUNTIME_DEFERRED_WEIGHT_ERROR,
                      "Error not all static placeholders were initialized.");
    }

    std::chrono::duration<double> duration =
        std::chrono::steady_clock::now() - startTime;
    LOG(INFO) << "Done loading deferred weights in " << duration.count()
              << " seconds";
  }
  // Init alternate name states.
  for (auto &network : networks) {
    for (auto &node : network.nodes) {
      node->initAlternateState();
    }
  }

  cleanupGuard.dismiss();
  cleanupProvision(localActiveNames, {}, false);
  return Error::success();
}