export function getTypeNarrowingCallback()

in packages/pyright-internal/src/analyzer/typeGuards.ts [88:606]


/**
 * Returns a callback that narrows the type of `reference` according to the
 * outcome of `testExpression`, or undefined if `testExpression` is not a
 * recognized narrowing pattern for `reference`.
 *
 * Recognized patterns include `X is None` / `X == None` (and their negations),
 * `X[<int>] is None`, `type(X) is Y`, `X is <enum or bool literal>`,
 * `X == <literal>`, `<literal> == X`, `X[<literal>] == <literal>`,
 * `len(X) == <int>`, `X.Y == <literal>`, `X.Y is <enum or bool literal>`,
 * `X in Y` (positive test only), `<str literal> in X`, `isinstance(X, Y)`,
 * `issubclass(X, Y)`, `callable(X)`, `bool(X)`, calls to functions declared to
 * return `TypeGuard`/`StrictTypeGuard`, plain truthiness of `X`, `not <expr>`,
 * and single-assignment local variables that alias one of these conditions.
 *
 * @param evaluator - Type evaluator used to evaluate subexpressions of the test.
 * @param reference - The expression whose type is to be narrowed.
 * @param testExpression - The condition expression being tested.
 * @param isPositiveTest - True to narrow for the branch where the condition
 *     holds, false for the branch where it does not.
 * @param recursionCount - Internal recursion depth; external callers should
 *     omit it. It bounds the recursive walk through walrus operands, `not`
 *     operands and aliased conditions so that a cycle of local-variable
 *     aliases cannot recurse without limit.
 * @returns A narrowing callback, or undefined if no narrowing applies.
 */
