throw propagation_error()

in torch/csrc/jit/passes/shape_analysis.cpp [118:2131]


  throw propagation_error()

namespace {

// Decides whether `v` is acceptable as an input when an operator is executed
// on stand-in values. Constants are always accepted. A tensor is accepted only
// if its scalar type is known and is not an integral type (bool is not treated
// as integral here). Anything else must be a subtype of float.
bool isValidArgumentForRunning(Value* v) {
  if (toIValue(v)) {
    // constants are always allowed
    return true;
  }
  const TensorTypePtr tensor_type = v->type()->cast<TensorType>();
  if (!tensor_type) {
    return v->type()->isSubtypeOf(*FloatType::get());
  }
  const auto dtype = tensor_type->scalarType();
  return dtype.has_value() &&
      !at::isIntegralType(*dtype, /*includeBool=*/false);
}

// Decides whether `v` is acceptable as an output of an operator that is
// executed on stand-in values: its type must be a subtype of Tensor or of
// Number.
bool isValidReturnForRunning(Value* v) {
  const auto& out_type = v->type();
  if (out_type->isSubtypeOf(*TensorType::get())) {
    return true;
  }
  return out_type->isSubtypeOf(*NumberType::get());
}

// Inspects the types contained in `t`:
//   - exactly one contained type: true iff that type is a subtype of Tensor;
//   - several contained types: true iff this function returns true for any of
//     them (recursive);
//   - no contained types: false.
// NOTE(review): because of the last rule, a type that reports no contained
// types yields false even if it is itself a tensor type - confirm this is the
// intended contract for callers.
bool containsTensorType(const TypePtr& t) {
  const auto contained = t->containedTypes();
  if (contained.empty()) {
    return false;
  }
  if (contained.size() == 1) {
    return contained.at(0)->isSubtypeOf(*TensorType::get());
  }
  return std::any_of(contained.begin(), contained.end(), containsTensorType);
}

// Collects the TensorType of every node input whose schema argument is declared
// as a Tensor; inputs bound to non-tensor arguments are skipped.
// Returns c10::nullopt when the node has no schema, the schema is vararg, any
// argument is a list of tensors, or the value bound to a Tensor argument does
// not carry a TensorType. When `complete` is set, a tensor input whose type is
// not complete also produces c10::nullopt.
c10::optional<std::vector<TensorTypePtr>> gatherTensorTypes(
    Node* node,
    bool complete = false) {
  const auto schema_ptr = node->maybeSchema();
  if (!schema_ptr) {
    return c10::nullopt;
  }
  const auto& schema = *schema_ptr;
  const auto& arguments = schema.arguments();
  // varargs primitives can't be handled because we don't know which inputs
  // should be Tensors
  if (schema.is_vararg()) {
    return c10::nullopt;
  }

  std::vector<TensorTypePtr> collected;
  for (const auto idx : c10::irange(arguments.size())) {
    const auto& declared = arguments[idx].type();
    if (declared->isSubtypeOf(*ListType::ofTensors())) {
      return c10::nullopt;
    }
    if (!declared->isSubtypeOf(*TensorType::get())) {
      // non-tensor argument: nothing to record
      continue;
    }
    auto actual = node->input(idx)->type()->cast<TensorType>();
    if (!actual) {
      return c10::nullopt;
    }
    if (complete && !actual->isComplete()) {
      return c10::nullopt;
    }
    collected.push_back(actual);
  }
  return collected;
}

// Converts a negative `dim` into `dim + sizes.size()`; a non-negative `dim` is
// returned unchanged. No range check is performed on the result.
int64_t wrapDim(int64_t dim, at::IntArrayRef sizes) {
  return dim < 0 ? dim + static_cast<int64_t>(sizes.size()) : dim;
}

// Folds `next` into the running scalar type `original`. An Undefined running
// type is simply replaced by `next`; otherwise the two are combined with
// c10::promoteTypes.
c10::ScalarType unionScalarTypes(
    c10::ScalarType original,
    c10::ScalarType next) {
  return original == c10::ScalarType::Undefined
      ? next
      : c10::promoteTypes(original, next);
}

// Promotes result types for arithmetic operations on Tensor operands using
// new type promotion logic. See tensor_attributes.rst for details.
// This doesn't handle the case of arithmetic ops with Scalar arguments (when
// `Tensor.getUnsafeTensorImpl()->is_wrapped_number()` would return true)
//
// Only the first two inputs of `node` are inspected. Returns c10::nullopt when
// the node has fewer than two inputs, or when either of those inputs is not a
// tensor or has an unknown scalar type. Note that the returned value may be
// ScalarType::Undefined only if neither input contributed a dtype, which the
// checks below rule out.
c10::optional<c10::ScalarType> getPromotedTypeForArithmeticOp(Node* node) {
  c10::ScalarType dimmed = c10::ScalarType::Undefined;
  c10::ScalarType zerodim = c10::ScalarType::Undefined;
  if (node->inputs().size() < 2) {
    return c10::nullopt;
  }
  // binary arithmetic ops, more than 2 args is alpha.
  for (const auto i : c10::irange(2)) {
    // Check the tensor type for null *before* dereferencing it.
    auto dtt = node->inputs()[i]->type()->cast<TensorType>();
    if (!dtt) {
      return c10::nullopt;
    }
    auto inputDtype = dtt->scalarType();
    if (!inputDtype) {
      return c10::nullopt;
    }
    if (dtt->dim() && *dtt->dim() > 0) {
      dimmed = unionScalarTypes(dimmed, *inputDtype);
    } else if (!isFloatingType(dimmed)) {
      // zero-dim (or unknown-rank) operand; only tracked while no dimensioned
      // floating-point operand has been seen
      zerodim = unionScalarTypes(zerodim, *inputDtype);
    }
  }
  // if a tensor with dimensions is already of the highest category, don't
  // need to check zero-dim tensors.
  if (isFloatingType(dimmed)) {
    return dimmed;
  }
  // int_tensor * zero_dim_floating -> floating_tensor
  if (isIntegralType(dimmed, false) && isFloatingType(zerodim)) {
    return zerodim;
  }
  // bool_tensor * non_bool_scalar -> non_bool_tensor
  if (c10::ScalarType::Bool == dimmed &&
      c10::ScalarType::Undefined != zerodim) {
    return zerodim;
  }
  // types of dimensioned tensors generally take precedence over zero-dim
  // tensors if not promoting due to category. e.g.:
  // int_tensor * long -> int_tensor
  if (c10::ScalarType::Undefined != dimmed) {
    return dimmed;
  }

  // no dimmed tensors. e.g. zero_dim_tensor + zero_dim_tensor.
  return zerodim;
}

class ShapePropagator : public PropertyPropBase {
 public:
  // Builds the alias database for `graph`, then walks the graph's top-level
  // block (and, through collectResizeSet's recursion, every nested block) to
  // fill resized_alias_set. aliasDb_ must be initialised before the body runs
  // because collectResizeSet queries it.
  explicit ShapePropagator(const std::shared_ptr<Graph>& graph)
      : PropertyPropBase(graph), aliasDb_(graph) {
    collectResizeSet(graph->block());
  }

 private:
  // Inputs of size-changing nodes (see resizesInput) that the alias database
  // reports as written by that node; filled by collectResizeSet.
  ValueSet resized_alias_set;
  // Alias information for the graph being analysed.
  const AliasDb aliasDb_;

  // Reports whether `n` is treated as an operator that can change the size of
  // one of its inputs: either its kind is in a fixed list of in-place ops, or
  // its schema declares a keyword-only Tensor argument named "out".
  bool resizesInput(Node* n) {
    static const std::unordered_set<Symbol> resize_ops{
        aten::resize_,
        aten::resize_as_,
        aten::copy_,
        aten::set_,
        aten::unsqueeze_,
        aten::t_,
        aten::transpose_,
    };

    if (resize_ops.find(n->kind()) != resize_ops.end()) {
      return true;
    }
    if (!n->maybeSchema()) {
      return false;
    }

    // ops which take the result and write to input "out"
    const auto out_index = n->schema().argumentIndexWithName("out");
    if (!out_index) {
      return false;
    }
    const auto& out_arg = n->schema().arguments().at(*out_index);
    return out_arg.kwarg_only() &&
        out_arg.type()->isSubtypeOf(*TensorType::get());
  }

  // Walks `block` and all blocks nested inside its nodes. For every node that
  // resizesInput accepts, each input the alias database says the node writes
  // to is added to resized_alias_set.
  void collectResizeSet(Block* block) {
    for (Node* node : block->nodes()) {
      // descend into sub-blocks first
      for (Block* nested : node->blocks()) {
        collectResizeSet(nested);
      }
      if (!resizesInput(node)) {
        continue;
      }
      for (Value* operand : node->inputs()) {
        if (aliasDb_.writesToAlias(node, {operand})) {
          resized_alias_set.insert(operand);
        }
      }
    }
  }

  // Produces a stand-in IValue for `v`:
  //   - a constant yields its actual value;
  //   - a tensor with a complete type yields a zero-filled tensor with the
  //     recorded sizes, strides, device and scalar type;
  //   - a float-typed value yields 0.
  // Any other value raises std::runtime_error.
  IValue representativeValue(Value* v) {
    const TypePtr value_type = v->type();
    // a real constant needs no stand-in
    if (auto constant = toIValue(v)) {
      return *constant;
    }
    if (TensorTypePtr tensor_type = value_type->cast<TensorType>()) {
      if (tensor_type->isComplete()) {
        at::DeviceGuard device_guard(*tensor_type->device());
        const auto options = at::TensorOptions(*tensor_type->device())
                                 .dtype(*tensor_type->scalarType());
        return at::empty_strided(
                   *tensor_type->sizes().concrete_sizes(),
                   *tensor_type->strides().concrete_sizes(),
                   options)
            .zero_();
      }
      // incomplete tensor types drop through to the error below
    } else if (value_type->isSubtypeOf(*FloatType::get())) {
      return 0.f;
    }
    // isValidArgumentForRunning is expected to have filtered out everything
    // that reaches this point
    std::stringstream message;
    message << "unable to create representative value for: "
            << value_type->str() << ". File a bug report";
    throw std::runtime_error(message.str());
  }

  // Makes inputs `idx1` and `idx2` of `node` share one broadcast shape by
  // inserting an explicit aten::expand in front of `node` for each of the two
  // inputs whose recorded sizes differ from the broadcast result.
  //
  // `types` is indexed by input position (the same index is used for `types`
  // and for node->inputs()); the entries at `idx1` and `idx2` must have
  // concrete sizes. On return, those same two entries are refreshed from the
  // node's (possibly replaced) inputs.
  void broadcastBinary(
      Node* node,
      std::vector<TensorTypePtr>& types,
      size_t idx1,
      size_t idx2) {
    auto expected_size = at::infer_size(
        *types[idx1]->sizes().concrete_sizes(),
        *types[idx2]->sizes().concrete_sizes());
    auto broadcast = [&](size_t input_idx) {
      TensorTypePtr input_type = types.at(input_idx);
      // already has the broadcast shape: nothing to insert
      if (input_type->sizes() == expected_size)
        return;
      auto graph = node->owningGraph();
      WithInsertPoint point_guard{node};
      Node* expand = graph
                         ->create(
                             aten::expand,
                             {node->inputs().at(input_idx),
                              graph->insertConstant(expected_size),
                              graph->insertConstant(false)})
                         ->insertBefore(node);
      // give the new expand node a type before rewiring the consumer
      propagateNode(expand);
      node->replaceInput(input_idx, expand->output());
    };
    broadcast(idx1);
    broadcast(idx2);
    // Write the refreshed types back to the slots they were read from.
    // Writing to fixed slots 0 and 1 would clobber slot 1 and leave slot
    // `idx2` stale whenever (idx1, idx2) != (0, 1).
    types[idx1] = node->inputs().at(idx1)->type()->expect<TensorType>();
    types[idx2] = node->inputs().at(idx2)->type()->expect<TensorType>();
  }

