in packages/pyright-internal/src/analyzer/typeGuards.ts [88:606]
/**
 * Returns a callback that narrows the type of `reference` for the branch in
 * which `testExpression` is truthy (`isPositiveTest === true`) or falsy
 * (`isPositiveTest === false`). Returns undefined when `testExpression` is not
 * a recognized type-guard form for `reference`.
 *
 * Recognized forms include:
 *   - `X is None` / `X is not None` / `X == None` / `X != None`, and `X[<int>] is None`
 *   - `type(X) is Y` / `type(X) is not Y`
 *   - `X is <enum or bool literal>`, `X == <literal>`, `<literal> == X`
 *   - `X[<str or int literal>] == <literal>`
 *   - `len(X) == <int literal>` / `len(X) != <int literal>`
 *   - `X.Y == <literal>`, `X.Y is <enum or bool literal>`
 *   - `X in Y`, `<str literal> in X` / `<str literal> not in X`
 *   - `isinstance(X, ...)`, `issubclass(X, ...)`, `callable(X)`, `bool(X)`
 *   - calls to user-defined `TypeGuard` / `StrictTypeGuard` functions
 *   - truthiness of `X` itself
 *   - a local variable whose single assignment is one of the above conditions
 *   - `not <expr>` (reachable when following such aliased conditions)
 *
 * @param evaluator Type evaluator used to evaluate sub-expressions of the test.
 * @param reference The expression whose type is being narrowed.
 * @param testExpression The condition expression being tested.
 * @param isPositiveTest True to narrow for the branch where the condition holds,
 *   false for the branch where it does not.
 * @returns A narrowing callback, or undefined if no pattern applies.
 */
