in scio-core/src/main/scala/com/spotify/scio/values/SCollectionWithSideInput.scala [143:220]
def batch(
  batchSize: Long,
  maxLiveWindows: Int = BatchDoFn.DEFAULT_MAX_LIVE_WINDOWS
): SCollectionWithSideInput[Iterable[T]] = {
  // Batch the wrapped collection first, then re-attach the same side inputs to the result.
  val batched = coll.batch(batchSize, maxLiveWindows)
  new SCollectionWithSideInput[Iterable[T]](batched, sides)
}
/**
 * Variant of [[SCollection.batchByteSized]] that retains [[SideInput]].
 *
 * Implemented on top of [[batchWeighted]]: `batchByteSize` is used as the batch weight and
 * `ScioUtil.elementByteSize(context)` as the per-element cost function.
 */
def batchByteSized(
  batchByteSize: Long,
  maxLiveWindows: Int = BatchDoFn.DEFAULT_MAX_LIVE_WINDOWS
): SCollectionWithSideInput[Iterable[T]] =
  batchWeighted(
    batchWeight = batchByteSize,
    cost = ScioUtil.elementByteSize(context),
    maxLiveWindows = maxLiveWindows
  )
/**
 * Variant of [[SCollection.batchWeighted]] that retains [[SideInput]].
 *
 * All three arguments are forwarded unchanged to the wrapped collection's `batchWeighted`; the
 * batched result is then wrapped again together with the same side inputs.
 */
def batchWeighted(
  batchWeight: Long,
  cost: T => Long,
  maxLiveWindows: Int = BatchDoFn.DEFAULT_MAX_LIVE_WINDOWS
): SCollectionWithSideInput[Iterable[T]] = {
  val weightedBatches = coll.batchWeighted(batchWeight, cost, maxLiveWindows)
  new SCollectionWithSideInput[Iterable[T]](weightedBatches, sides)
}
/**
 * Allows multiple outputs from [[SCollectionWithSideInput]].
 *
 * Every element is handed to `f` together with a [[SideInputContext]]; the element is then
 * emitted, unchanged, to the output identified by the [[SideOutput]] that `f` returned.
 *
 * @param sideOutputs
 *   the side outputs `f` is allowed to return
 * @param name
 *   name under which the transform is applied to the underlying collection
 * @param f
 *   picks the [[SideOutput]] an element is routed to
 * @return
 *   map of side output to [[SCollection]]
 * @throws IllegalStateException
 *   while processing an element, if `f` returns a [[SideOutput]] whose tag is not one of
 *   `sideOutputs`
 */
private[values] def transformWithSideOutputs(
  sideOutputs: Seq[SideOutput[T]],
  name: String = "TransformWithSideOutputs"
)(f: (T, SideInputContext[T]) => SideOutput[T]): Map[SideOutput[T], SCollection[T]] = {
  // Fresh tag for the mandatory main output of the multi-output ParDo.
  val _mainTag = SideOutput[T]()

  // Lookup from tag id back to its SideOutput, used to key the resulting map.
  val tagToSide = sideOutputs.map(e => e.tupleTag.getId -> e).toMap +
    (_mainTag.tupleTag.getId -> _mainTag)

  val sideTags =
    TupleTagList.of(sideOutputs.map(e => e.tupleTag.asInstanceOf[TupleTag[_]]).asJava)

  def transformWithSideOutputsFn(
    partitions: Seq[SideOutput[T]],
    f: (T, SideInputContext[T]) => SideOutput[T]
  ): DoFn[T, T] =
    new SideInputDoFn[T, T] {
      val g = ClosureCleaner.clean(f) // defeat closure

      /*
       * ProcessContext is required as an argument because it is passed to SideInputContext
       * */
      @ProcessElement
      private[scio] def processElement(c: DoFn[T, T]#ProcessContext, w: BoundedWindow): Unit = {
        val elem = c.element()
        val partition = g(elem, sideInputContext(c, w))
        if (!partitions.exists(_.tupleTag == partition.tupleTag)) {
          // Single-line message with a separated id list, so the offending tag and the allowed
          // tags are readable in logs.
          val allowedIds = partitions.map(_.tupleTag.getId).mkString(", ")
          throw new IllegalStateException(
            s"${partition.tupleTag.getId} is not part of [$allowedIds]"
          )
        }
        c.output(partition.tupleTag, elem)
      }
    }

  val transform = parDo[T, T](transformWithSideOutputsFn(sideOutputs, f))
    .withOutputTags(_mainTag.tupleTag, sideTags)

  val pCollectionWrapper = this.internal.apply(name, transform)

  // Wrap every produced collection, reuse the input coder (elements are emitted unchanged),
  // and key each one by the SideOutput registered for its tag id; unknown tags are dropped.
  pCollectionWrapper.getAll.asScala.view
    .mapValues(
      context
        .wrap(_)
        .asInstanceOf[SCollection[T]]
        .setCoder(internal.getCoder)
    )
    .flatMap { case (tt, col) => Try(tagToSide(tt.getId) -> col).toOption }
    .toMap
}