in tools/apachebeam-throttling/src/main/java/com/google/cloud/pso/beamthrottling/DynamicThrottlingTransform.java [257:414]
public PCollectionTuple expand(PCollection<InputT> input) {
  // Builds the throttling pipeline in two steps:
  //   1. "Grouping"   - assigns every element a random integer key so that the stateful step
  //                     below keeps one set of state cells and timers per key.
  //   2. "Throttling" - buffers elements per key and, on a timer, sends them to clientCall in
  //                     batches, rejecting some of them locally based on a rejection
  //                     probability computed from the per-key counters.
  // The returned tuple carries successTag, errorTag and throttlingTag outputs.

  // Grouping: keys are drawn uniformly from [1, numberOfGroups].
  PCollection<KV<Integer, InputT>> groups =
      input.apply(
          "Grouping",
          ParDo.of(
              new DoFn<InputT, KV<Integer, InputT>>() {
                @DoFn.ProcessElement
                public void processElement(ProcessContext context) {
                  int group = ThreadLocalRandom.current().nextInt(1, numberOfGroups + 1);
                  context.output(KV.of(group, context.element()));
                }
              }));

  // Unbounded inputs drive the timers from processing time, bounded inputs from event time.
  final TimeDomain timerDomain =
      input.isBounded().equals(PCollection.IsBounded.UNBOUNDED)
          ? TimeDomain.PROCESSING_TIME
          : TimeDomain.EVENT_TIME;

  // Throttling: processElement only buffers the payload in incomingReqBagState; the actual
  // client calls happen in the incomingReqBagTimer callback.
  PCollectionTuple enriched =
      groups.apply(
          "Throttling",
          ParDo.of(
                  new DoFn<KV<Integer, InputT>, OutputT>() {
                    @StateId("acceptedRequests")
                    private final StateSpec<ValueState<Integer>> acceptedRequestsStateSpec =
                        StateSpecs.value(VarIntCoder.of());

                    @StateId("incomingReqBagState")
                    private final StateSpec<BagState<InputT>> incomingReqBagStateSpec =
                        StateSpecs.bag(inputElementCoder);

                    @StateId("totalProcessedRequests")
                    private final StateSpec<ValueState<Integer>> totalProcessedRequestsStateSpec =
                        StateSpecs.value(VarIntCoder.of());

                    @TimerId("resetCountsTimer")
                    private final TimerSpec resetCountsTimerSpec = TimerSpecs.timer(timerDomain);

                    @TimerId("incomingReqBagTimer")
                    private final TimerSpec incomingReqBagTimerSpec =
                        TimerSpecs.timer(timerDomain);

                    /**
                     * Buffers the incoming payload and arms the timers when needed.
                     *
                     * @param context element context; the KV value is the payload to buffer.
                     * @param acceptedRequests count of requests that the client call completed
                     *     successfully (not used by this method).
                     * @param totalProcessedRequests count of requests handed to the client call.
                     * @param incomingReqBagState bag holding the payloads waiting to be processed.
                     * @param resetCountsTimer timer that clears both counters; armed here while
                     *     the total counter is still zero.
                     * @param incomingReqBagTimer timer that processes a batch from the bag; armed
                     *     here when the bag is empty.
                     */
                    @ProcessElement
                    public void processElement(
                        ProcessContext context,
                        @StateId("acceptedRequests") ValueState<Integer> acceptedRequests,
                        @StateId("totalProcessedRequests")
                            ValueState<Integer> totalProcessedRequests,
                        @StateId("incomingReqBagState") BagState<InputT> incomingReqBagState,
                        @TimerId("resetCountsTimer") Timer resetCountsTimer,
                        @TimerId("incomingReqBagTimer") Timer incomingReqBagTimer) {
                      if (incomingReqBagState.isEmpty().read()) {
                        incomingReqBagTimer
                            .offset(Duration.millis(batchInterval.toMillis()))
                            .setRelative();
                      }
                      if (firstNonNull(totalProcessedRequests.read(), 0) == 0) {
                        resetCountsTimer
                            .offset(Duration.millis(resetCounterInterval.toMillis()))
                            .setRelative();
                      }
                      incomingReqBagState.add(context.element().getValue());
                    }

                    /**
                     * Processes up to {@code numOfEventsToBeProcessedForBatch} buffered payloads.
                     *
                     * <p>For each payload a rejection probability of
                     * {@code (total - k * accepted) / (total + 1)} is compared with a random
                     * number in [0, 1). If the payload is let through, the client call is made;
                     * otherwise it is handled according to {@code throttlingStrategy}. Payloads
                     * beyond the batch limit and payloads marked for retry are written back to
                     * the bag.
                     *
                     * @param context timer context used to emit to the tagged outputs.
                     * @param acceptedRequests count of requests the client call completed
                     *     successfully.
                     * @param incomingReqBagState bag holding the payloads waiting to be processed.
                     * @param totalProcessedRequests count of requests handed to the client call.
                     * @param incomingReqBagTimer this timer; re-armed while payloads remain.
                     */
                    @OnTimer("incomingReqBagTimer")
                    public void incomingReqBagTimer(
                        OnTimerContext context,
                        @StateId("acceptedRequests") ValueState<Integer> acceptedRequests,
                        @StateId("incomingReqBagState") BagState<InputT> incomingReqBagState,
                        @StateId("totalProcessedRequests")
                            ValueState<Integer> totalProcessedRequests,
                        @TimerId("incomingReqBagTimer") Timer incomingReqBagTimer) {
                      Iterable<InputT> bag = incomingReqBagState.read();
                      int acceptedReqCount = firstNonNull(acceptedRequests.read(), 0);
                      int totalReqCount = firstNonNull(totalProcessedRequests.read(), 0);
                      List<InputT> retryRequests = new ArrayList<>();

                      for (InputT value :
                          Iterables.limit(bag, numOfEventsToBeProcessedForBatch)) {
                        // The divisor is forced to double so the ratio is never truncated by
                        // integer division, whatever the declared type of
                        // kInRejectionProbability is.
                        double reqRejectionProbability =
                            (totalReqCount - (kInRejectionProbability * acceptedReqCount))
                                / (double) (totalReqCount + 1);
                        double randomValue = ThreadLocalRandom.current().nextDouble();
                        boolean accepted = false;

                        if (reqRejectionProbability <= randomValue) {
                          // The request is attempted: it counts towards the total even if the
                          // call fails afterwards.
                          totalReqCount++;
                          totalProcessedRequests.write(totalReqCount);
                          try {
                            OutputT successTagValue = clientCall.call(value);
                            acceptedReqCount++;
                            acceptedRequests.write(acceptedReqCount);
                            context.output(successTag, successTagValue);
                            accepted = true;
                          } catch (ThrottlingException e) {
                            handleThrottled(context, value, e.getMessage(), retryRequests);
                          } catch (Exception e) {
                            context.output(errorTag, new Result<InputT>(value, e.getMessage()));
                          }
                        } else {
                          handleThrottled(
                              context,
                              value,
                              "Throttled by Client. Request rejection probability: "
                                  + reqRejectionProbability,
                              retryRequests);
                        }

                        // Guarded so the message is only assembled when it will be logged.
                        if (LOG.isDebugEnabled()) {
                          LOG.debug(" totalCount-" + totalReqCount + ", reqRejecProb-" + reqRejectionProbability + ", Random-" + randomValue + ", " + accepted + ", acceptCount-" + acceptedReqCount);
                        }
                      }

                      // Copy the unprocessed tail before clearing, then rebuild the bag from the
                      // tail followed by the payloads that must be retried.
                      List<InputT> latestElements = new ArrayList<>();
                      Iterables.addAll(
                          latestElements,
                          Iterables.skip(bag, numOfEventsToBeProcessedForBatch));
                      incomingReqBagState.clear();
                      for (InputT remaining : latestElements) {
                        incomingReqBagState.add(remaining);
                      }
                      for (InputT retry : retryRequests) {
                        incomingReqBagState.add(retry);
                      }

                      // Re-arm only while there is work left. processElement arms this timer
                      // again as soon as an element arrives in an empty bag, so firing on an
                      // empty bag would be a no-op.
                      if (!latestElements.isEmpty() || !retryRequests.isEmpty()) {
                        incomingReqBagTimer
                            .offset(Duration.millis(batchInterval.toMillis()))
                            .setRelative();
                      }
                    }

                    /**
                     * Applies {@code throttlingStrategy} to a payload that was not accepted:
                     * DLQ emits it to {@code throttlingTag}, DROP discards it and RETRY queues it
                     * in {@code retryRequests}.
                     */
                    private void handleThrottled(
                        OnTimerContext context,
                        InputT value,
                        String reason,
                        List<InputT> retryRequests) {
                      switch (throttlingStrategy) {
                        case DLQ:
                          context.output(throttlingTag, new Result<InputT>(value, reason));
                          break;
                        case DROP:
                          break;
                        case RETRY:
                          retryRequests.add(value);
                          break;
                      }
                    }

                    /**
                     * Clears both request counters and schedules the next reset.
                     *
                     * @param c timer context (unused).
                     * @param totalProcessedRequests count of requests handed to the client call.
                     * @param acceptedRequests count of requests the client call completed
                     *     successfully.
                     * @param resetCountsTimer this timer; re-armed on every invocation.
                     */
                    @OnTimer("resetCountsTimer")
                    public void resetCountsTimer(
                        OnTimerContext c,
                        @StateId("totalProcessedRequests")
                            ValueState<Integer> totalProcessedRequests,
                        @StateId("acceptedRequests") ValueState<Integer> acceptedRequests,
                        @TimerId("resetCountsTimer") Timer resetCountsTimer) {
                      totalProcessedRequests.clear();
                      acceptedRequests.clear();
                      resetCountsTimer
                          .offset(Duration.millis(resetCounterInterval.toMillis()))
                          .setRelative();
                    }
                  })
              .withOutputTags(successTag, TupleTagList.of(errorTag).and(throttlingTag)));
  return enriched;
}