app/helpers/S3Signer.scala (179 lines of code) (raw):
package helpers
import java.net.URLEncoder
import java.security.MessageDigest
import java.time.{OffsetDateTime, ZoneOffset, ZonedDateTime}
import java.time.format.DateTimeFormatter
import akka.http.scaladsl.model.HttpHeader.ParsingResult.{Error, Ok}
import akka.http.scaladsl.model.{HttpHeader, HttpRequest}
import akka.stream.Materializer
import akka.stream.scaladsl.{Flow, Keep, Sink}
import akka.util.ByteString
import com.amazonaws.auth.{AWSCredentialsProvider, AWSSessionCredentials}
import com.amazonaws.regions.Region
import javax.crypto.Mac
import javax.crypto.spec.SecretKeySpec
import play.api.Logger
import software.amazon.awssdk.auth.credentials.{AwsCredentialsProvider, AwsSessionCredentials}
import scala.concurrent.{ExecutionContext, Future}
import scala.util.{Failure, Success}
/** this trait implements logic for signing S3 requests over HTTP */
/* see https://docs.aws.amazon.com/AmazonS3/latest/API/sig-v4-header-based-auth.html*/
trait S3Signer {
// Logger that receives the debug/error output of every signing step; supplied by the implementing class.
protected val logger:org.slf4j.Logger
// Materializer handed to `runWith` when a non-empty request body is streamed through the hashing flow in signHttpRequest.
implicit val mat:Materializer
// Execution context on which the Futures built by signHttpRequest are mapped and combined.
implicit val ec:ExecutionContext
/**
  * Formats the date portion of a credential scope as `uuuuMMdd`.
  * The formatter is pinned to UTC so that a temporal carrying another offset is converted to UTC
  * before it is rendered, keeping it consistent with [[aws_compatible_datetime]].
  */
val aws_compatible_date:DateTimeFormatter = DateTimeFormatter.ofPattern("uuuuMMdd").withZone(ZoneOffset.UTC)
/**
  * Formats a timestamp as `uuuuMMdd'T'HHmmss'Z'` (the value used for the `x-amz-date` header).
  * The trailing `Z` in the pattern is a literal, so the formatter is pinned to UTC: without the zone override a
  * timestamp with a non-UTC offset would be rendered with its local wall-clock fields and still be labelled `Z`.
  */
val aws_compatible_datetime:DateTimeFormatter = DateTimeFormatter.ofPattern("uuuuMMdd'T'HHmmss'Z'").withZone(ZoneOffset.UTC)
/**
  * Computes the SHA-256 digest of the UTF-8 encoding of a string.
  * @param str text to hash
  * @return the digest as a lower-case hexadecimal string
  */
protected def digestString(str:String) = {
  val utf8Bytes = str.getBytes("UTF-8")
  val sha256 = MessageDigest.getInstance("SHA-256").digest(utf8Bytes)
  // render every byte as exactly two hex digits and join them
  sha256.map(b => "%02x".format(b)).mkString
}
/**
  * HMAC-SHA256 of `str` (UTF-8 encoded) under a binary key.
  * @return the MAC as a lower-case hexadecimal string
  */
protected def hmacString(key:Array[Byte],str:String):String = hmacBinary(key,str).map("%02x".format(_)).mkString

/**
  * HMAC-SHA256 of `str` (UTF-8 encoded) under a textual key, which is itself UTF-8 encoded before use.
  * @return the MAC as a lower-case hexadecimal string
  */
protected def hmacString(key:String,str:String):String = hmacString(key.getBytes("UTF-8"),str)

/**
  * HMAC-SHA256 of `str` (UTF-8 encoded) under a binary key.
  * @param key raw key bytes
  * @param str message to authenticate
  * @return the raw MAC bytes
  */
protected def hmacBinary(key:Array[Byte],str:String):Array[Byte] = {
  val macAlgorithm = "HmacSHA256"
  // The key spec is labelled with the same algorithm as the Mac it is fed to. The previous label ("SHA-256") names a
  // digest rather than a MAC; it produced the same output here, but a stricter JCE provider may refuse such a key.
  val secretKeySpec = new SecretKeySpec(key, macAlgorithm)
  val hmaccer = Mac.getInstance(macAlgorithm)
  hmaccer.init(secretKeySpec)
  hmaccer.doFinal(str.getBytes("UTF-8"))
}
/**
  * Splits a raw query string into a map of key-value pairs.
  *
  * Pairs are separated by `&`; each pair is split on its FIRST `=` only, so a value may itself contain `=`.
  * A pair without `=` maps the whole text to an empty value. Empty segments (as in `""` or `a=1&&b=2`) are skipped
  * rather than being turned into an empty-key entry. No percent-decoding is performed.
  *
  * @param queryString the raw query string, or None if the URI has none
  * @return the parsed parameters; empty if `queryString` is None. If a key occurs more than once the last value wins.
  */
protected def convertQueryString(queryString:Option[String]):Map[String,String] = {
  // limit=2 keeps any further '=' inside the value and also preserves a trailing empty value ("a=" -> Array("a",""))
  def makeTuple(str:String):(String,String) = str.split("=", 2) match {
    case Array(name, value) => (name, value)
    case _ => (str, "")
  }

  queryString
    .map(_.split("&").filter(_.nonEmpty).map(makeTuple).toMap)
    .getOrElse(Map.empty[String,String])
}
/**
  * Builds an [[HttpHeader]] from a name and a value by running them through `HttpHeader.parse`.
  * @param key header name
  * @param value header value
  * @return the parsed header
  * @throws RuntimeException if the parser reports an error; the message is the string form of that error
  */
protected def makeHttpHeader(key:String,value:String):HttpHeader = {
  val parsed = HttpHeader.parse(key, value)
  parsed match {
    case Error(problem) => throw new RuntimeException(problem.toString)
    case Ok(header, _) => header  //anything reported alongside a successful parse is ignored
  }
}
/**
  * Flattens a sequence of headers into a name -> value map.
  * If the same name occurs more than once, the later header replaces the earlier one.
  */
protected def headersToMap(headers: Seq[HttpHeader]) = {
  val nameValuePairs = for (header <- headers) yield header.name -> header.value
  nameValuePairs.toMap
}
/**
  * Asynchronously signs the given [[HttpRequest]] object using the credentials provider given.
  *
  * The request body is hashed (SHA-256), `x-amz-date` and `x-amz-content-sha256` headers are added (plus
  * `x-amz-security-token` when the resolved credentials are session credentials), and an `Authorization` header
  * carrying the computed signature is appended.
  *
  * Credentials are resolved synchronously on the calling thread, before the returned Future is built.
  *
  * @param req Incoming HttpRequest
  * @param region AWS region to access
  * @param serviceName service name
  * @param credsProvider AwsCredentialsProvider instance (or chain) that gives us credentials
  * @param timestamp optional signing time, defaulting to the current time. A value with a non-UTC offset is converted
  *                  to the same instant in UTC, because the date formats used for signing end in a literal `Z`.
  * @return a Future with the updated [[HttpRequest]]; it fails if hashing the body or any signing step fails
  */
def signHttpRequest(req:HttpRequest, region:Region, serviceName:String, credsProvider:AwsCredentialsProvider, timestamp:Option[OffsetDateTime]=None):Future[HttpRequest] = {
  import scala.collection.immutable.Seq

  //always sign with a UTC time, whatever offset the caller's timestamp carried
  val requestTime = timestamp
    .map(_.withOffsetSameInstant(ZoneOffset.UTC))
    .getOrElse(OffsetDateTime.now(ZoneOffset.UTC))
  //use the region name explicitly everywhere instead of relying on Region.toString in one place and getName in another
  val regionName = region.getName

  val contentHashFuture = if(req.entity.isKnownEmpty()) {
    logger.debug("request entity is empty")
    //the hash of an empty body is already known, so there is nothing to schedule
    Future.successful(ByteString(MessageDigest.getInstance("SHA-256").digest("".getBytes("UTF-8"))))
  } else {
    logger.debug("request entity has data")
    req.entity.getDataBytes()
      .via(new ContentHashingFlow("SHA-256"))
      .runWith(Sink.reduce[ByteString](_.concat(_)), mat)
  }

  val contentHashHexFuture = contentHashFuture.map(bs=>bs.map("%02x".format(_)).mkString)
  contentHashHexFuture.onComplete({
    case Success(string)=>logger.debug(s"content hash string is $string")
    case Failure(err)=>logger.error(s"failed to generate content hash", err)
  })

  val credentials = credsProvider.resolveCredentials()
  val sessionTokenHeaders = credentials match {
    case session:AwsSessionCredentials=>
      Seq(makeHttpHeader("x-amz-security-token", session.sessionToken()))
    case _=>
      Seq()
  }

  val updatedHeadersFuture = contentHashHexFuture.map(hash=>
    req.headers ++ Seq(
      makeHttpHeader("x-amz-date",requestTime.format(aws_compatible_datetime)),
      makeHttpHeader("x-amz-content-sha256",hash)
    ) ++ sessionTokenHeaders
  )

  //combine the two futures explicitly instead of peeking into an already-completed future with .value.get.get
  val canonStringFuture = for {
    hash <- contentHashHexFuture
    headers <- updatedHeadersFuture
  } yield calculateCanonicalString(req.method.value, req.uri.path.toString(), convertQueryString(req.uri.rawQueryString), headersToMap(headers), Some(hash))
  canonStringFuture.onComplete({
    case Success(str)=>logger.debug(s"canonicalString is $str")
    case Failure(err)=>logger.error(s"Canonical string failed", err)
  })

  val stringToSignFuture = canonStringFuture.map(cs=>stringToSign(regionName, serviceName, cs, requestTime))
  stringToSignFuture.onComplete({
    case Success(str)=>logger.debug(s"stringToSign is $str")
    case Failure(err)=>logger.error(s"stringToSign failed", err)
  })

  val signingKeyResult = signingKey(credentials.secretAccessKey(), serviceName, regionName, requestTime)
  val sig = stringToSignFuture.map(sts=>finalSignature(signingKeyResult, sts))

  //a for-comprehension keeps the static types of both results, so no asInstanceOf casts are needed
  for {
    finalSig <- sig
    signedHeaders <- updatedHeadersFuture
  } yield {
    val signedHeadersString = signedHeaders.map(_.name().toLowerCase()).sorted.mkString(";")
    val finalHeaders = signedHeaders ++ Seq(
      makeHttpHeader("Authorization", s"AWS4-HMAC-SHA256 Credential=${credentials.accessKeyId()}/${requestTime.format(aws_compatible_date)}/$regionName/$serviceName/aws4_request,SignedHeaders=$signedHeadersString,Signature=$finalSig")
    )
    logger.debug(s"Final headers are: $finalHeaders")
    req.withHeaders(finalHeaders)
  }
}
/**
  * Step one of the algorithm: calculate the canonical string.
  *
  * The result is six newline-separated parts: method, canonical path, canonical query string, canonical headers
  * (each terminated by a newline, which leaves a blank line after the list), the `;`-joined signed header names,
  * and the payload hash.
  *
  * @param httpMethod HTTP verb, used as given
  * @param uriPath request path; a path of length 0 or 1 is rendered as `/`, otherwise `+` is rewritten to `%20`
  *                and `:` to `%3A`
  * @param uriQueryParams query parameters. Names are URL-encoded here and values are used as given, so values
  *                       presumably need to arrive already encoded - TODO confirm against callers.
  * @param headers headers to sign. Names are lower-cased and values trimmed; if two names differ only by case one of
  *                them replaces the other.
  * @param payloadHash hex SHA-256 of the body; the hash of the empty string is used if None
  * @return the canonical request string
  */
protected def calculateCanonicalString(httpMethod:String, uriPath: String, uriQueryParams:Map[String,String],
                                       headers:Map[String,String], payloadHash:Option[String]) = {
  logger.debug(s"uriPath is $uriPath ${uriPath.length}")

  val canonicalUrl = if(uriPath.length<=1) {
    "/"
  } else {
    //Note that the spec at https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html demands that spaces should
    //be double-encoded for all services with the EXCEPTION of s3. So we encode once and make sure that + (how the java urlencoder does a space)
    //gets replaced with %20 which is what AWS wants
    uriPath
      .replaceAll("\\+","%20")  //java encoder often encodes spaces as '+' instead of %20
      .replaceAll(":", "%3A")   //: character is not being automatically encoded (%3A)
  }
  logger.debug(s"encodedUrl: $canonicalUrl")

  //Sort on the (encoded) parameter name rather than on the joined "name=value" text: '=' sorts after characters such
  //as '-' and '.', so sorting the joined text can put "a-b=.." ahead of "a=.." although the name "a" must come first.
  val canonicalQueryString = uriQueryParams.toList
    .map({ case (name, value) => (URLEncoder.encode(name,"UTF-8"), value) })
    .sorted
    .map({ case (name, value) => name + "=" + value })
    .mkString("&")
  logger.debug(s"canonicalQueryString: $canonicalQueryString")

  val hashedPayload = payloadHash.getOrElse(
    MessageDigest.getInstance("SHA-256").digest("".getBytes("UTF-8")).map("%02x".format(_)).mkString
  )

  //Lower-case the names BEFORE the presence check and the sort, so that lookup, ordering and the signed-headers list
  //all work from the same keys (sorting mixed-case names first would place e.g. "Range" ahead of "content-md5").
  val normalisedHeaders = headers.map({ case (name, value) => name.toLowerCase -> value.trim })
  val updatedHeaders = if(normalisedHeaders.contains("x-amz-content-sha256")){
    normalisedHeaders
  } else {
    normalisedHeaders + ("x-amz-content-sha256"->hashedPayload)
  }

  //Both lists are derived from updatedHeaders. Previously the values were looked up in the un-augmented `headers` map,
  //which threw NoSuchElementException whenever x-amz-content-sha256 had just been added above.
  val sortedHeaderNames = updatedHeaders.keys.toList.sorted
  val canonicalHeaders = sortedHeaderNames.map(header=>{
    header + ":" + updatedHeaders(header)
  }).mkString("\n") + "\n"
  logger.debug(s"canonicalHeaders: $canonicalHeaders")

  val signedHeaders = sortedHeaderNames.mkString(";")
  logger.debug(s"signedHeaders: $signedHeaders")

  logger.debug(s"hashedPayload: $hashedPayload")

  s"""$httpMethod
     |$canonicalUrl
     |$canonicalQueryString
     |$canonicalHeaders
     |$signedHeaders
     |$hashedPayload""".stripMargin
}
/**
  * Step two - create a string to sign.
  * The result has four newline-separated lines: the algorithm name, the request timestamp, the credential scope and
  * the hex SHA-256 digest of the canonical request.
  * @param region AWS region name (string)
  * @param serviceName AWS service name (String)
  * @param canonicalRequestString Canonical request string as provided by `calculateCanonicalString`
  * @param requestTime time of the request; supplies both the timestamp line and the date part of the scope
  * @return the string to be signed in step four
  */
protected def stringToSign(region:String, serviceName:String, canonicalRequestString: String, requestTime:OffsetDateTime) = {
  val timestamp = aws_compatible_datetime.format(requestTime)
  logger.debug(s"timestamp is $timestamp")

  val scopeDate = aws_compatible_date.format(requestTime)
  val scope = s"$scopeDate/$region/$serviceName/aws4_request"
  logger.debug(s"scope is $scope")

  val canonicalBytes = canonicalRequestString.getBytes("UTF-8")
  val canonRequestDigest = MessageDigest.getInstance("SHA-256")
    .digest(canonicalBytes)
    .map(b => "%02x".format(b))
    .mkString
  logger.debug(s"Digest of canonical string is $canonRequestDigest")

  s"""AWS4-HMAC-SHA256
     |$timestamp
     |$scope
     |$canonRequestDigest""".stripMargin
}
/**
  * Step three - Calculate signing key.
  * The key is derived by a chain of HMAC-SHA256 operations: the date is keyed with `"AWS4" + secretAccessKey`, the
  * region with that result, the service name with the next, and finally the literal `aws4_request`.
  *
  * Neither the secret access key nor any derived key material is written to the log.
  *
  * @param secretAccessKey the AWS secret access key
  * @param serviceName AWS service name
  * @param awsRegion AWS region name
  * @param requestTime time of the request; only its date, rendered with `aws_compatible_date`, is used
  * @return the raw bytes of the signing key
  */
protected def signingKey(secretAccessKey:String, serviceName:String, awsRegion:String, requestTime:OffsetDateTime) = {
  val dateValue = requestTime.format(aws_compatible_date)
  logger.debug(s"dateValue is $dateValue")

  //The log lines below used to interpolate the secret access key itself, together with the byte arrays (which only
  //print as an object identity such as [B@1a2b3c). They now record just the non-secret input of each derivation step.
  val dateKey = hmacBinary(("AWS4" + secretAccessKey).getBytes("UTF-8"), dateValue)
  logger.debug(s"dateKey derived from secret access key and $dateValue")

  val dateRegionKey = hmacBinary(dateKey, awsRegion)
  logger.debug(s"dateRegionKey derived from $awsRegion")

  val dateRegionServiceKey = hmacBinary(dateRegionKey, serviceName)
  logger.debug(s"dateRegionServiceKey derived from $serviceName")

  hmacBinary(dateRegionServiceKey, "aws4_request")
}
/**
  * Step four - sign the string to sign with the signing key.
  * @param signingKey raw key bytes produced by step three
  * @param stringToSign text produced by step two
  * @return the HMAC of `stringToSign` under `signingKey`, as returned by `hmacString` (a hex string)
  */
protected def finalSignature(signingKey:Array[Byte], stringToSign:String) =
  hmacString(signingKey, stringToSign)
}