export function getTypeNarrowingCallback(
    evaluator: TypeEvaluator,
    reference: ExpressionNode,
    testExpression: ExpressionNode,
    isPositiveTest: boolean,
    recursionCount = 0
): TypeNarrowingCallback | undefined {
    // Cap on nested recursive calls made by this function. The aliased-condition
    // path follows variable declarations, and nothing here prevents those from
    // forming a cycle; realistic code stays far below this depth.
    const maxRecursionCount = 64;
    if (recursionCount > maxRecursionCount) {
        return undefined;
    }

    if (testExpression.nodeType === ParseNodeType.AssignmentExpression) {
        // For "(name := expr)", prefer narrowing based on "expr" and fall back
        // to the assigned name.
        return (
            getTypeNarrowingCallback(
                evaluator,
                reference,
                testExpression.rightExpression,
                isPositiveTest,
                recursionCount + 1
            ) ??
            getTypeNarrowingCallback(evaluator, reference, testExpression.name, isPositiveTest, recursionCount + 1)
        );
    }

    if (testExpression.nodeType === ParseNodeType.BinaryOperation) {
        const isOrIsNotOperator =
            testExpression.operator === OperatorType.Is || testExpression.operator === OperatorType.IsNot;
        const equalsOrNotEqualsOperator =
            testExpression.operator === OperatorType.Equals || testExpression.operator === OperatorType.NotEquals;

        if (isOrIsNotOperator || equalsOrNotEqualsOperator) {
            // Invert the "isPositiveTest" value if this is an "is not" or "!=" operation.
            const adjIsPositiveTest =
                testExpression.operator === OperatorType.Is || testExpression.operator === OperatorType.Equals
                    ? isPositiveTest
                    : !isPositiveTest;

            // Look for "X is None", "X is not None", "X == None", and "X != None".
            // These are commonly-used patterns used in control flow.
            if (
                testExpression.rightExpression.nodeType === ParseNodeType.Constant &&
                testExpression.rightExpression.constType === KeywordType.None
            ) {
                // Allow the LHS to be either a simple expression or an assignment
                // expression that assigns to a simple name.
                let leftExpression = testExpression.leftExpression;
                if (leftExpression.nodeType === ParseNodeType.AssignmentExpression) {
                    leftExpression = leftExpression.name;
                }

                if (ParseTreeUtils.isMatchingExpression(reference, leftExpression)) {
                    return (type: Type) => {
                        return narrowTypeForIsNone(evaluator, type, adjIsPositiveTest);
                    };
                }

                // Look for "X[<int>] is None" with a single, plain, real integer index.
                if (
                    leftExpression.nodeType === ParseNodeType.Index &&
                    ParseTreeUtils.isMatchingExpression(reference, leftExpression.baseExpression) &&
                    leftExpression.items.length === 1 &&
                    !leftExpression.trailingComma &&
                    leftExpression.items[0].argumentCategory === ArgumentCategory.Simple &&
                    !leftExpression.items[0].name &&
                    leftExpression.items[0].valueExpression.nodeType === ParseNodeType.Number &&
                    leftExpression.items[0].valueExpression.isInteger &&
                    !leftExpression.items[0].valueExpression.isImaginary
                ) {
                    const indexValue = leftExpression.items[0].valueExpression.value;
                    // The literal value may be a bigint; only plain numbers are supported.
                    if (typeof indexValue === 'number') {
                        return (type: Type) => {
                            return narrowTupleTypeForIsNone(evaluator, type, adjIsPositiveTest, indexValue);
                        };
                    }
                }
            }

            // Look for "type(X) is Y" or "type(X) is not Y".
            if (isOrIsNotOperator && testExpression.leftExpression.nodeType === ParseNodeType.Call) {
                const callType = evaluator.getTypeOfExpression(
                    testExpression.leftExpression.leftExpression,
                    /* expectedType */ undefined,
                    EvaluatorFlags.DoNotSpecialize
                ).type;

                if (
                    isInstantiableClass(callType) &&
                    ClassType.isBuiltIn(callType, 'type') &&
                    testExpression.leftExpression.arguments.length === 1 &&
                    testExpression.leftExpression.arguments[0].argumentCategory === ArgumentCategory.Simple
                ) {
                    const arg0Expr = testExpression.leftExpression.arguments[0].valueExpression;
                    if (ParseTreeUtils.isMatchingExpression(reference, arg0Expr)) {
                        const classType = evaluator.getTypeOfExpression(testExpression.rightExpression).type;

                        if (isInstantiableClass(classType)) {
                            return (type: Type) => {
                                return narrowTypeForTypeIs(type, classType, adjIsPositiveTest);
                            };
                        }
                    }
                }
            }

            // Look for "X is Y" or "X is not Y" where Y is an enum or bool literal.
            if (isOrIsNotOperator) {
                if (ParseTreeUtils.isMatchingExpression(reference, testExpression.leftExpression)) {
                    const rightType = evaluator.getTypeOfExpression(testExpression.rightExpression).type;
                    if (
                        isClassInstance(rightType) &&
                        (ClassType.isEnumClass(rightType) || ClassType.isBuiltIn(rightType, 'bool')) &&
                        rightType.literalValue !== undefined
                    ) {
                        return (type: Type) => {
                            return narrowTypeForLiteralComparison(
                                evaluator,
                                type,
                                rightType,
                                adjIsPositiveTest,
                                /* isIsOperator */ true
                            );
                        };
                    }
                }
            }

            if (equalsOrNotEqualsOperator) {
                // Look for X == <literal> or X != <literal>
                if (ParseTreeUtils.isMatchingExpression(reference, testExpression.leftExpression)) {
                    const rightType = evaluator.getTypeOfExpression(testExpression.rightExpression).type;
                    if (isClassInstance(rightType) && rightType.literalValue !== undefined) {
                        return (type: Type) => {
                            return narrowTypeForLiteralComparison(
                                evaluator,
                                type,
                                rightType,
                                adjIsPositiveTest,
                                /* isIsOperator */ false
                            );
                        };
                    }
                }

                // Look for <literal> == X or <literal> != X
                if (ParseTreeUtils.isMatchingExpression(reference, testExpression.rightExpression)) {
                    const leftType = evaluator.getTypeOfExpression(testExpression.leftExpression).type;
                    if (isClassInstance(leftType) && leftType.literalValue !== undefined) {
                        return (type: Type) => {
                            return narrowTypeForLiteralComparison(
                                evaluator,
                                type,
                                leftType,
                                adjIsPositiveTest,
                                /* isIsOperator */ false
                            );
                        };
                    }
                }

                // Look for X[<literal>] == <literal> or X[<literal>] != <literal>.
                // A str index discriminates TypedDict entries; an int index
                // discriminates tuple entries.
                if (
                    testExpression.leftExpression.nodeType === ParseNodeType.Index &&
                    testExpression.leftExpression.items.length === 1 &&
                    !testExpression.leftExpression.trailingComma &&
                    testExpression.leftExpression.items[0].argumentCategory === ArgumentCategory.Simple &&
                    ParseTreeUtils.isMatchingExpression(reference, testExpression.leftExpression.baseExpression)
                ) {
                    const indexType = evaluator.getTypeOfExpression(
                        testExpression.leftExpression.items[0].valueExpression
                    ).type;

                    if (isClassInstance(indexType) && isLiteralType(indexType)) {
                        if (ClassType.isBuiltIn(indexType, 'str')) {
                            const rightType = evaluator.getTypeOfExpression(testExpression.rightExpression).type;
                            if (isClassInstance(rightType) && rightType.literalValue !== undefined) {
                                return (type: Type) => {
                                    return narrowTypeForDiscriminatedDictEntryComparison(
                                        evaluator,
                                        type,
                                        indexType,
                                        rightType,
                                        adjIsPositiveTest
                                    );
                                };
                            }
                        } else if (ClassType.isBuiltIn(indexType, 'int')) {
                            const rightType = evaluator.getTypeOfExpression(testExpression.rightExpression).type;
                            if (isClassInstance(rightType) && rightType.literalValue !== undefined) {
                                return (type: Type) => {
                                    return narrowTypeForDiscriminatedTupleComparison(
                                        evaluator,
                                        type,
                                        indexType,
                                        rightType,
                                        adjIsPositiveTest
                                    );
                                };
                            }
                        }
                    }
                }
            }

            // Look for len(x) == <literal> or len(x) != <literal>. The argument must be
            // a plain positional argument (not "*x" or a keyword argument), and the
            // literal must be a real integer (not an imaginary literal such as "2j").
            // Otherwise the comparison does not describe the length of x.
            if (
                equalsOrNotEqualsOperator &&
                testExpression.leftExpression.nodeType === ParseNodeType.Call &&
                testExpression.leftExpression.arguments.length === 1 &&
                testExpression.leftExpression.arguments[0].argumentCategory === ArgumentCategory.Simple &&
                !testExpression.leftExpression.arguments[0].name &&
                testExpression.rightExpression.nodeType === ParseNodeType.Number &&
                testExpression.rightExpression.isInteger &&
                !testExpression.rightExpression.isImaginary
            ) {
                const arg0Expr = testExpression.leftExpression.arguments[0].valueExpression;

                if (ParseTreeUtils.isMatchingExpression(reference, arg0Expr)) {
                    const callType = evaluator.getTypeOfExpression(
                        testExpression.leftExpression.leftExpression,
                        /* expectedType */ undefined,
                        EvaluatorFlags.DoNotSpecialize
                    ).type;

                    if (isFunction(callType) && callType.details.fullName === 'builtins.len') {
                        const tupleLength = testExpression.rightExpression.value;

                        // The literal value may be a bigint; only plain numbers are supported.
                        if (typeof tupleLength === 'number') {
                            return (type: Type) => {
                                return narrowTypeForTupleLength(evaluator, type, tupleLength, adjIsPositiveTest);
                            };
                        }
                    }
                }
            }

            // Look for X.Y == <literal> or X.Y != <literal>
            if (
                equalsOrNotEqualsOperator &&
                testExpression.leftExpression.nodeType === ParseNodeType.MemberAccess &&
                ParseTreeUtils.isMatchingExpression(reference, testExpression.leftExpression.leftExpression)
            ) {
                const rightType = evaluator.getTypeOfExpression(testExpression.rightExpression).type;
                const memberName = testExpression.leftExpression.memberName;
                if (isClassInstance(rightType) && rightType.literalValue !== undefined) {
                    return (type: Type) => {
                        return narrowTypeForDiscriminatedFieldComparison(
                            evaluator,
                            type,
                            memberName.value,
                            rightType,
                            adjIsPositiveTest
                        );
                    };
                }
            }

            // Look for X.Y is <literal> or X.Y is not <literal> where <literal> is
            // an enum or bool literal.
            if (
                isOrIsNotOperator &&
                testExpression.leftExpression.nodeType === ParseNodeType.MemberAccess &&
                ParseTreeUtils.isMatchingExpression(reference, testExpression.leftExpression.leftExpression)
            ) {
                const rightType = evaluator.getTypeOfExpression(testExpression.rightExpression).type;
                const memberName = testExpression.leftExpression.memberName;
                if (
                    isClassInstance(rightType) &&
                    (ClassType.isEnumClass(rightType) || ClassType.isBuiltIn(rightType, 'bool')) &&
                    rightType.literalValue !== undefined
                ) {
                    return (type: Type) => {
                        return narrowTypeForDiscriminatedFieldComparison(
                            evaluator,
                            type,
                            memberName.value,
                            rightType,
                            adjIsPositiveTest
                        );
                    };
                }
            }
        }

        if (testExpression.operator === OperatorType.In) {
            // Look for "x in y" where y is one of several built-in types.
            // Only the positive branch is narrowed.
            if (isPositiveTest && ParseTreeUtils.isMatchingExpression(reference, testExpression.leftExpression)) {
                const rightType = evaluator.getTypeOfExpression(testExpression.rightExpression).type;
                return (type: Type) => {
                    return narrowTypeForContains(evaluator, type, rightType);
                };
            }
        }

        if (testExpression.operator === OperatorType.In || testExpression.operator === OperatorType.NotIn) {
            if (ParseTreeUtils.isMatchingExpression(reference, testExpression.rightExpression)) {
                // Look for <string literal> in y where y is a union that contains
                // one or more TypedDicts.
                const leftType = evaluator.getTypeOfExpression(testExpression.leftExpression).type;
                if (isClassInstance(leftType) && ClassType.isBuiltIn(leftType, 'str') && isLiteralType(leftType)) {
                    const adjIsPositiveTest =
                        testExpression.operator === OperatorType.In ? isPositiveTest : !isPositiveTest;
                    return (type: Type) => {
                        return narrowTypeForTypedDictKey(
                            evaluator,
                            type,
                            ClassType.cloneAsInstantiable(leftType),
                            adjIsPositiveTest
                        );
                    };
                }
            }
        }
    }

    if (testExpression.nodeType === ParseNodeType.Call) {
        const callType = evaluator.getTypeOfExpression(
            testExpression.leftExpression,
            /* expectedType */ undefined,
            EvaluatorFlags.DoNotSpecialize
        ).type;

        // Look for "isinstance(X, Y)" or "issubclass(X, Y)".
        if (
            isFunction(callType) &&
            (callType.details.builtInName === 'isinstance' || callType.details.builtInName === 'issubclass') &&
            testExpression.arguments.length === 2
        ) {
            // Make sure the first parameter is a supported expression type
            // and the second parameter is a valid class type or a tuple
            // of valid class types.
            const isInstanceCheck = callType.details.builtInName === 'isinstance';
            const arg0Expr = testExpression.arguments[0].valueExpression;
            const arg1Expr = testExpression.arguments[1].valueExpression;
            if (ParseTreeUtils.isMatchingExpression(reference, arg0Expr)) {
                const arg1Type = evaluator.getTypeOfExpression(
                    arg1Expr,
                    undefined,
                    EvaluatorFlags.EvaluateStringLiteralAsType |
                        EvaluatorFlags.ParamSpecDisallowed |
                        EvaluatorFlags.TypeVarTupleDisallowed
                ).type;

                const classTypeList = getIsInstanceClassTypes(arg1Type);

                if (classTypeList) {
                    return (type: Type) => {
                        const narrowedType = narrowTypeForIsInstance(
                            evaluator,
                            type,
                            classTypeList,
                            isInstanceCheck,
                            isPositiveTest,
                            /* allowIntersections */ false,
                            testExpression
                        );
                        if (!isNever(narrowedType)) {
                            return narrowedType;
                        }

                        // Try again with intersection types allowed.
                        return narrowTypeForIsInstance(
                            evaluator,
                            type,
                            classTypeList,
                            isInstanceCheck,
                            isPositiveTest,
                            /* allowIntersections */ true,
                            testExpression
                        );
                    };
                }
            }
        }

        // Look for "callable(X)"
        if (
            isFunction(callType) &&
            callType.details.builtInName === 'callable' &&
            testExpression.arguments.length === 1
        ) {
            const arg0Expr = testExpression.arguments[0].valueExpression;
            if (ParseTreeUtils.isMatchingExpression(reference, arg0Expr)) {
                return (type: Type) => {
                    let narrowedType = narrowTypeForCallable(
                        evaluator,
                        type,
                        isPositiveTest,
                        testExpression,
                        /* allowIntersections */ false
                    );
                    if (isPositiveTest && isNever(narrowedType)) {
                        // Try again with intersections allowed.
                        narrowedType = narrowTypeForCallable(
                            evaluator,
                            type,
                            isPositiveTest,
                            testExpression,
                            /* allowIntersections */ true
                        );
                    }

                    return narrowedType;
                };
            }
        }

        // Look for "bool(X)"
        if (
            isInstantiableClass(callType) &&
            ClassType.isBuiltIn(callType, 'bool') &&
            testExpression.arguments.length === 1 &&
            !testExpression.arguments[0].name
        ) {
            if (ParseTreeUtils.isMatchingExpression(reference, testExpression.arguments[0].valueExpression)) {
                return (type: Type) => {
                    return narrowTypeForTruthiness(evaluator, type, isPositiveTest);
                };
            }
        }

        // Look for a TypeGuard function.
        if (testExpression.arguments.length >= 1) {
            const arg0Expr = testExpression.arguments[0].valueExpression;
            if (ParseTreeUtils.isMatchingExpression(reference, arg0Expr)) {
                // Does this look like it's a custom type guard function?
                if (
                    isFunction(callType) &&
                    callType.details.declaredReturnType &&
                    isClassInstance(callType.details.declaredReturnType) &&
                    ClassType.isBuiltIn(callType.details.declaredReturnType, ['TypeGuard', 'StrictTypeGuard'])
                ) {
                    // Evaluate the type guard call expression.
                    const functionReturnType = evaluator.getTypeOfExpression(testExpression).type;
                    if (
                        isClassInstance(functionReturnType) &&
                        ClassType.isBuiltIn(functionReturnType, 'bool') &&
                        functionReturnType.typeGuardType
                    ) {
                        const isStrictTypeGuard = ClassType.isBuiltIn(
                            callType.details.declaredReturnType,
                            'StrictTypeGuard'
                        );
                        const typeGuardType = functionReturnType.typeGuardType;

                        return (type: Type) => {
                            return narrowTypeForUserDefinedTypeGuard(
                                evaluator,
                                type,
                                typeGuardType,
                                isPositiveTest,
                                isStrictTypeGuard
                            );
                        };
                    }
                }
            }
        }
    }

    // The test expression is the reference itself: narrow on truthiness.
    if (ParseTreeUtils.isMatchingExpression(reference, testExpression)) {
        return (type: Type) => {
            return narrowTypeForTruthiness(evaluator, type, isPositiveTest);
        };
    }

    // Is this a reference to an aliased conditional expression (a local variable
    // that was assigned a value that can inform type narrowing of the reference expression)?
    if (
        testExpression.nodeType === ParseNodeType.Name &&
        reference.nodeType === ParseNodeType.Name &&
        testExpression !== reference
    ) {
        // Make sure the reference expression is a constant parameter or variable.
        // If the reference expression is modified within the scope multiple times,
        // we need to validate that it is not modified between the test expression
        // evaluation and the conditional check.
        const testExprDecl = getDeclsForLocalVar(evaluator, testExpression, testExpression);
        if (testExprDecl && testExprDecl.length === 1 && testExprDecl[0].type === DeclarationType.Variable) {
            const referenceDecls = getDeclsForLocalVar(evaluator, reference, testExpression);

            if (referenceDecls) {
                let modifyingDecls: Declaration[] = [];

                if (referenceDecls.length > 1) {
                    // If there is more than one assignment to the reference variable within
                    // the local scope, make sure that none of these assignments are done
                    // after the test expression but before the condition check.
                    //
                    // This is OK:
                    //  val = None
                    //  is_none = val is None
                    //  if is_none: ...
                    //
                    // This is not OK:
                    //  val = None
                    //  is_none = val is None
                    //  val = 1
                    //  if is_none: ...
                    modifyingDecls = referenceDecls.filter((decl) => {
                        return (
                            evaluator.isNodeReachable(testExpression, decl.node) &&
                            evaluator.isNodeReachable(decl.node, testExprDecl[0].node)
                        );
                    });
                }

                if (modifyingDecls.length === 0) {
                    const initNode = testExprDecl[0].inferredTypeSource;

                    if (
                        initNode &&
                        !ParseTreeUtils.isNodeContainedWithin(testExpression, initNode) &&
                        isExpressionNode(initNode)
                    ) {
                        // Following the alias may lead to another alias, so pass the
                        // incremented depth to bound the walk.
                        return getTypeNarrowingCallback(
                            evaluator,
                            reference,
                            initNode,
                            isPositiveTest,
                            recursionCount + 1
                        );
                    }
                }
            }
        }
    }

    // We normally won't find a "not" operator here because they are stripped out
    // by the binder when it creates condition flow nodes, but we can find this
    // in the case of local variables type narrowing.
    if (testExpression.nodeType === ParseNodeType.UnaryOperation) {
        if (testExpression.operator === OperatorType.Not) {
            return getTypeNarrowingCallback(
                evaluator,
                reference,
                testExpression.expression,
                !isPositiveTest,
                recursionCount + 1
            );
        }
    }

    return undefined;
}