in sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala [118:360]
/**
 * Executes a single Spark Connect ML command and builds its result.
 *
 * Dispatches on `mlCommand.getCommandCase`:
 *  - FIT: builds the estimator, fits it on the relation, registers the model in the
 *    session ML cache and returns its object-ref id.
 *  - FETCH: resolves an attribute through [[AttributeHelper]]. Summaries, models and
 *    non-empty arrays of models are registered in the cache and returned by id (array
 *    ids are joined with ","). Any other value is serialized as a param.
 *  - DELETE: removes the referenced ids that contain no "." from the cache and returns
 *    the ids that were actually removed, joined with ",".
 *  - CLEAN_CACHE: clears the cache and returns the value reported by `mlCache.clear()`.
 *  - GET_CACHE_INFO: returns `mlCache.getInfo()` as a literal param.
 *  - WRITE: saves a cached model, or an estimator/evaluator/transformer built from the
 *    request, via `MLUtils.write`.
 *  - READ: loads an operator from `path`. Models are registered in the cache;
 *    other operators are only described (name, uid, params).
 *  - EVALUATE: runs the evaluator on the relation and returns the metric.
 *
 * Side effects on every call: assigns `SummaryUtils.enableTrainingSummary`. When model
 * memory control is enabled, it also assigns
 * `TreeConfig.trainingEarlyStopModelSizeThresholdInBytes`.
 *
 * @param sessionHolder the session whose ML cache and Spark session are used
 * @param mlCommand     the command to execute
 * @return the command result
 * @throws MlUnsupportedException  for unsupported commands, operators, or algorithms
 *                                 under memory control, and for non-writable objects
 * @throws MLCacheInvalidException when WRITE references an id missing from the cache
 */
