in scala/scala-impl/src/org/jetbrains/plugins/scala/lang/psi/impl/expr/ExpectedTypesImpl.scala [243:632]
/**
 * Computes the types that `expr` is expected to conform to, derived from the syntactic construct
 * it appears in (the context of `expr.getDeepSameElementInContext`).
 *
 * Each result pairs an expected type with the type element it was taken from, when one is
 * available (e.g. the ascription of a typed expression or the declared type of a value).
 *
 * @param expr                 the expression whose expected types are computed
 * @param withResolvedFunction when `false`, invoked members in argument positions are typed via
 *                             shape resolution (`shapeResolve` / `shapeMultiType`); when `true`,
 *                             full resolution (`multiResolveScala` / `multiType`) is used
 * @param fromUnderscore       when `true` and `expr` is an underscore section (possibly wrapped in
 *                             parentheses), only the result types of expected function types are
 *                             returned
 * @return the expected types; empty when none can be derived from the context
 */
override def expectedExprTypes(expr: ScExpression, withResolvedFunction: Boolean = false,
                               fromUnderscore: Boolean = true): Array[ParameterType] = {
  import expr.projectContext
  implicit val context: Context = Context(expr)
  val sameInContext = expr.getDeepSameElementInContext

  // Expected type of a function literal's body: the result type of the expected function-like type.
  // Context function types are unwrapped first, unless the literal itself is a context function.
  def fromFunction(tp: ParameterType, isContextFunction: Boolean): Array[ParameterType] = {
    val functionLikeType = FunctionLikeType(expr)
    val tpUnwrapped =
      if (!isContextFunction) unwrapContextFunctionType(tp._1)
      else tp._1
    tpUnwrapped match {
      case functionLikeType(_, retTpe, _) => Array((retTpe, None))
      case _ => Array.empty
    }
  }

  // Expected type of a polymorphic function literal's body, expressed in terms of the literal's
  // own type parameters `tyParams`.
  def fromPolyFunction(tyParams: Seq[ScTypeParam])(tp: ParameterType): Array[ParameterType] = {
    tp._1 match {
      case PolyFunctionType(sig, retType) =>
        // match
        //   [pParams..]  => bodyType
        //   [tyParams..] => ...
        // so
        //   substitute bodyType[pParams.. -> tyParams..]
        val pParams = sig.typeParams
        val bodyType = {
          val paramTypes = sig.substitutedTypes.head.map(_.apply())
          FunctionType((retType, paramTypes))(expr.elementScope, Context.Empty)
        }
        // val f2: [T, S] => T => S => Unit = [_, _] => x => y => ???
        // is valid Scala, which would naively be translated into
        // scala.PolyFunction { def apply[_, _](x: _): Function1[_, Unit] }
        // which is not a valid Scala snippet, so introduce fresh names for
        // type parameters named `_` here.
        val renameUnderscoreTypeParams =
          tyParams.mapWithIndex {
            case (tParam, idx) =>
              val renamed =
                if (tParam.name == "_") {
                  TypeParameterType(
                    TypeParameter.light(
                      s"_$$$idx",
                      tParam.typeParameters.map(TypeParameter.apply),
                      tParam.lowerBound.getOrNothing,
                      tParam.upperBound.getOrAny
                    )
                  )
                } else TypeParameterType(tParam)
              tParam -> renamed
          }.to(Map)
        val bindTypeParams = ScSubstitutor.bind(pParams, tyParams)(p => TypeParameterType(TypeParameter(p)))
        val fullSubst = bindTypeParams.followed(ScSubstitutor.bind(tyParams)(renameUnderscoreTypeParams))
        val result = fullSubst(bodyType)
        Array((result, None))
      case _ =>
        Array.empty
    }
  }

  // Expected types of the enclosing expression `e`, with context function types unwrapped and
  // the originating type elements dropped.
  def expectedTypesUnwrapContextFunction(e: ScExpression, fromUnderscore: Boolean): Array[ParameterType] =
    e.expectedTypesEx(fromUnderscore).map(pt => unwrapContextFunctionType(pt._1) -> None)

  // Pairs each resolve result with its type and returns
  // (type, resolved element declares type parameters, isApplyDynamicNamed(resolve)).
  // For an `apply` method resolve, the method's polymorphic type under the resolve's substitutor
  // replaces the given type, except for kind-projector's synthetic lambda `apply`s.
  def mapResolves(resolves: Array[ScalaResolveResult], types: Array[TypeResult]): Array[(TypeResult, Boolean, Boolean)] = {
    // Loop-invariant: built once instead of once per resolve result.
    val syntheticKindProjectorApplyNames = Set(Lambda, LambdaSymbolic)
    resolves.zip(types).map {
      case (r, tp) =>
        val actualType = r match {
          case srr @ ScalaResolveResult(fun: ScFunction, s: ScSubstitutor) if fun.name == CommonNames.Apply =>
            if (srr.innerResolveResult.exists(inner => syntheticKindProjectorApplyNames.contains(inner.name)))
              tp
            else Right(fun.polymorphicType(s))
          case _ => tp
        }
        val isPolymorphic = r.element match {
          case tpo: ScTypeParametersOwner => tpo.typeParameters.nonEmpty
          case tpo: PsiTypeParameterListOwner => tpo.getTypeParameters.nonEmpty
          case _ => false
        }
        (actualType, isPolymorphic, isApplyDynamicNamed(r))
    }
  }

  // Position of `sameInContext` among `argExprs`; 0 when it is null or not found.
  def argIndex(argExprs: Seq[ScExpression]) =
    if (sameInContext == null) 0
    else argExprs.indexWhere(_ == sameInContext).max(0)

  // Expected types of `expr` as an argument of `invocation`. Each candidate type of the invoked
  // expression is adapted with `updateAccordingToExpectedType`, candidates equivalent to `Nothing`
  // are dropped, and the parameter type at the argument's position is computed.
  def expectedTypesForArg(invocation: MethodInvocation): Array[ParameterType] = {
    implicit val context: Context = Context(invocation)
    val argExprs = invocation.argumentExpressions
    val invoked = invocation.getEffectiveInvokedExpr
    val resolvedTypes = invoked match {
      case ref: ScReferenceExpression =>
        if (!withResolvedFunction) mapResolves(ref.shapeResolve, ref.shapeMultiType)
        else mapResolves(ref.multiResolveScala(false), ref.multiType)
      case gen: ScGenericCall =>
        if (!withResolvedFunction) {
          val multiType = gen.shapeMultiType
          gen.shapeMultiResolve
            .map(mapResolves(_, multiType))
            .getOrElse(multiType.map((_, false, false)))
        } else {
          val multiType = gen.multiType
          gen.multiResolve
            .map(mapResolves(_, multiType))
            .getOrElse(multiType.map((_, false, false)))
        }
      case _ => Array((invoked.getNonValueType(), false, false))
    }
    val updatedWithExpected =
      resolvedTypes.map {
        case (tpe, isPolymorphic, isDynamicNamed) =>
          (tpe.map(invocation.updateAccordingToExpectedType), isPolymorphic, isDynamicNamed)
      }
    updatedWithExpected
      .filterNot(_._1.exists(_.equiv(Nothing)))
      .flatMap {
        case (tpe, isPolymorphic, isDynamicNamed) =>
          computeExpectedParamType(
            expr,
            tpe,
            argExprs,
            argIndex(argExprs),
            Option(invocation),
            isDynamicNamed = isDynamicNamed,
            stripTypeArgs = isPolymorphic
          )
      }
  }

  val result: Array[ParameterType] = sameInContext.getContext match {
    case p: ScParenthesisedExpr => expectedTypesUnwrapContextFunction(p, fromUnderscore = false)
    //see SLS[6.11]
    case b: ScBlockExpr =>
      b.resultExpression match {
        case Some(e) if b.needCheckExpectedType && e == sameInContext => expectedTypesUnwrapContextFunction(b, fromUnderscore = true)
        case _ => Array.empty
      }
    //see SLS[6.16]
    case ifExpr: ScIf if ifExpr.condition.contains(sameInContext) => Array((api.Boolean, None))
    case ifExpr: ScIf if ifExpr.elseExpression.isDefined => expectedTypesUnwrapContextFunction(ifExpr, fromUnderscore = true)
    //see SLS[6.22]
    case tr @ ScTry(Some(`sameInContext`), _, _) => expectedTypesUnwrapContextFunction(tr, fromUnderscore = true)
    case wh: ScWhile if wh.condition.contains(sameInContext) => Array((api.Boolean, None))
    case _: ScWhile => Array((api.Unit, None))
    case d: ScDo if d.condition.contains(sameInContext) => Array((api.Boolean, None))
    case _: ScDo => Array((api.Unit, None))
    case _: ScFinallyBlock => Array((api.Unit, None))
    case _: ScCatchBlock => Array.empty
    case te: ScThrow =>
      // Not in the SLS, but in the implementation.
      val throwableClass = ScalaPsiManager.instance(te.getProject).getCachedClass(te.resolveScope, "java.lang.Throwable")
      val throwableType = throwableClass.map(new ScDesignatorType(_)).getOrElse(Any)
      Array((throwableType, None))
    //see SLS[8.4]
    case c: ScCaseClause => c.getContext.getContext match {
      // case body of a `match`: the match's expected types, refined by the clause's pattern
      case m: ScMatch =>
        val expectedForMatch = m.expectedTypesEx()
        if (expectedForMatch.isEmpty) Array.empty
        else {
          val matchSubst = PatternTypeInference.doForMatchClause(m, c)
          expectedForMatch.map { case (tpe, elem) => (matchSubst(tpe), elem) }
        }
      case b: ScBlockExpr if b.isInCatchBlock =>
        b.getContext.getContext.asInstanceOf[ScTry].expectedTypesEx(fromUnderscore = true)
      // case body of a pattern-matching anonymous function: the expected function's result type,
      // refined by the pattern matched against the function's parameter type(s)
      case b: ScBlockExpr if b.isPartialFunction =>
        val expectedForPf = b.expectedTypesEx(fromUnderscore = true)
        val functionLikeType = FunctionLikeType(expr)
        expectedForPf.collect {
          case (functionLikeType(_, resTpe, paramTypes), te) =>
            val subst = c.pattern.fold(ScSubstitutor.empty) { pattern =>
              val scrutineeType =
                if (paramTypes.size == 1) paramTypes.head
                else TupleType(paramTypes, context = pattern)
              PatternTypeInference.doTypeInference(pattern, scrutineeType)
            }
            (subst(resTpe), te)
        }
      case _ => Array.empty
    }
    //see SLS[6.23]
    case f: ScFunctionExpr => f.expectedTypesEx(fromUnderscore = true).flatMap(fromFunction(_, f.isContext))
    case f: ScPolyFunctionExpr => f.expectedTypesEx(fromUnderscore = true).flatMap(fromPolyFunction(f.typeParameters))
    case t: ScTypedExpression if t.getLastChild.is[ScSequenceArg] =>
      t.expectedTypesEx(fromUnderscore = true)
    //SLS[6.13]
    case t: ScTypedExpression =>
      t.typeElement match {
        case Some(te) => Array((te.`type`().getOrAny, Some(te)))
        case _ => Array.empty
      }
    //SLS[6.15]
    case a: ScAssignment if a.rightExpression.contains(sameInContext) =>
      a.leftExpression match {
        case ref: ScReferenceExpression if (!a.getContext.is[ScArgumentExprList] && !(
          a.getContext.is[ScInfixArgumentExpression] && a.getContext.asInstanceOf[ScInfixArgumentExpression].isCall)) ||
          ref.qualifier.isDefined ||
          ScUnderScoreSectionUtil.isUnderscore(expr) /* See SCL-3512, SCL-3525, SCL-4809, SCL-6785 */ =>
          ref.bind() match {
            case Some(ScalaResolveResult(named: PsiNamedElement, subst: ScSubstitutor)) =>
              named.nameContext match {
                case v: ScValue =>
                  Array((subst(named.asInstanceOf[ScTypedDefinition].`type`().getOrAny), v.typeElement))
                case v: ScVariable =>
                  Array((subst(named.asInstanceOf[ScTypedDefinition].`type`().getOrAny), v.typeElement))
                // parameterless method: use the argument of the mirrored setter call
                case f: ScFunction if f.paramClauses.clauses.isEmpty =>
                  a.mirrorMethodCall match {
                    case Some(call) =>
                      call.args.exprs.head.expectedTypesEx(fromUnderscore = fromUnderscore)
                    case None => Array.empty
                  }
                case p: ScParameter =>
                  //for named parameters
                  Array((subst(p.`type`().getOrAny), p.typeElement))
                case f: PsiField =>
                  Array((subst(f.getType.toScType()), None))
                case _ => Array.empty
              }
            case _ => Array.empty
          }
        case _: ScReferenceExpression => expectedExprTypes(a)
        case _: ScMethodCall =>
          a.mirrorMethodCall match {
            case Some(mirrorCall) => mirrorCall.args.exprs.last.expectedTypesEx(fromUnderscore = fromUnderscore)
            case _ => Array.empty
          }
        case _ => Array.empty
      }
    //method application
    case tuple: ScTuple if tuple.isCall => expectedTypesForArg(tuple.getContext.asInstanceOf[ScInfixExpr])
    // component of a tuple literal: the matching component of an expected tuple type of equal arity
    case tuple: ScTuple =>
      val result = Array.newBuilder[ParameterType]
      val exprs = tuple.exprs
      val index = exprs.indexOf(sameInContext)
      @tailrec
      def addType(aType: ScType): Unit = {
        aType match {
          case _: ScAbstractType => addType(aType.removeAbstracts)
          case TupleType(comps) if comps.length == exprs.length =>
            result += ((comps(index), None))
          case _ =>
        }
      }
      if (index >= 0) {
        for (tp: ScType <- tuple.expectedTypes()) addType(tp)
      }
      result.result()
    // component of a named tuple literal: the same-index component of an expected named tuple
    // type of equal size, provided the names also agree
    case comp: ScNamedTupleExprComponent =>
      val tuple = comp.namedTuple
      val result = Array.newBuilder[ParameterType]
      val components = tuple.components
      val index = components.indexOf(comp)
      @tailrec
      def addType(aType: ScType): Unit = {
        aType match {
          case _: ScAbstractType => addType(aType.removeAbstracts)
          case NamedTupleType(expectedComps) if expectedComps.length == components.length =>
            expectedComps(index) match {
              case (NamedTupleType.NameType(expectedName), expectedType) if expectedName == comp.name =>
                result += expectedType -> None
              case _ =>
            }
          case _ =>
        }
      }
      if (index >= 0) {
        for (tp: ScType <- tuple.expectedTypes()) addType(tp)
      }
      result.result()
    case infix@ScInfixExpr.withAssoc(_, _, `sameInContext`) if !expr.is[ScTuple] =>
      expr match {
        // empty parentheses as the argument: nothing to infer
        case p: ScParenthesisedExpr if p.innerElement.isEmpty => Array.empty
        case _ => expectedTypesForArg(infix)
      }
    //SLS[4.1]
    case v @ ScPatternDefinition.expr(`sameInContext`) if v.isSimple || (v.pList.patterns match { case Seq(_: ScWildcardPattern) => true; case _ => false }) =>
      declaredOrInheritedType(v)
    case v @ ScVariableDefinition.expr(`sameInContext`) if v.isSimple || (v.pList.patterns match { case Seq(_: ScWildcardPattern) => true; case _ => false }) =>
      declaredOrInheritedType(v)
    //SLS[4.6]
    case v: ScFunctionDefinition if v.body.contains(sameInContext) => declaredOrInheritedType(v)
    //default parameters
    case param: ScParameter =>
      param.typeElement match {
        case Some(_) => Array((param.`type`().getOrAny, param.typeElement))
        case _ => Array.empty
      }
    // `return` argument: the explicitly declared return type of the enclosing function
    case ret: ScReturn =>
      val fun: ScFunction = PsiTreeUtil.getContextOfType(ret, true, classOf[ScFunction])
      if (fun == null) return Array.empty
      fun.returnTypeElement match {
        case Some(rte: ScTypeElement) =>
          fun.returnType match {
            case Right(rt) => Array((rt, Some(rte)))
            case _ => Array.empty
          }
        case None => Array.empty
      }
    // argument of a method call, constructor invocation or self-invocation
    case args: ScArgumentExprList =>
      args.getContext match {
        case mc: ScMethodCall => expectedTypesForArg(mc)
        case ctx @ (_: ScConstructorInvocation | _: ScSelfInvocation) =>
          val argExprs = args.exprs
          val argIdx = argIndex(argExprs)
          val tps = ctx match {
            case c: ScConstructorInvocation =>
              val clauseIdx = c.arguments.indexOf(args)
              if (!withResolvedFunction) c.shapeMultiType(clauseIdx)
              else c.multiType(clauseIdx)
            case s: ScSelfInvocation =>
              val clauseIdx = s.arguments.indexOf(args)
              if (!withResolvedFunction) s.shapeMultiType(clauseIdx)
              else s.multiType(clauseIdx)
          }
          tps.flatMap(computeExpectedParamType(expr, _, argExprs, argIdx))
        case _ =>
          Array.empty
      }
    case guard: ScGuard =>
      guard.desugared flatMap { _.content } match {
        case Some(content) => content.expectedTypesEx(fromUnderscore = fromUnderscore)
        case _ => Array.empty
      }
    // result expression of a block that is a try body, case body or function-literal body,
    // or that sits two levels below a catch block
    case b: ScBlock if {
      val context = b.getContext
      context.is[ScTry, ScCaseClause, ScFunctionExpr] ||
        //extra null checks are needed for some broken code, in order not to throw NPEs
        context != null && {
          val context2 = context.getContext
          context2 != null && context2.getContext.is[ScCatchBlock]
        }
    } =>
      b.resultExpression match {
        case Some(e) if sameInContext == e => b.expectedTypesEx(fromUnderscore = true)
        case _ => Array.empty
      }
    case _ => Array.empty
  }

  // True when `expr`, after stripping any parentheses, contains underscore-section placeholders.
  @tailrec
  def checkIsUnderscore(expr: ScExpression): Boolean = {
    expr match {
      case p: ScParenthesisedExpr =>
        p.innerElement match {
          case Some(e) => checkIsUnderscore(e)
          case _ => false
        }
      case _ => ScUnderScoreSectionUtil.underscores(expr).nonEmpty
    }
  }

  if (fromUnderscore && checkIsUnderscore(expr)) {
    // Keep only the result types of expected function types; other expected types are dropped.
    result.collect[ParameterType] { case (FunctionType(rt: ScType, _), _) => (rt, None) }
  } else result
}