void WeaverOpBase::Compute()

in tensorflow_fold/loom/weaver_op_base.cc [80:128]


// Kernel entry point. Builds a Weaver from `metadata_str_` and calls
// `Weave(c, &weaver)` (presumably the hook that schedules the loom ops; confirm
// against the subclass). It then finalizes the weaver and publishes its wirings
// and batched constants.
//
// Outputs, in the order they are written:
//   0: every GetWiring(depth, op, arg) vector, concatenated end to end.
//   1: the offset into output 0 at which each of those wirings begins.
//   2: the length of each of those wirings.
//   "out_3_output_wiring": GetOutputWiring(ts) for each of the
//       `num_type_shapes_` type-shapes.
//   "out_4_constants": BatchConstantValues(ts) for each type-shape.
// Depths run from 1 through MaxDepth() inclusive. Within a depth, ops run in
// index order, and within an op, arguments run in index order. Outputs 1 and 2
// therefore hold one entry per (depth, op, argument) triple, in that order.
//
// Each of the following is reported through `c` and stops the kernel at that
// point:
//   - a metadata parse error (as InvalidArgument);
//   - a failing Weave() status;
//   - a failing output allocation or output-list lookup.
void WeaverOpBase::Compute(OpKernelContext *c) {
  Weaver weaver(metadata_str_);
  OP_REQUIRES(c, weaver.error_string().empty(), InvalidArgument(
      "Couldn't initialize weaver from metadata: ", weaver.error_string()));
  OP_REQUIRES_OK(c, Weave(c, &weaver));
  weaver.Finalize();

  const tensor_idx_t max_depth = weaver.MaxDepth();
  const tensor_idx_t num_ops = weaver.NumOps();

  // Number of inputs of each op, queried once up front so the depth loop
  // below does not have to ask the weaver again for every depth.
  std::vector<tensor_idx_t> op_arity;
  for (tensor_idx_t op = 0; op < num_ops; ++op) {
    op_arity.push_back(weaver.InputTypeShapes(op).size());
  }

  // Flatten every argument wiring into a single vector. Record where each
  // piece starts and how long it is, so consumers can slice it back out.
  std::vector<tensor_idx_t> flat_wiring;
  std::vector<tensor_idx_t> slice_offsets;
  std::vector<tensor_idx_t> slice_lengths;
  for (tensor_idx_t depth = 1; depth <= max_depth; ++depth) {
    tensor_idx_t op = 0;
    for (const tensor_idx_t arity : op_arity) {
      for (tensor_idx_t arg = 0; arg < arity; ++arg) {
        const std::vector<tensor_idx_t> &piece =
            weaver.GetWiring(depth, op, arg);
        slice_offsets.push_back(flat_wiring.size());
        slice_lengths.push_back(piece.size());
        flat_wiring.insert(flat_wiring.end(), piece.begin(), piece.end());
      }
      ++op;
    }
  }

  OP_REQUIRES_OK(c, OutputTensorIdxVector(c, 0, flat_wiring));
  OP_REQUIRES_OK(c, OutputTensorIdxVector(c, 1, slice_offsets));
  OP_REQUIRES_OK(c, OutputTensorIdxVector(c, 2, slice_lengths));

  // One output-wiring vector per type-shape.
  OpOutputList output_wirings;
  OP_REQUIRES_OK(c, c->output_list("out_3_output_wiring", &output_wirings));
  for (tensor_idx_t ts = 0; ts < num_type_shapes_; ++ts) {
    const auto &ts_wiring = weaver.GetOutputWiring(ts);
    OP_REQUIRES_OK(c, AddVectorToOutputList(ts_wiring, ts, &output_wirings));
  }

  // One batched-constants tensor per type-shape. This stays a separate pass
  // so that a failure above returns before any constant output is set.
  OpOutputList batched_constants;
  OP_REQUIRES_OK(c, c->output_list("out_4_constants", &batched_constants));
  for (tensor_idx_t ts = 0; ts < num_type_shapes_; ++ts) {
    batched_constants.set(ts, weaver.BatchConstantValues(ts));
  }
}