void AliasDb::analyzeImpl(Node* node)

in torch/csrc/jit/ir/alias_analysis.cpp [568:878]
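
Dispatches alias analysis for a single node: special-cased node kinds go to dedicated analyzers, ops registered as AliasAnalysisKind::CONSERVATIVE are treated worst-case, and everything else is analyzed from the alias annotations in its FunctionSchema.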


void AliasDb::analyzeImpl(Node* node) {
  auto op = node->maybeOperator();
  const bool hasSpecialCase = aliasAnalysisHasSpecialCaseFor(node->kind());
  if (op) {
    const auto analysis = op->aliasAnalysisKind();

    const bool registeredAsSpecialCase =
        analysis == AliasAnalysisKind::INTERNAL_SPECIAL_CASE;
    if (C10_UNLIKELY(registeredAsSpecialCase && !hasSpecialCase)) {
      TORCH_INTERNAL_ASSERT(
          false,
          "Op ",
          node->kind().toDisplayString(),
          " is registered with AliasAnalysisKind::INTERNAL_SPECIAL_CASE but doesn't have a special case.");
    } else if (C10_UNLIKELY(!registeredAsSpecialCase && hasSpecialCase)) {
      TORCH_INTERNAL_ASSERT(
          false,
          "Op ",
          node->kind().toDisplayString(),
          " has a special case and should be registered with AliasAnalysisKind::INTERNAL_SPECIAL_CASE but is registered with ",
          c10::toString(analysis));
    }
  } else {
    if (!hasSpecialCase) {
      std::ostringstream oss;
      for (const auto input : node->inputs()) {
        oss << input->type()->str() << ", ";
      }
      oss << "\n\nCandidates:";
      const auto& candidates = getAllOperatorsFor(node->kind());
      for (const auto& candidate : candidates) {
        oss << "\n\t" << candidate->schema();
      }
      TORCH_INTERNAL_ASSERT(
          false,
          "We don't have an op for ",
          node->kind().toDisplayString(),
          " but it isn't a special case. ",
          "Argument types: ",
          oss.str());
    }
  }

  // These nodes are not schematized, so we need to handle them specially
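  // (e.g. prim::If and prim::Loop carry arbitrary blocks, so the number and
  // aliasing of their inputs/outputs can't be captured by a fixed schema)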
  switch (node->kind()) {
    case prim::If:
      return analyzeIf(node);
    case prim::Loop:
      return analyzeLoop(node);
    case prim::FusionGroup:
    case prim::CudaFusionGroup:
    case prim::FunctionalGraph:
    case prim::DifferentiableGraph:
    case prim::FallbackGraph:
      return analyzeSubgraph(node);
    case prim::fork:
      return analyzeFork(node);
    case aten::wait:
      return analyzeWait(node);
    case prim::rpc_async:
    case prim::rpc_sync:
    case prim::rpc_remote:
      return analyzeRpcAsync(node);
    case aten::batch_norm:
      return analyzeBatchNorm(node);
    case aten::instance_norm:
      return analyzeInstanceNorm(node);
    case prim::GradOf:
      return analyzeGradOf(node);
    case prim::BroadcastMKLDNNTensors: {
      makePointerTo(node->outputs().at(0), node->inputs().at(0));
      makePointerTo(node->outputs().at(1), node->inputs().at(1));
      return;
    }
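    // The kinds below all create fresh values: analyzeCreator gives each
    // output of the node its own fresh alias set.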
    // TODO: think more about TensorExpr alias correctness
    case prim::TensorExprGroup:
    case prim::TensorExprDynamicGroup:
    case prim::MKLDNNGroup:
    case prim::ConstantMKLDNNTensor:
    case prim::StaticSubgraph:
    case prim::Constant:
    case prim::AutogradZero:
    case prim::AutogradAdd:
    case prim::FusedConcat:
    case prim::MMTreeReduce:
    case prim::MMBatchSide:
    case prim::BroadcastSizes:
    case prim::ChunkSizes:
    case prim::Closure:
    case prim::CreateObject:
    case prim::tolist:
    case prim::Uninitialized:
      return analyzeCreator(node);
    case prim::TupleConstruct:
    case prim::DictConstruct:
    case prim::ListConstruct:
      return analyzeContainerConstruct(node);
    case prim::TupleUnpack:
    case prim::TupleIndex:
    case prim::TupleSlice:
    case prim::ListUnpack:
    case prim::PythonOp:
    case prim::GetAttr:
      if (isFrozen_ && node->kind() == prim::GetAttr) {
        auto& ty = node->input()->type();
        if (ty->expectRef<ClassType>().is_module()) {
          return analyzeCreator(node);
        }
      }
      return analyzeExtractor(node);
    case prim::unchecked_cast:
      return makePointerTo(node->output(), node->input());
    case prim::ConstantChunk:
      return analyzeChunk(node);
    case prim::BroadcastingChunk:
      return analyzeBroadcastingChunk(node);
    case prim::SetAttr:
      return analyzeSetAttr(node);
    case prim::profile_ivalue:
    case prim::profile:
      makePointerTo(node->output(), node->inputs().at(0));
      return;
    case prim::TypeCheck:
    case prim::RequiresGradCheck: {
      auto num_inputs = node->inputs().size();
      for (const auto i : c10::irange(num_inputs)) {
        makePointerTo(node->outputs().at(i), node->inputs().at(i));
      }
      return;
    }
    case prim::BailOut:
      TORCH_INTERNAL_ASSERT(
          node->inputs().at(0)->node()->kind() == prim::BailoutTemplate);
      makePointerTo(node->output(), node->inputs().at(1));
      return;
    case prim::Guard:
      makePointerTo(node->output(), node->inputs().at(0));
      return;
    case prim::CallFunction:
    case prim::CallMethod:
    case prim::Enter:
    case prim::Exit:
      // TODO: this can be improved with summaries of what the function does
      // for now we assume the worst
      // NB: update safeToChangeAliasingRelationship if changed
      return analyzeConservative(node);
    case prim::Print:
    case prim::isinstance:
      // These ops do nothing
      return;
    default:
      if (tryRegisteredAnalysis(node)) {
        return;
      }
  }

  TORCH_INTERNAL_ASSERT(op, "We should have an op schema if we get here");
  const AliasAnalysisKind analysis = op->aliasAnalysisKind();
  TORCH_INTERNAL_ASSERT(
      analysis != AliasAnalysisKind::INTERNAL_SPECIAL_CASE &&
          !aliasAnalysisHasSpecialCaseFor(node->kind()),
      "Special cases should be handled already if we're here.");

  if (node->kind().is_aten() || node->kind().is_prim() ||
      node->kind().is_cuda()) {
    // TODO There is nothing in the system that relies on aten:: and prim::
    // ops using AliasAnalysisKind::FROM_SCHEMA or
    // AliasAnalysisKind::INTERNAL_SPECIAL_CASE, but this is the intended
    // behavior for all current ops and a good error check. We can consider
    // lifting this constraint later if we have a use case for it.
    TORCH_INTERNAL_ASSERT(
        analysis == AliasAnalysisKind::FROM_SCHEMA ||
            analysis == AliasAnalysisKind::CONSERVATIVE,
        "aten:: and prim:: operators should use AliasAnalysisKind::FROM_SCHEMA or "
        "AliasAnalysisKind::CONSERVATIVE (if really necessary), but ",
        node->kind().toDisplayString(),
        " doesn't. Note: ideally, prim:: operators shouldn't have a schema at all ",
        "and should use AliasAnalysisKind::INTERNAL_SPECIAL_CASE instead.");
  }
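  // Ops outside the aten::/prim::/cuda:: namespaces may pick any
  // AliasAnalysisKind at registration time, e.g. PURE_FUNCTION for ops that
  // neither alias nor mutate their inputs (handled in the switch's default
  // case above).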

  if (analysis == AliasAnalysisKind::CONSERVATIVE) {
    // TODO A previous implementation of alias analysis always accessed
    // node->schema(), which caused the schema caches in the Node class to be
    // filled for the full graph. Unfortunately, our JIT passes started relying
    // on that, so we need to keep doing it. Details: in
    // caffe2/torch/onnx/utils.py, _jit_pass_onnx is called on an invalid JIT
    // graph because _jit_pass_erase_number_types was called right before it,
    // so ints are now Tensors. If _jit_pass_onnx tried to look up operator
    // schemas there, it would crash. However, _jit_pass_constant_propagation,
    // which is called before it, runs alias analysis and pre-fills the schema
    // cache in all Node instances, so _jit_pass_onnx no longer needs to look
    // up operators to get the schemas. We should fix this.
    node->schema(); // fill the schema cache in the Node class

    return analyzeConservative(node);
  }

  TORCH_INTERNAL_ASSERT(
      analysis == AliasAnalysisKind::FROM_SCHEMA,
      "AliasAnalysisKind::CONSERVATIVE/PURE_FUNCTION/INTERNAL_SPECIAL_CASE should already have been handled above");
  const auto& schema = node->schema();

  // Bind the schema's "formal" alias annotation to the actual values those
  // schema arguments represent
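  //
  // For example, given an in-place schema of the form
  //   add_(Tensor(a!) self, Tensor other) -> Tensor(a!)
  // the formal set `a` binds to the Value passed as `self`, the `!` records a
  // write to that set, and the matching `(a!)` on the return later makes the
  // output alias `self`.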
  std::unordered_map<Symbol, Value*> formalToActual;
  for (const auto i : c10::irange(schema.arguments().size())) {
    const at::AliasInfo* formal = schema.arguments()[i].alias_info();
    const auto& actualValue = node->inputs().at(i);
    // Skip if there's no alias annotation
    if (!formal) {
      continue;
    }

    // If this type cannot alias, continue. Can occur with a VarType schema
    if (!isMutableTypeInternal(actualValue)) {
      continue;
    }

    // Do sanity checks on the alias annotation
    TORCH_INTERNAL_ASSERT(
        formal->containedTypes().size() == 0,
        "Composite types for alias analysis not yet supported");
    TORCH_INTERNAL_ASSERT(
        !formal->isWildcardBefore(),
        "Doesn't make sense for an input value to begin as a wildcard");

    const auto& formalAlias = formal->beforeSet();

    // skip if we've already bound this alias
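    // (e.g. for a hypothetical foo(Tensor(a) x, Tensor(a) y), `a` stays bound
    // to the first actual, `x`)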
    if (formalToActual.count(formalAlias) != 0) {
      continue;
    }

    // Bind the formal to the actual
    formalToActual[formalAlias] = actualValue;

    // Record writes
    if (formal->isWrite()) {
      registerWrite(actualValue, node);
    }

    // Now deal with sets after the '->'
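    // (e.g. a hypothetical annotation Tensor(a -> *) would say the value
    // enters in set `a` but may escape to the wildcard set after the op runs)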
    if (formal->isWildcardAfter()) {
      TORCH_INTERNAL_ASSERT(
          formal->afterSets().size() == 1,
          "If the after set contains a wildcard, "
          "there should be no other alias sets specified.");
      setWildcard(actualValue);
    } else {
      // We don't understand anything else in the after-set yet, so assert
      // there's been no change.
      TORCH_INTERNAL_ASSERT(formal->beforeSets() == formal->afterSets());
    }
  }

  // Use the formal-actual mapping to give aliases to the outputs
  for (const auto i : c10::irange(schema.returns().size())) {
    const auto actual = node->outputs().at(i);
    const at::AliasInfo* formal = schema.returns()[i].alias_info();
    if (!formal) {
      // This is a fresh tensor
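      // (e.g. aten::clone's return carries no alias annotation, so its output
      // gets its own fresh alias set)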
      giveFreshAlias(actual);
      continue;
    }

    // If this type cannot alias, continue. Can occur with a VarType schema
    if (!isMutableType(actual)) {
      continue;
    }

    TORCH_INTERNAL_ASSERT(
        formal->containedTypes().size() == 0,
        "Composite types for alias analysis not yet supported");
    TORCH_INTERNAL_ASSERT(formal->beforeSets() == formal->afterSets());

    if (formal->isWildcardBefore()) {
      TORCH_INTERNAL_ASSERT(
          formal->beforeSets().size() == 1,
          "If an output is a wildcard, "
          "there should be no other alias sets specified.");
      setWildcard(actual);
      continue;
    }

    for (const auto& formalAlias : formal->beforeSets()) {
      // If we encounter an alias annotation that wasn't in the inputs:
      if (!formalToActual.count(formalAlias)) {
        // If this alias is not seen elsewhere and is the only annotation on
        // the output, it's equivalent to being fresh:
        //   e.g. foo(Tensor(a) self) -> Tensor(b)
        if (formal->beforeSets().size() == 1) {
          giveFreshAlias(actual);
        }
        // Or it is of the form a|fresh, which we can ignore, taking the
        // conservative assumption that the output must alias `a`, e.g.
        //   aten::cuda(Tensor(a) self) -> Tensor(a|fresh)
        // Don't assign an alias set in that case.
        continue;
      }

      auto toAlias = formalToActual.at(formalAlias);
      makePointerTo(actual, toAlias);
    }

    // Record writes
    if (formal->isWrite()) {
      registerWrite(actual, node);
    }
  }
}