  // Operators that canPropagateShapeByRunningIt always rejects, i.e. that are
  // never executed on representative values.
  OperatorSet cannot_propagate_shape_by_running_it = {
      "aten::solve(Tensor self, Tensor A) -> (Tensor, Tensor)",
      "aten::inverse(Tensor self) -> Tensor",
  };

  // Answers whether `node` reads, directly or through the producers of its
  // inputs, a value that something may have written to. Such a node is not
  // safe to run in isolation because we cannot tell whether the write has
  // happened. Results are cached per node in dependsOnMutationMemo_.
  std::unordered_map<Node*, bool> dependsOnMutationMemo_;
  bool dependsOnMutation(Node* node) {
    const auto cached = dependsOnMutationMemo_.find(node);
    if (cached != dependsOnMutationMemo_.end()) {
      return cached->second;
    }

    // A possible writer to any value this node uses means the isolated result
    // is not guaranteed to match the real one.
    if (aliasDb_.hasWriters(node)) {
      dependsOnMutationMemo_[node] = true;
      return true;
    }

    // Otherwise look through the producers of the inputs, because a mutated
    // value can reach this node via pure functions:
    //   a += 1
    //   c = a + b
    //   d = c + 1
    // Here `d` depends on the mutation of `a` although `a` is not one of its
    // direct inputs. Every producer is visited (no early exit).
    bool any_producer_depends = false;
    for (Value* operand : node->inputs()) {
      if (dependsOnMutation(operand->node())) {
        any_producer_depends = true;
      }
    }

    dependsOnMutationMemo_[node] = any_producer_depends;
    return any_producer_depends;
  }

  // Decides whether `node` may be executed on representative values. The node
  // must not be in cannot_propagate_shape_by_running_it, must not depend on a
  // mutation, and all of its inputs and outputs must pass
  // isValidArgumentForRunning / isValidReturnForRunning respectively.
  bool canPropagateShapeByRunningIt(Node* node) {
    if (node->isMemberOf(cannot_propagate_shape_by_running_it) ||
        dependsOnMutation(node)) {
      return false;
    }

    const auto inputs = node->inputs();
    const auto outputs = node->outputs();
    return std::all_of(
               inputs.begin(), inputs.end(), isValidArgumentForRunning) &&
        std::all_of(outputs.begin(), outputs.end(), isValidReturnForRunning);
  }

  // True when containsTensorType is false for the type of every output of
  // `node` (e.g. float / float); in that case there is no shape to propagate.
  bool DoesntRefineOutputs(Node* node) {
    const auto outs = node->outputs();
    return std::none_of(outs.begin(), outs.end(), [](Value* out) {
      return containsTensorType(out->type());
    });
  }

  // Executes `node` on representative inputs and copies the resulting tensor
  // types onto the node's tensor outputs. `op` overrides the operation to run;
  // when it is empty the node's own operation is used. Returns false, without
  // running anything, if canPropagateShapeByRunningIt rejects the node.
  bool PropagateShapeOnNodeByRunningIt(Node* node, Operation op = nullptr) {
    if (!canPropagateShapeByRunningIt(node)) {
      return false;
    }
    if (!op) {
      op = node->getOperation();
    }

    Stack stack;
    for (Value* operand : node->inputs()) {
      stack.push_back(representativeValue(operand));
    }

    // XXX: exceptions thrown by the op are deliberately not caught for now.
    // This surfaces mistakes made while editing this code, and eventually it
    // shouldn't matter because this phase should be preceded by schema
    // checking.
    op(stack);

    AT_ASSERT(stack.size() == node->outputs().size());
    for (const auto idx : c10::irange(stack.size())) {
      // Ops may mix tensor and primitive outputs; primitive outputs already
      // have their most constrained type and are left alone.
      Value* out = node->outputs()[idx];
      auto declared = out->type()->cast<TensorType>();
      if (!declared || !stack[idx].isTensor()) {
        continue;
      }
      // Gradient information isn't always available on representative
      // inputs, so the output keeps its original requires-grad property.
      auto keep_grad = declared->requiresGrad();
      out->setType(TensorType::create(stack[idx].toTensor())
                       ->withRequiresGrad(keep_grad));
    }
    return true;
  }

  // Sets the output type of an aten::cat / prim::FusedConcat node.
  // The concatenated tensors are taken from the node itself (FusedConcat) or
  // from the prim::ListConstruct feeding its `tensors` input. Two strategies
  // are tried in order: an exact size computation when every operand has a
  // complete type and `dim` is constant, then a dimension-only type copied
  // from the first operand that has a tensor type. If neither applies the node
  // is marked unshaped.
  void PropagateCatShape(Node* cat_node) {
    // Exact sizes: all operands complete, constant in-range dim, and matching
    // sizes on every axis other than the concatenation axis.
    static const auto propagate_complete =
        [](Node* node, at::ArrayRef<Value*> tensors) -> bool {
      auto input_types =
          fmap(tensors, [](Value* v) { return v->type()->cast<TensorType>(); });
      for (const TensorTypePtr& tp : input_types) {
        if (tp == nullptr || !tp->isComplete()) {
          return false;
        }
      }
      if (!node->is_constant(attr::dim)) {
        return false;
      }
      std::vector<int64_t> result_sizes =
          *input_types[0]->sizes().concrete_sizes();
      const int64_t cat_dim =
          wrapDim(node->get<int64_t>(attr::dim).value(), result_sizes);
      const int64_t rank = static_cast<int64_t>(result_sizes.size());
      if (cat_dim < 0 || cat_dim >= rank) {
        return false;
      }

      // accumulate the extent along the concatenation axis from zero
      result_sizes[cat_dim] = 0;
      for (auto& tp : input_types) {
        auto operand_sizes = tp->sizes().concrete_sizes().value();
        if (result_sizes.size() != operand_sizes.size()) {
          return false;
        }
        for (const auto axis : c10::irange(rank)) {
          if (axis != cat_dim && result_sizes[axis] != operand_sizes[axis]) {
            return false;
          }
        }
        result_sizes[cat_dim] += operand_sizes[cat_dim];
      }
      node->output()->setType(input_types[0]->withSizes(result_sizes));
      return true;
    };
    // Fallback: reuse the dimension-only type of the first tensor operand.
    static const auto propagate = [](Node* node,
                                     at::ArrayRef<Value*> tensors) -> bool {
      for (Value* v : tensors) {
        auto tensor_type = v->type()->cast<TensorType>();
        if (tensor_type) {
          node->output()->setType(tensor_type->dimensionedOnly());
          return true;
        }
      }
      return false;
    };

    const bool is_fused = cat_node->kind() == prim::FusedConcat;
    Node* list_node =
        is_fused ? cat_node : cat_node->namedInput(attr::tensors)->node();
    if (is_fused || list_node->kind() == prim::ListConstruct) {
      auto tensors = list_node->inputs();
      if (!tensors.empty() &&
          (propagate_complete(cat_node, tensors) ||
           propagate(cat_node, tensors))) {
        return;
      }
    }
    setUnshapedType(cat_node);
  }

  // Sets the output type of a tensor-constructing node whose first input is a
  // (possibly nested) list. The number of dimensions is the list nesting
  // depth; the scalar type comes from the innermost element type unless a
  // constant non-None `dtype` argument overrides it; the device is CPU unless
  // a constant non-None `device` argument overrides it. If a `dtype` or
  // `device` argument exists in the schema but is not a constant, the output
  // type is left untouched.
  void propagateTorchTensorShape(Node* node) {
    // peel off list layers, counting one dimension per layer
    TypePtr element_type = node->inputs().at(0)->type();
    size_t rank = 0;
    for (auto as_list = element_type->cast<ListType>(); as_list;
         as_list = element_type->cast<ListType>()) {
      ++rank;
      element_type = as_list->getElementType();
    }

    at::optional<at::ScalarType> dtype =
        tryScalarTypeFromJitType(*element_type);
    if (auto dtype_index = node->schema().argumentIndexWithName("dtype")) {
      auto dtype_arg = toIValue(node->inputs().at(*dtype_index));
      if (!dtype_arg) {
        return;
      }
      if (!dtype_arg->isNone()) {
        dtype = dtype_arg->toScalarType();
      }
    }

    at::Device device = at::kCPU;
    if (auto device_index = node->schema().argumentIndexWithName("device")) {
      auto device_arg = toIValue(node->inputs().at(*device_index));
      if (!device_arg) {
        return;
      }
      if (!device_arg->isNone()) {
        device = device_arg->toDevice();
      }
    }

    node->output()->setType(TensorType::create(
        dtype, device, rank, /*requires_grad=*/c10::nullopt));
  }

  // Calls setUnshapedType on every value in `vs` that may alias a member of
  // resized_alias_set. Returns whether at least one such value was found.
  bool setUnshapedTypeIfAliasResizedSet(at::ArrayRef<Value*> vs) {
    bool found = false;
    for (Value* candidate : vs) {
      if (!aliasDb_.mayAlias(ValueSet{candidate}, resized_alias_set)) {
        continue;
      }
      setUnshapedType(candidate);
      found = true;
    }
    return found;
  }

