app/agent/origin.scala (256 lines of code) (raw):
package agent
import java.io.FileNotFoundException
import java.net.{URI, URL, URLConnection, URLStreamHandler}
import collectors.Instance
import conf.{AWS, PrismConfiguration}
import play.api.libs.json.{JsObject, JsValue, Json}
import software.amazon.awssdk.auth.credentials.{
AwsBasicCredentials,
AwsCredentialsProvider,
ProfileCredentialsProvider,
StaticCredentialsProvider
}
import software.amazon.awssdk.regions.Region
import software.amazon.awssdk.services.s3.S3Client
import software.amazon.awssdk.services.s3.model.GetObjectRequest
import software.amazon.awssdk.services.sts.StsClient
import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider
import software.amazon.awssdk.services.sts.model.AssumeRoleRequest
import utils.{AWSCredentialProviders, Logging, Marker}
import scala.io.Source
import scala.language.postfixOps
import scala.util.Try
import scala.util.control.NonFatal
import scala.util.matching.Regex
/** Builds the list of origins Prism works with, from the supplied
  * configuration.
  *
  * @param prismConfiguration
  *   configuration whose `accounts.aws`, `accounts.amis` and `accounts.json`
  *   lists are combined into [[all]]
  */
class Accounts(prismConfiguration: PrismConfiguration) extends Logging {

  /** Every configured origin: the AWS and AMI origins (each with an account
    * number filled in) followed by the JSON origins.
    *
    * An AWS origin that has no account number yet gets one by asking STS for
    * the caller identity with that origin's credentials. If that lookup fails
    * the failure is logged and the placeholder `"?????????"` is stored instead,
    * so a single misconfigured account does not stop the others from loading.
    */
  val all: Seq[Origin] = {
    val awsOrigins =
      (prismConfiguration.accounts.aws.list ++ prismConfiguration.accounts.amis.list)
        .map { awsOrigin =>
          if (awsOrigin.accountNumber.isDefined) {
            awsOrigin
          } else {
            Try {
              val stsClient = StsClient
                .builder()
                .credentialsProvider(awsOrigin.credentials.provider)
                .region(Region.AWS_GLOBAL)
                .build
              // The client is only needed for this one call; close it so its
              // underlying HTTP resources are released straight away.
              try {
                val accountNumber = stsClient.getCallerIdentity.account()
                awsOrigin.copy(accountNumber = Some(accountNumber))
              } finally {
                stsClient.close()
              }
            }.recover { case NonFatal(e) =>
              log.warn(s"Failed to extract the account number for $awsOrigin", e)
              awsOrigin.copy(accountNumber = Some("?????????"))
            }.get
          }
        }
    awsOrigins ++ prismConfiguration.accounts.json.list
  }

  /** Origins that apply to the given resource.
    *
    * @param resource
    *   resource name to look up
    * @return
    *   the origins whose `resources` set is empty (no restriction) or
    *   contains `resource`
    */
  def forResource(resource: String): Seq[Origin] = all.filter(origin =>
    origin.resources.isEmpty || origin.resources.contains(resource)
  )
}
/** A source of data: identifies a vendor/account pair, the resources it
  * covers, and how it is rendered as JSON.
  */
trait Origin extends Marker {

  /** Vendor identifier, emitted as the `vendor` JSON field. */
  def vendor: String

  /** Account name, emitted as the `accountName` JSON field. */
  def account: String

  /** Key/value pairs describing this origin for filtering; empty unless
    * overridden.
    */
  def filterMap: Map[String, String] = Map.empty

  /** Resource names this origin covers. */
  def resources: Set[String]

  /** Crawl rates keyed by name. */
  def crawlRate: Map[String, CrawlRate]

  /** Hook for adjusting an instance from this origin; the default returns the
    * input unchanged.
    */
  def transformInstance(input: Instance): Instance = input

  /** Fields every origin contributes to its JSON form. */
  def standardFields: Map[String, String] =
    Map("vendor" -> vendor, "accountName" -> account)

  /** Implementation-specific fields merged over [[standardFields]]. */
  def jsonFields: Map[String, String]

  /** JSON object made of [[standardFields]] plus [[jsonFields]]; on a key
    * clash the value from [[jsonFields]] is the one kept.
    */
  def toJson: JsObject = {
    val merged = standardFields ++ jsonFields
    JsObject(
      merged.toSeq.map { case (name, value) => name -> Json.toJson(value) }
    )
  }
}
/** Describes how to authenticate against AWS and derives the matching
  * credentials provider.
  *
  * `secretKey` lives in a second parameter list, which keeps it out of the
  * members generated for the case class (`toString`, `equals`, ...).
  *
  * @param accessKey
  *   static access key id; only used together with `secretKey`
  * @param role
  *   ARN of a role to assume
  * @param profile
  *   name of a credentials profile
  * @param regionName
  *   region id, used for any STS client created here
  * @param secretKey
  *   static secret key; only used together with `accessKey`
  */
