app/helpers/ParanoidS3Source.scala (101 lines of code) (raw):

package helpers import java.net.URLEncoder import java.time.OffsetDateTime import java.time.format.DateTimeFormatter import akka.actor.ActorSystem import akka.http.scaladsl.{Http, model} import akka.http.scaladsl.model._ import akka.stream._ import akka.stream.stage.{AbstractOutHandler, GraphStage, GraphStageLogic} import akka.http.scaladsl.model.HttpHeader.ParsingResult._ import akka.stream.scaladsl.Sink import akka.util.ByteString import com.amazonaws.regions.Region import scala.collection.immutable.Seq import play.api.Logger import akka.http.scaladsl.model.HttpProtocol import org.slf4j.LoggerFactory import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider import scala.concurrent.{Await, ExecutionContext, Future} import scala.concurrent.duration._ /** * Source which performs Listbucket requests and outputs the XML as a String to be sanitised * @param bucketName bucket to scan * @param region AWS region that the bucket lives in * @param actorSystem implicitly provided ActorSystem for akka http */ class ParanoidS3Source(bucketName:String, region:Region, credsProvider: AwsCredentialsProvider)(implicit actorSystem:ActorSystem, override val mat:Materializer) extends GraphStage[SourceShape[ByteString]] with S3Signer { private val out:Outlet[ByteString] = Outlet.create("ParanoidS3Source.out") implicit val ec:ExecutionContext = actorSystem.dispatcher override def shape: SourceShape[ByteString] = SourceShape.of(out) val logger = LoggerFactory.getLogger(getClass) /** * extract given entries from the XML manually. This is necessary as we are in "paranoid mode", and can't * rely on the XML being valid. * @param paramsToFind Seq[String] giving the parameters to find * @param body ByteString * @return a Map, containing each `paramToFind` pointing to an Option which has the string value, if present. * NOTE: this assumes that the keys and data to extract are both UTF-8 compatible (but the overall document can break) */ def findParams(paramsToFind:Seq[String], body:ByteString):Map[String,Option[String]] = { val paramsToFindBytes = paramsToFind.map(ByteString(_)) def captureString(toFind:ByteString, haystack:ByteString,n:Int):Option[ByteString] = { //now capture everything up to the next < character for(i<-n+toFind.length to haystack.length){ if(haystack(i)=="<".charAt(0).toByte){ val result = haystack.slice(n+toFind.length+1,i) return Some(result) } } None } def locateString(toFind:ByteString, haystack:ByteString):Option[Int] = { for(n<-0 to haystack.length-toFind.length){ val matcher = haystack.slice(n, n+toFind.length) if(matcher==toFind){ return Some(n) } } None } def findParamFuture(toFind: ByteString, haystack:ByteString):Future[Tuple2[ByteString, Option[ByteString]]] = Future { locateString(toFind, haystack) match { case Some(location)=>Tuple2(toFind, captureString(toFind, haystack, location)) case None=>Tuple2(toFind,None) } } Await.result(Future.sequence(paramsToFindBytes.map(findParamFuture(_,body))) .map(_.map(tuple=>Tuple2(tuple._1.utf8String, tuple._2.map(_.utf8String)))) .map(_.toMap) , 10.seconds) } override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { private val logger=Logger(getClass) var onPage:Int = 0 var continuationToken:Option[String] = None setHandler(out, new AbstractOutHandler { override def onPull(): Unit = { val headerSequence = Seq( HttpHeader.parse("Host", s"$bucketName.s3.${region.getName}.amazonaws.com"), HttpHeader.parse("Date", OffsetDateTime.now().format(DateTimeFormatter.RFC_1123_DATE_TIME)), ).map({ case Ok(header, errors)=>header case Error(err)=>throw new RuntimeException(err.toString) }) val baseParams = "list-type=2&encoding-type=url" //val baseParams = "delimiter=/&encoding-type=url&prefix" val qParams = continuationToken match { case Some(token)=> logger.debug(s"continuation token is $token") baseParams + s"&continuation-token=${URLEncoder.encode(token, "UTF-8")}" case None=> logger.debug("no continuation token") baseParams } val request = HttpRequest(HttpMethods.GET, Uri(s"https://$bucketName.s3.${region.getName}.amazonaws.com?$qParams"), headerSequence) val signedRequest = Await.result(signHttpRequest(request, region,"s3", credsProvider), 10 seconds) logger.debug(s"Signed request is ${signedRequest.toString()}") val response = Await.result(Http().singleRequest(signedRequest), 10 seconds) //we are in paranoid mode, so can't assume that this is valid xml (yet). So, we buffer the content and manually scan for //the continuationToken and isTruncated flags we require. val body = Await.result(response.entity.getDataBytes().runWith(Sink.fold[ByteString,ByteString](ByteString.empty)(_.concat(_)), mat), 10 seconds) push(out, body) val flags:Map[String,Option[String]] = findParams(Seq("NextContinuationToken","KeyCount","IsTruncated"), body) flags("IsTruncated") match { case Some(flag)=> if(flag=="true"){ continuationToken = flags("NextContinuationToken") } else { completeStage() } case None=>completeStage() } } }) } }