in packages/pyright-internal/src/analyzer/dataClasses.ts [66:514]
// Synthesizes the members that a dataclass gets implicitly and stores them
// in the symbol table of `classType`:
//
//   - `__init__` and `__new__` (unless `skipSynthesizeInit` is true)
//   - `__match_args__` (unless the class already defines it)
//   - `__eq__` and, when ordering is requested, `__lt__`/`__le__`/`__gt__`/`__ge__`
//   - `__hash__` (either a method or `None`, see below)
//   - `__dataclass_fields__`
//
// It also records the locally-declared dataclass entries in
// `classType.details.dataClassEntries`, fills in `localSlotsNames` when slots
// are generated, and forwards the entry types to `updateNamedTupleBaseClass`.
//
// Parameters:
//   evaluator          - type evaluator used to evaluate annotations and
//                        expressions and to report diagnostics.
//   node               - the class node whose suite is scanned for fields.
//   classType          - the dataclass type being populated (must satisfy
//                        `ClassType.isDataClass`).
//   skipSynthesizeInit - when true, `__init__`/`__new__` are not installed and
//                        the "field without default follows field with
//                        default" diagnostic is suppressed.
//   skipSynthesizeHash - when true, no `__hash__` is synthesized unless the
//                        class requests an unsafe hash.
export function synthesizeDataClassMethods(
    evaluator: TypeEvaluator,
    node: ClassNode,
    classType: ClassType,
    skipSynthesizeInit: boolean,
    skipSynthesizeHash: boolean
) {
    assert(ClassType.isDataClass(classType));

    const classTypeVar = synthesizeTypeVarForSelfCls(classType, /* isClsParam */ true);
    const newType = FunctionType.createInstance(
        '__new__',
        '',
        '',
        FunctionTypeFlags.ConstructorMethod | FunctionTypeFlags.SynthesizedMethod
    );
    const initType = FunctionType.createInstance('__init__', '', '', FunctionTypeFlags.SynthesizedMethod);

    // __new__ takes "cls" followed by the default parameters and returns
    // an instance of the class.
    FunctionType.addParameter(newType, {
        category: ParameterCategory.Simple,
        name: 'cls',
        type: classTypeVar,
        hasDeclaredType: true,
    });
    FunctionType.addDefaultParameters(newType);
    newType.details.declaredReturnType = convertToInstance(classTypeVar);

    // The same "self" parameter object is reused for every synthesized
    // instance method below.
    const selfParam: FunctionParameter = {
        category: ParameterCategory.Simple,
        name: 'self',
        type: synthesizeTypeVarForSelfCls(classType, /* isClsParam */ false),
        hasDeclaredType: true,
    };
    FunctionType.addParameter(initType, selfParam);
    initType.details.declaredReturnType = NoneType.createInstance();

    // Maintain a list of all dataclass entries (including
    // those from inherited classes) plus a list of only those
    // entries added by this class.
    const localDataClassEntries: DataClassEntry[] = [];
    const fullDataClassEntries: DataClassEntry[] = [];
    const allAncestorsKnown = addInheritedDataClassEntries(classType, fullDataClassEntries);

    if (!allAncestorsKnown) {
        // If one or more ancestor classes have an unknown type, we cannot
        // safely determine the parameter list, so we'll accept any parameters
        // to avoid a false positive.
        FunctionType.addDefaultParameters(initType);
    }

    // Maintain a list of "type evaluators". The alias is deliberately not
    // named "TypeEvaluator" so that it does not shadow the type of the
    // "evaluator" parameter.
    type EntryTypeEvaluator = () => Type;
    const localEntryTypeEvaluator: { entry: DataClassEntry; evaluator: EntryTypeEvaluator }[] = [];
    let sawKeywordOnlySeparator = false;

    node.suite.statements.forEach((statementList) => {
        if (statementList.nodeType === ParseNodeType.StatementList) {
            statementList.statements.forEach((statement) => {
                let variableNameNode: NameNode | undefined;
                let aliasName: string | undefined;
                let variableTypeEvaluator: EntryTypeEvaluator | undefined;
                let hasDefaultValue = false;
                let isKeywordOnly = ClassType.isDataClassKeywordOnlyParams(classType) || sawKeywordOnlySeparator;
                let defaultValueExpression: ExpressionNode | undefined;
                let includeInInit = true;

                if (statement.nodeType === ParseNodeType.Assignment) {
                    if (
                        statement.leftExpression.nodeType === ParseNodeType.TypeAnnotation &&
                        statement.leftExpression.valueExpression.nodeType === ParseNodeType.Name
                    ) {
                        variableNameNode = statement.leftExpression.valueExpression;
                        variableTypeEvaluator = () =>
                            evaluator.getTypeOfAnnotation(
                                (statement.leftExpression as TypeAnnotationNode).typeAnnotation,
                                {
                                    isVariableAnnotation: true,
                                    allowFinal: true,
                                    allowClassVar: true,
                                }
                            );
                    }

                    // An assignment provides a default value unless the RHS turns
                    // out to be a field constructor call (handled below).
                    hasDefaultValue = true;
                    defaultValueExpression = statement.rightExpression;

                    // If the RHS of the assignment is assigning a field instance where the
                    // "init" parameter is set to false, do not include it in the init method.
                    if (statement.rightExpression.nodeType === ParseNodeType.Call) {
                        const callType = evaluator.getTypeOfExpression(
                            statement.rightExpression.leftExpression,
                            /* expectedType */ undefined,
                            EvaluatorFlags.DoNotSpecialize
                        ).type;

                        if (
                            isDataclassFieldConstructor(
                                callType,
                                classType.details.dataClassBehaviors?.fieldDescriptorNames || []
                            )
                        ) {
                            const initArg = statement.rightExpression.arguments.find(
                                (arg) => arg.name?.value === 'init'
                            );
                            if (initArg && initArg.valueExpression) {
                                const value = evaluateStaticBoolExpression(
                                    initArg.valueExpression,
                                    AnalyzerNodeInfo.getFileInfo(node).executionEnvironment
                                );
                                if (value === false) {
                                    includeInInit = false;
                                }
                            } else {
                                // See if the field constructor has an `init` parameter with
                                // a default value.
                                let callTarget: FunctionType | undefined;
                                if (isFunction(callType)) {
                                    callTarget = callType;
                                } else if (isOverloadedFunction(callType)) {
                                    callTarget = evaluator.getBestOverloadForArguments(
                                        statement.rightExpression,
                                        callType,
                                        statement.rightExpression.arguments
                                    );
                                } else if (isInstantiableClass(callType)) {
                                    const initCall = evaluator.getBoundMethod(callType, '__init__');
                                    if (initCall) {
                                        if (isFunction(initCall)) {
                                            callTarget = initCall;
                                        } else if (isOverloadedFunction(initCall)) {
                                            callTarget = evaluator.getBestOverloadForArguments(
                                                statement.rightExpression,
                                                initCall,
                                                statement.rightExpression.arguments
                                            );
                                        }
                                    }
                                }

                                if (callTarget) {
                                    const initParam = callTarget.details.parameters.find((p) => p.name === 'init');
                                    if (initParam && initParam.defaultValueExpression && initParam.hasDeclaredType) {
                                        // Only a declared literal bool type of False
                                        // excludes the field from __init__.
                                        if (
                                            isClass(initParam.type) &&
                                            ClassType.isBuiltIn(initParam.type, 'bool') &&
                                            isLiteralType(initParam.type)
                                        ) {
                                            if (initParam.type.literalValue === false) {
                                                includeInInit = false;
                                            }
                                        }
                                    }
                                }
                            }

                            // An explicit, statically-evaluable "kw_only" argument
                            // overrides the class-level / separator-derived setting.
                            const kwOnlyArg = statement.rightExpression.arguments.find(
                                (arg) => arg.name?.value === 'kw_only'
                            );
                            if (kwOnlyArg && kwOnlyArg.valueExpression) {
                                const value = evaluateStaticBoolExpression(
                                    kwOnlyArg.valueExpression,
                                    AnalyzerNodeInfo.getFileInfo(node).executionEnvironment
                                );
                                if (value === false) {
                                    isKeywordOnly = false;
                                } else if (value === true) {
                                    isKeywordOnly = true;
                                }
                            }

                            // For a field constructor call, the field has a default only
                            // if one of these keyword arguments is supplied.
                            hasDefaultValue = statement.rightExpression.arguments.some(
                                (arg) =>
                                    arg.name?.value === 'default' ||
                                    arg.name?.value === 'default_factory' ||
                                    arg.name?.value === 'factory'
                            );

                            // A literal string "alias" argument renames the __init__ parameter.
                            const aliasArg = statement.rightExpression.arguments.find(
                                (arg) => arg.name?.value === 'alias'
                            );
                            if (aliasArg) {
                                const valueType = evaluator.getTypeOfExpression(aliasArg.valueExpression).type;
                                if (
                                    isClassInstance(valueType) &&
                                    ClassType.isBuiltIn(valueType, 'str') &&
                                    isLiteralType(valueType)
                                ) {
                                    aliasName = valueType.literalValue as string;
                                }
                            }
                        }
                    }
                } else if (statement.nodeType === ParseNodeType.TypeAnnotation) {
                    if (statement.valueExpression.nodeType === ParseNodeType.Name) {
                        variableNameNode = statement.valueExpression;
                        variableTypeEvaluator = () =>
                            evaluator.getTypeOfAnnotation(statement.typeAnnotation, {
                                isVariableAnnotation: true,
                                allowFinal: true,
                                allowClassVar: true,
                            });

                        // Is this a KW_ONLY separator introduced in Python 3.10?
                        if (statement.valueExpression.value === '_') {
                            const annotatedType = variableTypeEvaluator();

                            if (isClassInstance(annotatedType) && ClassType.isBuiltIn(annotatedType, 'KW_ONLY')) {
                                // The separator is not a field itself; it only affects
                                // the fields declared after it.
                                sawKeywordOnlySeparator = true;
                                variableNameNode = undefined;
                                variableTypeEvaluator = undefined;
                            }
                        }
                    }
                }

                if (variableNameNode && variableTypeEvaluator) {
                    const variableName = variableNameNode.value;

                    // Don't include class vars. PEP 557 indicates that they shouldn't
                    // be considered data class entries.
                    const variableSymbol = classType.details.fields.get(variableName);
                    const isFinal = variableSymbol
                        ?.getDeclarations()
                        .some((decl) => decl.type === DeclarationType.Variable && decl.isFinal);

                    if (variableSymbol?.isClassVar() && !isFinal) {
                        // If an ancestor class declared an instance variable but this dataclass
                        // declares a ClassVar, delete the older one from the full data class entries.
                        // We exclude final variables here because a Final type annotation is implicitly
                        // considered a ClassVar by the binder, but dataclass rules are different.
                        const index = fullDataClassEntries.findIndex((p) => p.name === variableName);
                        if (index >= 0) {
                            fullDataClassEntries.splice(index, 1);
                        }

                        const dataClassEntry: DataClassEntry = {
                            name: variableName,
                            classType,
                            alias: aliasName,
                            isKeywordOnly: false,
                            hasDefault: hasDefaultValue,
                            defaultValueExpression,
                            includeInInit,
                            type: UnknownType.create(),
                            isClassVar: true,
                        };
                        localDataClassEntries.push(dataClassEntry);
                    } else {
                        // Create a new data class entry, but defer evaluation of the type until
                        // we've compiled the full list of data class entries for this class. This
                        // allows us to handle circular references in types.
                        const dataClassEntry: DataClassEntry = {
                            name: variableName,
                            classType,
                            alias: aliasName,
                            isKeywordOnly,
                            hasDefault: hasDefaultValue,
                            defaultValueExpression,
                            includeInInit,
                            type: UnknownType.create(),
                            isClassVar: false,
                        };
                        localEntryTypeEvaluator.push({ entry: dataClassEntry, evaluator: variableTypeEvaluator });

                        // Add the new entry to the local entry list.
                        let insertIndex = localDataClassEntries.findIndex((e) => e.name === variableName);
                        if (insertIndex >= 0) {
                            localDataClassEntries[insertIndex] = dataClassEntry;
                        } else {
                            localDataClassEntries.push(dataClassEntry);
                        }

                        // Add the new entry to the full entry list.
                        insertIndex = fullDataClassEntries.findIndex((p) => p.name === variableName);
                        if (insertIndex >= 0) {
                            fullDataClassEntries[insertIndex] = dataClassEntry;
                        } else {
                            fullDataClassEntries.push(dataClassEntry);
                            insertIndex = fullDataClassEntries.length - 1;
                        }

                        // If we've already seen an entry with a default value defined,
                        // all subsequent entries must also have default values.
                        if (!isKeywordOnly && includeInInit && !skipSynthesizeInit && !hasDefaultValue) {
                            const firstDefaultValueIndex = fullDataClassEntries.findIndex(
                                (p) => p.hasDefault && p.includeInInit && !p.isKeywordOnly
                            );
                            if (firstDefaultValueIndex >= 0 && firstDefaultValueIndex < insertIndex) {
                                evaluator.addError(Localizer.Diagnostic.dataClassFieldWithDefault(), variableNameNode);
                            }
                        }
                    }
                }
            });
        }
    });

    classType.details.dataClassEntries = localDataClassEntries;

    // Now that the dataClassEntries field has been set with a complete list
    // of local data class entries for this class, perform deferred type
    // evaluations. This could involve circular type dependencies, so it's
    // required that the list be complete (even if types are not yet accurate)
    // before we perform the type evaluations.
    localEntryTypeEvaluator.forEach((entryEvaluator) => {
        entryEvaluator.entry.type = entryEvaluator.evaluator();
    });

    const symbolTable = classType.details.fields;
    const keywordOnlyParams: FunctionParameter[] = [];

    if (!skipSynthesizeInit) {
        // The field-derived parameter list can be trusted only when every
        // ancestor is known. Otherwise initType keeps just "self" plus the
        // default parameters added earlier, so it accepts any arguments.
        if (allAncestorsKnown) {
            fullDataClassEntries.forEach((entry) => {
                if (entry.includeInInit) {
                    // If the type refers to Self of the parent class, we need to
                    // transform it to refer to the Self of this subclass.
                    let effectiveType = entry.type;
                    if (entry.classType !== classType && requiresSpecialization(effectiveType)) {
                        const typeVarMap = new TypeVarMap(getTypeVarScopeId(entry.classType));
                        populateTypeVarMapForSelfType(typeVarMap, entry.classType, classType);
                        effectiveType = applySolvedTypeVars(effectiveType, typeVarMap);
                    }

                    const functionParam: FunctionParameter = {
                        category: ParameterCategory.Simple,
                        name: entry.alias || entry.name,
                        hasDefault: entry.hasDefault,
                        defaultValueExpression: entry.defaultValueExpression,
                        type: effectiveType,
                        hasDeclaredType: true,
                    };

                    // Keyword-only parameters are collected and appended after
                    // all positional parameters.
                    if (entry.isKeywordOnly) {
                        keywordOnlyParams.push(functionParam);
                    } else {
                        FunctionType.addParameter(initType, functionParam);
                    }
                }
            });

            if (keywordOnlyParams.length > 0) {
                // An unnamed VarArgList parameter precedes the keyword-only
                // parameters (presumably modelling the bare "*" separator).
                FunctionType.addParameter(initType, {
                    category: ParameterCategory.VarArgList,
                    type: AnyType.create(),
                });
                keywordOnlyParams.forEach((param) => {
                    FunctionType.addParameter(initType, param);
                });
            }
        }

        // Install the synthesized constructor methods regardless of whether
        // the ancestors are known, so the permissive __init__ built for the
        // unknown-ancestor case is actually used.
        symbolTable.set('__init__', Symbol.createWithType(SymbolFlags.ClassMember, initType));
        symbolTable.set('__new__', Symbol.createWithType(SymbolFlags.ClassMember, newType));
    }

    // Synthesize the __match_args__ class variable if it doesn't exist.
    const strType = evaluator.getBuiltInType(node, 'str');
    const tupleClassType = evaluator.getBuiltInType(node, 'tuple');
    if (
        tupleClassType &&
        isInstantiableClass(tupleClassType) &&
        strType &&
        isInstantiableClass(strType) &&
        !symbolTable.has('__match_args__')
    ) {
        const matchArgsNames: string[] = [];
        fullDataClassEntries.forEach((entry) => {
            if (entry.includeInInit && !entry.isKeywordOnly) {
                // Use the field name, not its alias (if it has one).
                matchArgsNames.push(entry.name);
            }
        });

        // Each name becomes a literal str instance in a fixed-length tuple.
        const literalTypes: TupleTypeArgument[] = matchArgsNames.map((name) => {
            return { type: ClassType.cloneAsInstance(ClassType.cloneWithLiteral(strType, name)), isUnbounded: false };
        });
        const matchArgsType = ClassType.cloneAsInstance(specializeTupleClass(tupleClassType, literalTypes));
        symbolTable.set('__match_args__', Symbol.createWithType(SymbolFlags.ClassMember, matchArgsType));
    }

    // Builds "(self, x: paramType) -> bool" under the given operator name
    // and installs it in the class's symbol table.
    const synthesizeComparisonMethod = (operator: string, paramType: Type) => {
        const operatorMethod = FunctionType.createInstance(operator, '', '', FunctionTypeFlags.SynthesizedMethod);
        FunctionType.addParameter(operatorMethod, selfParam);
        FunctionType.addParameter(operatorMethod, {
            category: ParameterCategory.Simple,
            name: 'x',
            type: paramType,
            hasDeclaredType: true,
        });
        operatorMethod.details.declaredReturnType = evaluator.getBuiltInObject(node, 'bool');
        symbolTable.set(operator, Symbol.createWithType(SymbolFlags.ClassMember, operatorMethod));
    };

    // Synthesize comparison operators.
    if (!ClassType.isSkipSynthesizedDataClassEq(classType)) {
        synthesizeComparisonMethod('__eq__', evaluator.getBuiltInObject(node, 'object'));
    }

    if (ClassType.isSynthesizedDataclassOrder(classType)) {
        // Ordering methods accept only instances of this class.
        const objType = ClassType.cloneAsInstance(classType);
        ['__lt__', '__le__', '__gt__', '__ge__'].forEach((operator) => {
            synthesizeComparisonMethod(operator, objType);
        });
    }

    // With a synthesized __eq__: frozen classes get a __hash__ method,
    // non-frozen classes get __hash__ set to None.
    let synthesizeHashFunction =
        !ClassType.isSkipSynthesizedDataClassEq(classType) && ClassType.isFrozenDataClass(classType);
    const synthesizeHashNone =
        !ClassType.isSkipSynthesizedDataClassEq(classType) && !ClassType.isFrozenDataClass(classType);

    if (skipSynthesizeHash) {
        synthesizeHashFunction = false;
    }

    // If the user has indicated that a hash function should be generated even if it's unsafe
    // to do so or there is already a hash function present, override the default logic.
    if (ClassType.isSynthesizeDataClassUnsafeHash(classType)) {
        synthesizeHashFunction = true;
    }

    if (synthesizeHashFunction) {
        const hashMethod = FunctionType.createInstance('__hash__', '', '', FunctionTypeFlags.SynthesizedMethod);
        FunctionType.addParameter(hashMethod, selfParam);
        hashMethod.details.declaredReturnType = evaluator.getBuiltInObject(node, 'int');
        symbolTable.set('__hash__', Symbol.createWithType(SymbolFlags.ClassMember, hashMethod));
    } else if (synthesizeHashNone && !skipSynthesizeHash) {
        symbolTable.set('__hash__', Symbol.createWithType(SymbolFlags.ClassMember, NoneType.createInstance()));
    }

    // __dataclass_fields__ is typed as dict[str, Any] when the builtin dict
    // class resolves; otherwise whatever getBuiltInType returned is used as is.
    let dictType = evaluator.getBuiltInType(node, 'dict');
    if (isInstantiableClass(dictType)) {
        dictType = ClassType.cloneAsInstance(
            ClassType.cloneForSpecialization(
                dictType,
                [evaluator.getBuiltInObject(node, 'str'), AnyType.create()],
                /* isTypeArgumentExplicit */ true
            )
        );
    }
    symbolTable.set('__dataclass_fields__', Symbol.createWithType(SymbolFlags.ClassMember, dictType));

    // When slots are generated and none were recorded yet, the slot names
    // are the names of the locally-declared entries.
    if (ClassType.isGeneratedDataClassSlots(classType) && classType.details.localSlotsNames === undefined) {
        classType.details.localSlotsNames = localDataClassEntries.map((entry) => entry.name);
    }

    // If this dataclass derived from a NamedTuple, update the NamedTuple with
    // the specialized entry types.
    updateNamedTupleBaseClass(
        classType,
        fullDataClassEntries.map((entry) => entry.type),
        /* isTypeArgumentExplicit */ true
    );
}