in src/fasttext.cc [323:368]
void FastText::quantize(const Args& qargs, const TrainCallback& callback) {
if (args_->model != model_name::sup) {
throw std::invalid_argument(
"For now we only support quantization of supervised models");
}
args_->input = qargs.input;
args_->qout = qargs.qout;
args_->output = qargs.output;
std::shared_ptr<DenseMatrix> input =
std::dynamic_pointer_cast<DenseMatrix>(input_);
std::shared_ptr<DenseMatrix> output =
std::dynamic_pointer_cast<DenseMatrix>(output_);
bool normalizeGradient = (args_->model == model_name::sup);
if (qargs.cutoff > 0 && qargs.cutoff < input->size(0)) {
auto idx = selectEmbeddings(qargs.cutoff);
dict_->prune(idx);
std::shared_ptr<DenseMatrix> ninput =
std::make_shared<DenseMatrix>(idx.size(), args_->dim);
for (auto i = 0; i < idx.size(); i++) {
for (auto j = 0; j < args_->dim; j++) {
ninput->at(i, j) = input->at(idx[i], j);
}
}
input = ninput;
if (qargs.retrain) {
args_->epoch = qargs.epoch;
args_->lr = qargs.lr;
args_->thread = qargs.thread;
args_->verbose = qargs.verbose;
auto loss = createLoss(output_);
model_ = std::make_shared<Model>(input, output, loss, normalizeGradient);
startThreads(callback);
}
}
input_ = std::make_shared<QuantMatrix>(
std::move(*(input.get())), qargs.dsub, qargs.qnorm);
if (args_->qout) {
output_ = std::make_shared<QuantMatrix>(
std::move(*(output.get())), 2, qargs.qnorm);
}
quant_ = true;
auto loss = createLoss(output_);
model_ = std::make_shared<Model>(input_, output_, loss, normalizeGradient);
}