def handleMlCommand()

in sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala [118:360]


  /**
   * Executes a single Spark Connect ML command for the given session and builds its result.
   *
   * Handles FIT, FETCH, DELETE, CLEAN_CACHE, GET_CACHE_INFO, WRITE, READ and EVALUATE. Fitted
   * or loaded models and fetched summaries are registered in the session's ML cache and handed
   * back to the client as cache ids.
   *
   * @param sessionHolder the session whose ML cache is used
   * @param mlCommand     the ML command received from the client
   * @return the command result; WRITE returns an empty result
   * @throws MlUnsupportedException for unsupported commands, operator types or write targets,
   *                                non-writable objects, and algorithms that are disallowed
   *                                while model cache memory control is enabled
   * @throws MLCacheInvalidException when WRITE references an id that is not in the cache
   */
  def handleMlCommand(
      sessionHolder: SessionHolder,
      mlCommand: proto.MlCommand): proto.MlCommandResult = {

    val mlCache = sessionHolder.mlCache
    val memoryControlEnabled = mlCache.getMemoryControlEnabled

    // Builds a result carrying only an object reference (a single cache id, or several
    // comma-separated ids).
    def objRefResult(id: String): proto.MlCommandResult =
      proto.MlCommandResult
        .newBuilder()
        .setOperatorInfo(
          proto.MlCommandResult.MlOperatorInfo
            .newBuilder()
            .setObjRef(proto.ObjectRef.newBuilder().setId(id)))
        .build()

    // Disable model training summary when memory control is enabled, because training
    // summaries do not support size estimation and offloading.
    // NOTE(review): SummaryUtils / TreeConfig look like global singletons, so these assignments
    // presumably affect every session, not just this one -- confirm.
    SummaryUtils.enableTrainingSummary = !memoryControlEnabled

    if (memoryControlEnabled) {
      val maxModelSize = mlCache.getModelMaxSize

      // Tree training stops early once the growing model exceeds
      // `TreeConfig.trainingEarlyStopModelSizeThresholdInBytes`. Each training iteration can
      // grow the tree nodes up to 2x, so the threshold is maxModelSize / 2.5: a factor of 2 for
      // that growth plus 0.5 of headroom, because in-memory size is not exactly proportional
      // to the node count. This keeps the final model below `maxModelSize`.
      TreeConfig.trainingEarlyStopModelSizeThresholdInBytes = (maxModelSize.toDouble / 2.5).toLong
    }

    mlCommand.getCommandCase match {
      case proto.MlCommand.CommandCase.FIT =>
        val fitCmd = mlCommand.getFit
        val estimatorProto = fitCmd.getEstimator
        assert(estimatorProto.getType == proto.MlOperator.OperatorType.OPERATOR_TYPE_ESTIMATOR)

        val dataset = MLUtils.parseRelationProto(fitCmd.getDataset, sessionHolder)
        val estimator =
          MLUtils.getEstimator(sessionHolder, estimatorProto, Some(fitCmd.getParams))

        if (memoryControlEnabled) {
          // Reject algorithms that cannot run under memory control first, so they fail fast
          // without running the dataset-based size estimation below.
          val estimatorClassName = estimator.getClass.getName
          if (estimatorClassName == "org.apache.spark.ml.fpm.FPGrowth") {
            throw MlUnsupportedException(
              "FPGrowth algorithm is not supported " +
                "if Spark Connect model cache offloading is enabled.")
          }
          if (estimatorClassName == "org.apache.spark.ml.clustering.LDA"
            && estimator
              .asInstanceOf[org.apache.spark.ml.clustering.LDA]
              .getOptimizer
              .equalsIgnoreCase("em")) {
            throw MlUnsupportedException(
              "LDA algorithm with 'em' optimizer is not supported " +
                "if Spark Connect model cache offloading is enabled.")
          }

          // Estimators that cannot estimate their model size throw
          // UnsupportedOperationException and are let through unchecked. Only the estimation
          // is guarded, so any exception raised by `checkModelSize` always propagates.
          val estimatedModelSize =
            try {
              Some(estimator.estimateModelSize(dataset))
            } catch {
              case _: UnsupportedOperationException => None
            }
          estimatedModelSize.foreach(size => mlCache.checkModelSize(size))
        }

        val model = estimator.fit(dataset).asInstanceOf[Model[_]]
        objRefResult(mlCache.register(model))

      case proto.MlCommand.CommandCase.FETCH =>
        val fetchCmd = mlCommand.getFetch
        val helper = AttributeHelper(
          sessionHolder,
          fetchCmd.getObjRef.getId,
          fetchCmd.getMethodsList.asScala.toArray)
        val attrResult = helper.getAttribute
        attrResult match {
          case s: Summary =>
            proto.MlCommandResult.newBuilder().setSummary(mlCache.register(s)).build()
          case m: Model[_] =>
            objRefResult(mlCache.register(m))
          case a: Array[_] if a.nonEmpty && a.forall(_.isInstanceOf[Model[_]]) =>
            // Every model in the array is cached; ids are returned comma-separated.
            val ids = a.map(m => mlCache.register(m.asInstanceOf[Model[_]]))
            objRefResult(ids.mkString(","))
          case _ =>
            proto.MlCommandResult
              .newBuilder()
              .setParam(Serializer.serializeParam(attrResult))
              .build()
        }

      case proto.MlCommand.CommandCase.DELETE =>
        // Ids containing '.' are never removed. Only ids that were actually removed from the
        // cache are reported back, in request order.
        val deletedIds = mlCommand.getDelete.getObjRefsList.asScala
          .map(_.getId)
          .filter(id => !id.contains(".") && mlCache.remove(id))
        objRefResult(deletedIds.mkString(","))

      case proto.MlCommand.CommandCase.CLEAN_CACHE =>
        val size = mlCache.clear()
        proto.MlCommandResult
          .newBuilder()
          .setParam(LiteralValueProtoConverter.toLiteralProto(size))
          .build()

      case proto.MlCommand.CommandCase.GET_CACHE_INFO =>
        proto.MlCommandResult
          .newBuilder()
          .setParam(LiteralValueProtoConverter.toLiteralProto(mlCache.getInfo()))
          .build()

      case proto.MlCommand.CommandCase.WRITE =>
        val writeCmd = mlCommand.getWrite
        writeCmd.getTypeCase match {
          case proto.MlCommand.Write.TypeCase.OBJ_REF => // save a cached model
            val objId = writeCmd.getObjRef.getId
            val model = mlCache.get(objId) match {
              case null => throw MLCacheInvalidException(s"model $objId")
              case m: Model[_] => m
              // e.g. a cached Summary: report it cleanly instead of a ClassCastException.
              case other => throw MlUnsupportedException(s"$other is not writable")
            }
            // Write params are set on a copy rather than on the cached instance.
            val copiedModel = model.copy(ParamMap.empty).asInstanceOf[Model[_]]
            MLUtils.setInstanceParams(copiedModel, writeCmd.getParams)

            copiedModel match {
              case m: MLWritable => MLUtils.write(m, writeCmd)
              case other => throw MlUnsupportedException(s"$other is not writable")
            }

          case proto.MlCommand.Write.TypeCase.OPERATOR => // save an estimator/evaluator/transformer
            val operatorProto = writeCmd.getOperator
            val operatorName = operatorProto.getName
            val params = Some(writeCmd.getParams)

            // Saves `op` if it is MLWritable; `label` names the operator kind in the error.
            def writeOperator(op: Any, label: String): Unit = op match {
              case writable: MLWritable => MLUtils.write(writable, writeCmd)
              case other => throw MlUnsupportedException(s"$label $other is not writable")
            }

            operatorProto.getType match {
              case proto.MlOperator.OperatorType.OPERATOR_TYPE_ESTIMATOR =>
                writeOperator(
                  MLUtils.getEstimator(sessionHolder, operatorProto, params),
                  "Estimator")
              case proto.MlOperator.OperatorType.OPERATOR_TYPE_EVALUATOR =>
                writeOperator(
                  MLUtils.getEvaluator(sessionHolder, operatorProto, params),
                  "Evaluator")
              case proto.MlOperator.OperatorType.OPERATOR_TYPE_TRANSFORMER =>
                writeOperator(
                  MLUtils.getTransformer(sessionHolder, operatorProto, params),
                  "Transformer")
              case _ =>
                throw MlUnsupportedException(s"Operator $operatorName is not supported")
            }

          case other => throw MlUnsupportedException(s"$other write not supported")
        }
        proto.MlCommandResult.newBuilder().build()

      case proto.MlCommand.CommandCase.READ =>
        val operator = mlCommand.getRead.getOperator
        val name = operator.getName
        val path = mlCommand.getRead.getPath

        operator.getType match {
          case proto.MlOperator.OperatorType.OPERATOR_TYPE_MODEL =>
            // Loaded models are cached and returned by id together with their uid and params.
            val model = MLUtils.loadTransformer(sessionHolder, name, path)
            val id = mlCache.register(model)
            proto.MlCommandResult
              .newBuilder()
              .setOperatorInfo(
                proto.MlCommandResult.MlOperatorInfo
                  .newBuilder()
                  .setObjRef(proto.ObjectRef.newBuilder().setId(id))
                  .setUid(model.uid)
                  .setParams(Serializer.serializeParams(model)))
              .build()

          case operatorType =>
            // Other operators are not cached; only their name, uid and params are returned.
            val mlOperator: Params = operatorType match {
              case proto.MlOperator.OperatorType.OPERATOR_TYPE_ESTIMATOR =>
                MLUtils.loadEstimator(sessionHolder, name, path).asInstanceOf[Params]
              case proto.MlOperator.OperatorType.OPERATOR_TYPE_EVALUATOR =>
                MLUtils.loadEvaluator(sessionHolder, name, path).asInstanceOf[Params]
              case proto.MlOperator.OperatorType.OPERATOR_TYPE_TRANSFORMER =>
                MLUtils.loadTransformer(sessionHolder, name, path).asInstanceOf[Params]
              case other =>
                throw MlUnsupportedException(s"$other read not supported")
            }

            proto.MlCommandResult
              .newBuilder()
              .setOperatorInfo(
                proto.MlCommandResult.MlOperatorInfo
                  .newBuilder()
                  .setName(name)
                  .setUid(mlOperator.uid)
                  .setParams(Serializer.serializeParams(mlOperator)))
              .build()
        }

      case proto.MlCommand.CommandCase.EVALUATE =>
        val evalCmd = mlCommand.getEvaluate
        val evalProto = evalCmd.getEvaluator
        assert(evalProto.getType == proto.MlOperator.OperatorType.OPERATOR_TYPE_EVALUATOR)

        val dataset = MLUtils.parseRelationProto(evalCmd.getDataset, sessionHolder)
        val evaluator =
          MLUtils.getEvaluator(sessionHolder, evalProto, Some(evalCmd.getParams))
        val metric = evaluator.evaluate(dataset)
        proto.MlCommandResult
          .newBuilder()
          .setParam(LiteralValueProtoConverter.toLiteralProto(metric))
          .build()

      case other => throw MlUnsupportedException(s"$other not supported")
    }
  }