void ComputeConstant()

in torch/csrc/jit/passes/onnx/shape_type_inference.cpp [1300:1610]


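// Attempts to statically evaluate node `n` under the given opset: any
// constant value, shape, or symbolic shape value that can be derived is
// cached in ConstantValueMap for reuse by downstream nodes.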
void ComputeConstant(Node* n, int opset_version) {
  if (n->kind() == ::c10::onnx::Constant) {
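    // Cache a deep copy of the Constant's tensor payload so the cached
    // value does not alias the node's attribute storage.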
    if (n->kindOf(attr::value) == AttributeKind::t) {
      at::Tensor const_val = n->t(attr::value);
      at::Tensor const_val_copy =
          at::empty(const_val.sizes(), const_val.options());
      const_val_copy.copy_(const_val);
      ConstantValueMap::SetValue(n->output()->debugName(), const_val_copy);
    }
    return;
  }
  auto only_rank_available = false;
  size_t rank = 0;
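  // Some ops below can only determine the output rank, not the full shape;
  // they record it here for the UpdateRank fallback at the end.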

  // Constant folding: try to evaluate the node from its constant inputs; on
  // success, cache the folded value and its static shape, and stop here.
  auto const_fold_val = ComputeConstantFolding(n, opset_version);
  if (const_fold_val.has_value()) {
    at::Tensor const_fold_val_copy = at::empty(
        const_fold_val.value().sizes(), const_fold_val.value().options());
    const_fold_val_copy.copy_(const_fold_val.value());
    ConstantValueMap::SetValue(n->output()->debugName(), const_fold_val_copy);
    UpdateShapeFromVector(n->output(), const_fold_val_copy.sizes().vec());
    return;
  }

  switch (n->kind()) {
    case ::c10::onnx::Add:
    case ::c10::onnx::Div:
    case ::c10::onnx::Equal:
    case ::c10::onnx::Greater:
    case ::c10::onnx::GreaterOrEqual:
    case ::c10::onnx::Less:
    case ::c10::onnx::LessOrEqual:
    case ::c10::onnx::Mod:
    case ::c10::onnx::Mul:
    case ::c10::onnx::Pow:
    case ::c10::onnx::Sub: {
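      // Elementwise binary ops: the output follows numpy-style broadcast
      // shape rules, handled in ProcessBroadcastNode.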
      ProcessBroadcastNode(n);
      break;
    }
    case ::c10::onnx::Shape: {
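      // If the input shape is fully known, onnx::Shape folds to a constant
      // 1-D int64 tensor of length rank(input).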
      auto input_shape =
          ConstantValueMap::GetShapeInto1DInt64Vector(n->input()->debugName());
      if (input_shape.has_value()) {
        auto shape_value = input_shape.value();
        // TODO: getDevice() ?
        auto options = c10::TensorOptions().dtype(at::kLong).device(at::kCPU);
        auto shape_value_size = static_cast<int64_t>(shape_value.size());
        auto f =
            at::from_blob(shape_value.data(), {shape_value_size}, at::kLong)
                .to(at::kCPU);
        // Need a copy here: at::from_blob does not take ownership, and
        // shape_value is a stack-local buffer.
        at::Tensor f_copy = at::empty({shape_value_size}, options);
        f_copy.copy_(f);
        ConstantValueMap::SetValue(n->output()->debugName(), f_copy);
        std::vector<::c10::ShapeSymbol> final_shape_vector(
            1, c10::ShapeSymbol::fromStaticSize(shape_value_size));
        ::c10::SymbolicShape final_shape(final_shape_vector);
        UpdateShape(n->output(), final_shape);
      } else if (ConstantValueMap::HasShape(n->input()->debugName())) {
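        // Shape is only symbolically known: record it as the output's shape
        // value so ops like Gather below can still reason over its symbols.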
        ConstantValueMap::SetShapeValue(
            n->output()->debugName(),
            ConstantValueMap::GetShape(n->input()->debugName()).value());
      }
      break;
    }
    case ::c10::onnx::Reshape: {
      ProcessReshapeNode(n, opset_version);
      break;
    }
    case ::c10::onnx::Gather: {
      if (ConstantValueMap::HasRank(n->input(0)->debugName()) &&
          ConstantValueMap::HasRank(n->input(1)->debugName())) {
        auto rank_0 =
            ConstantValueMap::GetRank(n->input(0)->debugName()).value();
        auto rank_1 =
            ConstantValueMap::GetRank(n->input(1)->debugName()).value();
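        // Per the ONNX spec, Gather output rank = rank(data) +
        // rank(indices) - 1.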
        only_rank_available = true;
        rank = rank_0 + rank_1 - 1;
      }
      if (ConstantValueMap::HasShapeValue(n->input(0)->debugName()) &&
          ConstantValueMap::HasValue(n->input(1)->debugName())) {
        auto shape_value =
            ConstantValueMap::GetShapeValue(n->input(0)->debugName()).value();
        auto idx_value =
            ConstantValueMap::GetValue(n->input(1)->debugName()).value();
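        // input(0) carries a shape value (e.g. produced by onnx::Shape), so
        // a constant Gather selects dimension symbols from it.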
        // Consider the case when Gather index is a scalar.
        if (idx_value.dim() == 0) {
          auto idx_value_0 = idx_value.item<int64_t>();
          if (idx_value_0 >= 0) {
            std::vector<c10::ShapeSymbol> dims = {shape_value.at(idx_value_0)};
            c10::SymbolicShape symShape(dims);
            ConstantValueMap::SetShapeValue(n->output()->debugName(), symShape);
          }
        }
      }
      break;
    }
    case ::c10::onnx::Transpose: {
      if (n->hasAttributeS("perm")) {
        auto perm_v = n->is(attr::perm);
        rank = perm_v.size();
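        // ONNX Transpose's default perm reverses all axes; only the 2-D
        // matrix-transpose form {1, 0} is recognized as such here.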
        auto is_default_perm = false;
        if (rank == 2 && perm_v[0] == 1 && perm_v[1] == 0) {
          is_default_perm = true;
        }
        auto shape_updated = false;
        if (ConstantValueMap::HasShape(n->input(0)->debugName())) {
          auto shape_size_0 =
              ConstantValueMap::GetShape(n->input(0)->debugName())
                  .value()
                  .sizes();
          if (shape_size_0.has_value()) {
            auto shape_vector_0 = shape_size_0.value();
            std::vector<::c10::ShapeSymbol> final_shape_vector(
                shape_vector_0.size(), ::c10::ShapeSymbol());
            if (is_default_perm) {
              std::reverse_copy(
                  std::begin(shape_vector_0),
                  std::end(shape_vector_0),
                  std::begin(final_shape_vector));
            } else {
              for (const auto i : c10::irange(shape_vector_0.size())) {
                final_shape_vector[i] = shape_vector_0[perm_v[i]];
              }
            }
            ::c10::SymbolicShape final_shape(final_shape_vector);
            UpdateShape(n->output(), final_shape);
            shape_updated = true;
          }
        }
        if (!shape_updated) {
          if (!is_default_perm) {
            only_rank_available = true;
          } else if (ConstantValueMap::HasRank(n->input(0)->debugName())) {
            rank = ConstantValueMap::GetRank(n->input(0)->debugName()).value();
            only_rank_available = true;
          }
        }
      }
      break;
    }
    case ::c10::onnx::Concat: {
      ProcessConcatNode(n);
      break;
    }
    case ::c10::onnx::ConstantOfShape: {
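      // With a constant target shape the output shape is known exactly, and
      // the filled tensor can be materialized (per the ONNX spec, the
      // default fill value is float 0).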
      if (ConstantValueMap::HasValue(n->input()->debugName())) {
        auto shape_temp = ConstantValueMap::GetValueInto1DInt64Vector(
            n->input()->debugName());
        UpdateShapeFromVector(n->output(), shape_temp);
        if (!shape_temp.empty()) {
          if (n->hasAttributeS("value")) {
            auto value = n->t(attr::value).repeat(shape_temp);
            ConstantValueMap::SetValue(n->output()->debugName(), value);
          } else {
            auto options =
                c10::TensorOptions().dtype(at::kFloat).device(at::kCPU);
            auto value = at::full({1}, 0.0, options).repeat(shape_temp);
            ConstantValueMap::SetValue(n->output()->debugName(), value);
          }
        }
      }
      break;
    }
    case ::c10::onnx::Expand: {
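      // Expand broadcasts the input against the requested shape;
      // ComputeShapeFromExpand applies the multidirectional broadcast rule.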
      if (ConstantValueMap::HasShape(n->input(0)->debugName())) {
        auto input0_shape_size =
            ConstantValueMap::GetShape(n->input(0)->debugName())
                .value()
                .sizes();
        if (input0_shape_size.has_value()) {
          auto input0_shape_value = input0_shape_size.value();
          if (ConstantValueMap::HasValue(n->input(1)->debugName())) {
            auto shape_temp = ConstantValueMap::GetValueInto1DInt64Vector(
                n->input(1)->debugName());
            auto final_shape =
                ComputeShapeFromExpand(input0_shape_value, shape_temp);
            if (final_shape.has_value()) {
              UpdateShape(n->output(), final_shape.value());
            }
          }
        }
      }
      break;
    }
    case ::c10::onnx::NonZero: {
      if (ConstantValueMap::HasRank(n->input()->debugName())) {
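        // onnx::NonZero yields a [rank(input), #nonzero] tensor: the first
        // dim is static, the second is generally unknown at export time.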
        auto rank = ConstantValueMap::GetRank(n->input()->debugName()).value();
        std::vector<c10::ShapeSymbol> dims;
        dims.emplace_back(
            c10::ShapeSymbol::fromStaticSize(static_cast<int64_t>(rank)));
        auto input_node = n->input()->node();
        if (input_node->kind() == ::c10::onnx::ConstantOfShape) {
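          // Special case: a ConstantOfShape filled with a nonzero scalar
          // means every element is nonzero, so #nonzero equals the total
          // element count of a fully static input shape.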
          if (input_node->hasAttributeS("value")) {
            auto value =
                input_node->t(attr::value).toType(at::ScalarType::Float);
            auto value_a = value.accessor<float, 1>();
            if (value_a.size(0) == 1 && std::abs(value_a[0]) > 1e-6) {
              if (ConstantValueMap::HasShape(n->input()->debugName())) {
                auto shape_size_0 =
                    ConstantValueMap::GetShape(n->input()->debugName()).value();
                if (shape_size_0.isComplete()) {
                  auto shape_vector_0 = shape_size_0.sizes().value();
                  int64_t num_elements = 1;
                  for (auto cur_dim : shape_vector_0) {
                    num_elements *= cur_dim.static_size();
                  }
                  dims.emplace_back(c10::ShapeSymbol::fromStaticSize(
                      static_cast<int64_t>(num_elements)));
                }
              }
            }
          }
        }
        if (dims.size() == 1) {
          dims.emplace_back(c10::ShapeSymbol::newSymbol());
        }
        c10::SymbolicShape shape_v(dims);
        UpdateShape(n->output(), shape_v);
      }
      break;
    }
    case ::c10::onnx::MatMul: {
      ProcessMatMulNode(n);
      break;
    }
    case ::c10::onnx::ReduceMean:
    case ::c10::onnx::ReduceProd: {
      ProcessReduceNode(n);
      break;
    }
    case ::c10::onnx::RNN:
    case ::c10::onnx::LSTM:
    case ::c10::onnx::GRU: {
      ProcessTimeSeriesNode(n);
      break;
    }
    case ::c10::onnx::Size: {
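      // onnx::Size is the total element count, i.e. the product of all
      // dimensions; it folds only when every dimension is static.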
      if (ConstantValueMap::HasShape(n->input(0)->debugName())) {
        auto input0_shape_size =
            ConstantValueMap::GetShape(n->input(0)->debugName())
                .value()
                .sizes();
        if (input0_shape_size.has_value()) {
          auto input0_shape_value = input0_shape_size.value();
          int64_t total_size = 1;
          auto is_full_static = true;
          for (const auto i : c10::irange(input0_shape_value.size())) {
            if (input0_shape_value[i].is_static()) {
              total_size *= input0_shape_value[i].static_size();
            } else {
              is_full_static = false;
              break;
            }
          }
          if (is_full_static) {
            auto f_final = onnx_constant_fold::IntToTensor(total_size);
            ConstantValueMap::SetValue(n->output(0)->debugName(), f_final);
          }
        }
      }
      break;
    }
    case ::c10::onnx::Slice: {
      ProcessSliceNode(n, opset_version);
      break;
    }
    case ::c10::onnx::Cast:
    case ::c10::onnx::Relu:
    case ::c10::onnx::Softmax: {
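      // Shape-preserving ops: the output shape equals the input shape.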
      ProcessUnchangeNode(n);
      break;
    }
    case ::c10::onnx::Tile: {
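      // Tile scales each input dimension by its repeat count;
      // ComputeShapeFromTile handles this when the repeats are constant.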
      if (ConstantValueMap::HasShape(n->input(0)->debugName())) {
        auto input0_shape_size =
            ConstantValueMap::GetShape(n->input(0)->debugName())
                .value()
                .sizes();
        if (input0_shape_size.has_value()) {
          auto input0_shape_value = input0_shape_size.value();
          if (ConstantValueMap::HasValue(n->input(1)->debugName())) {
            auto shape_temp = ConstantValueMap::GetValueInto1DInt64Vector(
                n->input(1)->debugName());
            auto final_shape =
                ComputeShapeFromTile(input0_shape_value, shape_temp);
            if (final_shape.has_value()) {
              UpdateShape(n->output(), final_shape.value());
            }
          }
        }
      }
      break;
    }
    case ::c10::onnx::Unsqueeze: {
      ProcessUnsqueezeNode(n);
      break;
    }
    default: {
      break;
    }
  }
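  // Rank-only fallback: applies to single-output nodes whose shape was not
  // resolved by any case above.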
  if (n->outputs().size() > 1 ||
      ConstantValueMap::HasShape(n->output(0)->debugName())) {
    return;
  }
  if (only_rank_available) {
    UpdateRank(n->output(), rank);
  }
}
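
The at::empty-plus-copy_ idiom above recurs whenever a tensor is cached in
ConstantValueMap. A minimal standalone sketch of the same pattern (the
helper name cache_copy is illustrative, not part of this file):

#include <ATen/ATen.h>

// Illustrative helper (not in the file): deep-copy a tensor before caching
// so the cached value does not alias graph-owned or stack-local storage.
at::Tensor cache_copy(const at::Tensor& src) {
  at::Tensor dst = at::empty(src.sizes(), src.options());
  dst.copy_(src); // element-wise copy into freshly allocated storage
  return dst;
}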