in torch/csrc/jit/passes/onnx/shape_type_inference.cpp [1300:1610]
// Attempts to compute and record compile-time information about the output
// of node `n`:
//  - an onnx::Constant node with a tensor `value` attribute gets a copy of
//    that tensor stored in ConstantValueMap;
//  - otherwise, if ComputeConstantFolding yields a value, a copy of it is
//    stored and the output shape is updated from its sizes;
//  - otherwise a per-op switch either delegates to a Process*Node helper or
//    derives a value / shape / shape-value / rank inline.
// After the switch, if the node does not have multiple outputs, no shape has
// been recorded for output 0, and only a rank was determined, that rank is
// recorded via UpdateRank.
//
// `opset_version` is forwarded to ComputeConstantFolding, ProcessReshapeNode
// and ProcessSliceNode.
void ComputeConstant(Node* n, int opset_version) {
  if (n->kind() == ::c10::onnx::Constant) {
    if (n->kindOf(attr::value) == AttributeKind::t) {
      at::Tensor const_val = n->t(attr::value);
      // Store a freshly allocated copy rather than the attribute tensor.
      at::Tensor const_val_copy =
          at::empty(const_val.sizes(), const_val.options());
      const_val_copy.copy_(const_val);
      ConstantValueMap::SetValue(n->output()->debugName(), const_val_copy);
    }
    return;
  }
  // Set by the cases below when a rank (but no shape) could be determined.
  auto only_rank_available = false;
  size_t rank = 0;
  // Constant folding.
  auto const_fold_val = ComputeConstantFolding(n, opset_version);
  if (const_fold_val.has_value()) {
    at::Tensor const_fold_val_copy = at::empty(
        const_fold_val.value().sizes(), const_fold_val.value().options());
    const_fold_val_copy.copy_(const_fold_val.value());
    ConstantValueMap::SetValue(n->output()->debugName(), const_fold_val_copy);
    UpdateShapeFromVector(n->output(), const_fold_val_copy.sizes().vec());
    return;
  }
  switch (n->kind()) {
    case ::c10::onnx::Add:
    case ::c10::onnx::Div:
    case ::c10::onnx::Equal:
    case ::c10::onnx::Greater:
    case ::c10::onnx::GreaterOrEqual:
    case ::c10::onnx::Less:
    case ::c10::onnx::LessOrEqual:
    case ::c10::onnx::Mod:
    case ::c10::onnx::Mul:
    case ::c10::onnx::Pow:
    case ::c10::onnx::Sub: {
      ProcessBroadcastNode(n);
      break;
    }
    case ::c10::onnx::Shape: {
      auto input_shape =
          ConstantValueMap::GetShapeInto1DInt64Vector(n->input()->debugName());
      if (input_shape.has_value()) {
        // The input shape is available as int64 values: the output value is
        // that vector, and the output shape is 1-D with one entry per dim.
        auto shape_value = input_shape.value();
        // TODO: getDevice() ?
        auto options = c10::TensorOptions().dtype(at::kLong).device(at::kCPU);
        auto shape_value_size = static_cast<int64_t>(shape_value.size());
        auto f =
            at::from_blob(shape_value.data(), {shape_value_size}, at::kLong)
                .to(at::kCPU);
        // Need copy here
        at::Tensor f_copy = at::empty({shape_value_size}, options);
        f_copy.copy_(f);
        ConstantValueMap::SetValue(n->output()->debugName(), f_copy);
        std::vector<::c10::ShapeSymbol> final_shape_vector(
            1, c10::ShapeSymbol::fromStaticSize(shape_value_size));
        ::c10::SymbolicShape final_shape(final_shape_vector);
        UpdateShape(n->output(), final_shape);
      } else if (ConstantValueMap::HasShape(n->input()->debugName())) {
        // Otherwise record the input's symbolic shape as the output's
        // "shape value".
        ConstantValueMap::SetShapeValue(
            n->output()->debugName(),
            ConstantValueMap::GetShape(n->input()->debugName()).value());
      }
      break;
    }
    case ::c10::onnx::Reshape: {
      ProcessReshapeNode(n, opset_version);
      break;
    }
    case ::c10::onnx::Gather: {
      if (ConstantValueMap::HasRank(n->input(0)->debugName()) &&
          ConstantValueMap::HasRank(n->input(1)->debugName())) {
        auto rank_0 =
            ConstantValueMap::GetRank(n->input(0)->debugName()).value();
        auto rank_1 =
            ConstantValueMap::GetRank(n->input(1)->debugName()).value();
        // The output rank is rank_0 + rank_1 - 1. Skip the degenerate case
        // where both ranks are 0 so the unsigned subtraction cannot wrap
        // around to a huge value.
        if (rank_0 + rank_1 > 0) {
          only_rank_available = true;
          rank = rank_0 + rank_1 - 1;
        }
      }
      if (ConstantValueMap::HasShapeValue(n->input(0)->debugName()) &&
          ConstantValueMap::HasValue(n->input(1)->debugName())) {
        auto shape_value =
            ConstantValueMap::GetShapeValue(n->input(0)->debugName()).value();
        auto idx_value =
            ConstantValueMap::GetValue(n->input(1)->debugName()).value();
        // Consider the case when Gather index is a scalar.
        if (idx_value.dim() == 0) {
          auto idx_value_0 = idx_value.item<int64_t>();
          // Negative indices are not resolved here.
          if (idx_value_0 >= 0) {
            std::vector<c10::ShapeSymbol> dims = {shape_value.at(idx_value_0)};
            c10::SymbolicShape symShape(dims);
            ConstantValueMap::SetShapeValue(n->output()->debugName(), symShape);
          }
        }
      }
      break;
    }
    case ::c10::onnx::Transpose: {
      if (n->hasAttributeS("perm")) {
        const auto& perm_v = n->is(attr::perm);
        rank = perm_v.size();
        // A 2-D swap is handled as a plain reversal of the input dims.
        const bool is_default_perm =
            rank == 2 && perm_v[0] == 1 && perm_v[1] == 0;
        auto shape_updated = false;
        if (ConstantValueMap::HasShape(n->input(0)->debugName())) {
          auto shape_size_0 =
              ConstantValueMap::GetShape(n->input(0)->debugName())
                  .value()
                  .sizes();
          if (shape_size_0.has_value()) {
            auto shape_vector_0 = shape_size_0.value();
            const size_t input_rank = shape_vector_0.size();
            // Only index the input dims through `perm` when it has exactly one
            // in-range entry per input dimension; otherwise
            // shape_vector_0[perm_v[i]] would read out of bounds.
            const bool perm_is_valid = perm_v.size() == input_rank &&
                std::all_of(
                    perm_v.begin(), perm_v.end(), [input_rank](int64_t axis) {
                      return axis >= 0 &&
                          static_cast<size_t>(axis) < input_rank;
                    });
            if (is_default_perm || perm_is_valid) {
              std::vector<::c10::ShapeSymbol> final_shape_vector(
                  input_rank, ::c10::ShapeSymbol());
              if (is_default_perm) {
                std::reverse_copy(
                    std::begin(shape_vector_0),
                    std::end(shape_vector_0),
                    std::begin(final_shape_vector));
              } else {
                for (const auto i : c10::irange(input_rank)) {
                  final_shape_vector[i] = shape_vector_0[perm_v[i]];
                }
              }
              ::c10::SymbolicShape final_shape(final_shape_vector);
              UpdateShape(n->output(), final_shape);
              shape_updated = true;
            }
          }
        }
        if (!shape_updated) {
          if (!is_default_perm) {
            // The rank is the length of `perm`.
            only_rank_available = true;
          } else if (ConstantValueMap::HasRank(n->input(0)->debugName())) {
            rank = ConstantValueMap::GetRank(n->input(0)->debugName()).value();
            only_rank_available = true;
          }
        }
      }
      break;
    }
    case ::c10::onnx::Concat: {
      ProcessConcatNode(n);
      break;
    }
    case ::c10::onnx::ConstantOfShape: {
      if (ConstantValueMap::HasValue(n->input()->debugName())) {
        auto shape_temp = ConstantValueMap::GetValueInto1DInt64Vector(
            n->input()->debugName());
        UpdateShapeFromVector(n->output(), shape_temp);
        if (!shape_temp.empty()) {
          if (n->hasAttributeS("value")) {
            auto value = n->t(attr::value).repeat(shape_temp);
            ConstantValueMap::SetValue(n->output()->debugName(), value);
          } else {
            // Without a `value` attribute the output is filled with float 0.
            auto options =
                c10::TensorOptions().dtype(at::kFloat).device(at::kCPU);
            auto value = at::full({1}, 0.0, options).repeat(shape_temp);
            ConstantValueMap::SetValue(n->output()->debugName(), value);
          }
        }
      }
      break;
    }
    case ::c10::onnx::Expand: {
      if (ConstantValueMap::HasShape(n->input(0)->debugName())) {
        auto input0_shape_size =
            ConstantValueMap::GetShape(n->input(0)->debugName())
                .value()
                .sizes();
        if (input0_shape_size.has_value()) {
          auto input0_shape_value = input0_shape_size.value();
          if (ConstantValueMap::HasValue(n->input(1)->debugName())) {
            auto shape_temp = ConstantValueMap::GetValueInto1DInt64Vector(
                n->input(1)->debugName());
            auto final_shape =
                ComputeShapeFromExpand(input0_shape_value, shape_temp);
            if (final_shape.has_value()) {
              UpdateShape(n->output(), final_shape.value());
            }
          }
        }
      }
      break;
    }
    case ::c10::onnx::NonZero: {
      if (ConstantValueMap::HasRank(n->input()->debugName())) {
        // Named distinctly so it does not shadow the function-level `rank`.
        const auto input_rank =
            ConstantValueMap::GetRank(n->input()->debugName()).value();
        // First output dim is the input rank.
        std::vector<c10::ShapeSymbol> dims;
        dims.emplace_back(
            c10::ShapeSymbol::fromStaticSize(static_cast<int64_t>(input_rank)));
        auto input_node = n->input()->node();
        if (input_node->kind() == ::c10::onnx::ConstantOfShape) {
          if (input_node->hasAttributeS("value")) {
            auto value =
                input_node->t(attr::value).toType(at::ScalarType::Float);
            auto value_a = value.accessor<float, 1>();
            // Input is a ConstantOfShape filled with a non-zero value: when
            // its shape is complete, the second output dim is the product of
            // the input dims.
            if (value_a.size(0) == 1 && std::abs(value_a[0]) > 1e-6) {
              if (ConstantValueMap::HasShape(n->input()->debugName())) {
                auto shape_size_0 =
                    ConstantValueMap::GetShape(n->input()->debugName()).value();
                if (shape_size_0.isComplete()) {
                  auto shape_vector_0 = shape_size_0.sizes().value();
                  int64_t num_elements = 1;
                  for (auto cur_dim : shape_vector_0) {
                    num_elements *= cur_dim.static_size();
                  }
                  dims.emplace_back(c10::ShapeSymbol::fromStaticSize(
                      static_cast<int64_t>(num_elements)));
                }
              }
            }
          }
        }
        // Second dim could not be determined statically: use a new symbol.
        if (dims.size() == 1) {
          dims.emplace_back(c10::ShapeSymbol::newSymbol());
        }
        c10::SymbolicShape shape_v(dims);
        UpdateShape(n->output(), shape_v);
      }
      break;
    }
    case ::c10::onnx::MatMul: {
      ProcessMatMulNode(n);
      break;
    }
    case ::c10::onnx::ReduceMean:
    case ::c10::onnx::ReduceProd: {
      ProcessReduceNode(n);
      break;
    }
    case ::c10::onnx::RNN:
    case ::c10::onnx::LSTM:
    case ::c10::onnx::GRU: {
      ProcessTimeSeriesNode(n);
      break;
    }
    case ::c10::onnx::Size: {
      if (ConstantValueMap::HasShape(n->input(0)->debugName())) {
        auto input0_shape_size =
            ConstantValueMap::GetShape(n->input(0)->debugName())
                .value()
                .sizes();
        if (input0_shape_size.has_value()) {
          auto input0_shape_value = input0_shape_size.value();
          // The output value is the product of the dims, available only when
          // every dim is static.
          int64_t total_size = 1;
          auto is_full_static = true;
          for (const auto i : c10::irange(input0_shape_value.size())) {
            if (input0_shape_value[i].is_static()) {
              total_size *= input0_shape_value[i].static_size();
            } else {
              is_full_static = false;
              break;
            }
          }
          if (is_full_static) {
            auto f_final = onnx_constant_fold::IntToTensor(total_size);
            ConstantValueMap::SetValue(n->output(0)->debugName(), f_final);
          }
        }
      }
      break;
    }
    case ::c10::onnx::Slice: {
      ProcessSliceNode(n, opset_version);
      break;
    }
    case ::c10::onnx::Cast:
    case ::c10::onnx::Relu:
    case ::c10::onnx::Softmax: {
      ProcessUnchangeNode(n);
      break;
    }
    case ::c10::onnx::Tile: {
      if (ConstantValueMap::HasShape(n->input(0)->debugName())) {
        auto input0_shape_size =
            ConstantValueMap::GetShape(n->input(0)->debugName())
                .value()
                .sizes();
        if (input0_shape_size.has_value()) {
          auto input0_shape_value = input0_shape_size.value();
          if (ConstantValueMap::HasValue(n->input(1)->debugName())) {
            auto shape_temp = ConstantValueMap::GetValueInto1DInt64Vector(
                n->input(1)->debugName());
            auto final_shape =
                ComputeShapeFromTile(input0_shape_value, shape_temp);
            if (final_shape.has_value()) {
              UpdateShape(n->output(), final_shape.value());
            }
          }
        }
      }
      break;
    }
    case ::c10::onnx::Unsqueeze: {
      ProcessUnsqueezeNode(n);
      break;
    }
    default: {
      break;
    }
  }
  // Nothing more to do for multi-output nodes or when a shape is already
  // recorded for output 0.
  if (n->outputs().size() > 1 ||
      ConstantValueMap::HasShape(n->output(0)->debugName())) {
    return;
  }
  if (only_rank_available) {
    UpdateRank(n->output(), rank);
  }
}