in scala/scala-impl/src/org/jetbrains/plugins/scala/lang/psi/impl/expr/ScForImpl.scala [152:568]
/**
 * Builds the desugared form of this for-comprehension as source text, parses it back into an
 * expression in the context of this element, and records where the patterns and enumerators of
 * the original `for` ended up in the desugared expression.
 *
 * The text is assembled left to right into one `StringBuilder`:
 *  - generators become `map` (last generator of a `yield`), `flatMap` (other generators of a
 *    `yield`) or `foreach` calls;
 *  - refutable generator patterns get a `withFilter { case p => true; case _ => false }` step
 *    unless better-monadic-for is enabled;
 *  - guards become `withFilter` (or `filter`, see `filterFunc`) calls;
 *  - for-bindings become `val` definitions, threaded through a `map` step when a guard follows
 *    (or, for display, after every binding).
 *
 * Placeholder underscores of the whole `for` become parameters `forAnonParam$i` of an enclosing
 * lambda, listed in the order of their indices.
 *
 * @param forDisplay `true` when the text is meant to be shown to the user. In that case:
 *                   - a multi-line `for` keeps line breaks;
 *                   - lambdas use `case` only when needed;
 *                   - synthetic names are `v$N`;
 *                   - the returned mapping iterators are always empty.
 * @return `None` if the `for` contains no generator. Otherwise it returns:
 *         - the desugared expression;
 *         - an iterator from original patterns to the leaf element found at the offset where
 *           each was first emitted;
 *         - the same kind of iterator for enumerators, whose offset is just before the generated
 *           method name.
 */
