public PCollectionTuple expand()

in tools/apachebeam-throttling/src/main/java/com/google/cloud/pso/beamthrottling/DynamicThrottlingTransform.java [257:414]


    /**
     * Applies client-side adaptive throttling to {@code input}.
     *
     * <p>Elements are spread over {@code numberOfGroups} random keys and buffered per key. Each key processes up
     * to {@code numOfEventsToBeProcessedForBatch} buffered elements every {@code batchInterval}. The request
     * counters that drive the rejection probability are cleared every {@code resetCounterInterval}.
     *
     * @param input elements to be sent through {@code clientCall}.
     * @return a tuple with {@code successTag} (client responses), {@code errorTag} (calls that failed with a
     *     non-throttling exception) and {@code throttlingTag} (throttled requests, emitted only under the DLQ
     *     strategy).
     */
    public PCollectionTuple expand(PCollection<InputT> input) {
        /*
         * {Grouping}: assigns each event a random key in [1, numberOfGroups]. The stateful step below then keeps
         * independent counters, buffer and timers per group.
         */
        PCollection<KV<Integer, InputT>> groups = input.apply("Grouping", ParDo.of(new DoFn<InputT, KV<Integer, InputT>>() {
            @DoFn.ProcessElement
            public void processElement(ProcessContext context) {
                context.output(KV.of(ThreadLocalRandom.current().nextInt(1, numberOfGroups + 1), context.element()));
            }
        }));

        // Unbounded input uses processing-time timers; bounded input uses event-time timers.
        final TimeDomain timerDomain = input.isBounded() == PCollection.IsBounded.UNBOUNDED
                ? TimeDomain.PROCESSING_TIME
                : TimeDomain.EVENT_TIME;

        /*
         * {Throttling}: processElement only buffers payloads in {incomingReqBagState}. The batch timer later sends
         * each payload to the external service or rejects it based on the rejection probability.
         */
        PCollectionTuple enriched = groups.apply("Throttling", ParDo.of(new DoFn<KV<Integer, InputT>, OutputT>() {
            @StateId("acceptedRequests")
            private final StateSpec<ValueState<Integer>> acceptedRequestsStateSpec =
                    StateSpecs.value(VarIntCoder.of());

            @StateId("incomingReqBagState")
            private final StateSpec<BagState<InputT>> incomingReqBagStateSpec =
                    StateSpecs.bag(inputElementCoder);

            @StateId("totalProcessedRequests")
            private final StateSpec<ValueState<Integer>> totalProcessedRequestsStateSpec =
                    StateSpecs.value(VarIntCoder.of());

            @TimerId("resetCountsTimer")
            private final TimerSpec resetCountsTimerSpec = TimerSpecs.timer(timerDomain);

            @TimerId("incomingReqBagTimer")
            private final TimerSpec incomingReqBagTimerSpec = TimerSpecs.timer(timerDomain);

            /**
             * Buffers the payload for a later batch and starts the batch and counter-reset timers when needed.
             *
             * @param context carries the keyed payload that should eventually be sent by the client.
             * @param totalProcessedRequests number of requests sent to the backend since the last counter reset.
             * @param incomingReqBagState requests waiting to be processed in a batch.
             * @param resetCountsTimer timer that periodically clears the request counters.
             * @param incomingReqBagTimer timer that processes the next batch from the bag.
             */
            @ProcessElement
            public void processElement(ProcessContext context,
                                       @StateId("totalProcessedRequests") ValueState<Integer> totalProcessedRequests,
                                       @StateId("incomingReqBagState") BagState<InputT> incomingReqBagState,
                                       @TimerId("resetCountsTimer") Timer resetCountsTimer,
                                       @TimerId("incomingReqBagTimer") Timer incomingReqBagTimer) {
                // The batch timer only keeps looping while the bag holds work, so an empty bag needs a new timer.
                if (incomingReqBagState.isEmpty().read()) {
                    incomingReqBagTimer.offset(Duration.millis(batchInterval.toMillis())).setRelative();
                }
                // A zero count means the counters are fresh or were just reset, so (re)start the reset cycle.
                if (firstNonNull(totalProcessedRequests.read(), 0) == 0) {
                    resetCountsTimer.offset(Duration.millis(resetCounterInterval.toMillis())).setRelative();
                }
                incomingReqBagState.add(context.element().getValue());
            }

            /**
             * Processes up to {@code numOfEventsToBeProcessedForBatch} buffered requests, then re-arms itself only
             * while requests remain buffered.
             *
             * <p>Each request is rejected locally with probability
             * {@code (total - kInRejectionProbability * accepted) / (total + 1)}. Requests that pass are sent
             * through {@code clientCall}. Locally rejected requests, and requests the client call rejects with a
             * {@code ThrottlingException}, are handled according to {@code throttlingStrategy}. Any other
             * exception from the call is emitted on {@code errorTag}.
             *
             * @param context used to emit successful, throttled and failed results.
             * @param acceptedRequests number of requests the backend accepted since the last counter reset.
             * @param incomingReqBagState requests waiting to be processed.
             * @param totalProcessedRequests number of requests sent to the backend since the last counter reset.
             * @param incomingReqBagTimer this timer; re-armed while the bag is non-empty after processing.
             */
            @OnTimer("incomingReqBagTimer")
            public void incomingReqBagTimer(
                    OnTimerContext context,
                    @StateId("acceptedRequests") ValueState<Integer> acceptedRequests,
                    @StateId("incomingReqBagState") BagState<InputT> incomingReqBagState,
                    @StateId("totalProcessedRequests") ValueState<Integer> totalProcessedRequests,
                    @TimerId("incomingReqBagTimer") Timer incomingReqBagTimer) {
                Iterable<InputT> bag = incomingReqBagState.read();

                int acceptedReqCount = firstNonNull(acceptedRequests.read(), 0);
                int totalReqCount = firstNonNull(totalProcessedRequests.read(), 0);
                List<InputT> retryRequests = new ArrayList<>();

                for (InputT value : Iterables.limit(bag, numOfEventsToBeProcessedForBatch)) {
                    // Divide in floating point so the probability cannot be truncated to 0 by int arithmetic.
                    double reqRejectionProbability =
                            (totalReqCount - (kInRejectionProbability * acceptedReqCount)) / (double) (totalReqCount + 1);
                    double randomValue = Math.random();
                    boolean accepted = false;
                    if (reqRejectionProbability <= randomValue) {
                        totalReqCount = totalReqCount + 1;
                        totalProcessedRequests.write(totalReqCount);
                        try {
                            OutputT successTagValue = clientCall.call(value);
                            acceptedReqCount = acceptedReqCount + 1;
                            acceptedRequests.write(acceptedReqCount);
                            context.output(successTag, successTagValue);
                            accepted = true;
                        } catch (ThrottlingException e) {
                            handleThrottled(context, value, e.getMessage(), retryRequests);
                        } catch (Exception e) {
                            context.output(errorTag, new Result<InputT>(value, e.getMessage()));
                        }
                    } else {
                        handleThrottled(context, value,
                                "Throttled by Client. Request rejection probability: " + reqRejectionProbability,
                                retryRequests);
                    }
                    LOG.debug("  totalCount-" + totalReqCount + ",  reqRejecProb-" + reqRejectionProbability + ",  Random-" + randomValue + ", " + accepted + ",  acceptCount-" + acceptedReqCount);
                }

                // Rebuild the bag: the unprocessed tail first (in order), then the requests queued for retry.
                List<InputT> latestElements = new ArrayList<>();
                Iterables.addAll(latestElements, Iterables.skip(bag, numOfEventsToBeProcessedForBatch));
                incomingReqBagState.clear();
                for (InputT pending : latestElements) {
                    incomingReqBagState.add(pending);
                }
                for (InputT retry : retryRequests) {
                    incomingReqBagState.add(retry);
                }
                // Re-arm only while there is buffered work; processElement re-arms when a request arrives at an
                // empty bag. Re-arming unconditionally would keep this timer firing forever, and an event-time
                // timer that always re-schedules itself prevents a bounded pipeline from draining its timers.
                if (!latestElements.isEmpty() || !retryRequests.isEmpty()) {
                    incomingReqBagTimer.offset(Duration.millis(batchInterval.toMillis())).setRelative();
                }
            }

            /**
             * Applies {@code throttlingStrategy} to a request that was rejected locally or by the backend.
             *
             * @param context used to emit the request on {@code throttlingTag} under the DLQ strategy.
             * @param value the rejected request.
             * @param reason message attached to the DLQ result.
             * @param retryRequests collects requests to be re-buffered under the RETRY strategy.
             */
            private void handleThrottled(OnTimerContext context, InputT value, String reason,
                                         List<InputT> retryRequests) {
                switch (throttlingStrategy) {
                    case DLQ:
                        context.output(throttlingTag, new Result<InputT>(value, reason));
                        break;
                    case DROP:
                        break;
                    case RETRY:
                        retryRequests.add(value);
                        break;
                }
            }

            /**
             * Clears both request counters so the rejection probability starts from fresh history. The reset cycle
             * keeps running only while requests are still buffered.
             */
            @OnTimer("resetCountsTimer")
            public void resetCountsTimer(
                    OnTimerContext c,
                    @StateId("totalProcessedRequests") ValueState<Integer> totalProcessedRequests,
                    @StateId("acceptedRequests") ValueState<Integer> acceptedRequests,
                    @StateId("incomingReqBagState") BagState<InputT> incomingReqBagState,
                    @TimerId("resetCountsTimer") Timer resetCountsTimer) {
                totalProcessedRequests.clear();
                acceptedRequests.clear();
                // Counters only change while the batch timer drains the bag. With an empty bag they stay at zero,
                // and processElement restarts this cycle when it sees a zero count, so stop looping here.
                if (!incomingReqBagState.isEmpty().read()) {
                    resetCountsTimer.offset(Duration.millis(resetCounterInterval.toMillis())).setRelative();
                }
            }
        }).withOutputTags(successTag, TupleTagList.of(errorTag).and(throttlingTag)));
        return enriched;
    }