  // Computes and records the output type(s) of a single node.
  //
  // The steps below are tried in order, and the first one that applies ends
  // the function:
  //   1. If any input may alias a resized value, those inputs and the node are
  //      passed to setUnshapedType.
  //   2. Node kinds listed in the switch are handled directly.
  //   3. Nodes with side effects are left untouched.
  //   4. aten::cat / prim::FusedConcat go to PropagateCatShape.
  //   5. If every tensor input has a complete type,
  //      PropagateCompleteShapeOnNode is tried.
  //   6. PropagateTensorShapeOnNode is tried.
  //   7. If DoesntRefineOutputs(node) holds, nothing more is done.
  //   8. The node is executed on representative values if that is permitted.
  //   9. Otherwise the node is passed to setUnshapedType.
  //
  // `insert_expands` is forwarded unchanged to steps 5 and 6.
  void propagateNode(Node* node, bool insert_expands = true) override {
    // Certain ops like resize_ change the input tensors size. Because our
    // analysis is flow invariant, we set any Tensor that can alias a resized
    // Tensor to the base Tensor Type without size information.
    if (setUnshapedTypeIfAliasResizedSet(node->inputs())) {
      return setUnshapedType(node);
    }

    // These don't require the types, and have complicated schema. Return early
    // after we process them.
    switch (node->kind()) {
      case prim::If:
        return processIf(node);
      case prim::Loop: {
        return processLoop(node);
      }
      case aten::Bool:
      case aten::Int:
      case aten::Float:
      case aten::ScalarImplicit:
      case aten::FloatImplicit:
      case aten::IntImplicit:
        return; // correct num type is already set
      case prim::NumToTensor: {
        // int/bool inputs become 0-dim Long CPU tensors, float inputs become
        // 0-dim Double CPU tensors; any other input type leaves the output
        // type unchanged.
        TypePtr typ = node->input()->type();
        if (typ->isSubtypeOf(*IntType::get()) ||
            typ->isSubtypeOf(*BoolType::get())) {
          node->output()->setType(TensorType::create(
              at::kLong, at::kCPU, 0, /*requires_grad=*/c10::nullopt));
        } else if (node->input()->type()->isSubtypeOf(*FloatType::get())) {
          node->output()->setType(TensorType::create(
              at::kDouble, at::kCPU, 0, /*requires_grad=*/c10::nullopt));
        }
        return;
      }
      case aten::tensor:
      case aten::as_tensor: {
        // as_tensor has an overloaded schema and can either have a tensor or
        // a list as the first input, if the input is a tensor, we delegate
        // the shape propagation in PropagateTensorShapeOnNode
        if (node->inputs().at(0)->type()->isSubtypeOf(*TensorType::get())) {
          break;
        }
        return propagateTorchTensorShape(node);
      }
      case prim::TupleConstruct: {
        // We refresh the tuple type, because the input types could have been
        // refined.
        auto orig_type = node->output()->type()->expect<TupleType>();
        auto new_types =
            fmap(node->inputs(), [](Value* v) { return v->type(); });
        node->output()->setType(
            orig_type->createWithContained(std::move(new_types)));
        return;
      }
      case prim::TupleUnpack: {
        // Each output takes the type of the corresponding tuple element.
        auto tuple_type = node->input()->type()->cast<TupleType>();
        AT_ASSERT(
            tuple_type &&
            tuple_type->elements().size() == node->outputs().size());
        auto elems = tuple_type->elements();
        for (size_t i = 0; i < node->outputs().size(); ++i) {
          node->output(i)->setType(elems[i]);
        }
        return;
      }
      case prim::Constant: {
        // Tensor constants get their type inferred from the stored value.
        if (node->output()->type()->isSubtypeOf(*TensorType::get())) {
          node->output()->inferTypeFrom(node->t(attr::value));
        }
        return;
      }
      case prim::unchecked_unwrap_optional: {
        // If we have specialized the optional type to the element type,
        // we want to pass it down. We write this as input.isSubtypeOf(output)
        // to be sure that we don't screw up nested optionals.
        if (node->input()->type()->isSubtypeOf(*node->output()->type())) {
          node->output()->setType(node->input()->type());
        }
        return;
      }
      case prim::ConstantChunk: {
        // Every chunk gets the dimension-only version of the input's type.
        Value* tensor = node->input();
        if (auto type = tensor->type()->cast<TensorType>()) {
          type = type->dimensionedOnly();
          for (Value* output : node->outputs()) {
            output->setType(type);
          }
        } else {
          setUnshapedType(node);
        }
        return;
      }
      case prim::grad: {
        // `tt` is not used afterwards; the expect<> call still requires the
        // input to be a TensorType.
        auto tt = node->input()->type()->expect<TensorType>();
        // grad may be undefined
        // requires_grad may be required
        auto grad_type = TensorType::get()->withPossiblyUndefined();
        node->output()->setType(grad_type);
        return;
      }
      case prim::CallFunction:
      case prim::CallMethod:
      case prim::AutogradZero: {
        setUnshapedType(node);
        return;
      }
      case prim::GetAttr: {
        auto cls = node->input()->type()->expect<ClassType>();
        // propagate any type specializations encoded in the type of the class
        node->output()->setType(cls->getAttribute(node->s(attr::name)));
        return;
      }
      case aten::_unwrap_optional: {
        // If we have specialized the optional type to the element type,
        // we want to pass it down. We write this as input.isSubtypeOf(output)
        // to be sure that we don't screw up nested optionals.
        if (node->input()->type()->isSubtypeOf(*node->output()->type())) {
          node->output()->setType(node->input()->type());
        }
        return;
      }
      default:
        break; // fall-through
    }

    // Nodes with side effects are not analysed any further.
    if (node->hasSideEffects()) {
      return;
    }

    if (node->matches("aten::cat(Tensor[] tensors, int dim) -> Tensor") ||
        node->kind() == prim::FusedConcat) {
      return PropagateCatShape(node);
    }

    // Only attempted when gatherTensorTypes succeeds with complete=true.
    if (auto maybe_complete_types =
            gatherTensorTypes(node, /*complete=*/true)) {
      if (PropagateCompleteShapeOnNode(
              node, insert_expands, std::move(*maybe_complete_types))) {
        return;
      }
    }

    if (PropagateTensorShapeOnNode(node, insert_expands)) {
      return;
    }

    // Nothing to refine when no output type is reported as containing a
    // tensor.
    if (DoesntRefineOutputs(node)) {
      return;
    }

    // Last resort: execute the op on representative values, or give up and
    // mark the node unshaped.
    if (PropagateShapeOnNodeByRunningIt(node)) {
      return;
    }
    return setUnshapedType(node);
  }

  // Returns the statically known length of the list value `list`, or
  // c10::nullopt when it cannot be determined. The length is known when the
  // value is a constant int list, or when it is produced by a
  // prim::ListConstruct node (one element per node input). Asserts that
  // `list` has a list type.
  static c10::optional<size_t> determineListSize(Value* list) {
    AT_ASSERT(list->type()->cast<ListType>());
    auto constant_list = constant_as<c10::List<int64_t>>(list);
    if (constant_list) {
      return constant_list->size();
    }
    Node* producer = list->node();
    if (producer->kind() != prim::ListConstruct) {
      return c10::nullopt;
    }
    return producer->inputs().size();
  }

  // Rationale for isValidArgumentForRunning / isValidReturnForRunning, i.e.
  // for deciding whether it is OK to try to run the op:
  // If an input is a constant, then we assume that the input is valid
  // and we can try to run it.
  // Otherwise:
  // Integral-typed _inputs_ are often an indicator that we're indexing into
  // a tensor, so we should special-case these ops in the shape propagation.
  // Additionally, passing a zero representative tensor into an integer
  // division op causes divide-by-zero errors.
  // _Outputs_ must be tensors or primitives.
  // We will call inferTypeFrom on the tensors, and ignore the primitives.
  // However, we allow primitive returns because we want to support mixed
  // primitive/tensor outputs.

