bool Weaver::MergeFromSerialized()

in tensorflow_fold/loom/weaver.cc [617:714]
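
Merges the schedule from another serialized Weaver (a WeaverMessage) into
this one: the other Weaver's LoomResults, constants, wiring, and output
result IDs are appended, shifted so they do not collide with the results
already present.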


bool Weaver::MergeFromSerialized(const string &other) {
  WeaverMessage message;
  if (!message.ParseFromString(other)) {
    error_string_ = "WeaverMessage couldn't be parsed.";
    return false;
  }

  if (num_constants_by_type_shape_.size() !=
      message.num_constants_by_type_shape_size()) {
    error_string_ =
        "WeaverMessage didn't have the same number of type-shapes.";
    return false;
  }

  // The loom_results_ from 'other' will get appended to the current set of
  // loom_results_, so we need to save the size to know how much to offset all
  // the result ID fields from 'other'.
  tensor_idx_t result_id_offset = loom_results_.size();
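  // For example (hypothetical numbers): if loom_results_ already holds ten
  // entries, a result ID of 3 from 'other' refers to entry 13 after the merge.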

  // This block copies message's loom results into loom_results_.
  //
  // For LoomResults created by MakeInput, no update is needed if this was one
  // of the inputs shared between the looms (named tensors).  Otherwise we
  // shift by the number of non-shared inputs in the current Weaver
  // (constants).
  //
  // For LoomResults created by CallOp, pos_idx needs to be shifted by the
  // number of times the op has been called at this depth.
  //
  // For any LoomResult, cached_passthrough needs to be shifted by
  // result_id_offset if it was set.
  //
  // Note: this copy needs to happen before wiring results get copied over so
  // that the call-counts are accurate, and before num_constants_by_type_shape_
  // is updated.
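  // WeaverMessage appears to store its LoomResults as parallel repeated
  // fields, so the length of any one of them (depth, here) gives the number
  // of serialized results.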
  tensor_idx_t other_num_loom_results = message.depth_size();
  for (tensor_idx_t i = 0; i < other_num_loom_results; ++i) {
    loom_results_.emplace_back(LoomResultFromMessage(message, i));
    auto &r = loom_results_.back();

    tensor_idx_t pos_idx_offset;
    if (r.depth == 0) {
      if (r.pos_idx < num_named_tensors_by_ts_idx_[r.ts_idx]) {
        // NamedTensor case.
        pos_idx_offset = 0;
      } else {  // Constant/BatchInput case.
        // Note: for BatchInput type-shapes this is always 0.
        pos_idx_offset = num_constants_by_type_shape_[r.ts_idx];
      }
    } else {  // Op output case.
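      // wiring_results_ keyed on arg 0 holds one result ID per call already
      // made to this op at this depth (assuming every op takes at least one
      // argument), which is exactly the shift pos_idx needs.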
      auto key = std::make_tuple(r.depth, r.op_idx, 0);
      pos_idx_offset = wiring_results_[key].size();
    }
    r.pos_idx += pos_idx_offset;
    if (r.cached_passthrough != -1) {
      r.cached_passthrough += result_id_offset;
    }

    // Update deepest_.
    deepest_ = std::max(deepest_, r.depth);
  }
  // Update num_constants_by_type_shape_ from message's constants.
  //
  // Also concatenate the lists of constant values.
  for (tensor_idx_t ts_idx = 0; ts_idx < num_type_shapes_; ++ts_idx) {
    num_constants_by_type_shape_[ts_idx] +=
        message.num_constants_by_type_shape(ts_idx);
    Tensor constants(metadata_.type_shape_metadata(ts_idx).dtype());
    if (!constants.FromProto(message.constant_values_by_type_shape(ts_idx))) {
      error_string_ = StrCat(
          "Conversion from TensorProto to Tensor failed in deserialization.  ",
          "ts_idx=", ts_idx);
      return false;
    }
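    // UnstackTensors is assumed to split the stacked constants tensor along
    // its leading dimension into one Tensor per constant, matching the
    // per-element layout of constant_values_by_type_shape_.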
    auto unstacked = UnstackTensors(constants);
    constant_values_by_type_shape_[ts_idx].insert(
        constant_values_by_type_shape_[ts_idx].end(),
        unstacked.begin(), unstacked.end());
  }

  // Copy message.wiring into wiring_results_, shifting all the result IDs by
  // result_id_offset.
  for (const auto &w : message.wiring()) {
    auto &ids = wiring_results_[
        std::make_tuple(w.depth(), w.op_idx(), w.arg_idx())];
    for (tensor_idx_t result_id : w.result_id()) {
      ids.push_back(result_id_offset + result_id);
    }
  }

  // Copy over message.output_result_id into output_result_ids_ (shifting by
  // result_id_offset).
  for (tensor_idx_t result_id : message.output_result_id()) {
    output_result_ids_.push_back(result_id_offset + result_id);
  }
  return true;
}
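
A minimal usage sketch (assumptions: the Weaver exposes Serialize() and
error_string() accessors and both weavers were built against the same
LoomMetadata; neither appears in this excerpt):

  // Hypothetical: fold weaver_b's schedule into weaver_a.
  string serialized_b = weaver_b.Serialize();
  if (!weaver_a.MergeFromSerialized(serialized_b)) {
    LOG(ERROR) << "Merge failed: " << weaver_a.error_string();
  }

After a successful merge, weaver_a holds both schedules: its loom_results_,
wiring_results_, constants, and output result IDs now include everything from
weaver_b, with every ID shifted past weaver_a's pre-existing results.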