in sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java [225:485]
private void addRunnerAndConsumersForPTransformRecursively(
BeamFnStateClient beamFnStateClient,
BeamFnDataClient queueingClient,
String pTransformId,
PTransform pTransform,
Supplier<String> processBundleInstructionId,
Supplier<List<CacheToken>> cacheTokens,
Supplier<Cache<?, ?>> bundleCache,
ProcessBundleDescriptor processBundleDescriptor,
RunnerApi.Components components,
SetMultimap<String, String> pCollectionIdsToConsumingPTransforms,
PCollectionConsumerRegistry pCollectionConsumerRegistry,
Set<String> processedPTransformIds,
PTransformFunctionRegistry startFunctionRegistry,
PTransformFunctionRegistry finishFunctionRegistry,
Consumer<ThrowingRunnable> addResetFunction,
Consumer<ThrowingRunnable> addTearDownFunction,
BiConsumer<ApiServiceDescriptor, DataEndpoint<?>> addDataEndpoint,
Consumer<TimerEndpoint<?>> addTimerEndpoint,
Consumer<BundleProgressReporter> addBundleProgressReporter,
BundleSplitListener splitListener,
BundleFinalizer bundleFinalizer,
Collection<BeamFnDataReadRunner<?>> channelRoots,
Map<ApiServiceDescriptor, BeamFnDataOutboundAggregator> outboundAggregatorMap,
Set<String> runnerCapabilities)
throws IOException {
    // Recursively ensure that all consumers of the output PCollections have been created.
    // Since we are creating the consumers first, we know that we are building the DAG
    // in reverse topological order.
for (String pCollectionId : pTransform.getOutputsMap().values()) {
for (String consumingPTransformId : pCollectionIdsToConsumingPTransforms.get(pCollectionId)) {
addRunnerAndConsumersForPTransformRecursively(
beamFnStateClient,
queueingClient,
consumingPTransformId,
processBundleDescriptor.getTransformsMap().get(consumingPTransformId),
processBundleInstructionId,
cacheTokens,
bundleCache,
processBundleDescriptor,
components,
pCollectionIdsToConsumingPTransforms,
pCollectionConsumerRegistry,
processedPTransformIds,
startFunctionRegistry,
finishFunctionRegistry,
addResetFunction,
addTearDownFunction,
addDataEndpoint,
addTimerEndpoint,
addBundleProgressReporter,
splitListener,
bundleFinalizer,
channelRoots,
outboundAggregatorMap,
runnerCapabilities);
}
}
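    // Runners can only be created for primitive transforms that carry a spec; spec-less and
    // composite transforms are rejected below.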
if (!pTransform.hasSpec()) {
throw new IllegalArgumentException(
String.format(
"Cannot process transform with no spec: %s",
TextFormat.printer().printToString(pTransform)));
}
if (pTransform.getSubtransformsCount() > 0) {
throw new IllegalArgumentException(
String.format(
"Cannot process composite transform: %s",
TextFormat.printer().printToString(pTransform)));
}
    // Skip pTransforms that have already been processed.
if (!processedPTransformIds.contains(pTransformId)) {
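      // Look up the runner factory for this transform's URN (falling back to the default
      // factory) and hand it a Context that exposes the per-bundle services it may need.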
urnToPTransformRunnerFactoryMap
.getOrDefault(pTransform.getSpec().getUrn(), defaultPTransformRunnerFactory)
.addRunnerForPTransform(
new Context() {
@Override
public PipelineOptions getPipelineOptions() {
return options;
}
@Override
public ShortIdMap getShortIdMap() {
return shortIds;
}
@Override
public BeamFnDataClient getBeamFnDataClient() {
return queueingClient;
}
@Override
public BeamFnStateClient getBeamFnStateClient() {
return beamFnStateClient;
}
@Override
public String getPTransformId() {
return pTransformId;
}
@Override
public PTransform getPTransform() {
return pTransform;
}
@Override
public Supplier<String> getProcessBundleInstructionIdSupplier() {
return processBundleInstructionId;
}
@Override
public Supplier<List<CacheToken>> getCacheTokensSupplier() {
return cacheTokens;
}
@Override
public Supplier<Cache<?, ?>> getBundleCacheSupplier() {
return bundleCache;
}
@Override
public Cache<?, ?> getProcessWideCache() {
return processWideCache;
}
@Override
public RunnerApi.Components getComponents() {
return components;
}
@Override
public Set<String> getRunnerCapabilities() {
return runnerCapabilities;
}
@Override
public <T> void addPCollectionConsumer(
String pCollectionId, FnDataReceiver<WindowedValue<T>> consumer) {
pCollectionConsumerRegistry.register(
pCollectionId, pTransformId, pTransform.getUniqueName(), consumer);
}
@Override
public <T> FnDataReceiver<T> addOutgoingDataEndpoint(
ApiServiceDescriptor apiServiceDescriptor,
org.apache.beam.sdk.coders.Coder<T> coder) {
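                // Aggregators are shared per ApiServiceDescriptor, so every transform writing to
                // the same data channel multiplexes over a single outbound aggregator.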
BeamFnDataOutboundAggregator aggregator =
outboundAggregatorMap.computeIfAbsent(
apiServiceDescriptor,
asd ->
queueingClient.createOutboundAggregator(
asd,
processBundleInstructionId,
runnerCapabilities.contains(
BeamUrns.getUrn(
StandardRunnerProtocols.Enum
.CONTROL_RESPONSE_ELEMENTS_EMBEDDING))));
return aggregator.registerOutputDataLocation(pTransformId, coder);
}
@Override
public <T> FnDataReceiver<Timer<T>> addOutgoingTimersEndpoint(
String timerFamilyId, org.apache.beam.sdk.coders.Coder<Timer<T>> coder) {
BeamFnDataOutboundAggregator aggregator;
if (!processBundleDescriptor.hasTimerApiServiceDescriptor()) {
throw new IllegalStateException(
String.format(
"Timers are unsupported because the "
+ "ProcessBundleRequest %s does not provide a timer ApiServiceDescriptor.",
processBundleInstructionId.get()));
}
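                // Timer outputs go through the same per-descriptor aggregator map, so the
                // aggregator is shared with any other endpoint targeting the timer
                // ApiServiceDescriptor.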
aggregator =
outboundAggregatorMap.computeIfAbsent(
processBundleDescriptor.getTimerApiServiceDescriptor(),
asd ->
queueingClient.createOutboundAggregator(
asd,
processBundleInstructionId,
runnerCapabilities.contains(
BeamUrns.getUrn(
StandardRunnerProtocols.Enum
.CONTROL_RESPONSE_ELEMENTS_EMBEDDING))));
return aggregator.registerOutputTimersLocation(
pTransformId, timerFamilyId, coder);
}
@Override
public FnDataReceiver<?> getPCollectionConsumer(String pCollectionId) {
return pCollectionConsumerRegistry.getMultiplexingConsumer(pCollectionId);
}
@Override
public void addStartBundleFunction(ThrowingRunnable startFunction) {
startFunctionRegistry.register(
pTransformId, pTransform.getUniqueName(), startFunction);
}
@Override
public void addFinishBundleFunction(ThrowingRunnable finishFunction) {
finishFunctionRegistry.register(
pTransformId, pTransform.getUniqueName(), finishFunction);
}
@Override
public <T> void addIncomingDataEndpoint(
ApiServiceDescriptor apiServiceDescriptor,
org.apache.beam.sdk.coders.Coder<T> coder,
FnDataReceiver<T> receiver) {
addDataEndpoint.accept(
apiServiceDescriptor, DataEndpoint.create(pTransformId, coder, receiver));
}
@Override
public <T> void addIncomingTimerEndpoint(
String timerFamilyId,
org.apache.beam.sdk.coders.Coder<Timer<T>> coder,
FnDataReceiver<Timer<T>> receiver) {
addTimerEndpoint.accept(
TimerEndpoint.create(pTransformId, timerFamilyId, coder, receiver));
}
@Override
public void addResetFunction(ThrowingRunnable resetFunction) {
addResetFunction.accept(resetFunction);
}
@Override
public <T> void addChannelRoot(BeamFnDataReadRunner<T> beamFnDataReadRunner) {
channelRoots.add(beamFnDataReadRunner);
}
@Override
public void addTearDownFunction(ThrowingRunnable tearDownFunction) {
addTearDownFunction.accept(tearDownFunction);
}
@Override
public void addBundleProgressReporter(
BundleProgressReporter bundleProgressReporter) {
addBundleProgressReporter.accept(bundleProgressReporter);
}
@Override
public BundleSplitListener getSplitListener() {
return splitListener;
}
@Override
public BundleFinalizer getBundleFinalizer() {
return bundleFinalizer;
}
});
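      // Mark the transform as processed so that transforms reachable through multiple output
      // PCollections are only wired up once.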
processedPTransformIds.add(pTransformId);
}
}