export function synthesizeDataClassMethods()

in packages/pyright-internal/src/analyzer/dataClasses.ts [66:514]


/**
 * Synthesizes the members that the dataclass machinery adds to a dataclass:
 * `__new__`, `__init__`, `__match_args__`, `__eq__`, the ordering operators,
 * `__hash__`, `__dataclass_fields__`, and (for dataclasses that generate slots)
 * the class's local slot names. The class's own dataclass entries are stored in
 * `classType.details.dataClassEntries`.
 *
 * @param evaluator - Evaluator used to evaluate field annotations and
 *   field-constructor calls and to report diagnostics.
 * @param node - Parse node of the class statement.
 * @param classType - The class being processed; must already be flagged as a dataclass.
 * @param skipSynthesizeInit - When true, `__init__` and `__new__` are not synthesized
 *   and the "field without default follows field with default" check is skipped.
 * @param skipSynthesizeHash - When true, neither a `__hash__` method nor
 *   `__hash__ = None` is synthesized unless the class requests an unsafe hash.
 */
export function synthesizeDataClassMethods(
    evaluator: TypeEvaluator,
    node: ClassNode,
    classType: ClassType,
    skipSynthesizeInit: boolean,
    skipSynthesizeHash: boolean
) {
    assert(ClassType.isDataClass(classType));

    const classTypeVar = synthesizeTypeVarForSelfCls(classType, /* isClsParam */ true);
    const newType = FunctionType.createInstance(
        '__new__',
        '',
        '',
        FunctionTypeFlags.ConstructorMethod | FunctionTypeFlags.SynthesizedMethod
    );
    const initType = FunctionType.createInstance('__init__', '', '', FunctionTypeFlags.SynthesizedMethod);

    FunctionType.addParameter(newType, {
        category: ParameterCategory.Simple,
        name: 'cls',
        type: classTypeVar,
        hasDeclaredType: true,
    });
    FunctionType.addDefaultParameters(newType);
    newType.details.declaredReturnType = convertToInstance(classTypeVar);

    const selfParam: FunctionParameter = {
        category: ParameterCategory.Simple,
        name: 'self',
        type: synthesizeTypeVarForSelfCls(classType, /* isClsParam */ false),
        hasDeclaredType: true,
    };
    FunctionType.addParameter(initType, selfParam);
    initType.details.declaredReturnType = NoneType.createInstance();

    // Maintain a list of all dataclass entries (including
    // those from inherited classes) plus a list of only those
    // entries added by this class.
    const localDataClassEntries: DataClassEntry[] = [];
    const fullDataClassEntries: DataClassEntry[] = [];
    const allAncestorsKnown = addInheritedDataClassEntries(classType, fullDataClassEntries);

    if (!allAncestorsKnown) {
        // If one or more ancestor classes have an unknown type, we cannot
        // safely determine the parameter list, so we'll accept any parameters
        // to avoid a false positive.
        FunctionType.addDefaultParameters(initType);
    }

    // Used to statically evaluate `init=` and `kw_only=` arguments; it does not
    // change while walking the class body, so look it up once.
    const executionEnvironment = AnalyzerNodeInfo.getFileInfo(node).executionEnvironment;

    // Maintain a list of deferred evaluators for the declared types of this
    // class's own entries. The alias is deliberately not named `TypeEvaluator`
    // so it does not shadow the interface used for the `evaluator` parameter.
    type EntryTypeEvaluator = () => Type;
    const localEntryTypeEvaluator: { entry: DataClassEntry; evaluator: EntryTypeEvaluator }[] = [];
    let sawKeywordOnlySeparator = false;

    node.suite.statements.forEach((statementList) => {
        if (statementList.nodeType === ParseNodeType.StatementList) {
            statementList.statements.forEach((statement) => {
                let variableNameNode: NameNode | undefined;
                let aliasName: string | undefined;
                let variableTypeEvaluator: EntryTypeEvaluator | undefined;
                let hasDefaultValue = false;
                let isKeywordOnly = ClassType.isDataClassKeywordOnlyParams(classType) || sawKeywordOnlySeparator;
                let defaultValueExpression: ExpressionNode | undefined;
                let includeInInit = true;

                if (statement.nodeType === ParseNodeType.Assignment) {
                    if (
                        statement.leftExpression.nodeType === ParseNodeType.TypeAnnotation &&
                        statement.leftExpression.valueExpression.nodeType === ParseNodeType.Name
                    ) {
                        variableNameNode = statement.leftExpression.valueExpression;
                        variableTypeEvaluator = () =>
                            evaluator.getTypeOfAnnotation(
                                (statement.leftExpression as TypeAnnotationNode).typeAnnotation,
                                {
                                    isVariableAnnotation: true,
                                    allowFinal: true,
                                    allowClassVar: true,
                                }
                            );
                    }

                    hasDefaultValue = true;
                    defaultValueExpression = statement.rightExpression;

                    // If the RHS of the assignment is assigning a field instance where the
                    // "init" parameter is set to false, do not include it in the init method.
                    if (statement.rightExpression.nodeType === ParseNodeType.Call) {
                        const callType = evaluator.getTypeOfExpression(
                            statement.rightExpression.leftExpression,
                            /* expectedType */ undefined,
                            EvaluatorFlags.DoNotSpecialize
                        ).type;
                        if (
                            isDataclassFieldConstructor(
                                callType,
                                classType.details.dataClassBehaviors?.fieldDescriptorNames || []
                            )
                        ) {
                            const initArg = statement.rightExpression.arguments.find(
                                (arg) => arg.name?.value === 'init'
                            );
                            if (initArg && initArg.valueExpression) {
                                const value = evaluateStaticBoolExpression(
                                    initArg.valueExpression,
                                    executionEnvironment
                                );
                                if (value === false) {
                                    includeInInit = false;
                                }
                            } else {
                                // See if the field constructor has an `init` parameter with
                                // a default value.
                                let callTarget: FunctionType | undefined;
                                if (isFunction(callType)) {
                                    callTarget = callType;
                                } else if (isOverloadedFunction(callType)) {
                                    callTarget = evaluator.getBestOverloadForArguments(
                                        statement.rightExpression,
                                        callType,
                                        statement.rightExpression.arguments
                                    );
                                } else if (isInstantiableClass(callType)) {
                                    const initCall = evaluator.getBoundMethod(callType, '__init__');
                                    if (initCall) {
                                        if (isFunction(initCall)) {
                                            callTarget = initCall;
                                        } else if (isOverloadedFunction(initCall)) {
                                            callTarget = evaluator.getBestOverloadForArguments(
                                                statement.rightExpression,
                                                initCall,
                                                statement.rightExpression.arguments
                                            );
                                        }
                                    }
                                }

                                if (callTarget) {
                                    const initParam = callTarget.details.parameters.find((p) => p.name === 'init');
                                    if (initParam && initParam.defaultValueExpression && initParam.hasDeclaredType) {
                                        // Only a declared `Literal[False]` default excludes the field.
                                        if (
                                            isClass(initParam.type) &&
                                            ClassType.isBuiltIn(initParam.type, 'bool') &&
                                            isLiteralType(initParam.type)
                                        ) {
                                            if (initParam.type.literalValue === false) {
                                                includeInInit = false;
                                            }
                                        }
                                    }
                                }
                            }

                            const kwOnlyArg = statement.rightExpression.arguments.find(
                                (arg) => arg.name?.value === 'kw_only'
                            );
                            if (kwOnlyArg && kwOnlyArg.valueExpression) {
                                const value = evaluateStaticBoolExpression(
                                    kwOnlyArg.valueExpression,
                                    executionEnvironment
                                );
                                if (value === false) {
                                    isKeywordOnly = false;
                                } else if (value === true) {
                                    isKeywordOnly = true;
                                }
                            }

                            // A field-constructor call supplies a default only through
                            // one of these keyword arguments.
                            hasDefaultValue = statement.rightExpression.arguments.some(
                                (arg) =>
                                    arg.name?.value === 'default' ||
                                    arg.name?.value === 'default_factory' ||
                                    arg.name?.value === 'factory'
                            );

                            const aliasArg = statement.rightExpression.arguments.find(
                                (arg) => arg.name?.value === 'alias'
                            );
                            if (aliasArg) {
                                const valueType = evaluator.getTypeOfExpression(aliasArg.valueExpression).type;
                                if (
                                    isClassInstance(valueType) &&
                                    ClassType.isBuiltIn(valueType, 'str') &&
                                    isLiteralType(valueType)
                                ) {
                                    aliasName = valueType.literalValue as string;
                                }
                            }
                        }
                    }
                } else if (statement.nodeType === ParseNodeType.TypeAnnotation) {
                    if (statement.valueExpression.nodeType === ParseNodeType.Name) {
                        variableNameNode = statement.valueExpression;
                        variableTypeEvaluator = () =>
                            evaluator.getTypeOfAnnotation(statement.typeAnnotation, {
                                isVariableAnnotation: true,
                                allowFinal: true,
                                allowClassVar: true,
                            });

                        // Is this a KW_ONLY separator introduced in Python 3.10?
                        // Only `_` is checked, so the annotation is evaluated eagerly
                        // just for that name; all other annotations stay deferred.
                        if (statement.valueExpression.value === '_') {
                            const annotatedType = variableTypeEvaluator();

                            if (isClassInstance(annotatedType) && ClassType.isBuiltIn(annotatedType, 'KW_ONLY')) {
                                sawKeywordOnlySeparator = true;
                                variableNameNode = undefined;
                                variableTypeEvaluator = undefined;
                            }
                        }
                    }
                }

                if (variableNameNode && variableTypeEvaluator) {
                    const variableName = variableNameNode.value;

                    // Don't include class vars. PEP 557 indicates that they shouldn't
                    // be considered data class entries.
                    const variableSymbol = classType.details.fields.get(variableName);
                    const isFinal = variableSymbol
                        ?.getDeclarations()
                        .some((decl) => decl.type === DeclarationType.Variable && decl.isFinal);

                    if (variableSymbol?.isClassVar() && !isFinal) {
                        // If an ancestor class declared an instance variable but this dataclass
                        // declares a ClassVar, delete the older one from the full data class entries.
                        // We exclude final variables here because a Final type annotation is implicitly
                        // considered a ClassVar by the binder, but dataclass rules are different.
                        const index = fullDataClassEntries.findIndex((p) => p.name === variableName);
                        if (index >= 0) {
                            fullDataClassEntries.splice(index, 1);
                        }
                        const dataClassEntry: DataClassEntry = {
                            name: variableName,
                            classType,
                            alias: aliasName,
                            isKeywordOnly: false,
                            hasDefault: hasDefaultValue,
                            defaultValueExpression,
                            includeInInit,
                            type: UnknownType.create(),
                            isClassVar: true,
                        };
                        localDataClassEntries.push(dataClassEntry);
                    } else {
                        // Create a new data class entry, but defer evaluation of the type until
                        // we've compiled the full list of data class entries for this class. This
                        // allows us to handle circular references in types.
                        const dataClassEntry: DataClassEntry = {
                            name: variableName,
                            classType,
                            alias: aliasName,
                            isKeywordOnly,
                            hasDefault: hasDefaultValue,
                            defaultValueExpression,
                            includeInInit,
                            type: UnknownType.create(),
                            isClassVar: false,
                        };
                        localEntryTypeEvaluator.push({ entry: dataClassEntry, evaluator: variableTypeEvaluator });

                        // Add the new entry to the local entry list.
                        let insertIndex = localDataClassEntries.findIndex((e) => e.name === variableName);
                        if (insertIndex >= 0) {
                            localDataClassEntries[insertIndex] = dataClassEntry;
                        } else {
                            localDataClassEntries.push(dataClassEntry);
                        }

                        // Add the new entry to the full entry list. A redeclared name keeps
                        // the position of the inherited entry it replaces.
                        insertIndex = fullDataClassEntries.findIndex((p) => p.name === variableName);
                        if (insertIndex >= 0) {
                            fullDataClassEntries[insertIndex] = dataClassEntry;
                        } else {
                            fullDataClassEntries.push(dataClassEntry);
                            insertIndex = fullDataClassEntries.length - 1;
                        }

                        // If we've already seen an entry with a default value defined,
                        // all subsequent entries must also have default values.
                        if (!isKeywordOnly && includeInInit && !skipSynthesizeInit && !hasDefaultValue) {
                            const firstDefaultValueIndex = fullDataClassEntries.findIndex(
                                (p) => p.hasDefault && p.includeInInit && !p.isKeywordOnly
                            );
                            if (firstDefaultValueIndex >= 0 && firstDefaultValueIndex < insertIndex) {
                                evaluator.addError(Localizer.Diagnostic.dataClassFieldWithDefault(), variableNameNode);
                            }
                        }
                    }
                }
            });
        }
    });

    classType.details.dataClassEntries = localDataClassEntries;

    // Now that the dataClassEntries field has been set with a complete list
    // of local data class entries for this class, perform deferred type
    // evaluations. This could involve circular type dependencies, so it's
    // required that the list be complete (even if types are not yet accurate)
    // before we perform the type evaluations.
    localEntryTypeEvaluator.forEach((entryEvaluator) => {
        entryEvaluator.entry.type = entryEvaluator.evaluator();
    });

    const symbolTable = classType.details.fields;
    const keywordOnlyParams: FunctionParameter[] = [];

    if (!skipSynthesizeInit) {
        // Build the precise parameter list only when every ancestor is known. When
        // some ancestor is unknown, `initType` keeps the permissive signature added
        // earlier. It is still installed below; previously it was built and then
        // discarded, so that fallback never took effect.
        if (allAncestorsKnown) {
            fullDataClassEntries.forEach((entry) => {
                if (entry.includeInInit) {
                    // If the type refers to Self of the parent class, we need to
                    // transform it to refer to the Self of this subclass.
                    let effectiveType = entry.type;
                    if (entry.classType !== classType && requiresSpecialization(effectiveType)) {
                        const typeVarMap = new TypeVarMap(getTypeVarScopeId(entry.classType));
                        populateTypeVarMapForSelfType(typeVarMap, entry.classType, classType);
                        effectiveType = applySolvedTypeVars(effectiveType, typeVarMap);
                    }

                    const functionParam: FunctionParameter = {
                        category: ParameterCategory.Simple,
                        name: entry.alias || entry.name,
                        hasDefault: entry.hasDefault,
                        defaultValueExpression: entry.defaultValueExpression,
                        type: effectiveType,
                        hasDeclaredType: true,
                    };

                    if (entry.isKeywordOnly) {
                        keywordOnlyParams.push(functionParam);
                    } else {
                        FunctionType.addParameter(initType, functionParam);
                    }
                }
            });

            if (keywordOnlyParams.length > 0) {
                // An unnamed *args entry separates the positional parameters
                // from the keyword-only ones that follow it.
                FunctionType.addParameter(initType, {
                    category: ParameterCategory.VarArgList,
                    type: AnyType.create(),
                });
                keywordOnlyParams.forEach((param) => {
                    FunctionType.addParameter(initType, param);
                });
            }
        }

        symbolTable.set('__init__', Symbol.createWithType(SymbolFlags.ClassMember, initType));
        symbolTable.set('__new__', Symbol.createWithType(SymbolFlags.ClassMember, newType));
    }

    // Synthesize the __match_args__ class variable if it doesn't exist.
    const strType = evaluator.getBuiltInType(node, 'str');
    const tupleClassType = evaluator.getBuiltInType(node, 'tuple');
    if (
        tupleClassType &&
        isInstantiableClass(tupleClassType) &&
        strType &&
        isInstantiableClass(strType) &&
        !symbolTable.has('__match_args__')
    ) {
        const matchArgsNames: string[] = [];
        fullDataClassEntries.forEach((entry) => {
            if (entry.includeInInit && !entry.isKeywordOnly) {
                // Use the field name, not its alias (if it has one).
                matchArgsNames.push(entry.name);
            }
        });
        const literalTypes: TupleTypeArgument[] = matchArgsNames.map((name) => {
            return { type: ClassType.cloneAsInstance(ClassType.cloneWithLiteral(strType, name)), isUnbounded: false };
        });
        const matchArgsType = ClassType.cloneAsInstance(specializeTupleClass(tupleClassType, literalTypes));
        symbolTable.set('__match_args__', Symbol.createWithType(SymbolFlags.ClassMember, matchArgsType));
    }

    // Creates `operator(self, x: paramType) -> bool` and stores it in the class's symbol table.
    const synthesizeComparisonMethod = (operator: string, paramType: Type) => {
        const operatorMethod = FunctionType.createInstance(operator, '', '', FunctionTypeFlags.SynthesizedMethod);
        FunctionType.addParameter(operatorMethod, selfParam);
        FunctionType.addParameter(operatorMethod, {
            category: ParameterCategory.Simple,
            name: 'x',
            type: paramType,
            hasDeclaredType: true,
        });
        operatorMethod.details.declaredReturnType = evaluator.getBuiltInObject(node, 'bool');
        symbolTable.set(operator, Symbol.createWithType(SymbolFlags.ClassMember, operatorMethod));
    };

    // Synthesize comparison operators.
    if (!ClassType.isSkipSynthesizedDataClassEq(classType)) {
        synthesizeComparisonMethod('__eq__', evaluator.getBuiltInObject(node, 'object'));
    }

    if (ClassType.isSynthesizedDataclassOrder(classType)) {
        const objType = ClassType.cloneAsInstance(classType);
        ['__lt__', '__le__', '__gt__', '__ge__'].forEach((operator) => {
            synthesizeComparisonMethod(operator, objType);
        });
    }

    // Synthesized __eq__ plus frozen => __hash__ method; synthesized __eq__ without
    // frozen => __hash__ is set to None.
    let synthesizeHashFunction =
        !ClassType.isSkipSynthesizedDataClassEq(classType) && ClassType.isFrozenDataClass(classType);
    const synthesizeHashNone =
        !ClassType.isSkipSynthesizedDataClassEq(classType) && !ClassType.isFrozenDataClass(classType);

    if (skipSynthesizeHash) {
        synthesizeHashFunction = false;
    }

    // If the user has indicated that a hash function should be generated even if it's unsafe
    // to do so or there is already a hash function present, override the default logic.
    if (ClassType.isSynthesizeDataClassUnsafeHash(classType)) {
        synthesizeHashFunction = true;
    }

    if (synthesizeHashFunction) {
        const hashMethod = FunctionType.createInstance('__hash__', '', '', FunctionTypeFlags.SynthesizedMethod);
        FunctionType.addParameter(hashMethod, selfParam);
        hashMethod.details.declaredReturnType = evaluator.getBuiltInObject(node, 'int');
        symbolTable.set('__hash__', Symbol.createWithType(SymbolFlags.ClassMember, hashMethod));
    } else if (synthesizeHashNone && !skipSynthesizeHash) {
        symbolTable.set('__hash__', Symbol.createWithType(SymbolFlags.ClassMember, NoneType.createInstance()));
    }

    // __dataclass_fields__ is typed as dict[str, Any] when `dict` resolves to a class.
    let dictType = evaluator.getBuiltInType(node, 'dict');
    if (isInstantiableClass(dictType)) {
        dictType = ClassType.cloneAsInstance(
            ClassType.cloneForSpecialization(
                dictType,
                [evaluator.getBuiltInObject(node, 'str'), AnyType.create()],
                /* isTypeArgumentExplicit */ true
            )
        );
    }
    symbolTable.set('__dataclass_fields__', Symbol.createWithType(SymbolFlags.ClassMember, dictType));

    if (ClassType.isGeneratedDataClassSlots(classType) && classType.details.localSlotsNames === undefined) {
        // ClassVar entries are not dataclass fields, so (as with the runtime
        // `dataclass(slots=True)`) they must not appear in the generated slots.
        classType.details.localSlotsNames = localDataClassEntries
            .filter((entry) => !entry.isClassVar)
            .map((entry) => entry.name);
    }

    // If this dataclass derived from a NamedTuple, update the NamedTuple with
    // the specialized entry types.
    updateNamedTupleBaseClass(
        classType,
        fullDataClassEntries.map((entry) => entry.type),
        /* isTypeArgumentExplicit */ true
    );
}