export function getTypeNarrowingCallback(
    evaluator: TypeEvaluator,
    reference: ExpressionNode,
    testExpression: ExpressionNode,
    isPositiveTest: boolean
): TypeNarrowingCallback | undefined {
    if (testExpression.nodeType === ParseNodeType.AssignmentExpression) {
        // For "(name := expr)", first try to narrow based on the assigned
        // expression, then fall back to treating the target name as the test.
        return (
            getTypeNarrowingCallback(evaluator, reference, testExpression.rightExpression, isPositiveTest) ??
            getTypeNarrowingCallback(evaluator, reference, testExpression.name, isPositiveTest)
        );
    }

    if (testExpression.nodeType === ParseNodeType.BinaryOperation) {
        const isOrIsNotOperator =
            testExpression.operator === OperatorType.Is || testExpression.operator === OperatorType.IsNot;
        const equalsOrNotEqualsOperator =
            testExpression.operator === OperatorType.Equals || testExpression.operator === OperatorType.NotEquals;

        if (isOrIsNotOperator || equalsOrNotEqualsOperator) {
            // Invert the "isPositiveTest" value if this is an "is not" or "!=" operation.
            const adjIsPositiveTest =
                testExpression.operator === OperatorType.Is || testExpression.operator === OperatorType.Equals
                    ? isPositiveTest
                    : !isPositiveTest;

            // Look for "X is None", "X is not None", "X == None", and "X != None".
            // These are commonly-used patterns used in control flow.
            if (
                testExpression.rightExpression.nodeType === ParseNodeType.Constant &&
                testExpression.rightExpression.constType === KeywordType.None
            ) {
                // Allow the LHS to be either a simple expression or an assignment
                // expression that assigns to a simple name.
                let leftExpression = testExpression.leftExpression;
                if (leftExpression.nodeType === ParseNodeType.AssignmentExpression) {
                    leftExpression = leftExpression.name;
                }

                if (ParseTreeUtils.isMatchingExpression(reference, leftExpression)) {
                    return (type: Type) => {
                        return narrowTypeForIsNone(evaluator, type, adjIsPositiveTest);
                    };
                }

                // Look for "X[<int literal>] is None" (a single, simple, non-imaginary
                // integer subscript of the reference).
                if (
                    leftExpression.nodeType === ParseNodeType.Index &&
                    ParseTreeUtils.isMatchingExpression(reference, leftExpression.baseExpression) &&
                    leftExpression.items.length === 1 &&
                    !leftExpression.trailingComma &&
                    leftExpression.items[0].argumentCategory === ArgumentCategory.Simple &&
                    !leftExpression.items[0].name &&
                    leftExpression.items[0].valueExpression.nodeType === ParseNodeType.Number &&
                    leftExpression.items[0].valueExpression.isInteger &&
                    !leftExpression.items[0].valueExpression.isImaginary
                ) {
                    const indexValue = leftExpression.items[0].valueExpression.value;
                    if (typeof indexValue === 'number') {
                        return (type: Type) => {
                            return narrowTupleTypeForIsNone(evaluator, type, adjIsPositiveTest, indexValue);
                        };
                    }
                }
            }

            // Look for "type(X) is Y" or "type(X) is not Y".
            if (isOrIsNotOperator && testExpression.leftExpression.nodeType === ParseNodeType.Call) {
                const callType = evaluator.getTypeOfExpression(
                    testExpression.leftExpression.leftExpression,
                    /* expectedType */ undefined,
                    EvaluatorFlags.DoNotSpecialize
                ).type;

                if (
                    isInstantiableClass(callType) &&
                    ClassType.isBuiltIn(callType, 'type') &&
                    testExpression.leftExpression.arguments.length === 1 &&
                    testExpression.leftExpression.arguments[0].argumentCategory === ArgumentCategory.Simple
                ) {
                    const arg0Expr = testExpression.leftExpression.arguments[0].valueExpression;
                    if (ParseTreeUtils.isMatchingExpression(reference, arg0Expr)) {
                        const classType = evaluator.getTypeOfExpression(testExpression.rightExpression).type;
                        if (isInstantiableClass(classType)) {
                            return (type: Type) => {
                                return narrowTypeForTypeIs(type, classType, adjIsPositiveTest);
                            };
                        }
                    }
                }
            }

            // Look for "X is Y" or "X is not Y" where Y is an enum or bool literal.
            if (isOrIsNotOperator) {
                if (ParseTreeUtils.isMatchingExpression(reference, testExpression.leftExpression)) {
                    const rightType = evaluator.getTypeOfExpression(testExpression.rightExpression).type;
                    if (
                        isClassInstance(rightType) &&
                        (ClassType.isEnumClass(rightType) || ClassType.isBuiltIn(rightType, 'bool')) &&
                        rightType.literalValue !== undefined
                    ) {
                        return (type: Type) => {
                            return narrowTypeForLiteralComparison(
                                evaluator,
                                type,
                                rightType,
                                adjIsPositiveTest,
                                /* isIsOperator */ true
                            );
                        };
                    }
                }
            }

            if (equalsOrNotEqualsOperator) {
                // Note: the outer "adjIsPositiveTest" already accounts for "!=" here,
                // since the operator is known to be either "==" or "!=".

                // Look for X == <literal> or X != <literal>
                if (ParseTreeUtils.isMatchingExpression(reference, testExpression.leftExpression)) {
                    const rightType = evaluator.getTypeOfExpression(testExpression.rightExpression).type;
                    if (isClassInstance(rightType) && rightType.literalValue !== undefined) {
                        return (type: Type) => {
                            return narrowTypeForLiteralComparison(
                                evaluator,
                                type,
                                rightType,
                                adjIsPositiveTest,
                                /* isIsOperator */ false
                            );
                        };
                    }
                }

                // Look for <literal> == X or <literal> != X
                if (ParseTreeUtils.isMatchingExpression(reference, testExpression.rightExpression)) {
                    const leftType = evaluator.getTypeOfExpression(testExpression.leftExpression).type;
                    if (isClassInstance(leftType) && leftType.literalValue !== undefined) {
                        return (type: Type) => {
                            return narrowTypeForLiteralComparison(
                                evaluator,
                                type,
                                leftType,
                                adjIsPositiveTest,
                                /* isIsOperator */ false
                            );
                        };
                    }
                }

                // Look for X[<literal>] == <literal> or X[<literal>] != <literal>
                if (
                    testExpression.leftExpression.nodeType === ParseNodeType.Index &&
                    testExpression.leftExpression.items.length === 1 &&
                    !testExpression.leftExpression.trailingComma &&
                    testExpression.leftExpression.items[0].argumentCategory === ArgumentCategory.Simple &&
                    ParseTreeUtils.isMatchingExpression(reference, testExpression.leftExpression.baseExpression)
                ) {
                    const indexType = evaluator.getTypeOfExpression(
                        testExpression.leftExpression.items[0].valueExpression
                    ).type;

                    if (isClassInstance(indexType) && isLiteralType(indexType)) {
                        if (ClassType.isBuiltIn(indexType, 'str')) {
                            // A str-literal key discriminates among dict-like entries.
                            const rightType = evaluator.getTypeOfExpression(testExpression.rightExpression).type;
                            if (isClassInstance(rightType) && rightType.literalValue !== undefined) {
                                return (type: Type) => {
                                    return narrowTypeForDiscriminatedDictEntryComparison(
                                        evaluator,
                                        type,
                                        indexType,
                                        rightType,
                                        adjIsPositiveTest
                                    );
                                };
                            }
                        } else if (ClassType.isBuiltIn(indexType, 'int')) {
                            // An int-literal index discriminates among tuple entries.
                            const rightType = evaluator.getTypeOfExpression(testExpression.rightExpression).type;
                            if (isClassInstance(rightType) && rightType.literalValue !== undefined) {
                                return (type: Type) => {
                                    return narrowTypeForDiscriminatedTupleComparison(
                                        evaluator,
                                        type,
                                        indexType,
                                        rightType,
                                        adjIsPositiveTest
                                    );
                                };
                            }
                        }
                    }
                }
            }

            // Look for len(x) == <literal> or len(x) != <literal>.
            // The argument must be a plain (non-unpacked) expression: "len(*x)" says
            // nothing about the length of x itself. The literal must be a real
            // integer: "2j" is tokenized as an integer-valued imaginary number, and
            // "len(x) == 2j" is always false, so it must not narrow to length 2.
            if (
                equalsOrNotEqualsOperator &&
                testExpression.leftExpression.nodeType === ParseNodeType.Call &&
                testExpression.leftExpression.arguments.length === 1 &&
                testExpression.leftExpression.arguments[0].argumentCategory === ArgumentCategory.Simple &&
                testExpression.rightExpression.nodeType === ParseNodeType.Number &&
                testExpression.rightExpression.isInteger &&
                !testExpression.rightExpression.isImaginary
            ) {
                const arg0Expr = testExpression.leftExpression.arguments[0].valueExpression;
                if (ParseTreeUtils.isMatchingExpression(reference, arg0Expr)) {
                    const callType = evaluator.getTypeOfExpression(
                        testExpression.leftExpression.leftExpression,
                        /* expectedType */ undefined,
                        EvaluatorFlags.DoNotSpecialize
                    ).type;

                    if (isFunction(callType) && callType.details.fullName === 'builtins.len') {
                        const tupleLength = testExpression.rightExpression.value;
                        if (typeof tupleLength === 'number') {
                            return (type: Type) => {
                                return narrowTypeForTupleLength(evaluator, type, tupleLength, adjIsPositiveTest);
                            };
                        }
                    }
                }
            }

            // Look for X.Y == <literal> or X.Y != <literal>
            if (
                equalsOrNotEqualsOperator &&
                testExpression.leftExpression.nodeType === ParseNodeType.MemberAccess &&
                ParseTreeUtils.isMatchingExpression(reference, testExpression.leftExpression.leftExpression)
            ) {
                const rightType = evaluator.getTypeOfExpression(testExpression.rightExpression).type;
                const memberName = testExpression.leftExpression.memberName;

                if (isClassInstance(rightType) && rightType.literalValue !== undefined) {
                    return (type: Type) => {
                        return narrowTypeForDiscriminatedFieldComparison(
                            evaluator,
                            type,
                            memberName.value,
                            rightType,
                            adjIsPositiveTest
                        );
                    };
                }
            }

            // Look for X.Y is <literal> or X.Y is not <literal> where <literal> is
            // an enum or bool literal. (The "==" / "!=" forms are handled above.)
            if (
                isOrIsNotOperator &&
                testExpression.leftExpression.nodeType === ParseNodeType.MemberAccess &&
                ParseTreeUtils.isMatchingExpression(reference, testExpression.leftExpression.leftExpression)
            ) {
                const rightType = evaluator.getTypeOfExpression(testExpression.rightExpression).type;
                const memberName = testExpression.leftExpression.memberName;

                if (
                    isClassInstance(rightType) &&
                    (ClassType.isEnumClass(rightType) || ClassType.isBuiltIn(rightType, 'bool')) &&
                    rightType.literalValue !== undefined
                ) {
                    return (type: Type) => {
                        return narrowTypeForDiscriminatedFieldComparison(
                            evaluator,
                            type,
                            memberName.value,
                            rightType,
                            adjIsPositiveTest
                        );
                    };
                }
            }
        }

        if (testExpression.operator === OperatorType.In) {
            // Look for "x in y" where y is one of several built-in types.
            // Only the positive branch is narrowed.
            if (isPositiveTest && ParseTreeUtils.isMatchingExpression(reference, testExpression.leftExpression)) {
                const rightType = evaluator.getTypeOfExpression(testExpression.rightExpression).type;
                return (type: Type) => {
                    return narrowTypeForContains(evaluator, type, rightType);
                };
            }
        }

        if (testExpression.operator === OperatorType.In || testExpression.operator === OperatorType.NotIn) {
            if (ParseTreeUtils.isMatchingExpression(reference, testExpression.rightExpression)) {
                // Look for <string literal> in y where y is a union that contains
                // one or more TypedDicts.
                const leftType = evaluator.getTypeOfExpression(testExpression.leftExpression).type;
                if (isClassInstance(leftType) && ClassType.isBuiltIn(leftType, 'str') && isLiteralType(leftType)) {
                    const adjIsPositiveTest =
                        testExpression.operator === OperatorType.In ? isPositiveTest : !isPositiveTest;
                    return (type: Type) => {
                        return narrowTypeForTypedDictKey(
                            evaluator,
                            type,
                            ClassType.cloneAsInstantiable(leftType),
                            adjIsPositiveTest
                        );
                    };
                }
            }
        }
    }

    if (testExpression.nodeType === ParseNodeType.Call) {
        const callType = evaluator.getTypeOfExpression(
            testExpression.leftExpression,
            /* expectedType */ undefined,
            EvaluatorFlags.DoNotSpecialize
        ).type;

        // Look for "isinstance(X, Y)" or "issubclass(X, Y)".
        if (
            isFunction(callType) &&
            (callType.details.builtInName === 'isinstance' || callType.details.builtInName === 'issubclass') &&
            testExpression.arguments.length === 2
        ) {
            // Make sure the first parameter is a supported expression type
            // and the second parameter is a valid class type or a tuple
            // of valid class types.
            const isInstanceCheck = callType.details.builtInName === 'isinstance';
            const arg0Expr = testExpression.arguments[0].valueExpression;
            const arg1Expr = testExpression.arguments[1].valueExpression;

            if (ParseTreeUtils.isMatchingExpression(reference, arg0Expr)) {
                const arg1Type = evaluator.getTypeOfExpression(
                    arg1Expr,
                    undefined,
                    EvaluatorFlags.EvaluateStringLiteralAsType |
                        EvaluatorFlags.ParamSpecDisallowed |
                        EvaluatorFlags.TypeVarTupleDisallowed
                ).type;

                const classTypeList = getIsInstanceClassTypes(arg1Type);
                if (classTypeList) {
                    return (type: Type) => {
                        const narrowedType = narrowTypeForIsInstance(
                            evaluator,
                            type,
                            classTypeList,
                            isInstanceCheck,
                            isPositiveTest,
                            /* allowIntersections */ false,
                            testExpression
                        );

                        if (!isNever(narrowedType)) {
                            return narrowedType;
                        }

                        // The narrowed result was Never; try again with
                        // intersection types allowed.
                        return narrowTypeForIsInstance(
                            evaluator,
                            type,
                            classTypeList,
                            isInstanceCheck,
                            isPositiveTest,
                            /* allowIntersections */ true,
                            testExpression
                        );
                    };
                }
            }
        }

        // Look for "callable(X)"
        if (
            isFunction(callType) &&
            callType.details.builtInName === 'callable' &&
            testExpression.arguments.length === 1
        ) {
            const arg0Expr = testExpression.arguments[0].valueExpression;
            if (ParseTreeUtils.isMatchingExpression(reference, arg0Expr)) {
                return (type: Type) => {
                    let narrowedType = narrowTypeForCallable(
                        evaluator,
                        type,
                        isPositiveTest,
                        testExpression,
                        /* allowIntersections */ false
                    );

                    if (isPositiveTest && isNever(narrowedType)) {
                        // Try again with intersections allowed.
                        narrowedType = narrowTypeForCallable(
                            evaluator,
                            type,
                            isPositiveTest,
                            testExpression,
                            /* allowIntersections */ true
                        );
                    }

                    return narrowedType;
                };
            }
        }

        // Look for "bool(X)"
        if (
            isInstantiableClass(callType) &&
            ClassType.isBuiltIn(callType, 'bool') &&
            testExpression.arguments.length === 1 &&
            !testExpression.arguments[0].name
        ) {
            if (ParseTreeUtils.isMatchingExpression(reference, testExpression.arguments[0].valueExpression)) {
                return (type: Type) => {
                    return narrowTypeForTruthiness(evaluator, type, isPositiveTest);
                };
            }
        }

        // Look for a TypeGuard function.
        if (testExpression.arguments.length >= 1) {
            const arg0Expr = testExpression.arguments[0].valueExpression;
            if (ParseTreeUtils.isMatchingExpression(reference, arg0Expr)) {
                // Does this look like it's a custom type guard function?
                if (
                    isFunction(callType) &&
                    callType.details.declaredReturnType &&
                    isClassInstance(callType.details.declaredReturnType) &&
                    ClassType.isBuiltIn(callType.details.declaredReturnType, ['TypeGuard', 'StrictTypeGuard'])
                ) {
                    // Evaluate the type guard call expression.
                    const functionReturnType = evaluator.getTypeOfExpression(testExpression).type;
                    if (
                        isClassInstance(functionReturnType) &&
                        ClassType.isBuiltIn(functionReturnType, 'bool') &&
                        functionReturnType.typeGuardType
                    ) {
                        const isStrictTypeGuard = ClassType.isBuiltIn(
                            callType.details.declaredReturnType,
                            'StrictTypeGuard'
                        );
                        const typeGuardType = functionReturnType.typeGuardType;

                        return (type: Type) => {
                            return narrowTypeForUserDefinedTypeGuard(
                                evaluator,
                                type,
                                typeGuardType,
                                isPositiveTest,
                                isStrictTypeGuard
                            );
                        };
                    }
                }
            }
        }
    }

    // The test expression is the reference itself: narrow on its truthiness.
    if (ParseTreeUtils.isMatchingExpression(reference, testExpression)) {
        return (type: Type) => {
            return narrowTypeForTruthiness(evaluator, type, isPositiveTest);
        };
    }

    // Is this a reference to an aliased conditional expression (a local variable
    // that was assigned a value that can inform type narrowing of the reference expression)?
    if (
        testExpression.nodeType === ParseNodeType.Name &&
        reference.nodeType === ParseNodeType.Name &&
        testExpression !== reference
    ) {
        // Make sure the reference expression is a constant parameter or variable.
        // If the reference expression is modified within the scope multiple times,
        // we need to validate that it is not modified between the test expression
        // evaluation and the conditional check.
        const testExprDecl = getDeclsForLocalVar(evaluator, testExpression, testExpression);
        if (testExprDecl && testExprDecl.length === 1 && testExprDecl[0].type === DeclarationType.Variable) {
            const referenceDecls = getDeclsForLocalVar(evaluator, reference, testExpression);

            if (referenceDecls) {
                let modifyingDecls: Declaration[] = [];

                if (referenceDecls.length > 1) {
                    // If there is more than one assignment to the reference variable within
                    // the local scope, make sure that none of these assignments are done
                    // after the test expression but before the condition check.
                    //
                    // This is OK:
                    //  val = None
                    //  is_none = val is None
                    //  if is_none: ...
                    //
                    // This is not OK:
                    //  val = None
                    //  is_none = val is None
                    //  val = 1
                    //  if is_none: ...
                    modifyingDecls = referenceDecls.filter((decl) => {
                        return (
                            evaluator.isNodeReachable(testExpression, decl.node) &&
                            evaluator.isNodeReachable(decl.node, testExprDecl[0].node)
                        );
                    });
                }

                if (modifyingDecls.length === 0) {
                    const initNode = testExprDecl[0].inferredTypeSource;

                    // Recurse on the aliased condition, guarding against an
                    // initializer that contains the test expression itself.
                    if (
                        initNode &&
                        !ParseTreeUtils.isNodeContainedWithin(testExpression, initNode) &&
                        isExpressionNode(initNode)
                    ) {
                        return getTypeNarrowingCallback(evaluator, reference, initNode, isPositiveTest);
                    }
                }
            }
        }
    }

    // We normally won't find a "not" operator here because they are stripped out
    // by the binder when it creates condition flow nodes, but we can find this
    // in the case of local variables type narrowing.
    if (testExpression.nodeType === ParseNodeType.UnaryOperation) {
        if (testExpression.operator === OperatorType.Not) {
            return getTypeNarrowingCallback(evaluator, reference, testExpression.expression, !isPositiveTest);
        }
    }

    return undefined;
}