void EdgeML::altMinSGD()

in cpp/src/ProtoNN/ProtoNNFunctions.cpp [575:992]


void EdgeML::altMinSGD(
  const EdgeML::Data& data,
  EdgeML::ProtoNN::ProtoNNModel& model,
  FP_TYPE *const stats,
  const std::string& outDir)
{
  // These asserts let us pass Eigen matrices and index types directly
  // to MKL BLAS calls.
  assert(sizeof(MKL_INT) == sizeof(Eigen::Index));
  assert(sizeof(Eigen::Index) == sizeof(dataCount_t));
  Timer timer("altMinSGD");

  /*
    Reference MATLAB prototype of this routine:

    [~, n] = size(X); [l, m] = size(Z);
    if isempty(iters)
      iters = 50;
    end
    tol = 1e-5;

    % alternating minimization
    counter = 1; epochs = 10; batchSize = 512;

    batchSize = 128; epochs = 3; sgdTol = 0.02;
    learning_rate = 0;
    learning_rate_Z = 0.2; learning_rate_B = 0.2; learning_rate_W = 0.2;
  */
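  // Overview: alternating minimization. Each outer iteration estimates a
  // step length for, then runs a few epochs of accelerated proximal SGD
  // on, each of the projection matrix W, the prototype-label matrix Z,
  // and the prototype matrix B in turn, with hard thresholding to the
  // sparsity budgets as the proximal step.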

  dataCount_t n = data.Xtrain.cols();
  int         epochs = model.hyperParams.epochs;
  FP_TYPE     sgdTol = (FP_TYPE) 0.02;
  dataCount_t bs = std::min((dataCount_t)model.hyperParams.batchSize, (dataCount_t)n);
#ifdef XML
  dataCount_t hessianbs = std::min((dataCount_t)(1 << 10), bs);
#else
  dataCount_t hessianbs = bs;
#endif 
  //const FP_TYPE hessianAdjustment = 1.0; // (FP_TYPE)hessianbs / (FP_TYPE)bs;
  int     etaUpdate = 0;
  FP_TYPE armijoZ((FP_TYPE)0.2), armijoB((FP_TYPE)0.2), armijoW((FP_TYPE)0.2);
  FP_TYPE fOld, fNew, etaZ(1), etaB(1), etaW(1);
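  // armijoZ/armijoB/armijoW scale the raw step-length estimates computed
  // below; after each sub-problem they are shrunk (x0.7) if the objective
  // worsened beyond a tolerance and grown (x1.1) if it improved well past
  // it, an Armijo-style adaptation.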

  LOG_INFO("\nComputing model size assuming 4 bytes per entry for matrices with sparsity > 0.5 and 8 bytes per entry for matrices with sparsity <= 0.5 (to store sparse matrices, we require about 4 bytes for the index information)...");
  LOG_INFO("Model size in kB = " + std::to_string(computeModelSizeInkB(model.hyperParams.lambdaW, model.hyperParams.lambdaZ, model.hyperParams.lambdaB, model.params.W, model.params.Z, model.params.B)));

  MatrixXuf WX(model.params.W.rows(), data.Xtrain.cols());
  mm(WX, model.params.W, CblasNoTrans, data.Xtrain, CblasNoTrans, 1.0, 0.0L);

  MatrixXuf WXvalidation(model.params.W.rows(), data.Xvalidation.cols());
  if (data.Xvalidation.cols() > 0) {
    mm(WXvalidation, model.params.W, CblasNoTrans, data.Xvalidation, CblasNoTrans, 1.0, 0.0L);
  }

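  // XML (extreme multi-label) builds: full-data objective evaluation is
  // expensive, so fixed random subsamples of the train and validation
  // sets (drawn via randPick) stand in for the full sets in all calls to
  // batchEvaluate below.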
#ifdef XML
  dataCount_t numEvalTrain = std::min((dataCount_t)20000, (dataCount_t)data.Xtrain.cols());
  MatrixXuf WX_sub(WX.rows(), numEvalTrain);
  SparseMatrixuf Y_sub(data.Ytrain.rows(), numEvalTrain);
  SparseMatrixuf X_sub(data.Xtrain.rows(), numEvalTrain);
  randPick(data.Xtrain, X_sub);
  randPick(data.Ytrain, Y_sub);

  mm(WX_sub, model.params.W, CblasNoTrans, X_sub, CblasNoTrans, 1.0, 0.0L);

  dataCount_t numEvalValidation = std::min((dataCount_t)10000, (dataCount_t)data.Xvalidation.cols());
  MatrixXuf WXvalidation_sub(WX.rows(), numEvalValidation);
  SparseMatrixuf Yvalidation_sub(data.Yvalidation.rows(), numEvalValidation);
  SparseMatrixuf Xvalidation_sub(data.Xvalidation.rows(), numEvalValidation);
  if (data.Xvalidation.cols() > 0) {
    randPick(data.Xvalidation, Xvalidation_sub);
    randPick(data.Yvalidation, Yvalidation_sub);
    mm(WXvalidation_sub, model.params.W, CblasNoTrans, Xvalidation_sub, CblasNoTrans, 1.0, 0.0L);
  }
#endif

  timer.nextTime("starting evaluation");


  LOG_INFO("\nInitial stats...");
#ifdef XML
  fNew = batchEvaluate(model.params.Z, Y_sub, Yvalidation_sub, model.params.B, WX_sub, WXvalidation_sub, model.hyperParams.gamma, model.hyperParams.problemType, stats);
#else 
  fNew = batchEvaluate(model.params.Z, data.Ytrain, data.Yvalidation, model.params.B, WX, WXvalidation, model.hyperParams.gamma, model.hyperParams.problemType, stats);
#endif 
  timer.nextTime("evaluating");

  VectorXf eta = VectorXf::Zero(10, 1);
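  // eta collects 10 per-minibatch step-length probes; after sorting,
  // the lower median eta(4) seeds the step length for accProxSGD.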

  MatrixXuf gtmpW(model.params.W.rows(), model.params.W.cols());
  WMatType  Wtmp(model.params.W.rows(), model.params.W.cols());

  MatrixXuf gtmpB(model.params.B.rows(), model.params.B.cols());
  BMatType Btmp(model.params.B.rows(), model.params.B.cols());

  MatrixXuf gtmpZ(model.params.Z.rows(), model.params.Z.cols());
  ZMatType  Ztmp(model.params.Z.rows(), model.params.Z.cols());


#if defined(DUMP) || defined(VERIFY)
  std::ofstream f;
  std::string fileName;
#endif

  LOG_INFO("\nStarting optimization. Number of outer iterations (altMinSGD) = " + std::to_string(model.hyperParams.iters));
  // for i = 1 : iters
  for (int i = 0; i < model.hyperParams.iters; ++i) {
    LOG_INFO(
      "\n=========================== " + std::to_string(i) + "\n"
      + "On iter " + std::to_string(i) + "\n"
      + "=========================== " + std::to_string(i));
    timer.nextTime("starting optimization w.r.t. W");
    LOG_INFO("Optimizing w.r.t. projection matrix (W)...");

#ifdef BTLS
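    // BTLS builds: btls (backtracking line search) picks the step length
    // over minibatches, seeded with twice the previous raw step,
    // (etaW / armijoW) * 2, instead of the secant probe in the #else branch.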
    etaW = armijoW * btls<WMatType>
      ([&model, &data]
       (const WMatType& W, const Eigen::Index begin, const Eigen::Index end)
       ->FP_TYPE {
         MatrixXuf WX = MatrixXuf::Zero(W.rows(), end - begin);
         SparseMatrixuf XMiddle = data.Xtrain.middleCols(begin, end - begin);
         mm(WX, W, CblasNoTrans, XMiddle, CblasNoTrans, 1.0, 0.0L);
         return L(model.params.Z, data.Ytrain,
                  gaussianKernel(model.params.B, WX, model.hyperParams.gamma),
                  begin, end);
       },
       [&model, &data]
       (const WMatType& W, const Eigen::Index begin, const Eigen::Index end)
       ->MatrixXuf {
         MatrixXuf WX = MatrixXuf::Zero(W.rows(), end - begin);
         SparseMatrixuf XMiddle = data.Xtrain.middleCols(begin, end - begin);
         mm(WX, W, CblasNoTrans, XMiddle, CblasNoTrans, 1.0, 0.0L);
         return gradL_W(model.params.B, data.Ytrain, model.params.Z, W, data.Xtrain,
                        gaussianKernel(model.params.B, WX, model.hyperParams.gamma),
                        model.hyperParams.gamma, begin, end);
       },
       std::bind(hardThrsd, std::placeholders::_1, model.hyperParams.lambdaW),
       model.params.W, n, bs, (etaW / armijoW) * 2);
#else
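    // Probe hessianbs-sized minibatches for a step length: take a tiny
    // step along the thresholded gradient and form the secant estimate
    //   eta(j) = ||Wtmp - W|| / ||gradL_W(Wtmp) - gradL_W(W)||,
    // roughly the inverse of the local Lipschitz constant of the
    // gradient (a Barzilai-Borwein-style curvature estimate).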
    for (auto j = 0; j < eta.size(); ++j) {
      Eigen::Index idx1 = (j*(Eigen::Index)hessianbs) % n;
      Eigen::Index idx2 = ((j + 1)*(Eigen::Index)hessianbs) % n;
      //assert (((j+1)*(Eigen::Index)hessianbs) < n);

      if (idx2 <= idx1) idx2 = n;

      gtmpW = gradL_W(model.params.B, data.Ytrain, model.params.Z, model.params.W, data.Xtrain,
        gaussianKernel(model.params.B, WX, model.hyperParams.gamma, idx1, idx2),
        model.hyperParams.gamma, idx1, idx2);

      MatrixXuf gtmpWThresh = gtmpW;
      hardThrsd(gtmpWThresh, model.hyperParams.lambdaW);

      Wtmp = model.params.W
        - 0.001*safeDiv(model.params.W.cwiseAbs().maxCoeff(), gtmpW.cwiseAbs().maxCoeff()) * gtmpWThresh;
      gtmpW -= gradL_W(model.params.B, data.Ytrain, model.params.Z, Wtmp, data.Xtrain,
        gaussianKernel(model.params.B, Wtmp*data.Xtrain.middleCols(idx1, idx2 - idx1), model.hyperParams.gamma),
        model.hyperParams.gamma, idx1, idx2);

      if (gtmpW.norm() <= 1e-20L) {
        LOG_WARNING("Difference between consecutive gradients of W has become really low.");
        eta(j) = 1.0;
      }
      else
        eta(j) = safeDiv((Wtmp - model.params.W).norm(), gtmpW.norm());
    }
    std::sort(eta.data(), eta.data() + eta.size());
    etaW = armijoW * eta(4);
#endif
    //LOG_INFO("Step-length estimate for gradW = " + std::to_string(etaW));

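    // Minimize w.r.t. W: `epochs` passes of accelerated proximal SGD with
    // hard thresholding to sparsity lambdaW as the proximal step.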
    accProxSGD<WMatType>
      (// TODO: capture just the needed members (model.params.*, data.*train,
       // model.hyperParams) instead of the whole model and data objects.
       [&model, &data]
       (const WMatType& W, const Eigen::Index begin, const Eigen::Index end)
       ->FP_TYPE {
         MatrixXuf WX = MatrixXuf::Zero(W.rows(), end - begin);
         SparseMatrixuf XMiddle = data.Xtrain.middleCols(begin, end - begin);
         mm(WX, W, CblasNoTrans, XMiddle, CblasNoTrans, 1.0, 0.0L);
         return L(model.params.Z, data.Ytrain,
                  gaussianKernel(model.params.B, WX, model.hyperParams.gamma),
                  begin, end);
       },
       [&model, &data]
       (const WMatType& W, const Eigen::Index begin, const Eigen::Index end)
       ->MatrixXuf {
         MatrixXuf WX = MatrixXuf::Zero(W.rows(), end - begin);
         SparseMatrixuf XMiddle = data.Xtrain.middleCols(begin, end - begin);
         mm(WX, W, CblasNoTrans, XMiddle, CblasNoTrans, 1.0, 0.0L);
         return gradL_W(model.params.B, data.Ytrain, model.params.Z, W, data.Xtrain,
                        gaussianKernel(model.params.B, WX, model.hyperParams.gamma),
                        model.hyperParams.gamma, begin, end);
       },
       std::bind(hardThrsd, std::placeholders::_1, model.hyperParams.lambdaW),
       model.params.W, epochs, n, bs, etaW, etaUpdate);
    timer.nextTime("ending gradW");
    //LOG_INFO("Final step-length for gradW = " + std::to_string(etaW));

    mm(WX, model.params.W, CblasNoTrans, data.Xtrain, CblasNoTrans, 1.0, 0.0L);
    if (data.Xvalidation.cols() > 0) {
      mm(WXvalidation, model.params.W, CblasNoTrans, data.Xvalidation, CblasNoTrans, 1.0, 0.0L);
    }

    fOld = fNew;
#ifdef XML
    mm(WX_sub, model.params.W, CblasNoTrans, X_sub, CblasNoTrans, 1.0, 0.0L);
    if (data.Xvalidation.cols() > 0) {
      mm(WXvalidation_sub, model.params.W, CblasNoTrans, Xvalidation_sub, CblasNoTrans, 1.0, 0.0L);
    }
    fNew = batchEvaluate(model.params.Z, Y_sub, Yvalidation_sub, model.params.B, WX_sub, WXvalidation_sub, model.hyperParams.gamma, model.hyperParams.problemType, stats + 9 * i + 3);
#else 
    fNew = batchEvaluate(model.params.Z, data.Ytrain, data.Yvalidation, model.params.B, WX, WXvalidation, model.hyperParams.gamma, model.hyperParams.problemType, stats + 9 * i + 3);
#endif 

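    // Adapt the step-length scale to the observed progress: the tolerance
    // band sgdTol * log(3)/log(2 + i) tightens as i grows; worsening
    // beyond it shrinks armijoW by 0.7, improvement beyond three times
    // it grows armijoW by 1.1.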
    if (fNew >= fOld * (1 + safeDiv(sgdTol * (FP_TYPE)log(3), (FP_TYPE)log(2 + i))))
      armijoW *= (FP_TYPE)0.7;
    else if (fNew <= fOld * (1 - safeDiv(3 * sgdTol * (FP_TYPE)log(3), (FP_TYPE)log(2 + i))))
      armijoW *= (FP_TYPE)1.1;

#ifdef VERIFY
    fileName = outDir + "/verify/W" + std::to_string(i);
    f.open(fileName);
    f << "W_check = [" << model.params.W << "];" << std::endl;
    f.close();
#endif 
#ifdef DUMP 
    fileName = outDir + "/dump/W" + std::to_string(i);
    f.open(fileName);
    f << model.params.W.format(eigen_tsv);
    f.close();
#endif 

    timer.nextTime("starting optimization w.r.t. Z");
    LOG_INFO("Optimizing w.r.t. prototype-label matrix (Z)...");

#ifdef BTLS
    etaZ = armijoZ * btls<ZMatType>
      ([&model, &data, &WX]
       (const ZMatType& Z, const Eigen::Index begin, const Eigen::Index end)
       ->FP_TYPE {
         return L(Z, data.Ytrain,
                  gaussianKernel(model.params.B, WX, model.hyperParams.gamma, begin, end),
                  begin, end);
       },
       [&model, &data, &WX]
       (const ZMatType& Z, const Eigen::Index begin, const Eigen::Index end)
       ->MatrixXuf {
         return gradL_Z(Z, data.Ytrain,
                        gaussianKernel(model.params.B, WX, model.hyperParams.gamma, begin, end),
                        begin, end);
       },
       std::bind(hardThrsd, std::placeholders::_1, model.hyperParams.lambdaZ),
       model.params.Z, n, bs, (etaZ / armijoZ) * 2);
#else
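    // Same secant-based step-length probe as for W above, now w.r.t. Z.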
    for (auto j = 0; j < eta.size(); ++j) {
      Eigen::Index idx1 = (j*(Eigen::Index)hessianbs) % n;
      Eigen::Index idx2 = ((j + 1)*(Eigen::Index)hessianbs) % n;
      if (idx2 <= idx1) idx2 = n;

      gtmpZ = gradL_Z(model.params.Z, data.Ytrain,
        gaussianKernel(model.params.B, WX, model.hyperParams.gamma, idx1, idx2),
        idx1, idx2);

      MatrixXuf gtmpZThresh = gtmpZ;
      hardThrsd(gtmpZThresh, model.hyperParams.lambdaZ);

      // Below: Ztmp = Z - 0.001*safeDiv(maxAbsVal(Z), gtmpZ.cwiseAbs().maxCoeff()) * gtmpZThresh;
      gtmpZThresh *= (FP_TYPE)-0.001*safeDiv(maxAbsVal(model.params.Z), gtmpZ.cwiseAbs().maxCoeff());
      typeMismatchAssign(Ztmp, gtmpZThresh);
      Ztmp += model.params.Z;

      gtmpZ -= gradL_Z(Ztmp, data.Ytrain,
        gaussianKernel(model.params.B, WX, model.hyperParams.gamma, idx1, idx2),
        idx1, idx2);

      if (gtmpZ.norm() <= 1e-20L) {
        LOG_WARNING("Difference between consecutive gradients of Z has become really low.");
        eta(j) = 1.0;
      }
      else
        eta(j) = safeDiv((Ztmp - model.params.Z).norm(), gtmpZ.norm());
    }
    std::sort(eta.data(), eta.data() + eta.size());
    etaZ = armijoZ * eta(4);
#endif
    //LOG_INFO("Step-length estimate for gradZ = " + std::to_string(etaZ));
    
    accProxSGD<ZMatType>
      ([&model, &data, &WX]
       (const ZMatType& Z, const Eigen::Index begin, const Eigen::Index end)
       ->FP_TYPE {
         return L(Z, data.Ytrain,
                  gaussianKernel(model.params.B, WX, model.hyperParams.gamma, begin, end),
                  begin, end);
       },
       [&model, &data, &WX]
       (const ZMatType& Z, const Eigen::Index begin, const Eigen::Index end)
       ->MatrixXuf {
         return gradL_Z(Z, data.Ytrain,
                        gaussianKernel(model.params.B, WX, model.hyperParams.gamma, begin, end),
                        begin, end);
       },
       std::bind(hardThrsd, std::placeholders::_1, model.hyperParams.lambdaZ),
       model.params.Z, epochs, n, bs, etaZ, etaUpdate);
    timer.nextTime("ending gradZ");
    //LOG_INFO("Final step-length for gradZ = " + std::to_string(etaZ));

    fOld = fNew;
#ifdef XML
    fNew = batchEvaluate(model.params.Z, Y_sub, Yvalidation_sub, model.params.B, WX_sub, WXvalidation_sub, model.hyperParams.gamma, model.hyperParams.problemType, stats + 9 * i + 6);
#else 
    fNew = batchEvaluate(model.params.Z, data.Ytrain, data.Yvalidation, model.params.B, WX, WXvalidation, model.hyperParams.gamma, model.hyperParams.problemType, stats + 9 * i + 6);
#endif

    if (fNew >= fOld * (1 + safeDiv(sgdTol * (FP_TYPE)log(3), (FP_TYPE)log(2 + i))))
      armijoZ *= (FP_TYPE)0.7;
    else if (fNew <= fOld * (1 - safeDiv(3 * sgdTol * (FP_TYPE)log(3), (FP_TYPE)log(2 + i))))
      armijoZ *= (FP_TYPE)1.1;

#ifdef VERIFY
    fileName = outDir + "/verify/Z" + std::to_string(i);
    f.open(fileName);
    f << "Z_check = [" << model.params.Z << "];" << std::endl;
    f.close();
#endif 
#ifdef DUMP
    fileName = outDir + "/dump/Z" + std::to_string(i);
    f.open(fileName);
    f << model.params.Z.format(eigen_tsv);
    f.close();
#endif 

    timer.nextTime("starting optimization w.r.t. B");
    LOG_INFO("Optimizing w.r.t. prototype matrix (B)...");

#ifdef BTLS
    etaB = armijoB * btls<BMatType>
      ([&model, &data, &WX]
       (const BMatType& B, const Eigen::Index begin, const Eigen::Index end)
       ->FP_TYPE {
         return L(model.params.Z, data.Ytrain,
                  gaussianKernel(B, WX, model.hyperParams.gamma, begin, end),
                  begin, end);
       },
       [&model, &data, &WX]
       (const BMatType& B, const Eigen::Index begin, const Eigen::Index end)
       ->MatrixXuf {
         return gradL_B(B, data.Ytrain, model.params.Z, WX,
                        gaussianKernel(B, WX, model.hyperParams.gamma, begin, end),
                        model.hyperParams.gamma, begin, end);
       },
       std::bind(hardThrsd, std::placeholders::_1, model.hyperParams.lambdaB),
       model.params.B, n, bs, (etaB / armijoB) * 2);
#else    
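    // Same secant-based step-length probe as for W above, now w.r.t. B.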
    for (auto j = 0; j < eta.size(); ++j) {
      Eigen::Index idx1 = (j*(Eigen::Index)hessianbs) % n;
      Eigen::Index idx2 = ((j + 1)*(Eigen::Index)hessianbs) % n;
      if (idx2 <= idx1) idx2 = n;

      gtmpB = gradL_B(model.params.B, data.Ytrain, model.params.Z, WX,
        gaussianKernel(model.params.B, WX, model.hyperParams.gamma, idx1, idx2),
        model.hyperParams.gamma, idx1, idx2);

      MatrixXuf gtmpBThresh = gtmpB;
      hardThrsd(gtmpBThresh, model.hyperParams.lambdaB);

      Btmp = model.params.B - 0.001*safeDiv(model.params.B.cwiseAbs().maxCoeff(), gtmpB.cwiseAbs().maxCoeff())*gtmpBThresh;

      gtmpB -= gradL_B(Btmp, data.Ytrain, model.params.Z, WX,
        gaussianKernel(Btmp, WX, model.hyperParams.gamma, idx1, idx2),
        model.hyperParams.gamma, idx1, idx2);

      if (gtmpB.norm() <= 1e-20L) {
        LOG_WARNING("Difference between consecutive gradients of B has become really low.");
        eta(j) = 1.0;
      }
      else
        eta(j) = safeDiv((Btmp - model.params.B).norm(), gtmpB.norm());
    }

    std::sort(eta.data(), eta.data() + eta.size());
    etaB = armijoB * eta(4);
#endif
    //LOG_INFO("Step-length estimate for gradB = " + std::to_string(etaB));

    accProxSGD<BMatType>
      ([&model, &data, &WX]
       (const BMatType& B, const Eigen::Index begin, const Eigen::Index end)
       ->FP_TYPE {
         return L(model.params.Z, data.Ytrain,
                  gaussianKernel(B, WX, model.hyperParams.gamma, begin, end),
                  begin, end);
       },
       [&model, &data, &WX]
       (const BMatType& B, const Eigen::Index begin, const Eigen::Index end)
       ->MatrixXuf {
         return gradL_B(B, data.Ytrain, model.params.Z, WX,
                        gaussianKernel(B, WX, model.hyperParams.gamma, begin, end),
                        model.hyperParams.gamma, begin, end);
       },
       std::bind(hardThrsd, std::placeholders::_1, model.hyperParams.lambdaB),
       model.params.B, epochs, n, bs, etaB, etaUpdate);
    timer.nextTime("ending gradB");
    //LOG_INFO("Final step-length for gradB = " + std::to_string(etaB));

    fOld = fNew;
#ifdef XML
    fNew = batchEvaluate(model.params.Z, Y_sub, Yvalidation_sub, model.params.B, WX_sub, WXvalidation_sub, model.hyperParams.gamma, model.hyperParams.problemType, stats + 9 * i + 9);
#else 
    fNew = batchEvaluate(model.params.Z, data.Ytrain, data.Yvalidation, model.params.B, WX, WXvalidation, model.hyperParams.gamma, model.hyperParams.problemType, stats + 9 * i + 9);
#endif

    if (fNew >= fOld * (1 + safeDiv(sgdTol * (FP_TYPE)log(3), (FP_TYPE)log(2 + i))))
      armijoB *= (FP_TYPE)0.7;
    else if (fNew <= fOld * (1 - safeDiv(3 * sgdTol * (FP_TYPE)log(3), (FP_TYPE)log(2 + i))))
      armijoB *= (FP_TYPE)1.1;

#ifdef VERIFY
    fileName = outDir + "/verify/B" + std::to_string(i);
    f.open(fileName);
    f << "B_check = [" << model.params.B << "];" << std::endl;
    f.close();
#endif 
#ifdef DUMP
    fileName = outDir + "/dump/B" + std::to_string(i);
    f.open(fileName);
    f << model.params.B.format(eigen_tsv);
    f.close();
#endif 
  }
}