  bool PropagateTensorShapeOnNode(Node* node, bool insert_expands) {
    static const auto broadcast =
        [](std::vector<TensorTypePtr>& tensor_types,
           c10::optional<at::ScalarType> t) -> TensorTypePtr {
      if (tensor_types.size() == 1) {
        return tensor_types[0]->dimensionedOnly()->withScalarType(t);
      }
      AT_ASSERT(!tensor_types.empty());
      auto any_type = tensor_types[0];
      auto max_dims = any_type->dim();
      for (auto& type : tensor_types) {
        if (!max_dims || !type->dim()) {
          max_dims = c10::nullopt;
        } else {
          max_dims = std::max(*max_dims, *type->dim());
        }
      }
      return TensorType::create(
          t,
          any_type->device(),
          max_dims,
          /*requires_grad=*/c10::nullopt);
    };

    using type_vec_t = std::vector<TensorTypePtr>;
    // Formula is expected to return a vector of length equal to the number of
    // tensor outputs of the node, or an empty vector which implies that it
    // failed to propagate.
    using formula_t = std::function<type_vec_t(Node*)>;
    static std::mutex shape_formulas_mutex;
    static std::vector<std::pair<OperatorSet, formula_t>> shape_formulas;
    struct register_formula_for {
      register_formula_for(OperatorSet operators, formula_t formula) {
        std::unique_lock<std::mutex> lock{shape_formulas_mutex};
        shape_formulas.emplace_back(std::move(operators), std::move(formula));
      }
    };

    // Requirements:
    //   dims           : preserved
    //   scalar type    : preserved
    //   device         : preserved
    //   tensor inputs  : 1
    //   tensor outputs : 1
    // Additionally:
    //   - First input should be the only tensor input
    static const register_formula_for simple_unary_ops{
        {
            "aten::acos(Tensor self) -> Tensor",
            "aten::neg(Tensor self) -> Tensor",
            "aten::t(Tensor self) -> Tensor",
            "aten::sigmoid(Tensor self) -> Tensor",
            "aten::logit(Tensor self, float? eps=None) -> Tensor",
            "aten::tanh(Tensor self) -> Tensor",
            "aten::relu(Tensor self) -> Tensor",
            "aten::asin(Tensor self) -> Tensor",
            "aten::atan(Tensor self) -> Tensor",
            "aten::ceil(Tensor self) -> Tensor",
            "aten::clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor",
            "aten::contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a)",
            "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor",
            "aten::celu(Tensor self, Scalar alpha) -> Tensor",
            "aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor",
            "aten::clamp_max(Tensor self, Scalar max) -> Tensor",
            "aten::clamp_min(Tensor self, Scalar min) -> Tensor",
            "aten::alpha_dropout(Tensor input, float p, bool train) -> Tensor",
            "aten::bernoulli(Tensor self, float p, *, Generator? generator) -> Tensor",
            "aten::cos(Tensor self) -> Tensor",
            "aten::cosh(Tensor self) -> Tensor",
            "aten::digamma(Tensor self) -> Tensor",
            "aten::dropout(Tensor input, float p, bool train) -> Tensor",
            "aten::elu(Tensor self, Scalar alpha, Scalar scale, Scalar input_scale) -> Tensor",
            "aten::erf(Tensor self) -> Tensor",
            "aten::erfc(Tensor self) -> Tensor",
            "aten::erfinv(Tensor self) -> Tensor",
            "aten::exp(Tensor self) -> Tensor",
            "aten::expm1(Tensor self) -> Tensor",
            "aten::log(Tensor self) -> Tensor",
            "aten::log10(Tensor self) -> Tensor",
            "aten::log1p(Tensor self) -> Tensor",
            "aten::log2(Tensor self) -> Tensor",
            "aten::log_sigmoid(Tensor self) -> Tensor",
            "aten::floor(Tensor self) -> Tensor",
            "aten::frac(Tensor self) -> Tensor",
            "aten::flip(Tensor self, int[] dims) -> Tensor",
            "aten::feature_alpha_dropout(Tensor input, float p, bool train) -> Tensor",
            "aten::feature_dropout(Tensor input, float p, bool train) -> Tensor",
            "aten::hardshrink(Tensor self, Scalar lambd) -> Tensor",
            "aten::hardtanh(Tensor self, Scalar min_val, Scalar max_val) -> Tensor",
            "aten::glu(Tensor self, int dim) -> Tensor",
            "aten::inverse(Tensor self) -> Tensor",
            "aten::leaky_relu(Tensor self, Scalar negative_slope) -> Tensor",
            "aten::lgamma(Tensor self) -> Tensor",
            "aten::mvlgamma(Tensor self, int p) -> Tensor",
            "aten::normal(float mean, Tensor std, *, Generator? generator) -> Tensor",
            "aten::normal(Tensor mean, float std, *, Generator? generator) -> Tensor",
            "aten::permute(Tensor self, int[] dims) -> Tensor",
            "aten::pin_memory(Tensor(a) self, Device? device=None) -> Tensor(a)",
            "aten::pinverse(Tensor self, float rcond) -> Tensor",
            "aten::reciprocal(Tensor self) -> Tensor",
            "aten::relu(Tensor self) -> Tensor",
            "aten::round(Tensor self) -> Tensor",
            "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
            "aten::rsqrt(Tensor self) -> Tensor",
            "aten::selu(Tensor self) -> Tensor",
            "aten::gelu(Tensor self) -> Tensor",
            "aten::sigmoid(Tensor self) -> Tensor",
            "aten::sign(Tensor self) -> Tensor",
            "aten::sin(Tensor self) -> Tensor",
            "aten::sinh(Tensor self) -> Tensor",
            "aten::softplus(Tensor self, Scalar beta, Scalar threshold) -> Tensor",
            "aten::softshrink(Tensor self, Scalar lambd) -> Tensor",
            "aten::sqrt(Tensor self) -> Tensor",
            "aten::tan(Tensor self) -> Tensor",
            "aten::tanh(Tensor self) -> Tensor",
            "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor",
            "aten::transpose(Tensor self, int dim0, int dim1) -> Tensor",
            "aten::tril(Tensor self, int diagonal) -> Tensor",
            "aten::triu(Tensor self, int diagonal) -> Tensor",
            "aten::trunc(Tensor self) -> Tensor",
            "aten::rot90(Tensor self, int k, int[] dims) -> Tensor",
            "aten::narrow(Tensor self, int dim, int start, int length) -> Tensor",
            "aten::slice(Tensor self, int dim, int? start=None, int? end=None, int step=1) -> Tensor",
            "aten::alias(Tensor self) -> Tensor",
        },
        [](Node* node) -> type_vec_t {
          auto input_type = node->input(0)->type()->cast<TensorType>();
          return input_type ? type_vec_t{input_type->dimensionedOnly()}
                            : type_vec_t{};
        }};

    // Requirements:
    //   dims           : preserved
    //   scalar type    : preserved, except complex maps to float
    //   device         : preserved
    //   tensor inputs  : 1
    //   tensor outputs : 1
    // Additionally:
    //   - First input should be the only tensor input
    static const register_formula_for simple_unary_ops_complex_to_float{
        {
            "aten::abs(Tensor self) -> Tensor",
        },
        [](Node* node) -> type_vec_t {
          auto input_type = node->input(0)->type()->cast<TensorType>();

          // Maps complex -> float
          if (input_type->scalarType()) {
            const auto scalar_type = *(input_type->scalarType());
            if (isComplexType(scalar_type)) {
              const auto out_type = c10::toValueType(scalar_type);
              return type_vec_t{
                  input_type->dimensionedOnly()->withScalarType(out_type)};
            }
          }

          return input_type ? type_vec_t{input_type->dimensionedOnly()}
                            : type_vec_t{};
        }};

    // Requirements:
    //   dims           : broadcast all tensor args
    //   scalar type    : promoted from input dtypes
    //   device         : always matching and preserved
    //   tensor inputs  : *
    //   tensor outputs : 1
    static const register_formula_for broadcasting_ops_arithmetic{
        {
            // Tensor-Tensor operators
            "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
            "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
            "aten::mul(Tensor self, Tensor other) -> Tensor",
            "aten::div(Tensor self, Tensor other) -> Tensor",
        },
        [](Node* node) -> type_vec_t {
          if (auto maybe_tensor_types = gatherTensorTypes(node)) {
            AT_ASSERT(maybe_tensor_types->size() >= 2);
            auto dtype = getPromotedTypeForArithmeticOp(node);
            return {broadcast(*maybe_tensor_types, dtype)};
          }
          return {};
        }};

    // Formula for broadcasting tensor-tensor ops whose output dtype is copied
    // from one of the first two tensor arguments.
    // Requirements:
    //   dims           : broadcast all tensor args
    //   scalar type    : always matching and preserved
    //   device         : always matching and preserved
    //   tensor inputs  : *
    //   tensor outputs : 1
    static const register_formula_for broadcasting_ops{
        {
            "aten::pow(Tensor self, Tensor exponent) -> Tensor",
            "aten::fmod(Tensor self, Tensor other) -> Tensor",
            "aten::remainder(Tensor self, Tensor other) -> Tensor",
            "aten::lerp(Tensor self, Tensor end, Scalar weight) -> Tensor",
            "aten::lerp(Tensor self, Tensor end, Tensor weight) -> Tensor",
            "aten::max(Tensor self, Tensor other) -> Tensor",
            "aten::min(Tensor self, Tensor other) -> Tensor",
            "aten::__and__(Tensor self, Tensor other) -> Tensor",
            "aten::__or__(Tensor self, Tensor other) -> Tensor",
            "aten::__xor__(Tensor self, Tensor other) -> Tensor",
            "aten::__lshift__(Tensor self, Tensor other) -> Tensor",
            "aten::__rshift__(Tensor self, Tensor other) -> Tensor",
            "aten::__iand__(Tensor self, Tensor other) -> Tensor",
            "aten::__ior__(Tensor self, Tensor other) -> Tensor",
            "aten::__ixor__(Tensor self, Tensor other) -> Tensor",
            "aten::__ilshift__(Tensor self, Tensor other) -> Tensor",
            "aten::__irshift__(Tensor self, Tensor other) -> Tensor",

            // Ops with Tensor-Tensor overloads only
            "aten::atan2(Tensor self, Tensor other) -> Tensor",
        },
        [](Node* node) -> type_vec_t {
          auto tensor_types = gatherTensorTypes(node);
          if (!tensor_types) {
            return {};
          }
          AT_ASSERT(tensor_types->size() >= 2);
          const auto lhs_dtype = (*tensor_types)[0]->scalarType();
          const auto rhs_dtype = (*tensor_types)[1]->scalarType();
          // Both dtypes must be known before anything can be inferred.
          if (!(lhs_dtype && rhs_dtype)) {
            return {};
          }
          // Keep the first argument's dtype when promotion leaves it
          // unchanged; in every other case use the second argument's dtype.
          const bool keep_lhs =
              c10::promoteTypes(*lhs_dtype, *rhs_dtype) == *lhs_dtype;
          return {
              broadcast(*tensor_types, keep_lhs ? *lhs_dtype : *rhs_dtype)};
        }};

    // Formula for addcdiv/addcmul: all tensor args are passed to broadcast()
    // and the output dtype is that of the first tensor argument (self).
    static const register_formula_for fused_accum_binary_ops{
        {
            // Non-binary ops
            "aten::addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value) -> Tensor",
            "aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value) -> Tensor",
        },
        [](Node* node) -> type_vec_t {
          auto tensor_types = gatherTensorTypes(node);
          if (!tensor_types) {
            return {};
          }
          const auto self_dtype = tensor_types->front()->scalarType();
          // Without a known dtype for self there is nothing to propagate.
          if (!self_dtype) {
            return {};
          }
          return {broadcast(*tensor_types, *self_dtype)};
        }};

    // Formula for arithmetic Tensor-Scalar ops. The output dtype is:
    //   - the default dtype when the tensor is integral (bool excluded) and
    //     the scalar is floating point,
    //   - promoteTypes(tensor, scalar) when the tensor is Bool and the scalar
    //     is not,
    //   - the tensor's dtype otherwise.
    static const register_formula_for broadcasting_tensor_scalar_ops_arithmetic{
        {
            // Tensor-Scalar operators
            "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor",
            "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor",
            "aten::mul(Tensor self, Scalar other) -> Tensor",
            "aten::div(Tensor self, Scalar other) -> Tensor",
        },
        [](Node* node) -> type_vec_t {
          auto tensor_types = gatherTensorTypes(node);
          if (!tensor_types) {
            return {};
          }
          const auto tensor_dtype = (*tensor_types)[0]->scalarType();
          const auto scalar_dtype =
              tryScalarTypeFromJitType(*node->inputs()[1]->type());
          if (!(tensor_dtype && scalar_dtype)) {
            return {};
          }

          // Start from the tensor's dtype and override it for the two
          // promotion cases below.
          auto result_dtype = *tensor_dtype;
          if (isIntegralType(*tensor_dtype, false) &&
              isFloatingType(*scalar_dtype)) {
            result_dtype =
                at::typeMetaToScalarType(caffe2::get_default_dtype());
          } else if (
              *tensor_dtype == c10::ScalarType::Bool &&
              *scalar_dtype != c10::ScalarType::Bool) {
            result_dtype = c10::promoteTypes(*tensor_dtype, *scalar_dtype);
          }
          return {broadcast(*tensor_types, result_dtype)};
        }};

    // Formula for non-arithmetic Tensor-Scalar ops.
    // NB: the output always takes the scalar type of the (first gathered)
    // Tensor argument; the Scalar argument never influences it.
    static const register_formula_for broadcasting_tensor_scalar_ops{
        {
            "aten::pow(Tensor self, Scalar exponent) -> Tensor",
            "aten::fmod(Tensor self, Scalar other) -> Tensor",
            "aten::remainder(Tensor self, Scalar other) -> Tensor",
            "aten::pow(Scalar self, Tensor exponent) -> Tensor",
            "aten::__and__(Tensor self, Scalar other) -> Tensor",
            "aten::__or__(Tensor self, Scalar other) -> Tensor",
            "aten::__xor__(Tensor self, Scalar other) -> Tensor",
            "aten::__lshift__(Tensor self, Scalar other) -> Tensor",
            "aten::__rshift__(Tensor self, Scalar other) -> Tensor",
            "aten::__iand__(Tensor self, Scalar other) -> Tensor",
            "aten::__ior__(Tensor self, Scalar other) -> Tensor",
            "aten::__ixor__(Tensor self, Scalar other) -> Tensor",
            "aten::__ilshift__(Tensor self, Scalar other) -> Tensor",
            "aten::__irshift__(Tensor self, Scalar other) -> Tensor",
        },
        [](Node* node) -> type_vec_t {
          auto tensor_types = gatherTensorTypes(node);
          if (!tensor_types) {
            return {};
          }
          const auto tensor_dtype = (*tensor_types)[0]->scalarType();
          return {broadcast(*tensor_types, tensor_dtype)};
        }};

    // aten::where is special in that its return dtype is that of the second
    // tensor argument (self) rather than that of the first (condition).
    static const register_formula_for where_op{
        {
            "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor",
        },
        [](Node* node) -> type_vec_t {
          auto tensor_types = gatherTensorTypes(node);
          if (!tensor_types) {
            return {};
          }
          // Index 1 is `self`; index 0 is the condition mask.
          const auto self_dtype = (*tensor_types)[1]->scalarType();
          return {broadcast(*tensor_types, self_dtype)};
        }};

    // Returns the type of the first input of `node` that is a TensorType with
    // a known number of dimensions, or nullptr when no input qualifies.
    static const auto any_tensor_type = [](Node* node) -> TensorTypePtr {
      for (Value* candidate : node->inputs()) {
        auto tensor_type = candidate->type()->cast<TensorType>();
        if (tensor_type && tensor_type->dim().has_value()) {
          return tensor_type;
        }
      }
      return nullptr;
    };

    // Formula for ops whose output type is copied from any input tensor with
    // a known rank.
    // Requirements:
    //   dims           : always matching and preserved
    //   scalar type    : always matching and preserved
    //   device         : always matching and preserved
    //   tensor inputs  : 2
    //   tensor outputs : 1
    static const register_formula_for binary_ops_strict_match{
        {
            "aten::normal(Tensor mean, Tensor std, *, Generator? generator) -> Tensor",
            "aten::mm(Tensor self, Tensor mat2) -> Tensor",
            "aten::bmm(Tensor self, Tensor mat2) -> Tensor",
        },
        [](Node* node) -> type_vec_t {
          auto known_type = any_tensor_type(node);
          if (!known_type) {
            return {};
          }
          return {known_type};
        }};

    // Formula for element-wise comparisons: the result is always typed as a
    // bool tensor.
    // Requirements:
    //   dims           : all tensor args are broadcast
    //   scalar type    : bool
    //   device         : always matching and preserved
    //   tensor inputs  : *
    //   tensor outputs : 1
    static const register_formula_for comparison_ops{
        {
            "aten::lt(Tensor self, Tensor other) -> Tensor",
            "aten::le(Tensor self, Tensor other) -> Tensor",
            "aten::gt(Tensor self, Tensor other) -> Tensor",
            "aten::ge(Tensor self, Tensor other) -> Tensor",
            "aten::eq(Tensor self, Tensor other) -> Tensor",
            "aten::ne(Tensor self, Tensor other) -> Tensor",
            "aten::lt(Tensor self, Scalar other) -> Tensor",
            "aten::le(Tensor self, Scalar other) -> Tensor",
            "aten::gt(Tensor self, Scalar other) -> Tensor",
            "aten::ge(Tensor self, Scalar other) -> Tensor",
            "aten::eq(Tensor self, Scalar other) -> Tensor",
            "aten::ne(Tensor self, Scalar other) -> Tensor",
        },
        [](Node* node) -> type_vec_t {
          auto tensor_types = gatherTensorTypes(node);
          if (!tensor_types) {
            return {};
          }
          return {broadcast(*tensor_types, at::kBool)};
        }};

    // Formula for the operator set returned by nn_ops_first_input_preserving():
    // the output takes the dimensionedOnly() form of the first input's type.
    static const register_formula_for nn_ops_first_input_formula{
        *nn_ops_first_input_preserving(), [](Node* node) -> type_vec_t {
          auto first_type = node->input(0)->type()->cast<TensorType>();
          if (!first_type) {
            return {};
          }
          return {first_type->dimensionedOnly()};
        }};

    // Formula for full reductions: the output is the first input's type with
    // its rank set to 0.
    // Requirements:
    //   dims           : 0
    //   scalar type    : preserved
    //   device         : preserved
    //   tensor inputs  : 1
    //   tensor outputs : 1
    // Additionally:
    //   - First input should be the only tensor input
    static const register_formula_for all_reduce_ops{
        {
            "aten::det(Tensor self) -> Tensor",
            "aten::logdet(Tensor self) -> Tensor",
            "aten::max(Tensor self) -> Tensor",
            "aten::min(Tensor self) -> Tensor",
            "aten::median(Tensor self) -> Tensor",
            "aten::nanmedian(Tensor self) -> Tensor",
            "aten::norm(Tensor self, Scalar p) -> Tensor",
            "aten::std(Tensor self, bool unbiased) -> Tensor",
            "aten::trace(Tensor self) -> Tensor",
            "aten::var(Tensor self, bool unbiased) -> Tensor",
            "aten::all(Tensor self) -> Tensor",
            "aten::any(Tensor self) -> Tensor",
        },
        [](Node* node) -> type_vec_t {
          auto self_type = node->input(0)->type()->cast<TensorType>();
          if (!self_type) {
            return {};
          }
          return {self_type->withDim(0)};
        }};

    // Formula for full reductions that accept an optional dtype override.
    // Requirements:
    //   dims           : 0
    //   scalar type    : dtype if specified, else preserved
    //   device         : preserved
    //   tensor inputs  : 1
    //   tensor outputs : 1
    // Additionally:
    //   - First input should be the only tensor input
    static const register_formula_for reduce_ops_with_opt_dtype{
        {"aten::mean(Tensor self, *, int? dtype) -> Tensor"},
        [](Node* node) -> type_vec_t {
          at::optional<IValue> dtype_arg = node->get(attr::dtype);
          auto self_type = node->input(0)->type()->cast<TensorType>();
          if (!self_type) {
            return {};
          }
          auto reduced = self_type->withDim(0);
          // Only an available, non-None dtype overrides the input's dtype.
          const bool dtype_given = dtype_arg && !dtype_arg->isNone();
          if (!dtype_given) {
            return {reduced};
          }
          return {reduced->withScalarType(dtype_arg->toScalarType())};
        }};

    // Formula for sum/prod over all elements.
    // Requirements:
    //   dims           : 0
    //   scalar type    : dtype if specified, else preserved if floating point,
    //                    otherwise long/int64 (left unset if the input dtype
    //                    is unknown)
    //   device         : preserved
    //   tensor inputs  : 1
    //   tensor outputs : 1
    // Additionally:
    //   - First input should be the only tensor input
    static const register_formula_for
        all_reduce_ops_with_integer_upcast_and_dtype{
            {
                "aten::sum(Tensor self, *, int? dtype) -> Tensor",
                "aten::prod(Tensor self, *, int? dtype) -> Tensor",
            },
            [](Node* node) -> type_vec_t {
              auto self_type = node->input(0)->type()->cast<TensorType>();
              if (!self_type) {
                return {};
              }
              auto reduced = self_type->withDim(0);

              // An explicit dtype argument takes precedence over any upcast.
              at::optional<IValue> dtype_arg = node->get(attr::dtype);
              if (dtype_arg && !dtype_arg->isNone()) {
                return {reduced->withScalarType(dtype_arg->toScalarType())};
              }

              // Known non-floating dtypes are upcast to int64; floating and
              // unknown dtypes pass through unchanged.
              const auto in_dtype = reduced->scalarType();
              if (in_dtype && !at::isFloatingType(*in_dtype)) {
                return {reduced->withScalarType(at::kLong)};
              }
              return {reduced};
            }};

    // Shared helper for reduction formulas. Starting from the first input's
    // tensor type it:
    //   1. gives up (returns {}) if the input is not a tensor or its dtype or
    //      rank is unknown,
    //   2. applies `opt_dtype` when it holds a non-None value, otherwise
    //      upcasts non-floating dtypes to int64 if `upcast_integer` is set,
    //   3. drops `num_reduced_dim` dimensions when that count is positive and
    //      does not exceed the input rank.
    static const auto reduce_op_handler = [](Node* node,
                                             int64_t num_reduced_dim = 0,
                                             bool upcast_integer = false,
                                             c10::optional<IValue> opt_dtype =
                                                 c10::nullopt) -> type_vec_t {
      auto out_type = node->input(0)->type()->cast<TensorType>();
      if (!out_type) {
        return {};
      }
      if (!out_type->scalarType() || !out_type->dim()) {
        return {};
      }

      const bool dtype_given = opt_dtype && !opt_dtype->isNone();
      if (dtype_given) {
        out_type = out_type->withScalarType(opt_dtype->toScalarType());
      } else if (
          upcast_integer && !at::isFloatingType(*out_type->scalarType())) {
        out_type = out_type->withScalarType(at::kLong);
      }

      // Nothing to drop, or more dims requested than available: keep rank.
      // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
      if (num_reduced_dim <= 0 || *out_type->dim() < num_reduced_dim) {
        return {out_type};
      }
      return {out_type->withDim(*out_type->dim() - num_reduced_dim)};
    };

    // Wraps reduce_op_handler for ops with a bool `keepdim` argument: when
    // keepdim is true no dimensions are dropped, otherwise `num_reduced_dim`
    // are. Returns {} when the keepdim value cannot be read from the node.
    static const auto multidim_reduce_with_keepdim =
        [](Node* node,
           int64_t num_reduced_dim,
           bool upcast_integer) -> type_vec_t {
      const auto keepdim = node->get<bool>(attr::keepdim);
      if (!keepdim) {
        return {};
      }
      const int64_t dims_to_drop = *keepdim ? 0 : num_reduced_dim;
      return reduce_op_handler(node, dims_to_drop, upcast_integer);
    };

    // Formula for argmax/argmin, which produce index tensors.
    // Requirements:
    //   dims           : 0 if dim is None, otherwise preserved if
    //                    keepdim == true or 1 smaller otherwise
    //   scalar type    : long/int64 (the outputs are indices)
    //   device         : preserved
    //   tensor inputs  : 1
    //   tensor outputs : 1
    // Additionally:
    //   - First input should be the only tensor input
    //   - Has a bool keepdim argument
    static const register_formula_for argminmax{
        {
            "aten::argmax(Tensor self, int? dim, bool keepdim) -> Tensor",
            "aten::argmin(Tensor self, int? dim, bool keepdim) -> Tensor",
        },
        [](Node* node) -> type_vec_t {
          auto self_type = node->input(0)->type()->cast<TensorType>();
          if (!self_type) {
            return {};
          }
          // dim == None: the reduction runs over the whole tensor and yields
          // a 0-d index tensor.
          if (node->input(1)->type()->kind() == c10::TypeKind::NoneType) {
            return {self_type->withDim(0)->withScalarType(at::kLong)};
          }
          auto index_types = multidim_reduce_with_keepdim(
              node, /*num_reduced_dim=*/1, /*upcast_integer=*/false);
          // The helper preserves the input dtype; indices are int64, in line
          // with how the indices output of dim_reduce_ops is typed.
          for (auto& index_type : index_types) {
            index_type = index_type->withScalarType(at::kLong);
          }
          return index_types;
        }};

    // Formula for single-dimension reductions, some of which also return an
    // indices tensor as a second output.
    // Requirements:
    //   dims           : preserved if keepdim == true, 1 smaller otherwise
    //   scalar type    : preserved for first output, long/int64 for second
    //                    output if it exists
    //   device         : preserved
    //   tensor inputs  : 1
    //   tensor outputs : 1 or 2
    // Additionally:
    //   - First input should be the only tensor input
    //   - Has a bool keepdim argument
    static const register_formula_for dim_reduce_ops{
        {
            "aten::all(Tensor self, int dim, bool keepdim) -> Tensor",
            "aten::any(Tensor self, int dim, bool keepdim) -> Tensor",

            // Ops returning indices as second output
            "aten::kthvalue(Tensor self, int k, int dim, bool keepdim) -> (Tensor, Tensor)",
            "aten::max(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)",
            "aten::min(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)",
            "aten::median(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)",
            "aten::nanmedian(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)",
            "aten::mode(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)",
        },
        [](Node* node) -> type_vec_t {
          // NB: the helper yields at most one type; for two-output ops the
          // indices type is derived from it and appended below.
          auto result_types = multidim_reduce_with_keepdim(
              node, /*num_reduced_dim=*/1, /*upcast_integer=*/false);
          const bool has_indices_output = node->outputs().size() == 2;
          if (has_indices_output && !result_types.empty()) {
            auto indices_type = result_types.back()->withScalarType(at::kLong);
            result_types.push_back(indices_type);
          }
          return result_types;
        }};

    // Formula for prod along a single dimension.
    // Requirements:
    //   dims           : preserved if keepdim == true, 1 smaller otherwise
    //   scalar type    : dtype if specified, preserved if floating point,
    //                    otherwise long/int64
    //   device         : preserved
    //   tensor inputs  : 1
    //   tensor outputs : 1
    // Additionally:
    //   - First input should be the only tensor input
    //   - has a bool keepdim argument
    static const register_formula_for dim_reduce_ops_with_integer_upcast{
        {
            "aten::prod(Tensor self, int dim, bool keepdim, *, int? dtype) -> Tensor",
        },
        [](Node* node) -> type_vec_t {
          auto maybe_keepdim = node->get<bool>(attr::keepdim);
          // Dereferencing an empty optional is undefined behaviour; when the
          // keepdim value is unavailable the output rank cannot be inferred,
          // so give up the same way multidim_reduce_with_keepdim does.
          if (!maybe_keepdim) {
            return {};
          }
          at::optional<IValue> opt_dtype = node->get(attr::dtype);
          return reduce_op_handler(
              node,
              /*num_reduce_dim=*/*maybe_keepdim ? 0 : 1,
              /*integer_upcast=*/true,
              opt_dtype);
        }};

    // Formula for rank-preserving ops along one dimension that accept an
    // optional dtype override.
    // Requirements:
    //   dims           : preserved
    //   scalar type    : dtype if specified, preserved if floating point,
    //    otherwise long/int64
    //   device         : preserved
    //   tensor inputs  : 1
    //   tensor outputs : 1
    // Additionally:
    //   - First input should be the only tensor input
    static const register_formula_for dim_reduce_ops_dtype{
        {"aten::cumprod(Tensor self, int dim, *, int? dtype) -> Tensor",
         "aten::cumsum(Tensor self, int dim, *, int? dtype) -> Tensor",
         "aten::log_softmax(Tensor self, int dim, int? dtype) -> Tensor"},
        [](Node* node) -> type_vec_t {
          // No dimension is removed; only the dtype may change.
          return reduce_op_handler(
              node,
              /*num_reduce_dim=*/0,
              /*integer_upcast=*/true,
              node->get(attr::dtype));
        }};

    // Formula for softmax: same rank as the input, no integer upcast.
    // Requirements:
    //   dims           : preserved
    //   scalar type    : dtype if specified, otherwise preserved
    //   device         : preserved
    //   tensor inputs  : 1
    //   tensor outputs : 1
    // Additionally:
    //   - has int dim and optional dtype arguments
    static const register_formula_for register_softmax{
        {"aten::softmax(Tensor self, int dim, int? dtype) -> Tensor"},
        [](Node* node) -> type_vec_t {
          return reduce_op_handler(
              node,
              /*num_reduced_dim=*/0,
              /*upcast_integer=*/false,
              node->get(attr::dtype));
        }};

    // Builds the output type of a tensor factory with `dim` dimensions from
    // the node's layout/device/dtype options. Returns {} if any of the three
    // cannot be read from the node. A None device maps to CPU and a None
    // dtype maps to double.
    // NOTE(review): the None-dtype fallback is at::kDouble rather than the
    // configured default dtype - confirm this is intended.
    static const auto factory_with_ndim = [](Node* node,
                                             int dim) -> type_vec_t {
      // The layout value itself is unused; it only has to be available.
      at::optional<IValue> layout_opt = node->get(attr::layout);
      if (!layout_opt) {
        return {};
      }

      at::optional<IValue> device_opt = node->get(attr::device);
      if (!device_opt) {
        return {};
      }
      c10::Device device(at::kCPU);
      if (!device_opt->isNone()) {
        device = device_opt->toDevice();
      }

      at::optional<IValue> dtype_opt = node->get(attr::dtype);
      if (!dtype_opt) {
        return {};
      }
      at::ScalarType dtype = at::kDouble;
      if (!dtype_opt->isNone()) {
        dtype = dtype_opt->toScalarType();
      }

      return {TensorType::create(
          dtype, device, dim, /*requires_grad=*/c10::nullopt)};
    };

    // Builds the output type of a `*_like` factory with `dim` dimensions.
    // Dtype and device default to those of the first input tensor and are
    // overridden by non-None dtype/device options. Returns {} if any of the
    // layout/device/dtype options cannot be read from the node.
    static const auto factory_like_with_ndim = [](Node* node,
                                                  int dim) -> type_vec_t {
      auto self_type = node->input(0)->type()->expect<TensorType>();
      auto out_dtype = self_type->scalarType();
      auto out_device = self_type->device();

      // The layout value itself is unused; it only has to be available.
      if (!node->get(attr::layout)) {
        return {};
      }

      at::optional<IValue> device_opt = node->get(attr::device);
      if (!device_opt) {
        return {};
      }
      if (!device_opt->isNone()) {
        out_device = device_opt->toDevice();
      }

      at::optional<IValue> dtype_opt = node->get(attr::dtype);
      if (!dtype_opt) {
        return {};
      }
      if (!dtype_opt->isNone()) {
        out_dtype = dtype_opt->toScalarType();
      }

      return {TensorType::create(
          out_dtype, out_device, dim, /*requires_grad=*/c10::nullopt)};
    };

    // Formula for `*_like` factories: the output has the rank of `self`.
    // Requirements:
    //   dims           : preserved
    //   scalar type    : equal to value of dtype
    //   device         : equal to value of device
    //   tensor inputs  : 1
    //   tensor outputs : 1
    // Additionally:
    //   - has ScalarType dtype, Layout layout and Device device arguments
    static const register_formula_for like_factories_with_options{
        {
            "aten::empty_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
            "aten::full_like(Tensor self, Scalar fill_value, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
            "aten::ones_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
            "aten::rand_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
            "aten::randint_like(Tensor self, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
            "aten::randint_like(Tensor self, int low, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
            "aten::randn_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
            "aten::zeros_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
        },
        [](Node* node) -> type_vec_t {
          auto self_type =
              node->namedInput(attr::self)->type()->cast<TensorType>();
          // The rank of `self` must be known to describe the output.
          if (!self_type || !self_type->dim()) {
            return {};
          }
          return factory_like_with_ndim(
              node, static_cast<int>(*self_type->dim()));
        }};

    // Formula for factories taking an explicit size list: the output rank is
    // the length of that list.
    // Requirements:
    //   dims           : equal to number of elements in size
    //   scalar type    : equal to value of dtype
    //   device         : equal to value of device
    //   tensor inputs  : 1
    //   tensor outputs : 1
    // Additionally:
    //   - has int[] size, ScalarType dtype, Layout layout and Device device
    //   arguments
    static const register_formula_for size_factories_with_options{
        {
            "aten::empty(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory, MemoryFormat? memory_format=contiguous_format) -> Tensor",
            "aten::full(int[] size, Scalar fill_value, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
            "aten::ones(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
            "aten::rand(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
            "aten::randn(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
            "aten::zeros(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
            "aten::randint(int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
            "aten::randint(int low, int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
        },
        [](Node* node) -> type_vec_t {
          auto size_list = node->get<c10::List<int64_t>>(attr::size);
          if (!size_list) {
            return {};
          }
          return factory_with_ndim(node, static_cast<int>(size_list->size()));
        }};

    // Maps an aten::_cast_* node kind to the scalar type it casts to.
    // Asserts (via AT_ASSERTM) on any other node kind.
    static const auto get_cast_scalar_type = [](Node* node) -> at::ScalarType {
      const auto kind = node->kind();
      if (kind == aten::_cast_Byte) {
        return at::kByte;
      }
      if (kind == aten::_cast_Char) {
        return at::kChar;
      }
      if (kind == aten::_cast_Double) {
        return at::kDouble;
      }
      if (kind == aten::_cast_Float) {
        return at::kFloat;
      }
      if (kind == aten::_cast_Half) {
        return at::kHalf;
      }
      if (kind == aten::_cast_Int) {
        return at::kInt;
      }
      if (kind == aten::_cast_Long) {
        return at::kLong;
      }
      if (kind == aten::_cast_Short) {
        return at::kShort;
      }
      AT_ASSERTM(
          false,
          "unknown node kind in get_cast_scalar_type: ",
          node->kind().toQualString());
    };
    // Formula for aten::_cast_* ops: the output is the type of `self` with
    // its scalar type replaced by the cast's target type.
    static const register_formula_for cast_ops{
        {
            "aten::_cast_Byte(Tensor self, bool non_blocking) -> Tensor",
            "aten::_cast_Char(Tensor self, bool non_blocking) -> Tensor",
            "aten::_cast_Double(Tensor self, bool non_blocking) -> Tensor",
            "aten::_cast_Float(Tensor self, bool non_blocking) -> Tensor",
            "aten::_cast_Half(Tensor self, bool non_blocking) -> Tensor",
            "aten::_cast_Int(Tensor self, bool non_blocking) -> Tensor",
            "aten::_cast_Long(Tensor self, bool non_blocking) -> Tensor",
            "aten::_cast_Short(Tensor self, bool non_blocking) -> Tensor",
        },
        [](Node* node) -> type_vec_t {
          auto self_type =
              node->namedInput(attr::self)->type()->cast<TensorType>();
          if (!self_type) {
            return {};
          }
          return {self_type->withScalarType(get_cast_scalar_type(node))};
        }};

    // First, try to match one of the registered formulas to their operator
    // sets. The first operator set containing the node decides the outcome:
    // an empty result means propagation failed, otherwise every output gets
    // the corresponding inferred type.
    for (auto& entry : shape_formulas) {
      if (!node->isMemberOf(entry.first)) {
        continue;
      }
      auto inferred_types = entry.second(node);
      if (inferred_types.empty()) {
        return false;
      }
      auto outputs = node->outputs();
      AT_ASSERT(inferred_types.size() == outputs.size());
      for (const auto i : c10::irange(inferred_types.size())) {
        AT_ASSERT(outputs[i]->type()->isSubtypeOf(*TensorType::get()));
        outputs[i]->setType(inferred_types[i]);
      }
      return true;
    }

    // This section implements shape prop for an assorted set of nodes that only
    // need partial information about their input types.
    const auto input_type = [node](size_t index) {
      auto result = node->input(index)->type()->cast<TensorType>();
      if (result) {
        result = result->dimensionedOnly();
      }
      return result;
    };
    // Each branch handles one schema (or a small family of schemas). A branch
    // that cannot infer the output type falls through to the code after the
    // chain instead of returning.
    if (node->matches(
            "aten::masked_select(Tensor self, Tensor mask) -> Tensor")) {
      // Output is typed as 1-D.
      if (auto type = input_type(0)) {
        node->output()->setType(type->withDim(1));
        return true;
      }
    } else if (node->matches("aten::detach(Tensor(a) self) -> Tensor(a)")) {
      if (auto type = input_type(0)) {
        node->output()->setType(type->withRequiresGrad(false));
        return true;
      }
    } else if (
        node->matches(
            "aten::batch_norm_stats(Tensor input, float eps) -> (Tensor, Tensor)")) {
      if (auto type = input_type(0)) {
        // Half inputs are reported with float outputs.
        if (type->scalarType() == at::kHalf) {
          type = type->withScalarType(at::kFloat);
        }
        type = type->withDim(1);
        node->outputs()[0]->setType(type);
        node->outputs()[1]->setType(type);
        return true;
      }
    } else if (node->matches(
                   "aten::dot(Tensor self, Tensor tensor) -> Tensor")) {
      if (auto type = any_tensor_type(node)) {
        node->output()->setType(type->withDim(0));
        return true;
      }
    } else if (
        node->matches("aten::mv(Tensor self, Tensor vec) -> Tensor") ||
        node->matches(
            "aten::addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta, Scalar alpha) -> Tensor")) {
      if (auto type = any_tensor_type(node)) {
        node->output()->setType(type->withDim(1));
        return true;
      }
    } else if (
        node->matches(
            "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor") ||
        node->matches(
            "aten::addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta, Scalar alpha) -> Tensor") ||
        node->matches(
            "aten::addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta, Scalar alpha) -> Tensor")) {
      if (auto type = any_tensor_type(node)) {
        node->output()->setType(type->withDim(2));
        return true;
      }
    } else if (
        node->matches(
            "aten::baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta, Scalar alpha) -> Tensor")) {
      if (auto type = any_tensor_type(node)) {
        node->output()->setType(type->withDim(3));
        return true;
      }
    } else if (
        node->matches(
            "aten::index_select(Tensor self, int dim, Tensor index) -> Tensor")) {
      auto type = input_type(0);
      // `index` is the third schema argument (input 2); input 1 is the
      // integer `dim` and never has a tensor type.
      auto index_type = input_type(2);
      // index_select behaves very weirdly when self.dim() == 0. It allows both
      // 0D and 1D indices, and returns a value that has as many dimensions as
      // index.
      if (type && index_type && type->dim()) {
        if (*type->dim() == 0) {
          node->output()->setType(type->withDim(index_type->dim()));
        } else {
          node->output()->setType(type);
        }
        return true;
      }
    } else if (
        node->matches(
            "aten::gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor")) {
      auto type = input_type(0);
      // `index` is the third schema argument (input 2); input 1 is the
      // integer `dim` and never has a tensor type.
      auto index_type = input_type(2);
      // Gather has this annoying edge case where index always needs to match
      // the number of dims of self, **except** when self is 1D and index is 0D
      // in which case we return a 0D output.
      if (type && index_type && index_type->dim()) {
        if (*index_type->dim() == 0) {
          node->output()->setType(type->withDim(0));
        } else {
          node->output()->setType(type);
        }
        return true;
      }
    } else if (
        node->matches(
            "aten::embedding(Tensor weight, Tensor indices, int padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor")) {
      auto weight_type = input_type(0);
      auto indices_type = input_type(1);
      // The output has one more dimension than `indices`.
      if (weight_type && indices_type && indices_type->dim()) {
        node->output()->setType(weight_type->withDim(*indices_type->dim() + 1));
        return true;
      }
    } else if (
        node->matches(
            "aten::bilinear(Tensor input1, Tensor input2, Tensor weight, Tensor? bias) -> Tensor")) {
      // Use whichever of the two data inputs has a tensor type, preferring
      // the first.
      if (auto type = input_type(0)) {
        node->output()->setType(type);
        return true;
      }
      if (auto type = input_type(1)) {
        node->output()->setType(type);
        return true;
      }
    } else if (
        node->matches(
            "aten::dist(Tensor self, Tensor other, Scalar p) -> Tensor")) {
      if (auto type = any_tensor_type(node)) {
        node->output()->setType(type->withDim(0));
        return true;
      }
    }

    // The code below implements formulas that need type information for all
    // their tensor inputs, and have exactly one output.
    std::vector<TensorTypePtr> tensor_types;
    // reshape_prop: when the length of the list passed as `shape_input` can
    // be determined, returns the first tensor type with that many dimensions;
    // otherwise returns nullptr.
    static const auto reshape_prop =
        [](Node* node,
           Symbol shape_input,
           const std::vector<TensorTypePtr>& tensor_types) -> TensorTypePtr {
      const auto target_rank =
          determineListSize(node->namedInput(shape_input));
      if (!target_rank) {
        return nullptr;
      }
      return tensor_types.at(0)->withDim(*target_rank);
    };
    const auto getSingleOutputType = [&]() -> TypePtr {
      if (node->matches("aten::type_as(Tensor self, Tensor other) -> Tensor")) {
        return tensor_types.at(0)->withScalarType(
            tensor_types.at(1)->scalarType());
      } else if (
          node->matches(
              "aten::view_as(Tensor(a) self, Tensor other) -> Tensor(a)") ||
          node->matches(
              "aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)") ||
          node->matches(
              "aten::reshape_as(Tensor(a) self, Tensor other) -> Tensor(a)")) {
        return tensor_types.at(0)->withDim(tensor_types.at(1)->dim());
      } else if (
          node->matches("aten::view(Tensor self, int[] size) -> Tensor") ||
          node->matches(
              "aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor") ||
          node->matches(
              "aten::as_strided(Tensor self, int[] size, int[] stride, int? storage_offset) -> Tensor")) {
        return reshape_prop(node, attr::size, tensor_types);
      } else if (
          node->matches(
              "aten::as_tensor(Tensor data, *, ScalarType? dtype, Device? device) -> Tensor")) {
        TypePtr input_type = node->inputs().at(0)->type();
        if (auto type = input_type->cast<TensorType>()) {
          if (type->scalarType() && type->device()) {
            at::ScalarType default_type = *type->scalarType();
            c10::Device default_device = *type->device();
            if (auto dtype_index =
                    node->schema().argumentIndexWithName("dtype")) {
              auto inp = toIValue(node->inputs().at(*dtype_index));
              if (inp == c10::nullopt) {
                return nullptr;
              }
              if (!inp->isNone()) {
                default_type = inp->toScalarType();
              }
            }
            if (auto device_index =
                    node->schema().argumentIndexWithName("device")) {
              auto inp = toIValue(node->inputs().at(*device_index));
              if (inp == c10::nullopt) {
                return nullptr;
              }
              if (!inp->isNone()) {
                default_device = inp->toDevice();
              }
            }
            node->output()->setType(TensorType::create(
                default_type,
                default_device,
                type->dim(),
                /*requires_grad=*/c10::nullopt));
          }
        }
        return nullptr;
      } else if (
          node->matches(
              "aten::reshape(Tensor(a) self, int[] shape) -> Tensor(a)")) {
        return reshape_prop(node, attr::shape, tensor_types);
      } else if (node->matches(
                     "aten::repeat(Tensor self, int[] repeats) -> Tensor")) {
        return reshape_prop(node, attr::repeats, tensor_types);
      } else if (node->matches(
                     "aten::unsqueeze(Tensor self, int dim) -> Tensor")) {
        auto& t = tensor_types.at(0);
        if (!t->dim()) {
          return t;
        }
        return t->withDim(*t->dim() + 1);
      } else if (
          node->matches(
              "aten::select(Tensor self, int dim, int index) -> Tensor") ||
          node->matches(
              "aten::diagonal(Tensor self, int offset, int dim1, int dim2) -> Tensor")) {
        auto& t = tensor_types.at(0);
        return t->dim() && *t->dim() > 0 ? t->withDim(*t->dim() - 1) : nullptr;
      } else if (node->matches(
                     "aten::matmul(Tensor self, Tensor other) -> Tensor")) {
        if (!tensor_types.at(0)->dim() || !tensor_types.at(1)->dim()) {
          return nullptr;
        }
        int dim1 = *tensor_types.at(0)->dim();
        int dim2 = *tensor_types.at(1)->dim();
        if (dim1 == 1 && dim2 == 1) {
          // Dot product
          return tensor_types.at(0)->withDim(0);
          // NOLINTNEXTLINE(bugprone-branch-clone)
        } else if (dim1 == 2 && dim2 == 2) {
          // Matrix multiply
          return tensor_types.at(0);
        } else if (dim1 == 1 && dim2 == 2) {
          // Unsqueeze + matrix multiply + squeeze
          return tensor_types.at(0);
        } else if (dim1 == 2 && dim2 == 1) {
          // Matrix vector multiply
          return tensor_types.at(1);
        } else {
          // Batched matrix multiply (possibly with squeeze + unsqueeze if one
          // argument is 1D)
          auto type = broadcast(tensor_types, tensor_types[0]->scalarType());
          if (dim1 == 1 || dim2 == 1) {
            type = type->withDim(type->dim().value() - 1);
          }
          return type;
        }
      } else if (node->matches("aten::nonzero(Tensor self) -> Tensor")) {
        return tensor_types.at(0)->dimensionedOnly()->withScalarType(at::kLong);
      } else if (node->matches(
                     "aten::take(Tensor self, Tensor index) -> Tensor")) {
        return tensor_types.at(1)->dimensionedOnly()->withScalarType(
            tensor_types.at(0)->scalarType());
      } else if (node->matches(
                     "aten::diagflat(Tensor self, int offset) -> Tensor")) {
        return tensor_types.at(0)->withDim(2);
      } else if (node->matches(
                     "aten::diag(Tensor self, int diagonal) -> Tensor")) {
        auto& t = tensor_types.at(0);
        if (t->dim() && *t->dim() == 1) {
          return t->withDim(2);
        } else if (t->dim() && *t->dim() == 2) {
          return t->withDim(1);
        } else {
          return nullptr;
        }
      } else if (
          node->matches(
              "aten::unfold(Tensor self, int dimension, int size, int step) -> Tensor")) {
        auto& t = tensor_types.at(0);
        if (!t->dim()) {
          return nullptr;
        }
        return t->withDim(*t->dim() + 1);
      } else if (node->matches(
                     "aten::polygamma(int n, Tensor self) -> Tensor")) {
        return tensor_types.at(0);
      }
      return nullptr;
    };
    if (auto maybe_tensor_types = gatherTensorTypes(node)) {
      tensor_types = std::move(*maybe_tensor_types);
    } else {
      return false;
    }
    if (node->outputs().size() == 1) {
      if (auto type = getSingleOutputType()) {
        node->output()->setType(type);
        return true;
      }
    }
    return false;
  }

