in tensorflow_fold/loom/weaver.cc [617:714]
bool Weaver::MergeFromSerialized(const string &other) {
WeaverMessage message;
if (!message.ParseFromString(other)) {
error_string_ = "WeaverMessage couldn't be parsed.";
return false;
}
if (num_constants_by_type_shape_.size() !=
message.num_constants_by_type_shape_size()) {
error_string_ =
"WeaverMessage didn't have the same number of type-shapes.";
return false;
}
// The loom_results_ from other will get appended to the current set of
// loom_results_, so we need to save the size to know how much to offset all
// the result ID fields from 'other'.
tensor_idx_t result_id_offset = loom_results_.size();
// This block copies message's loom results into loom_results_.
//
// For LoomResults created by MakeInput, no update need occur if this was one
// of the inputs shared between the looms (named tensors.) Otherwise we
// shift by the number of non-shared inputs in the current Scheduler
// (constants.)
//
// For LoomResults created by CallOp, pos_idx needs to be shifted by the
// number of times the op has been called at this depth.
//
// For any LoomResult, cached_passthrough needs to be shifted by
// result_id_offset if it was set.
//
// Note: this copy needs to happen before wiring results get copied over so
// that the call-counts are accurate, and before num_inputs_by_type_shape_ is
// updated.
tensor_idx_t other_num_loom_results = message.depth_size();
for (tensor_idx_t i = 0; i < other_num_loom_results; ++i) {
loom_results_.emplace_back(LoomResultFromMessage(message, i));
auto &r = loom_results_.back();
tensor_idx_t pos_idx_offset;
if (r.depth == 0) {
if (r.pos_idx < num_named_tensors_by_ts_idx_[r.ts_idx]) {
// NamedTensor case.
pos_idx_offset = 0;
} else { // Constant/BatchInput case.
// Note: for BatchInput typeshapes this is always 0.
pos_idx_offset = num_constants_by_type_shape_[r.ts_idx];
}
} else { // Op output case.
auto key = std::make_tuple(r.depth, r.op_idx, 0);
pos_idx_offset = wiring_results_[key].size();
}
r.pos_idx += pos_idx_offset;
if (r.cached_passthrough != -1) {
r.cached_passthrough += result_id_offset;
}
// Update deepest_.
deepest_ = std::max(deepest_, r.depth);
}
// Update num_inputs_by_type_shape_ from message's constants.
//
// Also concatenate the lists of constant values.
for (tensor_idx_t ts_idx = 0; ts_idx < num_type_shapes_; ++ts_idx) {
num_constants_by_type_shape_[ts_idx] +=
message.num_constants_by_type_shape(ts_idx);
Tensor constants(metadata_.type_shape_metadata(ts_idx).dtype());
if (!constants.FromProto(message.constant_values_by_type_shape(ts_idx))) {
error_string_ = StrCat(
"Conversion from TensorProto to Tensor failed in deserialization. ",
"ts_idx=", ts_idx);
return false;
}
auto unstacked = UnstackTensors(constants);
constant_values_by_type_shape_[ts_idx].insert(
constant_values_by_type_shape_[ts_idx].end(),
unstacked.begin(), unstacked.end());
}
// Copy over message.wiring into wiring_results_
// (Shifting all the result IDs by result_id_offset.)
for (const auto &w : message.wiring()) {
auto &ids = wiring_results_[
std::make_tuple(w.depth(), w.op_idx(), w.arg_idx())];
for (tensor_idx_t result_id : w.result_id()) {
ids.push_back(result_id_offset + result_id);
}
}
// Copy over outputs_result_ids (shifting by result_id_offset.)
for (tensor_idx_t result_id : message.output_result_id()) {
output_result_ids_.push_back(result_id_offset + result_id);
}
return true;
}