private def generateDesugaredExprTextWithMappings(forDisplay: Boolean): Option[(ScExpression, Iterator[(ScPattern, PsiElement)], Iterator[(ScEnumerator, PsiElement)])] = {
  // Only a multi-line `for` rendered for display keeps its line breaks; otherwise everything is on one line.
  val forceSingleLine = !(forDisplay && this.textContains('\n'))
  // Counter for the synthetic names of for-bindings whose pattern is not a plain binding pattern.
  var nextNameIdx = 0
  val `=>` = ScalaPsiUtil.functionArrow(getProject)
  // Placeholder underscores of the whole `for`, each mapped to its position in the sequence
  // returned by ScUnderScoreSectionUtil.underscores.
  val underscores = ScUnderScoreSectionUtil.underscores(this).zipWithIndex.toMap
  def underscoreName(i: Int): String = s"forAnonParam$$$i"
  // Collects the underscore sections below `expr`, without descending into an underscore section.
  def allUnderscores(expr: PsiElement): Seq[ScUnderscoreSection] = {
    expr match {
      case underscore: ScUnderscoreSection => Seq(underscore)
      case _ => expr.getChildren.flatMap(allUnderscores).toSeq
    }
  }
  // Replaces the placeholder underscores of this `for` inside `expr` with references to the
  // corresponding lambda parameters. `expr` itself is not modified; a copy is changed if needed.
  def normalizeUnderscores(expr: ScExpression): ScExpression = {
    expr match {
      case underscore: ScUnderscoreSection =>
        underscores.
          get(underscore).
          map(underscoreName).
          map(ScalaPsiElementFactory.createReferenceExpressionFromText).
          getOrElse(underscore)
      case _ =>
        allUnderscores(expr) map {
          underscores.get
        } match {
          case underscoreIndices if !underscoreIndices.exists(_.isDefined) =>
            expr
          case underscoreIndices =>
            // The copy has the same tree shape, so its underscores line up with `underscoreIndices`.
            val copyOfExpr = expr.copy().asInstanceOf[ScExpression]
            for {
              (underscore, Some(index)) <- allUnderscores(copyOfExpr) zip underscoreIndices
              name = underscoreName(index)
              referenceExpression = ScalaPsiElementFactory.createReferenceExpressionFromText(name)
            } {
              underscore.replaceExpression(referenceExpression, removeParenthesis = false)
            }
            copyOfExpr
        }
    }
  }
  def toTextWithNormalizedUnderscores(expr: ScExpression): String =
    toTextWithPrevWhitespaceInScala3(normalizeUnderscores(expr))
  // Whether the generator pattern can be used without a preceding filter step that drops
  // non-matching elements. Outside display mode irrefutability is checked against `Any`.
  def canSkipPatternMatchFilter(pattern: ScPattern): Boolean = {
    val features = pattern.features
    lazy val isIrrefutablePattern = pattern.isIrrefutableFor(
      forDisplay.option(pattern.expectedType).flatten.getOrElse(StdTypes.instance.Any)
    )
    if (features.isScala3) {
      // in Scala 3 a `case` keyword before the pattern always keeps the filter step
      val hasCase = pattern.prevSiblingNotWhitespace.exists(_.getNode.getElementType == ScalaTokenTypes.kCASE)
      !hasCase && (features.`Scala 3 Irrefutable Patterns` || isIrrefutablePattern)
    }
    else isIrrefutablePattern
  }
  val resultText = new mutable.StringBuilder()
  // offsets into `resultText` at which each pattern / enumerator was first emitted
  val patternMappings = mutable.Map.empty[ScPattern, Int]
  val enumMappings = mutable.Map.empty[ScEnumerator, Int]
  // Records the current end of `resultText` as the position of `whatOpt` (first occurrence wins).
  def markMappingHere[K](whatOpt: Option[K], mappings: mutable.Map[K, Int]): Unit = {
    for (what <- whatOpt if !mappings.contains(what))
      mappings += what -> resultText.length
  }
  // Appends `.funcName(args => body)`, or `.funcName { case args => body }` / `.funcName { args => body }`.
  // `appendBody` is evaluated after the arrow, so it appends the lambda body straight into `resultText`.
  def appendFunc[R](
    funcName: String,
    enumerator: Option[ScEnumerator],
    args: Seq[(Option[ScPattern], String)],
    forceCases: Boolean = false,
    forceBlock: Boolean = false
  )(appendBody: => R
  ): R = {
    val argPatterns = args.flatMap(_._1)
    // outside display mode a `case` clause is always generated
    val needsCase = !forDisplay || forceCases || args.size > 1 || argPatterns.exists(needsDeconstruction)
    val needsParenthesis = args.size > 1 || !needsCase && argPatterns.exists(needsParenthesisAsLambdaArgument)
    if (!forceSingleLine) {
      resultText ++= "\n"
    }
    resultText ++= "."
    markMappingHere(enumerator, enumMappings)
    resultText ++= funcName
    resultText ++= (if (needsCase) " { case " else if (forceBlock) " { " else "(")
    if (needsParenthesis)
      resultText ++= "("
    // a lambda without arguments still needs a parameter
    if (args.isEmpty) {
      resultText ++= "_"
    }
    for (((p, text), idx) <- args.zipWithIndex) {
      if (idx != 0) {
        resultText ++= ", "
      }
      markMappingHere(p, patternMappings)
      resultText ++= text
    }
    if (needsParenthesis)
      resultText ++= ")"
    resultText ++= " "
    resultText ++= `=>`
    resultText ++= " "
    // append the body part
    val ret = appendBody
    resultText ++= (if (needsCase || forceBlock) " }" else ")")
    ret
  }
  // A for-binding (`pattern = expr`) together with the name under which its value is passed on.
  case class ForBinding(forBinding: ScForBinding) {
    def exprText: String =
      forBinding.expr
        .fold("???")(toTextWithNormalizedUnderscores)
    val pattern: Option[ScPattern] = Option(forBinding.pattern)
    private val bindingPattern: Option[ScBindingPattern] = pattern.collect {
      case pattern: ScBindingPattern => pattern
    }
    val isBinding: Boolean = bindingPattern.isDefined
    val isWildCard: Boolean = pattern.exists(_.is[ScWildcardPattern])
    // The bound name, or a fresh synthetic name (allocated on construction) for any other pattern.
    val name: String = bindingPattern.fold({
      nextNameIdx += 1
      s"v$$${if (forDisplay) "" else "forIntellij"}$nextNameIdx"
    })(_.name)
    def patternText: String = pattern.fold(name)(_.getText)
  }
  // Appends the desugaring of `gen` and of the guards / for-bindings that follow it. It then
  // recurses into the next generator in `restEnums` (if any), inside the generated lambda body.
  def appendGen(gen: ScGenerator, restEnums: Seq[ScEnumerator]): Unit = {
    val rvalue = gen.expr.map(normalizeUnderscores)
    val isLastGen = !restEnums.exists(_.is[ScGenerator])
    val pattern = gen.pattern
    type Arg = (Option[ScPattern], String)
    val initialArg = Seq(Some(pattern) -> pattern.getText)
    // start with the generator expression
    val generatorNeedsBlock = this.features.indentationBasedSyntaxEnabled && rvalue.exists(_.textContains('\n'))
    lazy val generatorNeedsParenthesis = rvalue.exists {
      rvalue =>
        val inParenthesis = code"($rvalue).foo".getFirstChild.asInstanceOf[ScParenthesisedExpr]
        ScalaPsiUtil.needParentheses(inParenthesis, inParenthesis.innerElement.get)
    }
    if (generatorNeedsBlock)
      resultText ++= "{"
    else if (generatorNeedsParenthesis)
      resultText ++= "("
    resultText ++= rvalue.map(
      if (generatorNeedsBlock)
        toTextWithPrevWhitespaceInScala3
      else
        _.getText
    ).getOrElse("???")
    if (generatorNeedsBlock)
      resultText ++= "}"
    else if (generatorNeedsParenthesis)
      resultText ++= ")"
    // add guards and assignment enumerators
    val filterFunc: String = if (forDisplay && compilerRewritesWithFilterToFilter) {
      val rvalueType = rvalue.flatMap(_.`type`().toOption)
      def hasWithFilter = rvalueType.exists(hasMethod(_, "withFilter"))
      def hasFilter = rvalueType.exists(hasMethod(_, "filter"))
      // try to use withFilter
      // if the type does not have a withFilter method use filter except if filter doesn't exist either
      if (hasWithFilter || !hasFilter) "withFilter"
      else "filter"
    } else {
      "withFilter"
    }
    if (this.betterMonadicForEnabled || canSkipPatternMatchFilter(pattern)) {
      //do nothing
    } else {
      appendFunc(filterFunc, None, initialArg, forceCases = true) {
        resultText ++= "true; case _ => false"
      }
    }
    // Appends `val pattern = expr` for every binding. A pattern that is not a plain binding is
    // written as `name@pattern`, so its value is also available under the binding's name.
    def printForBindings(forBindings: Seq[ForBinding], newLines: Boolean): Unit = {
      if (newLines) {
        resultText ++= "\n"
      }
      forBindings.foreach {
        binding =>
          val pattern = binding.pattern
          val patternText = binding.patternText
          resultText ++= "val "
          if (binding.isBinding || (forDisplay && binding.isWildCard)) {
            markMappingHere(pattern, patternMappings)
            resultText ++= patternText
          } else {
            val needsParenthesis = forDisplay && pattern.exists(needsParenthesisAsNamedPattern)
            resultText ++= binding.name
            resultText ++= "@"
            if (needsParenthesis)
              resultText ++= "("
            markMappingHere(pattern, patternMappings)
            resultText ++= patternText
            if (needsParenthesis)
              resultText ++= ")"
          }
          resultText ++= " = "
          resultText ++= binding.exprText
          resultText ++= (if (newLines) "\n" else "; ")
      }
    }
    // before a guard can be printed, all forBindings have to be mapped into the argument tuple;
    // returns the arguments of the next lambda (previous args without wildcards, plus the bindings)
    def printForBindingMap(forBindings: Seq[ForBinding], args: Seq[Arg]): Seq[Arg] = {
      forBindings match {
        case first +: _ =>
          appendFunc("map", Some(first.forBinding), args, forceBlock = true) {
            val multilineForBindings = !forceSingleLine && (forBindings.length > 1 || forBindings.exists(_.forBinding.textContains('\n')))
            printForBindings(forBindings, newLines = multilineForBindings)
            if (multilineForBindings) {
              resultText ++= "\n"
            }
            // remove wildcards
            val argsWithoutWildcards = args.filterNot(_._2 == "_")
            val usedBindings = if (forDisplay) forBindings.filter(!_.isWildCard) else forBindings
            val needsArgParenthesis = argsWithoutWildcards.length + usedBindings.size != 1
            // if args is empty, we return ()
            if (needsArgParenthesis)
              resultText ++= "("
            resultText ++= (argsWithoutWildcards.map(_._2) ++ usedBindings.map(_.name)).mkString(", ")
            if (needsArgParenthesis)
              resultText ++= ")"
            if (multilineForBindings) {
              resultText ++= "\n"
            }
            argsWithoutWildcards ++ usedBindings.map(b => b.pattern -> b.patternText)
          }
        case _ =>
          args
      }
    }
    val (forBindingsAndGuards, nextEnums) = restEnums.span(!_.is[ScGenerator])
    val (forBindingsInGenBody, generatorArgs) = {
      // accumulate all ForBindings for the next guard
      // (but not if we want to display it, because it changes semantic)
      // see #SCL-16463
      forBindingsAndGuards.foldLeft[(Seq[ForBinding], Seq[Arg])]((Seq.empty, initialArg)) {
        case ((forBindings, args), forBinding: ScForBinding) =>
          if (!forDisplay) (forBindings :+ ForBinding(forBinding), args)
          else {
            val argsWithBindings = printForBindingMap(Seq(ForBinding(forBinding)), args)
            (Seq.empty, argsWithBindings)
          }
        case ((forBindings, args), guard: ScGuard) =>
          val argsWithBindings = printForBindingMap(forBindings, args)
          appendFunc(filterFunc, Some(guard), argsWithBindings) {
            resultText ++= guard.expr.map(toTextWithNormalizedUnderscores).getOrElse("???")
          }
          (Seq.empty, argsWithBindings)
        case _ =>
          ???
      }
    }
    val funcText = if (isYield)
      if (isLastGen) "map" else "flatMap"
    else
      "foreach"
    val needsMultiline =
      !forceSingleLine && forBindingsInGenBody.nonEmpty ||
        forBindingsInGenBody.exists(_.forBinding.textContains('\n'))
    appendFunc(funcText, Some(gen), generatorArgs, forceBlock = needsMultiline || forBindingsInGenBody.nonEmpty) {
      printForBindings(forBindingsInGenBody, newLines = needsMultiline)
      nextEnums.headOption match {
        case Some(nextGen: ScGenerator) =>
          if (!forceSingleLine) {
            resultText ++= "\n"
          }
          appendGen(nextGen, nextEnums.tail)
          if (!forceSingleLine) {
            resultText ++= "\n"
          }
        case _ =>
          assert(nextEnums.isEmpty)
          val needsBraces = forBindingsInGenBody.exists(_.forBinding.textContains('\n'))
          if (needsMultiline) {
            // add an empty line between the value definitions of the for-bindings and the body
            // to avoid merging
            resultText ++= "\n"
            if (needsBraces)
              resultText += '{'
          }
          // sometimes the body of a for loop is enclosed in {}
          // we can remove these brackets
          def withoutBodyBrackets(e: ScExpression): ScalaPsiElement = e match {
            case ScBlockExpr.Statements(inner) => inner
            case _ => e
          }
          resultText ++= body
            .map(normalizeUnderscores)
            .map(withoutBodyBrackets)
            .map(bodyToText)
            .getOrElse("{}")
          if (needsMultiline) {
            resultText ++= "\n"
            if (needsBraces)
              resultText += '}'
          }
      }
    }
  }
  val allEnumerators = enumerators.fold(Seq.empty[ScEnumerator])(_.enumerators)
  // for-bindings before the first generator are emitted as plain `val`s in a surrounding block
  val (initialBindings, firstGenerator, restEnumerators) = {
    val (initialBindings, rest) = allEnumerators.span(!_.is[ScGenerator])
    rest match {
      case (firstGenerator: ScGenerator) +: restEnumerators =>
        (initialBindings.filterByType[ScForBinding], firstGenerator, restEnumerators)
      case _ =>
        // without a generator there is nothing to desugar
        return None
    }
  }
  // The lambda parameters must follow the underscore indices (0, 1, ...). `underscores` is a
  // Map, and with more than 4 entries its iteration order is hash-based, so sort explicitly.
  val lambdaPrefix =
    if (underscores.isEmpty) ""
    else {
      val params = underscores.values.toSeq.sorted.map(underscoreName) match {
        case Seq(single) => single
        case several => several.commaSeparated(model = Model.Parentheses)
      }
      params + " " + `=>` + " "
    }
  if (initialBindings.isEmpty) {
    appendGen(firstGenerator, restEnumerators)
  } else {
    resultText += '{'
    if (!forceSingleLine) {
      resultText += '\n'
    }
    for (binding <- initialBindings) {
      val b = ForBinding(binding)
      resultText ++= "val "
      markMappingHere(b.pattern, patternMappings)
      resultText ++= b.patternText
      resultText ++= " = "
      resultText ++= b.exprText
      resultText += (if (forceSingleLine) ';' else '\n')
    }
    appendGen(firstGenerator, restEnumerators)
    resultText += '}'
  }
  // we need the braces to handle incomplete fors, because createExpressionWithContextFromText
  // needs to parse exactly one expression.
  // The parenthesis is needed for correct newline-handling
  val prefix = "{("
  resultText.insert(0, prefix + lambdaPrefix)
  // every recorded offset is relative to the text before the prefix was inserted
  val shiftOffset = prefix.length + lambdaPrefix.length
  resultText ++= ")}"
  val expression = ScalaPsiElementFactory.createExpressionWithContextFromText(
    resultText.toString,
    getContext,
    this
  )
  def withElements[T](mappings: mutable.Map[T, Int]) = for {
    (original, offset) <- if (forDisplay) Iterator.empty else mappings.iterator
    element = expression.findElementAt(offset + shiftOffset)
    if element != null
  } yield original -> element
  // unwrap the `{( ... )}` added above and attach the result to this element's context
  val desugared = expression match {
    case ScBlock(ScParenthesisedExpr(desugaredFor)) =>
      desugaredFor.context = getContext
      desugaredFor.child = this
      desugaredFor
    case _ =>
      expression
  }
  Some((desugared, withElements(patternMappings), withElements(enumMappings)))
}