in torch/csrc/jit/ir/alias_analysis.cpp [568:878]
void AliasDb::analyzeImpl(Node* node) {
  auto op = node->maybeOperator();
  const bool hasSpecialCase = aliasAnalysisHasSpecialCaseFor(node->kind());
  if (op) {
    const auto analysis = op->aliasAnalysisKind();

    const bool registeredAsSpecialCase =
        analysis == AliasAnalysisKind::INTERNAL_SPECIAL_CASE;
    if (C10_UNLIKELY(registeredAsSpecialCase && !hasSpecialCase)) {
      TORCH_INTERNAL_ASSERT(
          false,
          "Op ",
          node->kind().toDisplayString(),
          " is registered with AliasAnalysisKind::INTERNAL_SPECIAL_CASE but doesn't have a special case.");
    } else if (C10_UNLIKELY(!registeredAsSpecialCase && hasSpecialCase)) {
      TORCH_INTERNAL_ASSERT(
          false,
          "Op ",
          node->kind().toDisplayString(),
          " has a special case and should be registered with AliasAnalysisKind::INTERNAL_SPECIAL_CASE but is registered with ",
          c10::toString(analysis));
    }
  } else {
    if (!hasSpecialCase) {
      std::ostringstream oss;
      for (const auto input : node->inputs()) {
        oss << input->type()->str() << ", ";
      }
      oss << "\n\nCandidates:";
      const auto& candidates = getAllOperatorsFor(node->kind());
      for (const auto& candidate : candidates) {
        oss << "\n\t" << candidate->schema();
      }
      TORCH_INTERNAL_ASSERT(
          false,
          "We don't have an op for ",
          node->kind().toDisplayString(),
          " but it isn't a special case. ",
          "Argument types: ",
          oss.str());
    }
  }
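  // The checks above enforce an invariant: a node kind is handled by a
  // special case in the switch below if and only if its operator (when one
  // exists) is registered with AliasAnalysisKind::INTERNAL_SPECIAL_CASE.
  // For example, aten::batch_norm is special-cased below, so its operator
  // must be registered as INTERNAL_SPECIAL_CASE; an ordinary op like
  // aten::add must not be, and is analyzed from its schema instead.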
  // These nodes are not schematized, so we need to handle them specially
  switch (node->kind()) {
    case prim::If:
      return analyzeIf(node);
    case prim::Loop:
      return analyzeLoop(node);
    case prim::FusionGroup:
    case prim::CudaFusionGroup:
    case prim::FunctionalGraph:
    case prim::DifferentiableGraph:
    case prim::FallbackGraph:
      return analyzeSubgraph(node);
    case prim::fork:
      return analyzeFork(node);
    case aten::wait:
      return analyzeWait(node);
    case prim::rpc_async:
    case prim::rpc_sync:
    case prim::rpc_remote:
      return analyzeRpcAsync(node);
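    // aten::batch_norm and aten::instance_norm get special cases (presumably)
    // because whether running_mean / running_var are written to depends on a
    // runtime argument (training / use_input_stats), which a static schema
    // alias annotation cannot express; see analyzeBatchNorm /
    // analyzeInstanceNorm.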
    case aten::batch_norm:
      return analyzeBatchNorm(node);
    case aten::instance_norm:
      return analyzeInstanceNorm(node);
    case prim::GradOf:
      return analyzeGradOf(node);
    case prim::BroadcastMKLDNNTensors: {
      makePointerTo(node->outputs().at(0), node->inputs().at(0));
      makePointerTo(node->outputs().at(1), node->inputs().at(1));
      return;
    }
    // TODO: think more about TensorExpr alias correctness
    case prim::TensorExprGroup:
    case prim::TensorExprDynamicGroup:
    case prim::MKLDNNGroup:
    case prim::ConstantMKLDNNTensor:
    case prim::StaticSubgraph:
    case prim::Constant:
    case prim::AutogradZero:
    case prim::AutogradAdd:
    case prim::FusedConcat:
    case prim::MMTreeReduce:
    case prim::MMBatchSide:
    case prim::BroadcastSizes:
    case prim::ChunkSizes:
    case prim::Closure:
    case prim::CreateObject:
    case prim::tolist:
    case prim::Uninitialized:
      return analyzeCreator(node);
    case prim::TupleConstruct:
    case prim::DictConstruct:
    case prim::ListConstruct:
      return analyzeContainerConstruct(node);
    case prim::TupleUnpack:
    case prim::TupleIndex:
    case prim::TupleSlice:
    case prim::ListUnpack:
    case prim::PythonOp:
    case prim::GetAttr:
      if (isFrozen_ && node->kind() == prim::GetAttr) {
        auto& ty = node->input()->type();
        if (ty->expectRef<ClassType>().is_module()) {
          return analyzeCreator(node);
        }
      }
      return analyzeExtractor(node);
    case prim::unchecked_cast:
      return makePointerTo(node->output(), node->input());
    case prim::ConstantChunk:
      return analyzeChunk(node);
    case prim::BroadcastingChunk:
      return analyzeBroadcastingChunk(node);
    case prim::SetAttr:
      return analyzeSetAttr(node);
    case prim::profile_ivalue:
    case prim::profile:
      makePointerTo(node->output(), node->inputs().at(0));
      return;
    case prim::TypeCheck:
    case prim::RequiresGradCheck: {
      auto num_inputs = node->inputs().size();
      for (const auto i : c10::irange(num_inputs)) {
        makePointerTo(node->outputs().at(i), node->inputs().at(i));
      }
      return;
    }
    case prim::BailOut:
      TORCH_INTERNAL_ASSERT(
          node->inputs().at(0)->node()->kind() == prim::BailoutTemplate);
      makePointerTo(node->output(), node->inputs().at(1));
      return;
    case prim::Guard:
      makePointerTo(node->output(), node->inputs().at(0));
      return;
    case prim::CallFunction:
    case prim::CallMethod:
    case prim::Enter:
    case prim::Exit:
      // TODO: This can be improved with summaries of what the function does.
      // For now we assume the worst.
      // NB: update safeToChangeAliasingRelationship if this changes.
      return analyzeConservative(node);
    case prim::Print:
    case prim::isinstance:
      // These ops do nothing
      return;
    default:
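      // No special case matched: fall back to the analysis registered on the
      // operator itself. (This is expected to cover ops registered with
      // AliasAnalysisKind::PURE_FUNCTION, whose outputs are fresh values; see
      // tryRegisteredAnalysis.)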
      if (tryRegisteredAnalysis(node)) {
        return;
      }
  }
  TORCH_INTERNAL_ASSERT(op, "We should have an op schema if we get here");
  const AliasAnalysisKind analysis = op->aliasAnalysisKind();
  TORCH_INTERNAL_ASSERT(
      analysis != AliasAnalysisKind::INTERNAL_SPECIAL_CASE &&
          !aliasAnalysisHasSpecialCaseFor(node->kind()),
      "Special cases should be handled already if we're here.");
  if (node->kind().is_aten() || node->kind().is_prim() ||
      node->kind().is_cuda()) {
    // TODO There is nothing in the system that relies on aten:: and prim::
    // ops using AliasAnalysisKind::FROM_SCHEMA or
    // AliasAnalysisKind::INTERNAL_SPECIAL_CASE, but this is the intended
    // behavior for all current ops and a good error check. We can consider
    // lifting this constraint later if we have a use case for it.
    TORCH_INTERNAL_ASSERT(
        analysis == AliasAnalysisKind::FROM_SCHEMA ||
            analysis == AliasAnalysisKind::CONSERVATIVE,
        "aten:: and prim:: operators should use AliasAnalysisKind::FROM_SCHEMA or "
        "AliasAnalysisKind::CONSERVATIVE (if really necessary), but ",
        node->kind().toDisplayString(),
        " doesn't. Note: ideally, prim:: operators shouldn't have a schema at all, ",
        "and should use AliasAnalysisKind::INTERNAL_SPECIAL_CASE instead.");
  }
  if (analysis == AliasAnalysisKind::CONSERVATIVE) {
    // TODO A previous implementation of alias analysis always accessed
    // node->schema(), which caused the schema caches in the Node class to be
    // filled for the full graph. Unfortunately, our JIT passes started relying
    // on that, so we need to keep doing it. Details: in
    // caffe2/torch/onnx/utils.py, _jit_pass_onnx is called on an invalid JIT
    // graph because _jit_pass_erase_number_types was called right before and
    // ints are now Tensors instead. So if _jit_pass_onnx tries to look up
    // operator schemas, it will crash. However, _jit_pass_constant_propagation,
    // which is called before it, runs alias analysis and prefills the schema
    // cache in all Node instances so that _jit_pass_onnx doesn't need to look
    // up operators to get the schemas anymore. We should fix this.
    node->schema(); // fill the schema cache in the Node class
    return analyzeConservative(node);
  }
  TORCH_INTERNAL_ASSERT(
      analysis == AliasAnalysisKind::FROM_SCHEMA,
      "AliasAnalysisKind::CONSERVATIVE/PURE_FUNCTION/INTERNAL_SPECIAL_CASE should already have been handled above");
  const auto& schema = node->schema();

  // Bind the schema's "formal" alias annotation to the actual values those
  // schema arguments represent
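  // For example, given the in-place schema
  //   aten::add_(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
  // the formal set `a` is bound to the actual Value* passed as `self`, and a
  // write to it is recorded because of the `!`.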
  std::unordered_map<Symbol, Value*> formalToActual;
  for (const auto i : c10::irange(schema.arguments().size())) {
    const at::AliasInfo* formal = schema.arguments()[i].alias_info();
    const auto& actualValue = node->inputs().at(i);

    // Skip if there's no alias annotation
    if (!formal) {
      continue;
    }

    // If this type cannot alias, continue. Can occur with a VarType schema
    if (!isMutableTypeInternal(actualValue)) {
      continue;
    }

    // Do sanity checks on the alias annotation
    TORCH_INTERNAL_ASSERT(
        formal->containedTypes().size() == 0,
        "Composite types for alias analysis not yet supported");
    TORCH_INTERNAL_ASSERT(
        !formal->isWildcardBefore(),
        "Doesn't make sense for an input value to begin as a wildcard");
    const auto& formalAlias = formal->beforeSet();

    // Skip if we've already bound this alias
    if (formalToActual.count(formalAlias) != 0) {
      continue;
    }

    // Bind the formal to the actual
    formalToActual[formalAlias] = actualValue;

    // Record writes
    if (formal->isWrite()) {
      registerWrite(actualValue, node);
    }
    // Now deal with sets after the '->'
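    // For example, a schema like
    //   aten::append.t(t[](a!) self, t(c -> *) el) -> t[](a!)
    // annotates `el` as `(c -> *)`: after the call the element may alias
    // anything, so the actual value is added to the wildcard set.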
    if (formal->isWildcardAfter()) {
      TORCH_INTERNAL_ASSERT(
          formal->afterSets().size() == 1,
          "If the after set contains a wildcard, "
          "there should be no other alias sets specified.");
      setWildcard(actualValue);
    } else {
      // We don't understand anything else in the after yet, so assert there's
      // been no change.
      TORCH_INTERNAL_ASSERT(formal->beforeSets() == formal->afterSets());
    }
  }
  // Use the formal-actual mapping to give aliases to the outputs
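  // Continuing the aten::add_ example above: the return value is annotated
  // Tensor(a!), so the output is made to point to whatever `a` was bound to
  // (the `self` input), rather than getting a fresh alias set.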
  for (const auto i : c10::irange(schema.returns().size())) {
    const auto actual = node->outputs().at(i);
    const at::AliasInfo* formal = schema.returns()[i].alias_info();
    if (!formal) {
      // This is a fresh tensor
      giveFreshAlias(actual);
      continue;
    }

    // If this type cannot alias, continue. Can occur with a VarType schema
    if (!isMutableType(actual)) {
      continue;
    }

    TORCH_INTERNAL_ASSERT(
        formal->containedTypes().size() == 0,
        "Composite types for alias analysis not yet supported");
    TORCH_INTERNAL_ASSERT(formal->beforeSets() == formal->afterSets());
    if (formal->isWildcardBefore()) {
      TORCH_INTERNAL_ASSERT(
          formal->beforeSets().size() == 1,
          "If an output is a wildcard, "
          "there should be no other alias sets specified.");
      setWildcard(actual);
      continue;
    }
    for (const auto& formalAlias : formal->beforeSets()) {
      // If we encounter an alias annotation that wasn't in the inputs:
      if (!formalToActual.count(formalAlias)) {
        // If this alias is not seen elsewhere and is the only annotation on
        // the output, it's equivalent to being fresh:
        //   e.g. foo(Tensor(a) self) -> Tensor(b)
        if (formal->beforeSets().size() == 1) {
          giveFreshAlias(actual);
        }
        // Or it is of the form `a|fresh`, which we can ignore, taking the
        // conservative assumption that the output must alias `a`, e.g.
        //   aten::cuda(Tensor(a) self) -> Tensor(a|fresh)
        // Don't assign an alias set in that case.
        continue;
      }

      auto toAlias = formalToActual.at(formalAlias);
      makePointerTo(actual, toAlias);
    }

    // Record writes
    if (formal->isWrite()) {
      registerWrite(actual, node);
    }
  }
}