def fit()

in mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala [100:316]

Fits a WeightedLeastSquaresModel to an RDD[Instance]: one tree aggregation collects the weighted sufficient statistics, the normal equations are formed and solved in a standardized space (Cholesky, or Quasi-Newton when an L1 penalty is present), and the solution is mapped back to the original feature and label scales.


  def fit(
      instances: RDD[Instance],
      instr: OptionalInstrumentation = OptionalInstrumentation.create(
        classOf[WeightedLeastSquares]),
      depth: Int = 2
    ): WeightedLeastSquaresModel = {
    if (regParam == 0.0) {
      instr.logWarning("regParam is zero, which might cause numerical instability and overfitting.")
    }

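    // Collect the weighted sufficient statistics (count, label and feature means, and
    // second moments) in a single pass; depth controls the shape of the tree aggregation.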
    val summary = instances.treeAggregate[Aggregator](
      zeroValue = new Aggregator,
      seqOp = (agg: Aggregator, x: Instance) => agg.add(x),
      combOp = (agg1: Aggregator, agg2: Aggregator) => agg1.merge(agg2),
      depth = depth,
      finalAggregateOnExecutor = true)
    summary.validate()
    instr.logInfo(log"Number of instances: ${MDC(COUNT, summary.count)}.")
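    // k counts the solved coefficients (one extra for the intercept column when fitIntercept);
    // triK is the length of the packed upper triangular storage of the numFeatures x
    // numFeatures feature Gram matrix, and wSum is the total instance weight.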
    val k = if (fitIntercept) summary.k + 1 else summary.k
    val numFeatures = summary.k
    val triK = summary.triK
    val wSum = summary.wSum

    val rawBStd = summary.bStd
    val rawBBar = summary.bBar
    // if b is constant (rawBStd is zero), then b cannot be scaled. In that case, setting
    // bStd = abs(rawBBar) ensures that b is no longer scaled in the L-BFGS algorithm.
    val bStd = if (rawBStd == 0.0) math.abs(rawBBar) else rawBStd

    if (rawBStd == 0) {
      if (fitIntercept || rawBBar == 0.0) {
        if (rawBBar == 0.0) {
          instr.logWarning("Mean and standard deviation of the label are zero, so the " +
            "coefficients and the intercept will all be zero; as a result, training is not " +
            "needed.")
        } else {
          instr.logWarning("The standard deviation of the label is zero, so the coefficients " +
            "will be zeros and the intercept will be the mean of the label; as a result, " +
            "training is not needed.")
        }
        val coefficients = new DenseVector(Array.ofDim(numFeatures))
        val intercept = rawBBar
        val diagInvAtWA = new DenseVector(Array(0D))
        return new WeightedLeastSquaresModel(coefficients, intercept, diagInvAtWA, Array(0D))
      } else {
        require(!(regParam > 0.0 && standardizeLabel), "The standard deviation of the label is " +
          "zero. Model cannot be regularized when labels are standardized.")
        instr.logWarning("The standard deviation of the label is zero. Consider setting " +
          "fitIntercept=true.")
      }
    }

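    // Move the label statistics into the standardized space: bBar is the scaled label mean
    // and bbBar the scaled second moment of the label.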
    val bBar = summary.bBar / bStd
    val bbBar = summary.bbBar / (bStd * bStd)

    val aStd = summary.aStd
    val aStdValues = aStd.values

    val aBar = {
      val _aBar = summary.aBar
      val _aBarValues = _aBar.values
      var i = 0
      // scale aBar to standardized space in-place
      while (i < numFeatures) {
        if (aStdValues(i) == 0.0) {
          _aBarValues(i) = 0.0
        } else {
          _aBarValues(i) /= aStdValues(i)
        }
        i += 1
      }
      _aBar
    }
    val aBarValues = aBar.values

    val abBar = {
      val _abBar = summary.abBar
      val _abBarValues = _abBar.values
      var i = 0
      // scale abBar to standardized space in-place
      while (i < numFeatures) {
        if (aStdValues(i) == 0.0) {
          _abBarValues(i) = 0.0
        } else {
          _abBarValues(i) /= (aStdValues(i) * bStd)
        }
        i += 1
      }
      _abBar
    }
    val abBarValues = abBar.values

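    // aaBar holds the upper triangle of the feature Gram matrix in packed column-major
    // order; p walks the packed entries column by column.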
    val aaBar = {
      val _aaBar = summary.aaBar
      val _aaBarValues = _aaBar.values
      var j = 0
      var p = 0
      // scale aaBar to standardized space in-place
      while (j < numFeatures) {
        val aStdJ = aStdValues(j)
        var i = 0
        while (i <= j) {
          val aStdI = aStdValues(i)
          if (aStdJ == 0.0 || aStdI == 0.0) {
            _aaBarValues(p) = 0.0
          } else {
            _aaBarValues(p) /= (aStdI * aStdJ)
          }
          p += 1
          i += 1
        }
        j += 1
      }
      _aaBar
    }
    val aaBarValues = aaBar.values

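    // The penalty is expressed in the standardized label space, then split into L1 and L2
    // pieces according to the elastic net mixing parameter.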
    val effectiveRegParam = regParam / bStd
    val effectiveL1RegParam = elasticNetParam * effectiveRegParam
    val effectiveL2RegParam = (1.0 - elasticNetParam) * effectiveRegParam

    // add L2 regularization to diagonals
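    // (in packed upper triangular storage the diagonal of 1-based column j sits at index
    // j * (j + 1) / 2 - 1; starting at i = 0 and stepping i += j with j = 2, 3, ... visits
    // exactly those entries)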
    var i = 0
    var j = 2
    while (i < triK) {
      var lambda = effectiveL2RegParam
      if (!standardizeFeatures) {
        val std = aStdValues(j - 2)
        if (std != 0.0) {
          lambda /= (std * std)
        } else {
          lambda = 0.0
        }
      }
      if (!standardizeLabel) {
        lambda *= bStd
      }
      aaBarValues(i) += lambda
      i += j
      j += 1
    }

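    // Assemble the packed AtA and AtB terms handed to the solvers (getAtA/getAtB also fold
    // in the intercept column when fitIntercept is set).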
    val aa = getAtA(aaBarValues, aBarValues)
    val ab = getAtB(abBarValues, bBar)

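    // Choose a solver: an L1 penalty rules out the closed-form Cholesky path, so Auto picks
    // Quasi-Newton whenever both elasticNetParam and regParam are nonzero.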
    val solver = if ((solverType == WeightedLeastSquares.Auto && elasticNetParam != 0.0 &&
      regParam != 0.0) || (solverType == WeightedLeastSquares.QuasiNewton)) {
      val effectiveL1RegFun: Option[(Int) => Double] = if (effectiveL1RegParam != 0.0) {
        Some((index: Int) => {
            if (fitIntercept && index == numFeatures) {
              0.0
            } else if (standardizeFeatures) {
              effectiveL1RegParam
            } else {
              if (aStdValues(index) != 0.0) effectiveL1RegParam / aStdValues(index) else 0.0
            }
          })
      } else {
        None
      }
      new QuasiNewtonSolver(fitIntercept, maxIter, tol, effectiveL1RegFun)
    } else {
      new CholeskySolver
    }

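    // Solve in the standardized space; the solution carries the coefficients plus an
    // optional inverse Gram matrix (aaInv) and objective history, depending on the solver.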
    val solution = solver match {
      case cholesky: CholeskySolver =>
        try {
          cholesky.solve(bBar, bbBar, ab, aa, aBar)
        } catch {
          // if the Auto solver is used and Cholesky fails due to a singular AtA, then fall
          // back to the Quasi-Newton solver.
          case _: SingularMatrixException if solverType == WeightedLeastSquares.Auto =>
            instr.logWarning("Cholesky solver failed due to singular covariance matrix. " +
              "Retrying with Quasi-Newton solver.")
            // ab and aa were modified in place, so reconstruct them
            val _aa = getAtA(aaBarValues, aBarValues)
            val _ab = getAtB(abBarValues, bBar)
            val newSolver = new QuasiNewtonSolver(fitIntercept, maxIter, tol, None)
            newSolver.solve(bBar, bbBar, _ab, _aa, aBar)
        }
      case qn: QuasiNewtonSolver =>
        qn.solve(bBar, bbBar, ab, aa, aBar)
    }

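    // The solution vector stores the intercept in its last slot (still in the standardized
    // label space); multiplying by bStd undoes the label scaling.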
    val (coefficientArray, intercept) = if (fitIntercept) {
      (solution.coefficients.slice(0, solution.coefficients.length - 1),
        solution.coefficients.last * bStd)
    } else {
      (solution.coefficients, 0.0)
    }

    // convert the coefficients from the scaled space to the original space
    var q = 0
    val len = coefficientArray.length
    while (q < len) {
      coefficientArray(q) *= { if (aStdValues(q) != 0.0) bStd / aStdValues(q) else 0.0 }
      q += 1
    }

    // aaInv is a packed upper triangular matrix; here we extract all of its diagonal
    // elements, undoing the wSum normalization and the feature standardization as we go
    val diagInvAtWA = solution.aaInv.map { inv =>
      new DenseVector((1 to k).map { i =>
        val multiplier = if (i == k && fitIntercept) {
          1.0
        } else {
          aStdValues(i - 1) * aStdValues(i - 1)
        }
        inv(i + (i - 1) * i / 2 - 1) / (wSum * multiplier)
      }.toArray)
    }.getOrElse(new DenseVector(Array(0D)))

    new WeightedLeastSquaresModel(new DenseVector(coefficientArray), intercept, diagInvAtWA,
      solution.objectiveHistory.getOrElse(Array(0D)))
  }
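
Usage, for context: a minimal driver sketch, not taken from the Spark sources. It assumes the
constructor field order implied by the parameters fit() reads (fitIntercept, regParam,
elasticNetParam, standardizeFeatures, standardizeLabel) and uses named arguments to stay robust
to that assumption. WeightedLeastSquares is private[ml], so code like this would have to live
under the org.apache.spark.ml package.

import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.rdd.RDD

// `sc` is an existing SparkContext; the data below is made up for illustration.
val instances: RDD[Instance] = sc.parallelize(Seq(
  Instance(17.0, 1.0, Vectors.dense(1.0, 7.0)),
  Instance(19.0, 2.0, Vectors.dense(1.0, 8.0)),
  Instance(23.0, 1.0, Vectors.dense(2.0, 9.0)),
  Instance(29.0, 1.0, Vectors.dense(3.0, 11.0))))

// regParam > 0 avoids the zero-regularization warning logged at the top of fit();
// elasticNetParam = 0.0 keeps the penalty pure L2, so the Auto solver can use Cholesky.
val wls = new WeightedLeastSquares(
  fitIntercept = true, regParam = 0.1, elasticNetParam = 0.0,
  standardizeFeatures = true, standardizeLabel = true)

val model = wls.fit(instances)
println(s"coefficients = ${model.coefficients}, intercept = ${model.intercept}")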