def handleMlCommand(
    sessionHolder: SessionHolder,
    mlCommand: proto.MlCommand): proto.MlCommandResult = {
  val mlCache = sessionHolder.mlCache
  val memoryControlEnabled = mlCache.getMemoryControlEnabled

  // Result whose operator info carries only an object reference id.
  def objRefResult(id: String): proto.MlCommandResult =
    proto.MlCommandResult
      .newBuilder()
      .setOperatorInfo(
        proto.MlCommandResult.MlOperatorInfo
          .newBuilder()
          .setObjRef(proto.ObjectRef.newBuilder().setId(id)))
      .build()

  // Training summaries are disabled under memory control because they support
  // neither size estimation nor offloading.
  // NOTE(review): this flag and the TreeConfig threshold below are assigned through
  // object-style references, so they look like shared (not per-session) state.
  // Confirm this is safe when sessions with different settings run concurrently.
  SummaryUtils.enableTrainingSummary = !memoryControlEnabled
  if (memoryControlEnabled) {
    val maxModelSize = mlCache.getModelMaxSize
    // Tree training stops early once the growing model exceeds
    // `TreeConfig.trainingEarlyStopModelSizeThresholdInBytes`. To keep the final model
    // below `maxModelSize`, the threshold is `maxModelSize / 2.5`. The factor of 2 is
    // because each tree-training iteration can up to double the number of nodes. The
    // extra 0.5 is headroom, because in-memory size is not exactly proportional to
    // the node count.
    TreeConfig.trainingEarlyStopModelSizeThresholdInBytes = (maxModelSize.toDouble / 2.5).toLong
  }

  mlCommand.getCommandCase match {
    case proto.MlCommand.CommandCase.FIT =>
      val fitCmd = mlCommand.getFit
      val estimatorProto = fitCmd.getEstimator
      assert(estimatorProto.getType == proto.MlOperator.OperatorType.OPERATOR_TYPE_ESTIMATOR)
      val dataset = MLUtils.parseRelationProto(fitCmd.getDataset, sessionHolder)
      val estimator =
        MLUtils.getEstimator(sessionHolder, estimatorProto, Some(fitCmd.getParams))
      if (memoryControlEnabled) {
        // Reject unsupported algorithms first. These checks are cheap, whereas
        // `estimateModelSize` is handed the whole dataset.
        // Exact class-name comparison is intentional (subclasses are not matched).
        val estimatorClassName = estimator.getClass.getName
        if (estimatorClassName == "org.apache.spark.ml.fpm.FPGrowth") {
          throw MlUnsupportedException(
            "FPGrowth algorithm is not supported " +
              "if Spark Connect model cache offloading is enabled.")
        }
        if (estimatorClassName == "org.apache.spark.ml.clustering.LDA" &&
          estimator
            .asInstanceOf[org.apache.spark.ml.clustering.LDA]
            .getOptimizer
            .equalsIgnoreCase("em")) {
          throw MlUnsupportedException(
            "LDA algorithm with 'em' optimizer is not supported " +
              "if Spark Connect model cache offloading is enabled.")
        }
        // Only the estimation may legitimately be unsupported. A failure from the
        // size check itself must propagate, not be swallowed.
        val estimatedModelSize =
          try Some(estimator.estimateModelSize(dataset))
          catch { case _: UnsupportedOperationException => None }
        estimatedModelSize.foreach(size => mlCache.checkModelSize(size))
      }
      val model = estimator.fit(dataset).asInstanceOf[Model[_]]
      objRefResult(mlCache.register(model))

    case proto.MlCommand.CommandCase.FETCH =>
      val fetch = mlCommand.getFetch
      val helper = AttributeHelper(
        sessionHolder,
        fetch.getObjRef.getId,
        fetch.getMethodsList.asScala.toArray)
      helper.getAttribute match {
        case s: Summary =>
          proto.MlCommandResult.newBuilder().setSummary(mlCache.register(s)).build()
        case m: Model[_] =>
          objRefResult(mlCache.register(m))
        case a: Array[_] if a.nonEmpty && a.forall(_.isInstanceOf[Model[_]]) =>
          val ids = a.map(m => mlCache.register(m.asInstanceOf[Model[_]]))
          objRefResult(ids.mkString(","))
        case other =>
          proto.MlCommandResult.newBuilder().setParam(Serializer.serializeParam(other)).build()
      }

    case proto.MlCommand.CommandCase.DELETE =>
      // Ids containing "." are skipped; presumably they denote references into another
      // cached object rather than cache entries. TODO confirm.
      // `remove` has a side effect: the strict `filter` evaluates it exactly once per
      // id, in request order, and keeps only ids that were actually removed.
      val removedIds = mlCommand.getDelete.getObjRefsList.asScala
        .map(_.getId)
        .filter(id => !id.contains(".") && mlCache.remove(id))
      objRefResult(removedIds.mkString(","))

    case proto.MlCommand.CommandCase.CLEAN_CACHE =>
      val size = mlCache.clear()
      proto.MlCommandResult
        .newBuilder()
        .setParam(LiteralValueProtoConverter.toLiteralProto(size))
        .build()

    case proto.MlCommand.CommandCase.GET_CACHE_INFO =>
      proto.MlCommandResult
        .newBuilder()
        .setParam(LiteralValueProtoConverter.toLiteralProto(mlCache.getInfo()))
        .build()

    case proto.MlCommand.CommandCase.WRITE =>
      val write = mlCommand.getWrite
      write.getTypeCase match {
        case proto.MlCommand.Write.TypeCase.OBJ_REF => // save a cached model
          val objId = write.getObjRef.getId
          // Match instead of casting: the cache also holds non-model objects
          // (e.g. summaries registered by FETCH).
          val model = mlCache.get(objId) match {
            case null => throw MLCacheInvalidException(s"model $objId")
            case m: Model[_] => m
            case _ => throw MlUnsupportedException(s"Object $objId is not a model")
          }
          val copiedModel = model.copy(ParamMap.empty).asInstanceOf[Model[_]]
          MLUtils.setInstanceParams(copiedModel, write.getParams)
          copiedModel match {
            case m: MLWritable => MLUtils.write(m, write)
            case other => throw MlUnsupportedException(s"$other is not writable")
          }

        case proto.MlCommand.Write.TypeCase.OPERATOR => // save an estimator/evaluator/transformer
          val operatorProto = write.getOperator
          val params = Some(write.getParams)
          // `kind` only labels the "not writable" error message.
          val (kind, operator): (String, Params) = operatorProto.getType match {
            case proto.MlOperator.OperatorType.OPERATOR_TYPE_ESTIMATOR =>
              ("Estimator", MLUtils.getEstimator(sessionHolder, operatorProto, params))
            case proto.MlOperator.OperatorType.OPERATOR_TYPE_EVALUATOR =>
              ("Evaluator", MLUtils.getEvaluator(sessionHolder, operatorProto, params))
            case proto.MlOperator.OperatorType.OPERATOR_TYPE_TRANSFORMER =>
              ("Transformer", MLUtils.getTransformer(sessionHolder, operatorProto, params))
            case _ =>
              throw MlUnsupportedException(s"Operator ${operatorProto.getName} is not supported")
          }
          operator match {
            case writable: MLWritable => MLUtils.write(writable, write)
            case other => throw MlUnsupportedException(s"$kind $other is not writable")
          }

        case other => throw MlUnsupportedException(s"$other write not supported")
      }
      proto.MlCommandResult.newBuilder().build()

    case proto.MlCommand.CommandCase.READ =>
      val operator = mlCommand.getRead.getOperator
      val name = operator.getName
      val path = mlCommand.getRead.getPath
      operator.getType match {
        case proto.MlOperator.OperatorType.OPERATOR_TYPE_MODEL =>
          // Models are registered in the cache so the client can refer to them by id.
          val model = MLUtils.loadTransformer(sessionHolder, name, path)
          val id = mlCache.register(model)
          proto.MlCommandResult
            .newBuilder()
            .setOperatorInfo(
              proto.MlCommandResult.MlOperatorInfo
                .newBuilder()
                .setObjRef(proto.ObjectRef.newBuilder().setId(id))
                .setUid(model.uid)
                .setParams(Serializer.serializeParams(model)))
            .build()

        case operatorType =>
          // Other operators are not cached; only their description is returned.
          val mlOperator: Params = operatorType match {
            case proto.MlOperator.OperatorType.OPERATOR_TYPE_ESTIMATOR =>
              MLUtils.loadEstimator(sessionHolder, name, path).asInstanceOf[Params]
            case proto.MlOperator.OperatorType.OPERATOR_TYPE_EVALUATOR =>
              MLUtils.loadEvaluator(sessionHolder, name, path).asInstanceOf[Params]
            case proto.MlOperator.OperatorType.OPERATOR_TYPE_TRANSFORMER =>
              MLUtils.loadTransformer(sessionHolder, name, path).asInstanceOf[Params]
            case unsupported =>
              throw MlUnsupportedException(s"$unsupported read not supported")
          }
          proto.MlCommandResult
            .newBuilder()
            .setOperatorInfo(
              proto.MlCommandResult.MlOperatorInfo
                .newBuilder()
                .setName(name)
                .setUid(mlOperator.uid)
                .setParams(Serializer.serializeParams(mlOperator)))
            .build()
      }

    case proto.MlCommand.CommandCase.EVALUATE =>
      val evalCmd = mlCommand.getEvaluate
      val evalProto = evalCmd.getEvaluator
      assert(evalProto.getType == proto.MlOperator.OperatorType.OPERATOR_TYPE_EVALUATOR)
      val dataset = MLUtils.parseRelationProto(evalCmd.getDataset, sessionHolder)
      val evaluator =
        MLUtils.getEvaluator(sessionHolder, evalProto, Some(evalCmd.getParams))
      val metric = evaluator.evaluate(dataset)
      proto.MlCommandResult
        .newBuilder()
        .setParam(LiteralValueProtoConverter.toLiteralProto(metric))
        .build()

    case other => throw MlUnsupportedException(s"$other not supported")
  }
}