app/util/StepFunctions.scala (81 lines of code) (raw):
package util
import java.time.Instant
import com.amazonaws.services.stepfunctions.model._
import com.fasterxml.jackson.core.JsonParseException
import com.gu.media.upload.model._
import play.api.libs.json.{JsResultException, Json}
import scala.jdk.CollectionConverters._
class StepFunctions(awsConfig: AWSConfig) {
def getById(id: String): Option[Upload] = {
val arn = s"${awsConfig.pipelineArn.replace(":stateMachine:", ":execution:")}:$id"
try {
val request = new DescribeExecutionRequest().withExecutionArn(arn)
val result = awsConfig.stepFunctionsClient.describeExecution(request)
val upload = Json.parse(result.getInput).validate[Upload].asOpt
upload.map(fillInStartTimestamp(result, _))
} catch {
case _: ExecutionDoesNotExistException =>
None
}
}
def getJobs(atomId: String): Iterable[ExecutionListItem] = {
val runningJobs = getExecutions(atomId, ExecutionStatus.RUNNING)
val failedJobs = getExecutions(atomId, ExecutionStatus.FAILED).filter(lessThan10MinutesOld)
runningJobs ++ failedJobs
}
def getTaskEntered(events: Iterable[HistoryEvent]): Option[(String, Upload)] = for {
event <- events.find(_.getType == "TaskStateEntered")
details = event.getStateEnteredEventDetails
upload <- Json.parse(details.getInput).validate[Upload].asOpt
} yield {
details.getName -> upload
}
def getExecutionFailed(events: Iterable[HistoryEvent]): Option[String] = {
events.find(_.getType == "ExecutionFailed").flatMap { event =>
val cause = event.getExecutionFailedEventDetails.getCause
try {
Some((Json.parse(cause) \ "errorMessage").as[String])
} catch {
case _: JsonParseException | _: JsResultException =>
Some(cause)
}
}
}
def start(upload: Upload): Unit = {
val stepFunctionsRequest = new StartExecutionRequest()
.withName(upload.id)
.withStateMachineArn(awsConfig.pipelineArn)
.withInput(Json.stringify(Json.toJson(upload)))
awsConfig.stepFunctionsClient.startExecution(stepFunctionsRequest)
}
def getEventsInReverseOrder(execution: ExecutionListItem): Iterable[HistoryEvent] = {
val request = new GetExecutionHistoryRequest()
.withExecutionArn(execution.getExecutionArn)
.withReverseOrder(true)
.withMaxResults(20)
awsConfig.stepFunctionsClient.getExecutionHistory(request).getEvents.asScala
}
private def getExecutions(atomId: String, filter: ExecutionStatus): Iterable[ExecutionListItem] = {
val request = new ListExecutionsRequest()
.withStateMachineArn(awsConfig.pipelineArn)
.withStatusFilter(filter)
val results = awsConfig.stepFunctionsClient.listExecutions(request).getExecutions.asScala
results.filter(_.getName.startsWith(atomId))
}
private def lessThan10MinutesOld(e: ExecutionListItem): Boolean = {
val now = Instant.now().toEpochMilli
val end = e.getStopDate.toInstant.toEpochMilli
(now - end) < (1000 * 60 * 10)
}
private def fillInStartTimestamp(result: DescribeExecutionResult, upload: Upload): Upload = {
if(upload.metadata.startTimestamp.isEmpty) {
upload.copy(
metadata = upload.metadata.copy(
startTimestamp = Some(result.getStartDate.getTime)
)
)
} else {
upload
}
}
}