override def expectedExprTypes()

in scala/scala-impl/src/org/jetbrains/plugins/scala/lang/psi/impl/expr/ExpectedTypesImpl.scala [243:632]


  /** Computes the types that `expr` is expected to conform to, judging by the syntactic context it appears in.
    *
    * The context is taken from `expr.getDeepSameElementInContext`. For example:
    *  - an `if`/`while`/`do` condition expects `Boolean`;
    *  - a `throw` operand expects `java.lang.Throwable` (or `Any` if that class cannot be found);
    *  - a typed expression expects its ascribed type;
    *  - an argument expects the corresponding parameter type of the invoked method or constructor.
    *
    * Contexts such as parentheses, block results, `if` branches, `try` bodies and case clauses derive their
    * answer from the expected types of the enclosing expression.
    *
    * @param expr                 expression whose expected types are computed
    * @param withResolvedFunction when `false`, invoked methods/constructors are looked up with shape resolution
    *                             (`shapeResolve`/`shapeMultiType`); when `true`, full resolution is used
    *                             (`multiResolveScala`/`multiType`)
    * @param fromUnderscore       when `true` and `expr` (looking through parentheses) contains underscore
    *                             placeholders, every function-typed result is replaced by its return type and
    *                             non-function results are dropped
    * @return pairs of an expected type and, where available, the type element it was read from
    */
  override def expectedExprTypes(expr: ScExpression, withResolvedFunction: Boolean = false,
                                 fromUnderscore: Boolean = true): Array[ParameterType] = {
    import expr.projectContext
    implicit val context: Context = Context(expr)

    val sameInContext = expr.getDeepSameElementInContext

    // Expected type of a function literal's body: the result type of an expected function-like type.
    // Context function types are unwrapped first, unless the literal itself is a context function.
    def fromFunction(tp: ParameterType, isContextFunction: Boolean): Array[ParameterType] = {
      val functionLikeType = FunctionLikeType(expr)

      val tpUnwrapped =
        if (!isContextFunction) unwrapContextFunctionType(tp._1)
        else                    tp._1

      tpUnwrapped match {
        case functionLikeType(_, retTpe, _) => Array((retTpe, None))
        case _                              => Array.empty
      }
    }

    // Expected type of the body of a polymorphic function literal `[tyParams..] => body`,
    // derived from an expected PolyFunction type.
    def fromPolyFunction(tyParams: Seq[ScTypeParam])(tp: ParameterType): Array[ParameterType] = {
      tp._1 match {
        case PolyFunctionType(sig, retType) =>
          // match
          //   [pParams..] => bodyType
          //   [tyParams..] => ...
          // so
          // substitute bodyType[pParams.. -> tyParams..]
          val pParams = sig.typeParams

          val bodyType = {
            val paramTypes = sig.substitutedTypes.head.map(_.apply())
            FunctionType((retType, paramTypes))(expr.elementScope, Context.Empty)
          }

          //val f2: [T, S] => T => S => Unit = [_, _] => x => y => ???
          //Above is valid scala, which would naively be translated into
          //scala.PolyFunction { def apply[_, _](x: _): Function1[_, Unit] }
          //which is not a valid scala snippet, so just introduce fresh names for
          //type parameters named _ here.
          val renameUnderscoreTypeParams =
            tyParams.mapWithIndex {
              case (tParam, idx) =>
                val renamed =
                  if (tParam.name == "_") {
                    TypeParameterType(
                      TypeParameter.light(
                        s"_$$$idx",
                        tParam.typeParameters.map(TypeParameter.apply),
                        tParam.lowerBound.getOrNothing,
                        tParam.upperBound.getOrAny
                      )
                    )
                  } else TypeParameterType(tParam)
                tParam -> renamed
            }.to(Map)

          val bindTypeParams = ScSubstitutor.bind(pParams, tyParams)(p => TypeParameterType(TypeParameter(p)))
          val fullSubst = bindTypeParams.followed(ScSubstitutor.bind(tyParams)(renameUnderscoreTypeParams))
          val result = fullSubst(bodyType)
          Array((result, None))
        case _ =>
          Array.empty
      }
    }

    // Expected types of `e` with context function types unwrapped; the type elements are dropped.
    def expectedTypesUnwrapContextFunction(e: ScExpression, fromUnderscore: Boolean): Array[ParameterType] =
      e.expectedTypesEx(fromUnderscore).map(pt => unwrapContextFunctionType(pt._1) -> None)

    // For every resolve result: its type, whether the resolved element declares type parameters,
    // and whether it is an `applyDynamicNamed` call.
    def mapResolves(resolves: Array[ScalaResolveResult], types: Array[TypeResult]): Array[(TypeResult, Boolean, Boolean)] = {
      // Loop-invariant: built once instead of once per resolve result.
      val syntheticKindProjectorApplyNames = Set(Lambda, LambdaSymbolic)

      resolves.zip(types).map {
        case (r, tp) =>
          val actualType = r match {
            case srr @ ScalaResolveResult(fun: ScFunction, s: ScSubstitutor) if fun.name == CommonNames.Apply =>
              // If the inner resolve is one of the synthetic `Lambda`/`LambdaSymbolic` names
              // (kind-projector syntax, judging by the set's name), keep the type computed so far.
              if (srr.innerResolveResult.exists(inner => syntheticKindProjectorApplyNames.contains(inner.name)))
                tp
              else Right(fun.polymorphicType(s))
            case _ => tp
          }

          val isPolymorphic = r.element match {
            case tpo: ScTypeParametersOwner     => tpo.typeParameters.nonEmpty
            case tpo: PsiTypeParameterListOwner => tpo.getTypeParameters.nonEmpty
            case _                              => false
          }

          (actualType, isPolymorphic, isApplyDynamicNamed(r))
      }
    }

    // Position of `sameInContext` among `argExprs`, or 0 when it is not one of them.
    def argIndex(argExprs: Seq[ScExpression]): Int =
      if (sameInContext == null) 0
      else argExprs.indexWhere(_ == sameInContext).max(0)

    // Expected types for `expr` as an argument of `invocation`. Steps:
    //  1. resolve the invoked expression (shape or full resolution, per `withResolvedFunction`);
    //  2. adapt each candidate type to the invocation's expected type;
    //  3. drop candidates equivalent to `Nothing`;
    //  4. compute the parameter type at `expr`'s argument position.
    def expectedTypesForArg(invocation: MethodInvocation): Array[ParameterType] = {
      implicit val context: Context = Context(invocation)

      val argExprs = invocation.argumentExpressions
      val invoked  = invocation.getEffectiveInvokedExpr

      val resolvedTypes = invoked match {
        case ref: ScReferenceExpression =>
          if (!withResolvedFunction) mapResolves(ref.shapeResolve, ref.shapeMultiType)
          else                       mapResolves(ref.multiResolveScala(false), ref.multiType)
        case gen: ScGenericCall =>
          if (!withResolvedFunction) {
            val multiType = gen.shapeMultiType

            gen.shapeMultiResolve
              .map(mapResolves(_, multiType))
              .getOrElse(multiType.map((_, false, false)))
          } else {
            val multiType = gen.multiType

            gen.multiResolve
              .map(mapResolves(_, multiType))
              .getOrElse(multiType.map((_, false, false)))
          }
        case _ => Array((invoked.getNonValueType(), false, false))
      }

      val updatedWithExpected =
        resolvedTypes.map {
          case (tpe, isPolymorphic, isDynamicNamed) =>
            (tpe.map(invocation.updateAccordingToExpectedType), isPolymorphic, isDynamicNamed)
        }

      updatedWithExpected
        .filterNot(_._1.exists(_.equiv(Nothing)))
        .flatMap {
          case (tpe, isPolymorphic, isDynamicNamed) =>
            computeExpectedParamType(
              expr,
              tpe,
              argExprs,
              argIndex(argExprs),
              Option(invocation),
              isDynamicNamed = isDynamicNamed,
              stripTypeArgs  = isPolymorphic
            )
        }
    }

    // True when `a` sits in an argument list, or in an infix-argument position whose `isCall` holds.
    def isCallArgument(a: ScAssignment): Boolean = a.getContext match {
      case _: ScArgumentExprList               => true
      case infixArg: ScInfixArgumentExpression => infixArg.isCall
      case _                                   => false
    }

    val result: Array[ParameterType] = sameInContext.getContext match {
      case p: ScParenthesisedExpr => expectedTypesUnwrapContextFunction(p, fromUnderscore = false)
      //see SLS[6.11]
      case b: ScBlockExpr =>
        b.resultExpression match {
          case Some(e) if b.needCheckExpectedType && e == sameInContext => expectedTypesUnwrapContextFunction(b, fromUnderscore = true)
          case _                                                        => Array.empty
        }
      //see SLS[6.16]
      case ifExpr: ScIf if ifExpr.condition.contains(sameInContext) => Array((api.Boolean, None))
      case ifExpr: ScIf if ifExpr.elseExpression.isDefined => expectedTypesUnwrapContextFunction(ifExpr, fromUnderscore = true)
      //see SLS[6.22]
      case tr @ ScTry(Some(`sameInContext`), _, _) => expectedTypesUnwrapContextFunction(tr, fromUnderscore = true)
      case wh: ScWhile if wh.condition.contains(sameInContext) => Array((api.Boolean, None))
      case _: ScWhile                                          => Array((api.Unit, None))
      case d: ScDo if d.condition.contains(sameInContext)      => Array((api.Boolean, None))
      case _: ScDo                                             => Array((api.Unit, None))
      case _: ScFinallyBlock                                   => Array((api.Unit, None))
      case _: ScCatchBlock                                     => Array.empty
      case te: ScThrow =>
        // Not in the SLS, but in the implementation.
        val throwableClass = ScalaPsiManager.instance(te.getProject).getCachedClass(te.resolveScope, "java.lang.Throwable")
        val throwableType = throwableClass.map(new ScDesignatorType(_)).getOrElse(Any)
        Array((throwableType, None))
      //see SLS[8.4]
      case c: ScCaseClause => c.getContext.getContext match {
        case m: ScMatch =>
          val expectedForMatch = m.expectedTypesEx()

          if (expectedForMatch.isEmpty) Array.empty
          else {
            val matchSubst = PatternTypeInference.doForMatchClause(m, c)
            expectedForMatch.map { case (tpe, elem) => (matchSubst(tpe), elem) }
          }
        case b: ScBlockExpr if b.isInCatchBlock =>
          b.getContext.getContext.asInstanceOf[ScTry].expectedTypesEx(fromUnderscore = true)
        case b: ScBlockExpr if b.isPartialFunction =>
          val expectedForPf    = b.expectedTypesEx(fromUnderscore = true)
          val functionLikeType = FunctionLikeType(expr)

          expectedForPf.collect {
            case (functionLikeType(_, resTpe, paramTypes), te) =>

              val subst = c.pattern.fold(ScSubstitutor.empty) { pattern =>
                val scrutineeType =
                  if (paramTypes.size == 1) paramTypes.head
                  else                      TupleType(paramTypes, context = pattern)

                PatternTypeInference.doTypeInference(pattern, scrutineeType)
              }

              (subst(resTpe), te)
          }
        case _ => Array.empty
      }
      //see SLS[6.23]
      case f: ScFunctionExpr => f.expectedTypesEx(fromUnderscore = true).flatMap(fromFunction(_, f.isContext))
      case f: ScPolyFunctionExpr => f.expectedTypesEx(fromUnderscore = true).flatMap(fromPolyFunction(f.typeParameters))
      // the typed expression's last child is an ScSequenceArg: delegate to the typed expression
      case t: ScTypedExpression if t.getLastChild.is[ScSequenceArg] =>
        t.expectedTypesEx(fromUnderscore = true)
      //SLS[6.13]
      case t: ScTypedExpression =>
        t.typeElement match {
          case Some(te) => Array((te.`type`().getOrAny, Some(te)))
          case _ => Array.empty
        }
      //SLS[6.15]
      case a: ScAssignment if a.rightExpression.contains(sameInContext) =>
        a.leftExpression match {
          case ref: ScReferenceExpression if !isCallArgument(a) ||
            ref.qualifier.isDefined ||
            ScUnderScoreSectionUtil.isUnderscore(expr) /* See SCL-3512, SCL-3525, SCL-4809, SCL-6785 */ =>
            ref.bind() match {
              case Some(ScalaResolveResult(named: PsiNamedElement, subst: ScSubstitutor)) =>
                named.nameContext match {
                  case v: ScValue =>
                    Array((subst(named.asInstanceOf[ScTypedDefinition].
                      `type`().getOrAny), v.typeElement))
                  case v: ScVariable =>
                    Array((subst(named.asInstanceOf[ScTypedDefinition].
                      `type`().getOrAny), v.typeElement))
                  case f: ScFunction if f.paramClauses.clauses.isEmpty =>
                    a.mirrorMethodCall match {
                      case Some(call) =>
                        call.args.exprs.head.expectedTypesEx(fromUnderscore = fromUnderscore)
                      case None => Array.empty
                    }
                  case p: ScParameter =>
                    //for named parameters
                    Array((subst(p.`type`().getOrAny), p.typeElement))
                  case f: PsiField =>
                    Array((subst(f.getType.toScType()), None))
                  case _ => Array.empty
                }
              case _ => Array.empty
            }
          // NOTE(review): this recursion uses the default `withResolvedFunction`/`fromUnderscore`
          // instead of the values passed to this call — confirm that is intended.
          case _: ScReferenceExpression => expectedExprTypes(a)
          case _: ScMethodCall =>
            a.mirrorMethodCall match {
              case Some(mirrorCall) => mirrorCall.args.exprs.last.expectedTypesEx(fromUnderscore = fromUnderscore)
              case _ => Array.empty
            }
          case _ => Array.empty
        }
      //method application
      case tuple: ScTuple if tuple.isCall => expectedTypesForArg(tuple.getContext.asInstanceOf[ScInfixExpr])
      case tuple: ScTuple =>
        val builder = Array.newBuilder[ParameterType]
        val exprs = tuple.exprs
        val index = exprs.indexOf(sameInContext)
        // Collects the component at `index` of every expected tuple type of matching arity.
        @tailrec
        def addType(aType: ScType): Unit = {
          aType match {
            case _: ScAbstractType => addType(aType.removeAbstracts)
            case TupleType(comps) if comps.length == exprs.length =>
              builder += ((comps(index), None))
            case _ =>
          }
        }
        if (index >= 0) {
          for (tp: ScType <- tuple.expectedTypes()) addType(tp)
        }
        builder.result()
      case comp: ScNamedTupleExprComponent =>
        val tuple = comp.namedTuple
        val builder = Array.newBuilder[ParameterType]
        val components = tuple.components
        val index = components.indexOf(comp)
        // Collects the component type at `index` of every expected named tuple type of matching arity,
        // provided the expected component name equals `comp.name`.
        @tailrec
        def addType(aType: ScType): Unit = {
          aType match {
            case _: ScAbstractType => addType(aType.removeAbstracts)
            case NamedTupleType(expectedComps) if expectedComps.length == components.length =>
              expectedComps(index) match {
                case (NamedTupleType.NameType(expectedName), expectedType) if expectedName == comp.name =>
                  builder += expectedType -> None
                case _ =>
              }
            case _ =>
          }
        }
        if (index >= 0) {
          for (tp: ScType <- tuple.expectedTypes()) addType(tp)
        }
        builder.result()
      case infix@ScInfixExpr.withAssoc(_, _, `sameInContext`) if !expr.is[ScTuple] =>
        expr match {
          // an empty parenthesised argument (`()`) yields no expected types
          case p: ScParenthesisedExpr if p.innerElement.isEmpty => Array.empty
          case _                                                => expectedTypesForArg(infix)
        }
      //SLS[4.1]
      case v @ ScPatternDefinition.expr(`sameInContext`)  if v.isSimple || (v.pList.patterns match { case Seq(_: ScWildcardPattern) => true; case _ => false }) =>
        declaredOrInheritedType(v)
      case v @ ScVariableDefinition.expr(`sameInContext`) if v.isSimple || (v.pList.patterns match { case Seq(_: ScWildcardPattern) => true; case _ => false }) =>
        declaredOrInheritedType(v)
      //SLS[4.6]
      case v: ScFunctionDefinition if v.body.contains(sameInContext) => declaredOrInheritedType(v)
      //default parameters
      case param: ScParameter =>
        param.typeElement match {
          case Some(_) => Array((param.`type`().getOrAny, param.typeElement))
          case _ => Array.empty
        }
      case ret: ScReturn =>
        // Only an explicitly declared return type of the enclosing function is used.
        val fun: ScFunction = PsiTreeUtil.getContextOfType(ret, true, classOf[ScFunction])
        if (fun == null) Array.empty
        else fun.returnTypeElement match {
          case Some(rte: ScTypeElement) =>
            fun.returnType match {
              case Right(rt) => Array((rt, Some(rte)))
              case _ => Array.empty
            }
          case None => Array.empty
        }
      case args: ScArgumentExprList =>
        args.getContext match {
          case mc: ScMethodCall => expectedTypesForArg(mc)
          case ctx @ (_: ScConstructorInvocation | _: ScSelfInvocation) =>
            val argExprs = args.exprs
            val argIdx = argIndex(argExprs)

            val tps = ctx match {
              case c: ScConstructorInvocation =>
                val clauseIdx = c.arguments.indexOf(args)

                if (!withResolvedFunction) c.shapeMultiType(clauseIdx)
                else c.multiType(clauseIdx)

              case s: ScSelfInvocation =>
                val clauseIdx = s.arguments.indexOf(args)

                if (!withResolvedFunction) s.shapeMultiType(clauseIdx)
                else s.multiType(clauseIdx)
            }

            tps.flatMap(computeExpectedParamType(expr, _, argExprs, argIdx))

          case _ =>
            Array.empty
        }
      case guard: ScGuard =>
        guard.desugared.flatMap(_.content) match {
          case Some(content) => content.expectedTypesEx(fromUnderscore = fromUnderscore)
          case _ => Array.empty
        }
      case b: ScBlock if {
        val context = b.getContext
        context.is[ScTry, ScCaseClause, ScFunctionExpr] ||
          //extra null checks are needed for some broken code, in order not to throw NPEs
          context != null && {
            val context2 = context.getContext
            context2 != null && context2.getContext.is[ScCatchBlock]
          }
      } =>
        b.resultExpression match {
          case Some(e) if sameInContext == e => b.expectedTypesEx(fromUnderscore = true)
          case _ => Array.empty
        }
      case _ => Array.empty
    }

    // Whether `expr`, looking through any number of parentheses, contains underscore placeholders.
    @tailrec
    def checkIsUnderscore(expr: ScExpression): Boolean = {
      expr match {
        case p: ScParenthesisedExpr =>
          p.innerElement match {
            case Some(e) => checkIsUnderscore(e)
            case _ => false
          }
        case _ => ScUnderScoreSectionUtil.underscores(expr).nonEmpty
      }
    }

    if (fromUnderscore && checkIsUnderscore(expr)) {
      // Keep only function-typed expectations, replaced by their result types (presumably because an
      // underscore section is the body of an anonymous function whose expected type is a function type).
      result.collect[ParameterType] {
        case (FunctionType(rt: ScType, _), _) => (rt, None)
      }
    } else result
  }