case class Credentials(
    accessKey: Option[String],
    role: Option[String],
    profile: Option[String],
    regionName: String
)(secretKey: Option[String]) {

  val region: Region = Region.of(regionName)

  /** Assume-role request for `roleArn` under the session name "prism". */
  private def prismRoleRequest(roleArn: String): AssumeRoleRequest =
    AssumeRoleRequest.builder
      .roleSessionName("prism")
      .roleArn(roleArn)
      .build

  /** Provider that assumes the role described by `request` via `sts`. */
  private def assumeRoleProvider(
      sts: StsClient,
      request: AssumeRoleRequest
  ): StsAssumeRoleCredentialsProvider =
    StsAssumeRoleCredentialsProvider.builder
      .stsClient(sts)
      .refreshRequest(request)
      .build

  // `id` is a short label for the chosen credentials and `provider` supplies
  // them. Precedence: role via profile, role alone, profile alone, static
  // key pair, and finally the deploy-tools provider chain.
  val (id, provider) = (role, profile, accessKey, secretKey) match {
    case (Some(roleArn), Some(profileName), _, _) =>
      // Assume the role using the named profile as the base identity.
      val profileSts = StsClient.builder
        .credentialsProvider(
          ProfileCredentialsProvider.builder.profileName(profileName).build
        )
        .region(region)
        .build
      val request = prismRoleRequest(roleArn)
      (s"$profileName/$roleArn", assumeRoleProvider(profileSts, request))

    case (Some(roleArn), _, _, _) =>
      // Assume the role from whatever credentials the STS client resolves
      // by default.
      val request = prismRoleRequest(roleArn)
      val defaultSts = StsClient.builder
        .region(region)
        .build
      (roleArn, assumeRoleProvider(defaultSts, request))

    case (_, Some(profileName), _, _) =>
      (
        profileName,
        ProfileCredentialsProvider.builder.profileName(profileName).build
      )

    case (_, _, Some(keyId), Some(secret)) =>
      (
        keyId,
        StaticCredentialsProvider.create(
          AwsBasicCredentials.create(keyId, secret)
        )
      )

    case _ =>
      ("default", AWSCredentialProviders.deployToolsCredentialsProviderChain)
  }
}
/** Factory methods for [[AmazonOrigin]]. */
object AmazonOrigin {

  /** Captures the account id out of an IAM role ARN. */
  val ArnIamAccountExtractor: Regex = """arn:aws:iam::(\d+):role.*""".r

  /** Creates an origin whose account number is taken from the role ARN in
    * `credentials`; it is left empty when there is no role or the role does
    * not match [[ArnIamAccountExtractor]].
    */
  def apply(
      account: String,
      region: String,
      resources: Set[String],
      stagePrefix: Option[String],
      credentials: Credentials,
      ownerId: Option[String],
      crawlRates: Map[String, CrawlRate]
  ): AmazonOrigin = {
    val accountFromRole: Option[String] =
      credentials.role.collect { case ArnIamAccountExtractor(accountId) =>
        accountId
      }
    new AmazonOrigin(
      account = account,
      region = region,
      credentials = credentials,
      resources = resources,
      stagePrefix = stagePrefix,
      accountNumber = accountFromRole,
      ownerId = ownerId,
      crawlRate = crawlRates
    )
  }

  /** Creates an origin restricted to the "images" resource, with no stage
    * prefix and the account number supplied by the caller.
    */
  def amis(
      name: String,
      region: String,
      accountNumber: Option[String],
      credentials: Credentials,
      ownerId: Option[String],
      crawlRates: Map[String, CrawlRate]
  ): AmazonOrigin =
    new AmazonOrigin(
      account = name,
      region = region,
      credentials = credentials,
      resources = Set("images"),
      stagePrefix = None,
      accountNumber = accountNumber,
      ownerId = ownerId,
      crawlRate = crawlRates
    )
}
/** An [[Origin]] for one AWS account in one region.
  *
  * @param account
  *   account name
  * @param region
  *   AWS region id
  * @param credentials
  *   credentials used for this account
  * @param resources
  *   resource names this origin covers
  * @param stagePrefix
  *   when defined, passed to `Instance.prefixStage` for every instance
  * @param accountNumber
  *   AWS account number, when known
  * @param ownerId
  *   owner id, when configured
  * @param crawlRate
  *   crawl rates keyed by name
  */
