app/services/DynamoBanditData.scala (93 lines of code) (raw):
package services
import com.typesafe.scalalogging.StrictLogging
import io.circe.{Decoder, Encoder}
import models.DynamoErrors.{DynamoError, DynamoGetError}
import software.amazon.awssdk.services.dynamodb.DynamoDbClient
import software.amazon.awssdk.services.dynamodb.model.{AttributeValue, QueryRequest}
import utils.Circe.dynamoMapToJson
import io.circe.generic.auto._
import zio.blocking.effectBlocking
import zio.{ZEnv, ZIO}
import scala.jdk.CollectionConverters._
/**
 * Shapes of the items read from the bandit DynamoDb table.
 *
 * Each stored item decodes to one TestSample: the per-variant figures for a test, tagged with a timestamp.
 */
case class VariantSample(
  variantName: String,
  annualisedValueInGBP: Double,
  annualisedValueInGBPPerView: Double,
  views: Double
)

case class TestSample(
  testName: String,
  variants: List[VariantSample],
  timestamp: String
)

object TestSample {
  // Instances come from io.circe.generic.auto. The types are deliberately left inferred: annotating them as
  // Decoder/Encoder[TestSample] would let the summoner resolve to this very val instead of deriving one.
  implicit val decoder = implicitly[Decoder[TestSample]]
  implicit val encoder = implicitly[Encoder[TestSample]]
}
/**
 * Shapes of the data returned to the client.
 */

// One variant at one sample: that sample's views, plus the mean computed from the samples up to and including it
case class EnrichedVariantSampleData(
  variantName: String,
  views: Double,
  mean: Double
)

case class EnrichedTestSampleData(
  timestamp: String,
  variants: List[EnrichedVariantSampleData]
)

// Overall mean and total views for a single variant, taken across all samples
case class VariantSummary(
  variantName: String,
  mean: Double,
  views: Double
)

case class BanditData(
  variantSummaries: List[VariantSummary],
  samples: List[EnrichedTestSampleData]
)

object BanditData {
  // Instances come from io.circe.generic.auto; types are left inferred so the summoner cannot pick up these vals.
  implicit val decoder = implicitly[Decoder[BanditData]]
  implicit val encoder = implicitly[Encoder[BanditData]]
}
/**
 * Reads bandit samples for a test from DynamoDb and derives the summary and per-sample data
 * returned to the client.
 *
 * @param stage  deployment stage; "PROD" reads the PROD table, anything else reads the CODE table
 * @param client DynamoDb client used for the queries
 */
class DynamoBanditData(stage: String, client: DynamoDbClient) extends StrictLogging {
  // No DEV table for bandit data
  private val tableName = s"support-bandit-${if (stage == "PROD") "PROD" else "CODE"}"

  /**
   * Fetches every item whose `testName` key equals `<channel>_<testName>`, in ascending sort-key order.
   * NOTE(review): the sort key is presumably the timestamp; the per-sample windows below rely on
   * the items being chronological.
   *
   * A single DynamoDb Query call returns at most 1MB of data, so the paginator is used to follow
   * LastEvaluatedKey across all pages. The paginator fetches pages lazily while it is iterated, so
   * it is fully materialised inside effectBlocking to keep all network I/O (and its failures)
   * inside the effect.
   */
  private def query(
    testName: String,
    channel: String
  ): ZIO[ZEnv, DynamoGetError, List[java.util.Map[String, AttributeValue]]] =
    effectBlocking {
      val request = QueryRequest
        .builder()
        .tableName(tableName)
        .keyConditionExpression("testName = :testName")
        .expressionAttributeValues(
          Map(":testName" -> AttributeValue.builder.s(s"${channel}_$testName").build).asJava
        )
        .scanIndexForward(true)
        .build()

      client.queryPaginator(request).items().asScala.toList
    }.mapError(DynamoGetError)

  /**
   * Views-weighted mean of annualisedValueInGBPPerView over `samples`, where `population` is their
   * total views.
   *
   * Returns 0 when the population is 0. Otherwise a set of zero-view samples yields NaN (0/0),
   * which circe encodes as JSON null.
   */
  private def sampleMean(samples: Iterable[VariantSample], population: Double): Double =
    if (population == 0D) 0D
    else
      samples.foldLeft(0D)((acc, sample) =>
        acc + (sample.views / population) * sample.annualisedValueInGBPPerView
      )

  /**
   * Overall mean and total views per variant, across all samples.
   * Variants are listed in the order they first appear, so the output order is deterministic.
   */
  private def buildVariantSummaries(samples: Seq[TestSample]): List[VariantSummary] = {
    val variantSamples = samples.flatMap(_.variants)
    val samplesByVariant = variantSamples.groupBy(_.variantName)

    variantSamples.map(_.variantName).distinct.toList.map { variantName =>
      val history = samplesByVariant.getOrElse(variantName, Seq.empty)
      val population = history.map(_.views).sum
      VariantSummary(variantName = variantName, mean = sampleMean(history, population), views = population)
    }
  }

  /**
   * For each sample, calculate each variant's mean from the samples up to and including it, so
   * that we can visualise how the Bandit's view of the variants changed over time.
   *
   * @param sampleCount if defined, only the last `sampleCount` samples (the current one included)
   *                    feed each mean; values below 1 are treated as 1 (current sample only)
   */
  private def buildEnrichedSamples(samples: Seq[TestSample], sampleCount: Option[Int]): List[EnrichedTestSampleData] = {
    val indexedSamples = samples.zipWithIndex

    // Every variant's samples, each tagged with the index of the TestSample it came from.
    // Windows are expressed in TestSample indices rather than positions in a variant's own
    // history. Positions would misalign, and pull in later samples, whenever a variant is
    // missing from some samples.
    val historyByVariant: Map[String, Seq[(VariantSample, Int)]] =
      indexedSamples
        .flatMap { case (sample, idx) => sample.variants.map(variant => (variant, idx)) }
        .groupBy { case (variant, _) => variant.variantName }

    indexedSamples.map { case (sample, idx) =>
      // Window of TestSample indices [windowStart, idx]
      val windowStart = sampleCount.fold(0)(n => math.max(idx - math.max(n, 1) + 1, 0))

      val variants = sample.variants.map { current =>
        val windowed = historyByVariant
          .getOrElse(current.variantName, Seq.empty)
          .collect { case (variant, sampleIdx) if sampleIdx >= windowStart && sampleIdx <= idx => variant }
        val population = windowed.map(_.views).sum
        EnrichedVariantSampleData(current.variantName, current.views, sampleMean(windowed, population))
      }

      EnrichedTestSampleData(sample.timestamp, variants)
    }.toList
  }

  /**
   * Loads all samples for a test and builds the bandit data for the client.
   * Items that fail to decode are logged and skipped.
   *
   * @param sampleCount optional window size for the per-sample means (see buildEnrichedSamples)
   */
  def getDataForTest(testName: String, channel: String, sampleCount: Option[Int]): ZIO[ZEnv, DynamoError, BanditData] =
    query(testName, channel).map { items =>
      val samples: List[TestSample] = items.flatMap { item =>
        dynamoMapToJson(item).as[TestSample] match {
          case Right(row) => Some(row)
          case Left(error) =>
            logger.error(s"Failed to decode item from Dynamo: ${error.getMessage}")
            None
        }
      }

      BanditData(buildVariantSummaries(samples), buildEnrichedSamples(samples, sampleCount))
    }
}