def batch()

in scio-core/src/main/scala/com/spotify/scio/values/SCollectionWithSideInput.scala [143:220]


  /** [[SCollection.batch]] that retains [[SideInput]]. */
  def batch(
    batchSize: Long,
    maxLiveWindows: Int = BatchDoFn.DEFAULT_MAX_LIVE_WINDOWS
  ): SCollectionWithSideInput[Iterable[T]] = {
    // Batch the underlying collection first, then re-attach the same side inputs.
    val batched = coll.batch(batchSize, maxLiveWindows)
    new SCollectionWithSideInput[Iterable[T]](batched, sides)
  }

  /**
   * [[SCollection.batchByteSized]] that retains [[SideInput]].
   *
   * Delegates to [[batchWeighted]], using the cost function returned by
   * `ScioUtil.elementByteSize` for this collection's context.
   */
  def batchByteSized(
    batchByteSize: Long,
    maxLiveWindows: Int = BatchDoFn.DEFAULT_MAX_LIVE_WINDOWS
  ): SCollectionWithSideInput[Iterable[T]] = {
    // The explicit type keeps the element type of the cost function pinned to T.
    val byteCost: T => Long = ScioUtil.elementByteSize(context)
    batchWeighted(batchByteSize, byteCost, maxLiveWindows)
  }

  /**
   * [[SCollection.batchWeighted]] that retains [[SideInput]].
   *
   * The arguments are forwarded unchanged to the underlying collection; the resulting batches are
   * wrapped together with the side inputs of this instance.
   */
  def batchWeighted(
    batchWeight: Long,
    cost: T => Long,
    maxLiveWindows: Int = BatchDoFn.DEFAULT_MAX_LIVE_WINDOWS
  ): SCollectionWithSideInput[Iterable[T]] = {
    val weighted = coll.batchWeighted(batchWeight, cost, maxLiveWindows)
    new SCollectionWithSideInput[Iterable[T]](weighted, sides)
  }

  /**
   * Allows multiple outputs from [[SCollectionWithSideInput]].
   *
   * Every input element is emitted, unchanged, to the [[SideOutput]] that `f` selects for it.
   * Processing fails with an `IllegalStateException` when `f` returns a [[SideOutput]] whose tag
   * is not among `sideOutputs`.
   *
   * @param sideOutputs
   *   the side outputs that `f` is allowed to return
   * @param name
   *   name under which the transform is applied
   * @param f
   *   picks the target [[SideOutput]] for an element, with access to the [[SideInputContext]]
   * @return
   *   map of side output to [[SCollection]]; outputs whose tag id is not known to this method are
   *   left out
   */
  private[values] def transformWithSideOutputs(
    sideOutputs: Seq[SideOutput[T]],
    name: String = "TransformWithSideOutputs"
  )(f: (T, SideInputContext[T]) => SideOutput[T]): Map[SideOutput[T], SCollection[T]] = {
    // Internal tag used as the main output of the multi-output ParDo.
    val _mainTag = SideOutput[T]()
    // Lookup from tag id back to the SideOutput, covering the side outputs and the main tag.
    val tagToSide = sideOutputs.map(e => e.tupleTag.getId -> e).toMap +
      (_mainTag.tupleTag.getId -> _mainTag)

    val sideTags =
      TupleTagList.of(sideOutputs.map(e => e.tupleTag.asInstanceOf[TupleTag[_]]).asJava)

    def transformWithSideOutputsFn(
      partitions: Seq[SideOutput[T]],
      f: (T, SideInputContext[T]) => SideOutput[T]
    ): DoFn[T, T] =
      new SideInputDoFn[T, T] {
        val g = ClosureCleaner.clean(f) // defeat closure

        /*
         * ProcessContext is required as an argument because it is passed to SideInputContext
         * */
        @ProcessElement
        private[scio] def processElement(c: DoFn[T, T]#ProcessContext, w: BoundedWindow): Unit = {
          val elem = c.element()
          val partition = g(elem, sideInputContext(c, w))
          if (!partitions.exists(_.tupleTag == partition.tupleTag)) {
            // Single-line message with separated ids, so it stays readable in logs.
            val knownIds = partitions.map(_.tupleTag.getId).mkString(", ")
            throw new IllegalStateException(
              s"${partition.tupleTag.getId} is not part of $knownIds"
            )
          }

          c.output(partition.tupleTag, elem)
        }
      }

    val transform = parDo[T, T](transformWithSideOutputsFn(sideOutputs, f))
      .withOutputTags(_mainTag.tupleTag, sideTags)

    val pCollectionWrapper = this.internal.apply(name, transform)
    pCollectionWrapper.getAll.asScala.view
      .mapValues(
        context
          .wrap(_)
          .asInstanceOf[SCollection[T]]
          .setCoder(internal.getCoder)
      )
      // Keep only outputs whose tag id maps back to a known SideOutput; a failed lookup is dropped.
      .flatMap { case (tt, col) => Try(tagToSide(tt.getId) -> col).toOption }
      .toMap
  }