case class AmazonOrigin(
    account: String,
    region: String,
    credentials: Credentials,
    resources: Set[String],
    stagePrefix: Option[String],
    accountNumber: Option[String],
    ownerId: Option[String],
    crawlRate: Map[String, CrawlRate]
) extends Origin {

  lazy val vendor = "aws"

  override lazy val filterMap =
    Map("vendor" -> vendor, "region" -> region, "accountName" -> account)

  /** Applies the stage prefix to `input` when one is configured; otherwise
    * returns `input` untouched.
    */
  override def transformInstance(input: Instance): Instance =
    stagePrefix match {
      case Some(prefix) => input.prefixStage(prefix)
      case None         => input
    }

  /** Region and credentials id, plus account number and owner id when they
    * are present.
    */
  val jsonFields: Map[String, String] = {
    val always = Map("region" -> region, "credentials" -> credentials.id)
    val whenPresent =
      List("accountNumber" -> accountNumber, "ownerId" -> ownerId).collect {
        case (key, Some(value)) => key -> value
      }
    always ++ whenPresent
  }

  /** The region as an SDK v2 `Region` value. */
  val awsRegionV2: Region = Region.of(region)

  override def toMarkerMap: Map[String, Any] = {
    val regionId = awsRegionV2.id
    Map("region" -> regionId)
  }
}
/** An [[Origin]] whose data is a JSON document fetched from a URL.
  *
  * The URL may contain the placeholder `%resource%`, which is replaced with
  * the resource name on each fetch. Supported schemes are `classpath`, `s3`
  * and anything `java.net.URL` can open.
  *
  * @param vendor
  *   vendor identifier
  * @param account
  *   account name
  * @param url
  *   URL template of the JSON document
  * @param resources
  *   resource names this origin covers
  * @param crawlRate
  *   crawl rates keyed by name
  */
case class JsonOrigin(
    vendor: String,
    account: String,
    url: String,
    resources: Set[String],
    crawlRate: Map[String, CrawlRate]
) extends Origin
    with Logging {

  /** Stream handler that resolves the path of a `classpath:` URL as a
    * classpath resource.
    */
  private val classpathHandler = new URLStreamHandler {
    override def openConnection(u: URL): URLConnection =
      Option(getClass.getResource(u.getPath))
        .map(_.openConnection())
        .getOrElse {
          throw new FileNotFoundException(
            s"${u.getPath} not found on classpath"
          )
        }
  }

  /** Chooses credentials from the user-info part of an S3 URL.
    *
    *   - user-info starting with `arn:` is treated as a role ARN to assume
    *   - any other user-info is treated as a profile name
    *   - no user-info falls back to the deploy-tools provider chain
    *
    * @param url
    *   the `s3://[userinfo@]bucket/key` location
    * @return
    *   the credentials provider to use for that location
    */
  def credsFromS3Url(url: URI): AwsCredentialsProvider =
    Option(url.getUserInfo) match {
      case Some(role) if role.startsWith("arn:") =>
        val request: AssumeRoleRequest = AssumeRoleRequest.builder
          .roleSessionName("prismS3")
          .roleArn(role)
          .build
        // The assume-role provider cannot be built without an STS client, so
        // supply one in the same region the S3 client connects to.
        // NOTE(review): this client lives as long as the returned provider;
        // nothing visible here closes either of them.
        val stsClient = StsClient.builder
          .region(AWS.connectionRegion)
          .build
        StsAssumeRoleCredentialsProvider.builder
          .stsClient(stsClient)
          .refreshRequest(request)
          .build
      case Some(profile) =>
        AWSCredentialProviders.profileCredentialsProvider(profile)
      case None =>
        AWSCredentialProviders.deployToolsCredentialsProviderChain
    }

  /** Downloads the object an `s3://bucket/key` URI points at. */
  private def fetchFromS3(s3Location: URI): Array[Byte] = {
    val s3Client = S3Client.builder
      .credentialsProvider(credsFromS3Url(s3Location))
      .region(AWS.connectionRegion)
      .build
    // A client is created for every fetch, so release it once the bytes have
    // been copied out of the response.
    try {
      s3Client
        .getObjectAsBytes(
          GetObjectRequest.builder
            .bucket(s3Location.getHost)
            .key(s3Location.getPath.stripPrefix("/"))
            .build
        )
        .asByteArray
    } finally {
      s3Client.close()
    }
  }

  /** Opens the document at `location` as UTF-8 text, whatever its scheme. */
  private def openSource(location: URI): Source =
    location.getScheme match {
      case "classpath" =>
        Source.fromURL(
          new URL(null, location.toString, classpathHandler),
          "utf-8"
        )
      case "s3" =>
        // Decode explicitly as UTF-8, like the other branches, instead of
        // relying on the platform default charset.
        Source.fromBytes(fetchFromS3(location), "utf-8")
      case _ =>
        Source.fromURL(location.toURL, "utf-8")
    }

  /** Fetches and parses the JSON document for `resource`.
    *
    * @param resource
    *   resource whose `name` replaces `%resource%` in the URL
    * @return
    *   the parsed JSON
    */
  def data(resource: ResourceType): JsValue = {
    val source: Source = openSource(
      new URI(url.replace("%resource%", resource.name))
    )
    val jsonText: String =
      try {
        source.getLines().mkString
      } finally {
        source.close()
      }
    Json.parse(jsonText)
  }

  val jsonFields = Map("url" -> url)

  override def toMarkerMap: Map[String, Any] = jsonFields
}