in sagemaker-spark-sdk/src/main/scala/com/amazonaws/services/sagemaker/sparksdk/algorithms/LinearLearnerSageMakerEstimator.scala [658:811]
  /** Sets the criteria for selecting the best model, e.g. "precision_at_target_recall". */
  def setBinaryClassifierModelSelectionCriteria(value: String): this.type =
    set(binaryClassifierModelSelectionCriteria, value)
  /** Sets the target recall for the "precision_at_target_recall" criteria. */
  def setTargetRecall(value: Double): this.type = set(targetRecall, value)
  /** Sets the target precision for the "recall_at_target_precision" criteria. */
  def setTargetPrecision(value: Double): this.type = set(targetPrecision, value)
  /** Sets the weight multiplier for positive examples, e.g. "auto" or "balanced". */
  def setPositiveExampleWeightMult(value: String): this.type =
    set(positiveExampleWeightMult, value)
  /** Sets the weight multiplier for positive examples as a number. */
  def setPositiveExampleWeightMult(value: Double): this.type =
    set(positiveExampleWeightMult, value.toString)
}
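
// A hedged usage sketch for the setters above (assumes an existing
// LinearLearnerBinaryClassifier instance named estimator; the values shown are
// documented Linear Learner hyperparameter values):
//
//   estimator
//     .setBinaryClassifierModelSelectionCriteria("precision_at_target_recall")
//     .setTargetRecall(0.9)
//     .setPositiveExampleWeightMult("balanced")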
/**
* A [[SageMakerEstimator]] that runs a Linear Learner training job in "multiclass classifier" mode
* in SageMaker and returns a [[SageMakerModel]] that can be used to transform a DataFrame using
* the hosted Linear Learner model. The Linear Learner Multiclass Classifier is useful for
* classifying examples into one of three or more classes.
*
* Amazon SageMaker Linear Learner trains on RecordIO-encoded Amazon Record protobuf data.
* SageMaker Spark writes a DataFrame to S3 by selecting a column of Vectors named "features"
* and, if present, a column of Doubles named "label". These names are configurable by passing
* a map with entries in trainingSparkDataFormatOptions with key "labelColumnName" or
* "featuresColumnName", with values corresponding to the desired label and features columns.
*
* Inferences made against an Endpoint hosting a Linear Learner Multiclass classifier model
* contain a "score" field and a "predicted_label" field, both appended to the input DataFrame
* as Doubles.
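*
* A minimal end-to-end sketch (hedged: the role ARN, instance settings, and class count are
* illustrative; trainingData and testData are assumed DataFrames with "features" and "label"
* columns; setNumClasses is assumed to be the num_classes setter from
* MultiClassClassifierParams):
* {{{
*   val estimator = new LinearLearnerMultiClassClassifier(
*     sagemakerRole = IAMRole("arn:aws:iam::123456789012:role/SageMakerRole"),
*     trainingInstanceType = "ml.p2.xlarge",
*     trainingInstanceCount = 1,
*     endpointInstanceType = "ml.c4.xlarge",
*     endpointInitialInstanceCount = 1)
*   estimator.setNumClasses(10) // num_classes is required in multiclass mode
*   val model = estimator.fit(trainingData)
*   val scored = model.transform(testData) // appends "score" and "predicted_label"
* }}}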
*
* @param sagemakerRole The SageMaker TrainingJob and Hosting IAM Role. Used by SageMaker to
*                      access S3 and ECR resources. SageMaker hosted Endpoint instances
*                      launched by this Estimator run with this role.
* @param trainingInstanceType The SageMaker TrainingJob Instance Type to use
* @param trainingInstanceCount The number of instances of trainingInstanceType to run a
*                              SageMaker Training Job with
* @param endpointInstanceType The SageMaker Endpoint Config instance type
* @param endpointInitialInstanceCount The SageMaker Endpoint Config minimum number of instances
*                                     that can be used to host modelImage
* @param requestRowSerializer Serializes Spark DataFrame [[Row]]s for transformation by Models
*                             built from this Estimator.
* @param responseRowDeserializer Deserializes an Endpoint response into a series of [[Row]]s.
* @param trainingInputS3DataPath An S3 location to upload SageMaker Training Job input data to.
* @param trainingOutputS3DataPath An S3 location for SageMaker to store Training Job output
*                                 data to.
* @param trainingInstanceVolumeSizeInGB The EBS volume size in gigabytes of each instance.
* @param trainingProjectedColumns The columns to project from the Dataset being fit before
*                                 training. If None is passed then no specific projection
*                                 will occur and all columns will be serialized.
* @param trainingChannelName The SageMaker Channel name for the serialized Dataset used as
*                            training input
* @param trainingContentType The MIME type of the training data.
* @param trainingS3DataDistribution The SageMaker Training Job S3 data distribution scheme.
* @param trainingSparkDataFormat The Spark Data Format name used to serialize the Dataset being
*                                fit for input to SageMaker.
* @param trainingSparkDataFormatOptions The Spark Data Format Options used during serialization
*                                       of the Dataset being fit.
* @param trainingInputMode The SageMaker Training Job Channel input mode.
* @param trainingCompressionCodec The type of compression to use when serializing the Dataset
*                                 being fit for input to SageMaker.
* @param trainingMaxRuntimeInSeconds A SageMaker Training Job Termination Condition
*                                    MaxRuntimeInSeconds.
* @param trainingKmsKeyId A KMS key ID for the Output Data Source
* @param modelEnvironmentVariables The environment variables that SageMaker will set on the model
*                                  container during execution.
* @param endpointCreationPolicy Defines how a SageMaker Endpoint referenced by a
*                               SageMakerModel is created.
* @param sagemakerClient Amazon SageMaker client. Used to send CreateTrainingJob, CreateModel,
*                        and CreateEndpoint requests.
* @param region The region in which to run the algorithm. If not specified, gets the region from
*               the DefaultAwsRegionProviderChain.
* @param s3Client AmazonS3. Used to create a bucket for staging SageMaker Training Job input
*                 and/or output if either are set to S3AutoCreatePath.
* @param stsClient AmazonSTS. Used to resolve the account number when creating staging
*                  input / output buckets.
* @param modelPrependInputRowsToTransformationRows Whether the transformation result on Models
*          built by this Estimator should also include the input Rows. If true, each output Row
*          is formed by a concatenation of the input Row with the corresponding Row produced by
*          SageMaker Endpoint invocation, produced by responseRowDeserializer.
*          If false, each output Row is just taken from responseRowDeserializer.
* @param deleteStagingDataAfterTraining Whether to remove the training data from S3 after
*                                       training completes or fails.
* @param namePolicyFactory The [[NamePolicyFactory]] to use when naming SageMaker entities
*                          created during fit
* @param uid The unique identifier of this Estimator. Used to represent this stage in Spark
*            ML pipelines.
*/
class LinearLearnerMultiClassClassifier(
override val sagemakerRole : IAMRoleResource = IAMRoleFromConfig(),
override val trainingInstanceType : String,
override val trainingInstanceCount : Int,
override val endpointInstanceType : String,
override val endpointInitialInstanceCount : Int,
override val requestRowSerializer : RequestRowSerializer =
new ProtobufRequestRowSerializer(),
override val responseRowDeserializer : ResponseRowDeserializer =
new LinearLearnerMultiClassClassifierProtobufResponseRowDeserializer(),
override val trainingInputS3DataPath : S3Resource = S3AutoCreatePath(),
override val trainingOutputS3DataPath : S3Resource = S3AutoCreatePath(),
override val trainingInstanceVolumeSizeInGB : Int = 1024,
override val trainingProjectedColumns : Option[List[String]] = None,
override val trainingChannelName : String = "train",
override val trainingContentType: Option[String] = None,
override val trainingS3DataDistribution : String =
S3DataDistribution.ShardedByS3Key.toString,
override val trainingSparkDataFormat : String = "sagemaker",
override val trainingSparkDataFormatOptions : Map[String, String] = Map(),
override val trainingInputMode : String = TrainingInputMode.File.toString,
override val trainingCompressionCodec : Option[String] = None,
override val trainingMaxRuntimeInSeconds : Int = 24 * 60 * 60,
override val trainingKmsKeyId : Option[String] = None,
override val modelEnvironmentVariables : Map[String, String] = Map(),
override val endpointCreationPolicy : EndpointCreationPolicy =
EndpointCreationPolicy.CREATE_ON_CONSTRUCT,
override val sagemakerClient : AmazonSageMaker
= AmazonSageMakerClientBuilder.defaultClient,
override val region : Option[String] = None,
override val s3Client : AmazonS3 = AmazonS3ClientBuilder.defaultClient(),
override val stsClient : AWSSecurityTokenService =
AWSSecurityTokenServiceClientBuilder.defaultClient(),
override val modelPrependInputRowsToTransformationRows : Boolean = true,
override val deleteStagingDataAfterTraining : Boolean = true,
override val namePolicyFactory : NamePolicyFactory = new RandomNamePolicyFactory(),
override val uid : String = Identifiable.randomUID("sagemaker"))
extends LinearLearnerSageMakerEstimator(
sagemakerRole,
trainingInstanceType,
trainingInstanceCount,
endpointInstanceType,
endpointInitialInstanceCount,
requestRowSerializer,
responseRowDeserializer,
trainingInputS3DataPath,
trainingOutputS3DataPath,
trainingInstanceVolumeSizeInGB,
trainingProjectedColumns,
trainingChannelName,
trainingContentType,
trainingS3DataDistribution,
trainingSparkDataFormat,
trainingSparkDataFormatOptions,
trainingInputMode,
trainingCompressionCodec,
trainingMaxRuntimeInSeconds,
trainingKmsKeyId,
modelEnvironmentVariables,
endpointCreationPolicy,
sagemakerClient,
region,
s3Client,
stsClient,
modelPrependInputRowsToTransformationRows,
deleteStagingDataAfterTraining,
namePolicyFactory,
uid) with MultiClassClassifierParams {