private def generateDesugaredExprTextWithMappings()

in scala/scala-impl/src/org/jetbrains/plugins/scala/lang/psi/impl/expr/ScForImpl.scala [152:568]


  /**
   * Desugars this `for` expression into a chain of `map` / `flatMap` / `foreach` / filter calls,
   * builds that chain as text and parses the text back into an expression in this element's context.
   *
   * @param forDisplay `true` when the result is meant to be shown to the user. In that mode line
   *                   breaks of a multi-line `for` are kept, generated names are shorter, `case`
   *                   clauses are only used where needed, pattern irrefutability is checked against
   *                   the pattern's expected type (when known), and no element mappings are produced.
   * @return `None` when the enumerators contain no generator; otherwise the desugared expression
   *         together with iterators mapping original patterns and enumerators to the element found
   *         at the position where their text was emitted into the desugared expression. Both
   *         iterators are empty when `forDisplay` is `true`.
   */
  private def generateDesugaredExprTextWithMappings(forDisplay: Boolean): Option[(ScExpression, Iterator[(ScPattern, PsiElement)], Iterator[(ScEnumerator, PsiElement)])] = {
    // Line breaks are only emitted when rendering a multi-line `for` for display.
    val forceSingleLine = !(forDisplay && this.getText.contains("\n"))

    // Counter for fresh names of for-bindings whose pattern is not a simple binding pattern.
    var nextNameIdx = 0
    val `=>` = ScalaPsiUtil.functionArrow(getProject)

    // Underscore sections reported for this `for`, each mapped to its position. They are renamed to
    // `forAnonParam$<index>` and become the parameters of the lambda prefix built further below.
    val underscores = ScUnderScoreSectionUtil.underscores(this).zipWithIndex.toMap

    def underscoreName(i: Int): String = s"forAnonParam$$$i"

    // Collects underscore sections in document order; does not descend into an underscore section.
    def allUnderscores(expr: PsiElement): Seq[ScUnderscoreSection] = {
      expr match {
        case underscore: ScUnderscoreSection => Seq(underscore)
        case _ => expr.getChildren.flatMap(allUnderscores).toSeq
      }
    }

    // Replaces every underscore known in `underscores` by a reference to its generated name.
    // Works on a copy, so the original PSI tree is never modified.
    def normalizeUnderscores(expr: ScExpression): ScExpression = {
      expr match {
        case underscore: ScUnderscoreSection =>
          underscores.
            get(underscore).
            map(underscoreName).
            map(ScalaPsiElementFactory.createReferenceExpressionFromText).
            getOrElse(underscore)
        case _ =>
          allUnderscores(expr) map {
            underscores.get
          } match {
            case underscoreIndices if !underscoreIndices.exists(_.isDefined) =>
              expr
            case underscoreIndices =>
              val copyOfExpr = expr.copy().asInstanceOf[ScExpression]
              // the copy has the same structure, so its underscores line up positionally with the original's
              for {
                (underscore, Some(index)) <- allUnderscores(copyOfExpr) zip underscoreIndices
                name = underscoreName(index)
                referenceExpression = ScalaPsiElementFactory.createReferenceExpressionFromText(name)
              } {
                underscore.replaceExpression(referenceExpression, removeParenthesis = false)
              }
              copyOfExpr
          }
      }
    }

    def toTextWithNormalizedUnderscores(expr: ScExpression): String =
      toTextWithPrevWhitespaceInScala3(normalizeUnderscores(expr))

    // Whether a generator pattern can skip the `withFilter { case p => true; case _ => false }` step.
    // Scala 3: only patterns not introduced by `case`, and only when the irrefutable-patterns feature
    // is enabled or the pattern is irrefutable. Scala 2: only irrefutable patterns.
    def canSkipPatternMatchFilter(pattern: ScPattern): Boolean = {
      val features = pattern.features

      lazy val isIrrefutablePattern = pattern.isIrrefutableFor(
        forDisplay.option(pattern.expectedType).flatten.getOrElse(StdTypes.instance.Any)
      )

      if (features.isScala3) {
        val hasCase = pattern.prevSiblingNotWhitespace.exists(_.getNode.getElementType == ScalaTokenTypes.kCASE)
        !hasCase && (features.`Scala 3 Irrefutable Patterns` || isIrrefutablePattern)
      }
      else isIrrefutablePattern
    }

    val resultText = new mutable.StringBuilder()
    // offsets (into `resultText`, before the prefix is inserted) where each element's text was emitted
    val patternMappings = mutable.Map.empty[ScPattern, Int]
    val enumMappings = mutable.Map.empty[ScEnumerator, Int]

    // Records the current output offset for `what`; only the first occurrence is kept.
    def markMappingHere[K](whatOpt: Option[K], mappings: mutable.Map[K, Int]): Unit = {
      for (what <- whatOpt if !mappings.contains(what))
        mappings += what -> resultText.length
    }

    // Appends `.funcName` followed by a lambda whose parameters are `args`: `{ case ... => body }`,
    // `{ ... => body }` (forceBlock) or `(... => body)`. The body is written by `appendBody`,
    // whose result is returned.
    def appendFunc[R](
      funcName:   String,
      enumerator: Option[ScEnumerator],
      args:       Seq[(Option[ScPattern], String)],
      forceCases: Boolean = false,
      forceBlock: Boolean = false
    )(appendBody: => R
    ): R = {
      val argPatterns = args.flatMap(_._1)
      val needsCase = !forDisplay || forceCases || args.size > 1 || argPatterns.exists(needsDeconstruction)
      val needsParenthesis = args.size > 1 || !needsCase && argPatterns.exists(needsParenthesisAsLambdaArgument)

      if (!forceSingleLine) {
        resultText ++= "\n"
      }
      resultText ++= "."
      markMappingHere(enumerator, enumMappings)
      resultText ++= funcName

      resultText ++= (if (needsCase) " { case " else if (forceBlock) " { " else "(")

      if (needsParenthesis)
        resultText ++= "("
      if (args.isEmpty) {
        resultText ++= "_"
      }
      for (((p, text), idx) <- args.zipWithIndex) {
        if (idx != 0) {
          resultText ++= ", "
        }
        markMappingHere(p, patternMappings)
        resultText ++= text
      }
      if (needsParenthesis)
        resultText ++= ")"
      resultText ++= " "
      resultText ++= `=>`
      resultText ++= " "

      // append the body part
      val ret = appendBody

      resultText ++= (if (needsCase || forceBlock) " }" else ")")
      ret
    }

    // A for-binding (`pattern = expr`). When the pattern is not a simple binding pattern, a fresh
    // name is generated (incrementing `nextNameIdx`) so the whole value can be referenced later.
    case class ForBinding(forBinding: ScForBinding) {

      def exprText: String =
        forBinding.expr
          .fold("???")(toTextWithNormalizedUnderscores)

      val pattern: Option[ScPattern] = Option(forBinding.pattern)

      private val bindingPattern: Option[ScBindingPattern] = pattern.collect {
        case pattern: ScBindingPattern => pattern
      }

      val isBinding: Boolean = bindingPattern.isDefined
      val isWildCard: Boolean = pattern.exists(_.is[ScWildcardPattern])

      val name: String = bindingPattern.fold({
        nextNameIdx += 1
        s"v$$${if (forDisplay) "" else "forIntellij"}$nextNameIdx"
      })(_.name)

      def patternText: String = pattern.fold(name)(_.getText)
    }

    // Appends the desugaring of generator `gen` followed by `restEnums`: its filter (if needed),
    // guards and for-bindings, then `map`/`flatMap`/`foreach` whose body recurses into the next
    // generator or, for the last one, contains the `for` body.
    def appendGen(gen: ScGenerator, restEnums: Seq[ScEnumerator]): Unit = {
      val rvalue = gen.expr.map(normalizeUnderscores)
      val isLastGen = !restEnums.exists(_.is[ScGenerator])

      val pattern = gen.pattern
      type Arg = (Option[ScPattern], String)
      val initialArg = Seq(Some(pattern) -> pattern.getText)

      // start with the generator expression
      val generatorNeedsBlock = this.features.indentationBasedSyntaxEnabled && rvalue.exists(_.textContains('\n'))
      lazy val generatorNeedsParenthesis = rvalue.exists {
        rvalue =>
          // check whether the expression would need parentheses as the receiver of a method call
          val inParenthesis = code"($rvalue).foo".getFirstChild.asInstanceOf[ScParenthesisedExpr]
          ScalaPsiUtil.needParentheses(inParenthesis, inParenthesis.innerElement.get)
      }

      if (generatorNeedsBlock)
        resultText ++= "{"
      else if (generatorNeedsParenthesis)
        resultText ++= "("

      resultText ++= rvalue.map(
        if (generatorNeedsBlock)
          toTextWithPrevWhitespaceInScala3
        else
          _.getText
      ).getOrElse("???")

      if (generatorNeedsBlock)
        resultText ++= "}"
      else if (generatorNeedsParenthesis)
        resultText ++= ")"

      // add guards and assignment enumerators
      val filterFunc: String = if (forDisplay && compilerRewritesWithFilterToFilter) {
        val rvalueType = rvalue.flatMap(_.`type`().toOption)
        def hasWithFilter = rvalueType.exists(hasMethod(_, "withFilter"))
        def hasFilter = rvalueType.exists(hasMethod(_, "filter"))

        // try to use withFilter
        // if the type does not have a withFilter method use filter except if filter doesn't exist either
        if (hasWithFilter || !hasFilter) "withFilter"
        else "filter"
      } else {
        "withFilter"
      }

      if (this.betterMonadicForEnabled || canSkipPatternMatchFilter(pattern)) {
        //do nothing
      } else {
        appendFunc(filterFunc, None, initialArg, forceCases = true) {
          resultText ++= "true; case _ => false"
        }
      }

      // Emits `val <pattern or name@pattern> = <expr>` for each binding, separated by newlines or `; `.
      def printForBindings(forBindings: Seq[ForBinding], newLines: Boolean): Unit = {
        if (newLines) {
          resultText ++= "\n"
        }
        forBindings.foreach {
          binding =>
            val pattern = binding.pattern
            val patternText = binding.patternText

            resultText ++= "val "
            if (binding.isBinding || (forDisplay && binding.isWildCard)) {
              markMappingHere(pattern, patternMappings)
              resultText ++= patternText
            } else {
              val needsParenthesis = forDisplay && pattern.exists(needsParenthesisAsNamedPattern)
              resultText ++= binding.name
              resultText ++= "@"
              if (needsParenthesis)
                resultText ++= "("
              markMappingHere(pattern, patternMappings)
              resultText ++= patternText
              if (needsParenthesis)
                resultText ++= ")"
            }
            resultText ++= " = "
            resultText ++= binding.exprText
            resultText ++= (if (newLines) "\n" else "; ")
        }
      }

      // before a guard can be printed, all forBindings have to be mapped into the argument tuple
      def printForBindingMap(forBindings: Seq[ForBinding], args: Seq[Arg]): Seq[Arg] = {
        forBindings match {
          case first +: _ =>
            appendFunc("map", Some(first.forBinding), args, forceBlock = true) {
              val multilineForBindings = !forceSingleLine && (forBindings.length > 1 || forBindings.exists(_.forBinding.getText.contains("\n")))
              printForBindings(forBindings, newLines = multilineForBindings)

              if (multilineForBindings) {
                resultText ++= "\n"
              }

              // remove wildcards
              val argsWithoutWildcards = args.filterNot(_._2 == "_")

              val usedBindings = if (forDisplay) forBindings.filter(!_.isWildCard) else forBindings
              val needsArgParenthesis = argsWithoutWildcards.length + usedBindings.size != 1

              // if args is empty, we return ()
              if (needsArgParenthesis)
                resultText ++= "("
              resultText ++= (argsWithoutWildcards.map(_._2) ++ usedBindings.map(_.name)).mkString(", ")
              if (needsArgParenthesis)
                resultText ++= ")"

              if (multilineForBindings) {
                resultText ++= "\n"
              }

              argsWithoutWildcards ++ usedBindings.map(b => b.pattern -> b.patternText)
            }
          case _ =>
            args
        }
      }

      // everything up to the next generator belongs to this generator
      val (forBindingsAndGuards, nextEnums) = restEnums.span(!_.is[ScGenerator])

      val (forBindingsInGenBody, generatorArgs) = {
        // accumulate all ForBindings for the next guard
        // (but not if we want to display it, because it changes semantic)
        // see #SCL-16463
        forBindingsAndGuards.foldLeft[(Seq[ForBinding], Seq[Arg])]((Seq.empty, initialArg)) {
          case ((forBindings, args), forBinding: ScForBinding) =>
            if (!forDisplay) (forBindings :+ ForBinding(forBinding), args)
            else {
              val argsWithBindings = printForBindingMap(Seq(ForBinding(forBinding)), args)
              (Seq.empty, argsWithBindings)
            }

          case ((forBindings, args), guard: ScGuard) =>
            val argsWithBindings = printForBindingMap(forBindings, args)

            appendFunc(filterFunc, Some(guard), argsWithBindings) {
              resultText ++= guard.expr.map(toTextWithNormalizedUnderscores).getOrElse("???")
            }

            (Seq.empty, argsWithBindings)
          case _ =>
            ???
        }
      }

      val funcText = if (isYield)
        if (isLastGen) "map" else "flatMap"
      else
        "foreach"

      val needsMultiline =
        !forceSingleLine && forBindingsInGenBody.nonEmpty ||
          forBindingsInGenBody.exists(_.forBinding.textContains('\n'))
      appendFunc(funcText, Some(gen), generatorArgs, forceBlock = needsMultiline || forBindingsInGenBody.nonEmpty) {
        printForBindings(forBindingsInGenBody, newLines = needsMultiline)

        nextEnums.headOption match {
          case Some(nextGen: ScGenerator) =>
            if (!forceSingleLine) {
              resultText ++= "\n"
            }
            appendGen(nextGen, nextEnums.tail)

            if (!forceSingleLine) {
              resultText ++= "\n"
            }
          case _ =>
            // `span` above guarantees nextEnums is either empty or starts with a generator
            assert(nextEnums.isEmpty)
            val needsBraces = forBindingsInGenBody.exists(_.forBinding.textContains('\n'))

            if (needsMultiline) {
              // add an empty line between the value definitions of the for-bindings and the body
              // to avoid merging
              resultText ++= "\n"
              if (needsBraces)
                resultText += '{'
            }

            // sometimes the body of a for loop is enclosed in {}
            // we can remove these brackets
            def withoutBodyBrackets(e: ScExpression): ScalaPsiElement = e match {
              case ScBlockExpr.Statements(inner) => inner
              case _ => e
            }

            resultText ++= body
              .map(normalizeUnderscores)
              .map(withoutBodyBrackets)
              .map(bodyToText)
              .getOrElse("{}")

            if (needsMultiline) {
              resultText ++= "\n"
              if (needsBraces)
                resultText += '}'
            }
        }
      }
    }

    val allEnumerators = enumerators.fold(Seq.empty[ScEnumerator])(_.enumerators)
    // for-bindings before the first generator are emitted as plain `val`s around the whole chain
    val (initialBindings, firstGenerator, restEnumerators) = {
      val (initialBindings, rest) = allEnumerators.span(!_.is[ScGenerator])

      rest match {
        case (firstGenerator: ScGenerator) +: restEnumerators =>
          (initialBindings.filterByType[ScForBinding], firstGenerator, restEnumerators)
        case _ =>
          return None
      }
    }

    // Parameter list of the anonymous function formed by placeholder underscores, e.g.
    // `forAnonParam$0 => ` or `(forAnonParam$0, forAnonParam$1) => `.
    // `underscores` is a Map: once it holds more than four entries its iteration order is
    // unspecified, so the indices are sorted to keep the parameters in source order.
    val lambdaPrefix = {
      val sortedIndices = underscores.values.toSeq.sorted
      if (sortedIndices.isEmpty) ""
      else {
        val params = sortedIndices.map(underscoreName) match {
          case Seq(arg) => arg
          case args => args.commaSeparated(model = Model.Parentheses)
        }
        params + " " + `=>` + " "
      }
    }

    if (initialBindings.isEmpty) {
      appendGen(firstGenerator, restEnumerators)
    } else {
      resultText += '{'
      if (!forceSingleLine) {
        resultText += '\n'
      }

      for (binding <- initialBindings) {
        val b = ForBinding(binding)
        resultText ++= "val "
        markMappingHere(b.pattern, patternMappings)
        resultText ++= b.patternText
        resultText ++= " = "
        resultText ++= b.exprText
        resultText += (if (forceSingleLine) ';' else '\n')
      }

      appendGen(firstGenerator, restEnumerators)
      resultText += '}'
    }

    // we need the braces to handle incomplete fors, because createExpressionWithContextFromText
    // needs to parse exactly one expression.
    // The parenthesis is needed for correct newline-handling
    val prefix = "{("
    resultText.insert(0, prefix + lambdaPrefix)
    // recorded mapping offsets were taken before the prefix was inserted
    val shiftOffset = prefix.length + lambdaPrefix.length
    resultText ++= ")}"
    val expression = ScalaPsiElementFactory.createExpressionWithContextFromText(
      resultText.toString,
      getContext,
      this
    )

    // Resolves recorded offsets to elements of the parsed expression (lazily; empty for display).
    def withElements[T](mappings: mutable.Map[T, Int]) = for {
      (original, offset) <- if (forDisplay) Iterator.empty else mappings.iterator

      element = expression.findElementAt(offset + shiftOffset)
      if element != null
    } yield original -> element

    // strip the synthetic `{( ... )}` wrapper and attach the result to this element's context
    val desugared = expression match {
      case ScBlock(ScParenthesisedExpr(desugaredFor)) =>
        desugaredFor.context = getContext
        desugaredFor.child = this
        desugaredFor
      case _ =>
        expression
    }

    Some((desugared, withElements(patternMappings), withElements(enumMappings)))
  }