  // Sets the output type(s) of `node` using hand-written, per-operator rules
  // that work on the concrete sizes/strides of its tensor inputs.
  //
  // Parameters:
  //   node           - the node whose output type(s) should be filled in.
  //   insert_expands - when true, the binary broadcasting ops (pow/min/max and
  //                    the comparison ops) are routed through broadcastBinary
  //                    before being run.
  //   tensor_types   - types of the node's tensor inputs, in input order.
  //                    Taken by value because broadcastBinary receives it.
  //                    NOTE(review): the rules below call
  //                    `concrete_sizes().value()` unconditionally, so callers
  //                    are presumably expected to pass complete types only.
  //
  // Returns true when one of the rules below handled the node (for the rules
  // that delegate to PropagateShapeOnNodeByRunningIt, its result is returned
  // as-is). When no rule matches, the node is marked via setUnshapedType and
  // false is returned.
  //
  // Inputs that a rule cannot handle are rejected through SHAPE_ASSERT or an
  // explicit `throw propagation_error()`.
  bool PropagateCompleteShapeOnNode(
      Node* node,
      bool insert_expands,
      std::vector<TensorTypePtr> tensor_types) {
    // For expensive ops we can directly encode their shape propagation
    // here, otherwise we fallback to running a fake version of the op
    // to get a quick and dirty propagation.
    if (node->matches(
            "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") ||
        node->matches(
            "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") ||
        node->matches("aten::mul(Tensor self, Tensor other) -> Tensor")) {
      // These nodes handle tensors of different shapes internally, so there's
      // no need to insert explicit expand nodes.
      return PropagateShapeOnNodeByRunningIt(node);
    } else if (node->matches(
                   "aten::div(Tensor self, Tensor other) -> Tensor")) {
      // "div" handles tensors of different shapes internally, so there's no
      // need to insert explicit expand nodes.
      // Note that this branch could be merged with the one above, but "div" is
      // not always safe to run by itself due to integer divide-by-zero.
      // We fake the execution by running the "mul" operation instead.
      auto op = getOperatorForLiteral(
                    "aten::mul(Tensor self, Tensor other) -> Tensor")
                    ->getOperation();
      return PropagateShapeOnNodeByRunningIt(node, op);
    } else if (node->matches(
                   "aten::pow(Tensor self, Scalar exponent) -> Tensor")) {
      node->output()->setType(tensor_types.at(0));
      return true;
    } else if (
        node->matches(
            "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor") ||
        node->matches(
            "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor") ||
        node->matches("aten::div(Tensor self, Scalar other) -> Tensor") ||
        node->matches("aten::mul(Tensor self, Scalar other) -> Tensor")) {
      // Tensor-with-Scalar arithmetic: the output keeps the tensor's
      // sizes/strides; only the scalar type may change.
      auto first_scalar_type = tensor_types.at(0)->scalarType();
      auto second_scalar_type =
          tryScalarTypeFromJitType(*node->inputs()[1]->type());
      if (!first_scalar_type || !second_scalar_type) {
        return false;
      }
      // Integral tensor combined with a floating scalar yields the default
      // dtype.
      if (isIntegralType(*first_scalar_type, false) &&
          isFloatingType(*second_scalar_type)) {
        auto default_dtype =
            at::typeMetaToScalarType(caffe2::get_default_dtype());
        auto type = tensor_types.at(0)->withScalarType(default_dtype);
        node->output()->setType(type);
        return true;
      }
      // Bool tensor combined with a non-bool scalar is promoted.
      if (c10::ScalarType::Bool == *first_scalar_type &&
          c10::ScalarType::Bool != *second_scalar_type) {
        auto result_type =
            c10::promoteTypes(*first_scalar_type, *second_scalar_type);
        auto type = tensor_types.at(0)->withScalarType(result_type);
        node->output()->setType(type);
        return true;
      }
      // Otherwise the tensor's own scalar type is kept.
      auto type = tensor_types.at(0)->withScalarType(first_scalar_type);
      node->output()->setType(type);
      return true;
    } else if (
        insert_expands &&
        (node->matches("aten::pow(Tensor self, Tensor exponent) -> Tensor") ||
         node->matches("aten::min(Tensor self, Tensor other) -> Tensor") ||
         node->matches("aten::max(Tensor self, Tensor other) -> Tensor") ||
         node->matches("aten::lt(Tensor self, Tensor other) -> Tensor") ||
         node->matches("aten::le(Tensor self, Tensor other) -> Tensor") ||
         node->matches("aten::gt(Tensor self, Tensor other) -> Tensor") ||
         node->matches("aten::ge(Tensor self, Tensor other) -> Tensor") ||
         node->matches("aten::eq(Tensor self, Tensor other) -> Tensor") ||
         node->matches("aten::ne(Tensor self, Tensor other) -> Tensor"))) {
      // Binary broadcasting ops
      // NB: the output type is not computed by hand here, because the type
      // casting logic in scalar cases is non-trivial. After broadcasting the
      // two operands it's better to just run the op.
      broadcastBinary(node, tensor_types, 0, 1);
      return PropagateShapeOnNodeByRunningIt(node);
    } else if (
        node->matches(
            "aten::logit(Tensor self, float? eps = None) -> Tensor") ||
        node->matches("aten::neg(Tensor self) -> Tensor") ||
        node->matches("aten::sigmoid(Tensor self) -> Tensor") ||
        node->matches("aten::tanh(Tensor self) -> Tensor")) {
      node->output()->setType(tensor_types.at(0)->contiguous());
      return true;
    } else if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
      // (n x k) * (k x m) -> contiguous (n x m) with the lhs dtype/device.
      auto lhs_type = tensor_types.at(0);
      auto rhs_type = tensor_types.at(1);
      auto lhs_sizes = lhs_type->sizes().concrete_sizes().value();
      auto rhs_sizes = rhs_type->sizes().concrete_sizes().value();
      SHAPE_ASSERT(
          *lhs_type->sizes().size() == 2 && *rhs_type->sizes().size() == 2);
      node->output()->setType(TensorType::createContiguous(
          *lhs_type->scalarType(),
          *lhs_type->device(),
          at::IntArrayRef{lhs_sizes[0], rhs_sizes[1]}));
      return true;
    } else if (node->matches("aten::t(Tensor self) -> Tensor")) {
      // Transpose of a 2-D tensor: swap both sizes and strides.
      auto tp = tensor_types.at(0);
      auto sizes = tp->sizes().concrete_sizes().value();
      auto strides = tp->strides().concrete_sizes().value();
      SHAPE_ASSERT(sizes.size() == 2);
      std::swap(sizes.at(0), sizes.at(1));
      std::swap(strides.at(0), strides.at(1));
      node->output()->setType(tp->withSizesStrides(sizes, strides));
      return true;
    } else if (
        node->matches(
            "aten::narrow(Tensor self, int dim, int start, int length) -> Tensor",
            /*const_inputs=*/{attr::dim, attr::length})) {
      // Only the extent along `dim` changes; strides are kept as-is.
      auto tp = tensor_types.at(0);
      auto sizes = tp->sizes().concrete_sizes().value();
      int64_t dim = node->get<int64_t>(attr::dim).value();
      int64_t length = node->get<int64_t>(attr::length).value();
      SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) < sizes.size());
      sizes.at(dim) = length;
      node->output()->setType(
          tp->withSizesStrides(sizes, tp->strides().concrete_sizes().value()));
      return true;
    } else if (node->matches(
                   "aten::sum(Tensor self, *, int? dtype) -> Tensor")) {
      // Full reduction: the result is zero-dimensional.
      node->output()->setType(tensor_types.at(0)->withSizes({}));
      return true;
    } else if (
        node->matches(
            "aten::sum(Tensor self, int[] dim, bool keepdim, *, int? dtype) -> Tensor",
            /*const_inputs=*/{attr::dim, attr::keepdim})) {
      auto& tp = tensor_types.at(0);
      auto sizes = tp->sizes().concrete_sizes().value();
      std::vector<int64_t> dims =
          node->get<c10::List<int64_t>>(attr::dim).value().vec();
      bool keepdim = node->get<bool>(attr::keepdim).value();
      // Resolve negative dims against the input rank before `sizes` is
      // modified.
      for (int64_t& dim : dims) {
        dim = wrapDim(dim, sizes);
      }
      // Process the largest dim first so that erasing an entry never shifts
      // the position of a dim that is still to be processed. Sorting (rather
      // than merely reversing) keeps this true for any order of `dim`.
      std::sort(dims.rbegin(), dims.rend());
      for (int64_t dim : dims) {
        SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) < sizes.size());
        if (keepdim) {
          sizes.at(dim) = 1;
        } else {
          sizes.erase(sizes.begin() + dim);
        }
      }
      node->output()->setType(tp->withSizes(sizes));
      return true;
    } else if (node->matches(
                   "aten::squeeze(Tensor self, int dim) -> Tensor",
                   /*const_inputs=*/attr::dim)) {
      auto& tp = tensor_types.at(0);
      auto sizes = tp->sizes().concrete_sizes().value();
      auto strides = tp->strides().concrete_sizes().value();
      int64_t dim = wrapDim(node->get<int64_t>(attr::dim).value(), sizes);
      SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) < sizes.size());
      // The dimension is only dropped when its extent is 1.
      if (sizes.at(dim) == 1) {
        sizes.erase(sizes.begin() + dim);
        strides.erase(strides.begin() + dim);
      }
      node->output()->setType(tp->withSizesStrides(sizes, strides));
      return true;
    } else if (node->matches(
                   "aten::unsqueeze(Tensor self, int dim) -> Tensor",
                   /*const_inputs=*/attr::dim)) {
      auto& tp = tensor_types.at(0);
      auto sizes = tp->sizes().concrete_sizes().value();
      auto strides = tp->strides().concrete_sizes().value();
      int64_t dim = wrapDim(node->get<int64_t>(attr::dim).value(), sizes);
      // `dim == sizes.size()` is allowed: it appends a trailing dimension.
      SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) <= sizes.size());
      int64_t new_stride = dim >= static_cast<int64_t>(sizes.size())
          ? 1
          : sizes.at(dim) * strides.at(dim);
      sizes.insert(sizes.begin() + dim, 1);
      strides.insert(strides.begin() + dim, new_stride);
      node->output()->setType(tp->withSizesStrides(sizes, strides));
      return true;
    } else if (node->matches(
                   "aten::view(Tensor self, int[] size) -> Tensor",
                   /*const_inputs=*/attr::size)) {
      auto sizes = node->get<c10::List<int64_t>>(attr::size).value();
      // At most one entry may be -1; its extent is inferred from the number
      // of elements of the input.
      bool inferred = false;
      size_t inferred_idx = 0;
      int64_t size_product = 1;
      for (const auto i : c10::irange(sizes.size())) {
        if (sizes.get(i) == -1) {
          if (inferred)
            throw propagation_error();
          inferred = true;
          inferred_idx = i;
        } else {
          size_product *= sizes.get(i);
        }
      }

      if (inferred) {
        // A zero product would divide by zero; a negative one can only come
        // from an invalid size list.
        SHAPE_ASSERT(size_product > 0);
        // Keep the arithmetic signed so the division below does not mix
        // signed and unsigned operands.
        int64_t numel = 1;
        auto concrete_sizes =
            tensor_types.at(0)->sizes().concrete_sizes().value();
        for (int64_t s : concrete_sizes)
          numel *= s;
        int64_t inferred_size = numel / size_product;
        sizes[inferred_idx] = inferred_size;
      }
      node->output()->setType(tensor_types.at(0)->withSizes(sizes.vec()));
      return true;
    } else if (node->matches(
                   "aten::type_as(Tensor self, Tensor other) -> Tensor")) {
      if (tensor_types.at(0)->scalarType() ==
          tensor_types.at(1)->scalarType()) {
        node->output()->setType(node->namedInput(attr::self)->type());
      } else {
        // This will be a copy, so the result will be contiguous
        node->output()->setType(tensor_types.at(1)->withSizes(
            tensor_types.at(0)->sizes().concrete_sizes().value()));
      }
      return true;
    } else if (
        node->matches(
            "aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor",
            /*const_inputs=*/attr::size)) {
      auto tp = tensor_types.at(0);
      auto sizesAndStrides = at::inferExpandGeometry_dimvector(
          tp->sizes().concrete_sizes().value(),
          tp->strides().concrete_sizes().value(),
          node->get<c10::List<int64_t>>(attr::size).value().vec());
      node->output()->setType(
          tp->withSizesStrides(sizesAndStrides.sizes, sizesAndStrides.strides));
      return true;
    } else if (
        node->matches(
            "aten::index_select(Tensor self, int dim, Tensor index) -> Tensor",
            /*const_inputs=*/attr::dim)) {
      // The extent along `dim` becomes the length of the 1-D index tensor.
      auto ten = tensor_types.at(0);
      auto index = tensor_types.at(1);
      int64_t dim = node->get<int64_t>(attr::dim).value();
      SHAPE_ASSERT(*index->sizes().size() == 1);
      SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) < ten->sizes().size());
      std::vector<int64_t> sizes = ten->sizes().concrete_sizes().value();
      sizes[dim] = index->sizes()[0].value();
      node->output()->setType(ten->withSizes(sizes));
      return true;
    } else if (node->matches(
                   "aten::chunk(Tensor self, int chunks, int dim) -> Tensor[]",
                   /*const_inputs=*/{attr::chunks, attr::dim})) {
      auto input_type = tensor_types.at(0);
      auto sizes = input_type->sizes().concrete_sizes().value();
      auto strides = input_type->strides().concrete_sizes().value();
      int64_t dim = wrapDim(node->get<int64_t>(attr::dim).value(), sizes);
      int64_t chunks = node->get<int64_t>(attr::chunks).value();
      // Reject inputs that would otherwise divide by zero or index `sizes`
      // out of bounds.
      SHAPE_ASSERT(chunks > 0);
      SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) < sizes.size());
      const int64_t dim_size = sizes.at(dim);
      // Every output gets the floor-divided extent...
      // NOTE(review): torch.chunk is documented as using a ceil-divided chunk
      // size; confirm whether this floor/remainder split is still intended.
      sizes.at(dim) = dim_size / chunks;
      for (Value* output : node->outputs()) {
        output->setType(input_type->withSizesStrides(sizes, strides));
      }
      // ...and, if the split is uneven, the last output gets the remainder.
      const int64_t remainder = dim_size % chunks;
      if (remainder != 0) {
        sizes.at(dim) = remainder;
        node->outputs().back()->setType(
            input_type->withSizesStrides(sizes, strides));
      }
      return true;
    } else if (node->kind() == ::c10::onnx::Shape) {
      // The output is a 1-D CPU int64 tensor with one entry per input dim.
      SHAPE_ASSERT(node->inputs().size() == 1 && node->outputs().size() == 1);
      std::vector<int64_t> dim_vec = {
          static_cast<int64_t>(*tensor_types.at(0)->sizes().size())};
      at::IntArrayRef dims(dim_vec);
      node->output()->setType(
          TensorType::createContiguous(at::kLong, at::kCPU, dims));
      return true;
    } else if (node->kind() == ::c10::onnx::Reshape) {
      setUnshapedType(node);
      return true;
    }
    // No rule matched: fall back to an unshaped output type.
    setUnshapedType(node);
    return false;
  }
};
} // anonymous namespace