in cpp/src/ProtoNN/ProtoNNFunctions.cpp [575:992]
void EdgeML::altMinSGD(
const EdgeML::Data& data,
EdgeML::ProtoNN::ProtoNNModel& model,
FP_TYPE *const stats,
const std::string& outDir)
{
// These asserts let us pass Eigen matrices and indices directly to MKL BLAS
// calls and use Eigen::Index interchangeably with dataCount_t.
assert(sizeof(MKL_INT) == sizeof(Eigen::Index));
assert(sizeof(Eigen::Index) == sizeof(dataCount_t));
Timer timer("altMinSGD");
/*
[~, n] = size(X); [l, m] = size(Z);
if isempty(iters)
iters = 50;
end
tol = 1e-5;
% alternating minimization
counter = 1;epochs = 10; batchSize = 512;
batchSize = 128; epochs = 3; sgdTol = 0.02;
learning_rate = 0;
learning_rate_Z = 0.2; learning_rate_B = 0.2; learning_rate_W = 0.2;
*/
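// SGD setup: n training points; each accProxSGD call below runs `epochs`
// passes over the data with mini-batches of size bs (capped at n).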
dataCount_t n = data.Xtrain.cols();
int epochs = model.hyperParams.epochs;
FP_TYPE sgdTol = (FP_TYPE) 0.02;
dataCount_t bs = std::min((dataCount_t)model.hyperParams.batchSize, (dataCount_t)n);
#ifdef XML
dataCount_t hessianbs = std::min((dataCount_t)(1 << 10), bs);
#else
dataCount_t hessianbs = bs;
#endif
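// hessianbs is the batch size for the curvature (step-length) probes below;
// the XML build caps it at 1024 to keep the probes cheap.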
//const FP_TYPE hessianAdjustment = 1.0; // (FP_TYPE)hessianbs / (FP_TYPE)bs;
int etaUpdate = 0;
FP_TYPE armijoZ((FP_TYPE)0.2), armijoB((FP_TYPE)0.2), armijoW((FP_TYPE)0.2);
FP_TYPE fOld, fNew, etaZ(1), etaB(1), etaW(1);
LOG_INFO("\nComputing model size assuming 4 bytes per entry for matrices with sparsity > 0.5 and 8 bytes per entry for matrices with sparsity <= 0.5 (to store sparse matrices, we require about 4 bytes for the index information)...");
LOG_INFO("Model size in kB = " + std::to_string(computeModelSizeInkB(model.hyperParams.lambdaW, model.hyperParams.lambdaZ, model.hyperParams.lambdaB, model.params.W, model.params.Z, model.params.B)));
MatrixXuf WX(model.params.W.rows(), data.Xtrain.cols());
mm(WX, model.params.W, CblasNoTrans, data.Xtrain, CblasNoTrans, 1.0, 0.0L);
MatrixXuf WXvalidation(model.params.W.rows(), data.Xvalidation.cols());
if (data.Xvalidation.cols() > 0) {
mm(WXvalidation, model.params.W, CblasNoTrans, data.Xvalidation, CblasNoTrans, 1.0, 0.0L);
}
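// In the extreme multi-label (XML) build, objective evaluation uses random
// subsamples (at most 20000 training and 10000 validation points) instead of
// the full data to keep evaluation cheap.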
#ifdef XML
dataCount_t numEvalTrain = std::min((dataCount_t)20000, (dataCount_t)data.Xtrain.cols());
MatrixXuf WX_sub(WX.rows(), numEvalTrain);
SparseMatrixuf Y_sub(data.Ytrain.rows(), numEvalTrain);
SparseMatrixuf X_sub(data.Xtrain.rows(), numEvalTrain);
randPick(data.Xtrain, X_sub);
randPick(data.Ytrain, Y_sub);
mm(WX_sub, model.params.W, CblasNoTrans, X_sub, CblasNoTrans, 1.0, 0.0L);
dataCount_t numEvalValidation = std::min((dataCount_t)10000, (dataCount_t)data.Xvalidation.cols());
MatrixXuf WXvalidation_sub(WX.rows(), numEvalValidation);
SparseMatrixuf Yvalidation_sub(data.Yvalidation.rows(), numEvalValidation);
SparseMatrixuf Xvalidation_sub(data.Xvalidation.rows(), numEvalValidation);
if (data.Xvalidation.cols() > 0) {
randPick(data.Xvalidation, Xvalidation_sub);
randPick(data.Yvalidation, Yvalidation_sub);
mm(WXvalidation_sub, model.params.W, CblasNoTrans, Xvalidation_sub, CblasNoTrans, 1.0, 0.0L);
}
#endif
timer.nextTime("starting evaluation");
LOG_INFO("\nInitial stats...");
#ifdef XML
fNew = batchEvaluate(model.params.Z, Y_sub, Yvalidation_sub, model.params.B, WX_sub, WXvalidation_sub, model.hyperParams.gamma, model.hyperParams.problemType, stats);
#else
fNew = batchEvaluate(model.params.Z, data.Ytrain, data.Yvalidation, model.params.B, WX, WXvalidation, model.hyperParams.gamma, model.hyperParams.problemType, stats);
#endif
timer.nextTime("evaluating");
VectorXf eta = VectorXf::Zero(10, 1);
MatrixXuf gtmpW(model.params.W.rows(), model.params.W.cols());
WMatType Wtmp(model.params.W.rows(), model.params.W.cols());
MatrixXuf gtmpB(model.params.B.rows(), model.params.B.cols());
BMatType Btmp(model.params.B.rows(), model.params.B.cols());
MatrixXuf gtmpZ(model.params.Z.rows(), model.params.Z.cols());
ZMatType Ztmp(model.params.Z.rows(), model.params.Z.cols());
#if defined(DUMP) || defined(VERIFY)
std::ofstream f;
std::string fileName;
#endif
LOG_INFO("\nStarting optimization. Number of outer iterations (altMinSGD) = " + std::to_string(model.hyperParams.iters));
// for i = 1 : iters
for (int i = 0; i < model.hyperParams.iters; ++i) {
LOG_INFO(
"\n=========================== " + std::to_string(i) + "\n"
+ "On iter " + std::to_string(i) + "\n"
+ "=========================== " + std::to_string(i));
timer.nextTime("starting optimization w.r.t. W");
LOG_INFO("Optimizing w.r.t. projection matrix (W)...");
#ifdef BTLS
etaW = armijoW * btls<WMatType>
([&model, &data] (const WMatType& W, const Eigen::Index begin, const Eigen::Index end) ->FP_TYPE {
MatrixXuf WX = MatrixXuf::Zero(W.rows(), end - begin);
SparseMatrixuf XMiddle = data.Xtrain.middleCols(begin, end - begin);
mm(WX, W, CblasNoTrans, XMiddle,
CblasNoTrans, 1.0, 0.0L);
return L(model.params.Z, data.Ytrain,
gaussianKernel(model.params.B, WX, model.hyperParams.gamma),
begin, end);
},
[&model, &data]
(const WMatType& W, const Eigen::Index begin, const Eigen::Index end)
->MatrixXuf {
MatrixXuf WX = MatrixXuf::Zero(W.rows(), end - begin);
SparseMatrixuf XMiddle = data.Xtrain.middleCols(begin, end - begin);
mm(WX, W, CblasNoTrans,
XMiddle,
CblasNoTrans, 1.0, 0.0L);
return gradL_W(model.params.B, data.Ytrain, model.params.Z, W, data.Xtrain,
gaussianKernel(model.params.B, WX, model.hyperParams.gamma),
model.hyperParams.gamma, begin, end);
},
std::bind(hardThrsd, std::placeholders::_1, model.hyperParams.lambdaW),
model.params.W, n, bs, (etaW/armijoW)*2);
#else
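// Step-length estimation (without backtracking line search): for each of the
// 10 probes, take a tiny hard-thresholded gradient step on a mini-batch and
// set eta(j) = ||deltaW|| / ||deltaGrad|| -- a secant (finite-difference)
// estimate of the inverse gradient-Lipschitz constant. The sorted eta(4)
// below is the lower median of these probes.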
for (auto j = 0; j < eta.size(); ++j) {
Eigen::Index idx1 = (j*(Eigen::Index)hessianbs) % n;
Eigen::Index idx2 = ((j + 1)*(Eigen::Index)hessianbs) % n;
//assert (((j+1)*(Eigen::Index)hessianbs) < n);
if (idx2 <= idx1) idx2 = n;
gtmpW = gradL_W(model.params.B, data.Ytrain, model.params.Z, model.params.W, data.Xtrain,
gaussianKernel(model.params.B, WX, model.hyperParams.gamma, idx1, idx2),
model.hyperParams.gamma, idx1, idx2);
MatrixXuf gtmpWThresh = gtmpW;
hardThrsd(gtmpWThresh, model.hyperParams.lambdaW);
Wtmp = model.params.W
- 0.001*safeDiv(model.params.W.cwiseAbs().maxCoeff(), gtmpW.cwiseAbs().maxCoeff()) * gtmpWThresh;
gtmpW -= gradL_W(model.params.B, data.Ytrain, model.params.Z, Wtmp, data.Xtrain,
gaussianKernel(model.params.B, Wtmp*data.Xtrain.middleCols(idx1, idx2 - idx1), model.hyperParams.gamma),
model.hyperParams.gamma, idx1, idx2);
if (gtmpW.norm() <= 1e-20L) {
LOG_WARNING("Difference between consecutive gradients of W has become really low.");
eta(j) = 1.0;
}
else
eta(j) = safeDiv((Wtmp - model.params.W).norm(), gtmpW.norm());
}
std::sort(eta.data(), eta.data() + eta.size());
etaW = armijoW * eta(4);
#endif
//LOG_INFO("Step-length estimate for gradW = " + std::to_string(etaW));
accProxSGD<WMatType>
(//[&model.params.Z, &data.Ytrain, &model.params.B, &data.Xtrain, &model.hyperParams] TODO: Figure out the elegant way of getting this to work
[&model, &data]
(const WMatType& W, const Eigen::Index begin, const Eigen::Index end)
->FP_TYPE {
MatrixXuf WX = MatrixXuf::Zero(W.rows(), end - begin);
SparseMatrixuf XMiddle = data.Xtrain.middleCols(begin, end - begin);
mm(WX, W, CblasNoTrans,
XMiddle,
CblasNoTrans, 1.0, 0.0L);
return L(model.params.Z, data.Ytrain, gaussianKernel(model.params.B, WX, model.hyperParams.gamma), begin, end);
},
// [&(model.params.B), &(data.Ytrain), &(model.params.Z), &(data.Xtrain), &(model.hyperParams)]
[&model, &data]
(const WMatType& W, const Eigen::Index begin, const Eigen::Index end)
->MatrixXuf {
MatrixXuf WX = MatrixXuf::Zero(W.rows(), end - begin);
SparseMatrixuf XMiddle = data.Xtrain.middleCols(begin, end - begin);
mm(WX, W, CblasNoTrans,
XMiddle,
CblasNoTrans, 1.0, 0.0L);
return gradL_W(model.params.B, data.Ytrain, model.params.Z, W, data.Xtrain,
gaussianKernel(model.params.B, WX, model.hyperParams.gamma),
model.hyperParams.gamma, begin, end);
},
std::bind(hardThrsd, std::placeholders::_1, model.hyperParams.lambdaW),
model.params.W, epochs, n, bs, etaW, etaUpdate);
timer.nextTime("ending gradW");
//LOG_INFO("Final step-length for gradW = " + std::to_string(etaW));
mm(WX, model.params.W, CblasNoTrans, data.Xtrain, CblasNoTrans, 1.0, 0.0L);
if (data.Xvalidation.cols() > 0) {
mm(WXvalidation, model.params.W, CblasNoTrans, data.Xvalidation, CblasNoTrans, 1.0, 0.0L);
}
fOld = fNew;
#ifdef XML
mm(WX_sub, model.params.W, CblasNoTrans, X_sub, CblasNoTrans, 1.0, 0.0L);
if (data.Xvalidation.cols() > 0) {
mm(WXvalidation_sub, model.params.W, CblasNoTrans, Xvalidation_sub, CblasNoTrans, 1.0, 0.0L);
}
fNew = batchEvaluate(model.params.Z, Y_sub, Yvalidation_sub, model.params.B, WX_sub, WXvalidation_sub, model.hyperParams.gamma, model.hyperParams.problemType, stats + 9 * i + 3);
#else
fNew = batchEvaluate(model.params.Z, data.Ytrain, data.Yvalidation, model.params.B, WX, WXvalidation, model.hyperParams.gamma, model.hyperParams.problemType, stats + 9 * i + 3);
#endif
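// Adapt the step-length multiplier from the change in objective: shrink by
// 0.7 if the objective rose beyond a tolerance that decays like
// log(3)/log(2+i), grow by 1.1 if it fell by more than 3x that tolerance,
// and leave it unchanged otherwise.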
if (fNew >= fOld * (1 + safeDiv(sgdTol*(FP_TYPE)log(3), (FP_TYPE)log(2 + i))))
armijoW *= (FP_TYPE)0.7;
else if (fNew <= fOld * (1 - safeDiv(3 * sgdTol*(FP_TYPE)log(3), (FP_TYPE)log(2 + i))))
armijoW *= (FP_TYPE)1.1;
#ifdef VERIFY
fileName = outDir + "/verify/W" + std::to_string(i);
f.open(fileName);
f << "W_check = [" << model.params.W << "];" << std::endl;
f.close();
#endif
#ifdef DUMP
fileName = outDir + "/dump/W" + std::to_string(i);
f.open(fileName);
f << model.params.W.format(eigen_tsv);
f.close();
#endif
timer.nextTime("starting optimization w.r.t. Z");
LOG_INFO("Optimizing w.r.t. prototype-label matrix (Z)...");
#ifdef BTLS
etaZ = armijoZ * btls<ZMatType>
([&model, &data, &WX]
(const ZMatType& Z, const Eigen::Index begin, const Eigen::Index end)
->FP_TYPE {return L(Z, data.Ytrain, gaussianKernel(model.params.B, WX, model.hyperParams.gamma, begin, end), begin, end); },
[&model, &data, &WX]
(const ZMatType& Z, const Eigen::Index begin, const Eigen::Index end)
->MatrixXuf
{return gradL_Z(Z, data.Ytrain,
gaussianKernel(model.params.B, WX, model.hyperParams.gamma, begin, end),
begin, end); },
std::bind(hardThrsd, std::placeholders::_1, model.hyperParams.lambdaZ),
model.params.Z, n, bs, (etaZ/armijoZ)*2);
#else
for (auto j = 0; j < eta.size(); ++j) {
Eigen::Index idx1 = (j*(Eigen::Index)hessianbs) % n;
Eigen::Index idx2 = ((j + 1)*(Eigen::Index)hessianbs) % n;
if (idx2 <= idx1) idx2 = n;
gtmpZ = gradL_Z(model.params.Z, data.Ytrain,
gaussianKernel(model.params.B, WX, model.hyperParams.gamma, idx1, idx2),
idx1, idx2);
MatrixXuf gtmpZThresh = gtmpZ;
hardThrsd(gtmpZThresh, model.hyperParams.lambdaZ);
// Below: Ztmp = Z - 0.001*safeDiv(maxAbsVal(Z), gtmpZ.cwiseAbs().maxCoeff()) * gtmpZThresh;
gtmpZThresh *= (FP_TYPE)-0.001*safeDiv(maxAbsVal(model.params.Z), gtmpZ.cwiseAbs().maxCoeff());
typeMismatchAssign(Ztmp, gtmpZThresh);
Ztmp += model.params.Z;
gtmpZ -= gradL_Z(Ztmp, data.Ytrain,
gaussianKernel(model.params.B, WX, model.hyperParams.gamma, idx1, idx2),
idx1, idx2);
if (gtmpZ.norm() <= 1e-20L) {
LOG_WARNING("Difference between consecutive gradients of Z has become really low.");
eta(j) = 1.0;
}
else
eta(j) = safeDiv((Ztmp - model.params.Z).norm(), gtmpZ.norm());
}
std::sort(eta.data(), eta.data() + eta.size());
etaZ = armijoZ * eta(4);
#endif
//LOG_INFO("Step-length estimate for gradZ = " + std::to_string(etaZ));
accProxSGD<ZMatType>
(//[&model.params.B, &data.Ytrain, &WX, &model.hyperParams]
[&model, &data, &WX]
(const ZMatType& Z, const Eigen::Index begin, const Eigen::Index end)
->FP_TYPE {return L(Z, data.Ytrain, gaussianKernel(model.params.B, WX, model.hyperParams.gamma, begin, end), begin, end); },
//[&WX, &data.Ytrain, &model.params.B, &model.hyperParams]
[&model, &data, &WX]
(const ZMatType& Z, const Eigen::Index begin, const Eigen::Index end)
->MatrixXuf
{return gradL_Z(Z, data.Ytrain,
gaussianKernel(model.params.B, WX, model.hyperParams.gamma, begin, end),
begin, end); },
std::bind(hardThrsd, std::placeholders::_1, model.hyperParams.lambdaZ),
model.params.Z, epochs, n, bs, etaZ, etaUpdate);
timer.nextTime("ending gradZ");
//LOG_INFO("Final step-length for gradZ = " + std::to_string(etaZ));
fOld = fNew;
#ifdef XML
fNew = batchEvaluate(model.params.Z, Y_sub, Yvalidation_sub, model.params.B, WX_sub, WXvalidation_sub, model.hyperParams.gamma, model.hyperParams.problemType, stats + 9 * i + 6);
#else
fNew = batchEvaluate(model.params.Z, data.Ytrain, data.Yvalidation, model.params.B, WX, WXvalidation, model.hyperParams.gamma, model.hyperParams.problemType, stats + 9 * i + 6);
#endif
if (fNew >= fOld * (1 + safeDiv(sgdTol*(FP_TYPE)log(3), (FP_TYPE)log(2 + i))))
armijoZ *= (FP_TYPE)0.7;
else if (fNew <= fOld * (1 - safeDiv(3 * sgdTol*(FP_TYPE)log(3), (FP_TYPE)log(2 + i))))
armijoZ *= (FP_TYPE)1.1;
#ifdef VERIFY
fileName = outDir + "/verify/Z" + std::to_string(i);
f.open(fileName);
f << "Z_check = [" << model.params.Z << "];" << std::endl;
f.close();
#endif
#ifdef DUMP
fileName = outDir + "/dump/Z" + std::to_string(i);
f.open(fileName);
f << model.params.Z.format(eigen_tsv);
f.close();
#endif
timer.nextTime("starting optimization w.r.t. B");
LOG_INFO("Optimizing w.r.t. prototype matrix (B)...");
#ifdef BTLS
etaB = armijoB * btls<BMatType>
([&model, &data, &WX]
(const BMatType& B, const Eigen::Index begin, const Eigen::Index end)
->FP_TYPE {return L(model.params.Z, data.Ytrain, gaussianKernel(B, WX, model.hyperParams.gamma, begin, end), begin, end); },
[&model, &data, &WX]
(const BMatType& B, const Eigen::Index begin, const Eigen::Index end)
->MatrixXuf
{return gradL_B(B, data.Ytrain, model.params.Z, WX,
gaussianKernel(B, WX, model.hyperParams.gamma, begin, end),
model.hyperParams.gamma, begin, end); },
std::bind(hardThrsd, std::placeholders::_1, model.hyperParams.lambdaB),
model.params.B, n, bs, (etaB/armijoB)*2);
#else
for (auto j = 0; j < eta.size(); ++j) {
Eigen::Index idx1 = (j*(Eigen::Index)hessianbs) % n;
Eigen::Index idx2 = ((j + 1)*(Eigen::Index)hessianbs) % n;
if (idx2 <= idx1) idx2 = n;
gtmpB = gradL_B(model.params.B, data.Ytrain, model.params.Z, WX,
gaussianKernel(model.params.B, WX, model.hyperParams.gamma, idx1, idx2),
model.hyperParams.gamma, idx1, idx2);
MatrixXuf gtmpBThresh = gtmpB;
hardThrsd(gtmpBThresh, model.hyperParams.lambdaB);
Btmp = model.params.B - 0.001*safeDiv(model.params.B.cwiseAbs().maxCoeff(), gtmpB.cwiseAbs().maxCoeff())*gtmpBThresh;
gtmpB -= gradL_B(Btmp, data.Ytrain, model.params.Z, WX,
gaussianKernel(Btmp, WX, model.hyperParams.gamma, idx1, idx2),
model.hyperParams.gamma, idx1, idx2);
if (gtmpB.norm() <= 1e-20L) {
LOG_WARNING("Difference between consecutive gradients of B has become really low.");
eta(j) = 1.0;
}
else
eta(j) = safeDiv((Btmp - model.params.B).norm(), gtmpB.norm());
}
std::sort(eta.data(), eta.data() + eta.size());
etaB = armijoB * eta(4);
#endif
//LOG_INFO("Step-length estimate for gradB = " + std::to_string(etaB));
accProxSGD<BMatType>
(//[&model.params.Z, &data.Ytrain, &WX, &model.hyperParams]
[&model, &data, &WX]
(const BMatType& B, const Eigen::Index begin, const Eigen::Index end)
->FP_TYPE {return L(model.params.Z, data.Ytrain, gaussianKernel(B, WX, model.hyperParams.gamma, begin, end), begin, end); },
//[&WX, &data.Ytrain, &model.params.Z, &model.hyperParams]
[&model, &data, &WX]
(const BMatType& B, const Eigen::Index begin, const Eigen::Index end)
->MatrixXuf
{return gradL_B(B, data.Ytrain, model.params.Z, WX,
gaussianKernel(B, WX, model.hyperParams.gamma, begin, end),
model.hyperParams.gamma, begin, end); },
std::bind(hardThrsd, std::placeholders::_1, model.hyperParams.lambdaB),
model.params.B, epochs, n, bs, etaB, etaUpdate);
timer.nextTime("ending gradB");
//LOG_INFO("Final step-length for gradB = " + std::to_string(etaB));
fOld = fNew;
#ifdef XML
fNew = batchEvaluate(model.params.Z, Y_sub, Yvalidation_sub, model.params.B, WX_sub, WXvalidation_sub, model.hyperParams.gamma, model.hyperParams.problemType, stats + 9 * i + 9);
#else
fNew = batchEvaluate(model.params.Z, data.Ytrain, data.Yvalidation, model.params.B, WX, WXvalidation, model.hyperParams.gamma, model.hyperParams.problemType, stats + 9 * i + 9);
#endif
if (fNew >= fOld * (1 + safeDiv(sgdTol*(FP_TYPE)log(3), (FP_TYPE)log(2 + i))))
armijoB *= (FP_TYPE)0.7;
else if (fNew <= fOld * (1 - safeDiv(3 * sgdTol*(FP_TYPE)log(3), (FP_TYPE)log(2 + i))))
armijoB *= (FP_TYPE)1.1;
#ifdef VERIFY
fileName = outDir + "/verify/B" + std::to_string(i);
f.open(fileName);
f << "B_check = [" << model.params.B << "];" << std::endl;
f.close();
#endif
#ifdef DUMP
fileName = outDir + "/dump/B" + std::to_string(i);
f.open(fileName);
f << model.params.B.format(eigen_tsv);
f.close();
#endif
}
}