in sentiment/train.js [160:244]
/**
 * Entry point of the training script.
 *
 * Parses command-line arguments, loads the data via `loadData` (multi-hot
 * encoded when `modelType` is 'multihot'), builds and compiles the model,
 * trains it (optionally logging to TensorBoard when a log directory is
 * given), evaluates it on the test split, and optionally:
 *   - saves `metadata.json` plus the model artifacts to `modelSaveDir`
 *     (skipped for multihot models, which are not supported for saving);
 *   - writes embedding-matrix/label files using `embeddingFilesPrefix`.
 *
 * @returns {Promise<void>}
 */
async function main() {
  const args = parseArguments();

  // The backend package is required unconditionally, even though `tfn` is only
  // referenced explicitly for TensorBoard and tensor disposal below —
  // presumably because loading it also sets up the native backend. Confirm
  // before making this lazy.
  let tfn;
  if (args.gpu) {
    console.log('Using GPU for training');
    tfn = require('@tensorflow/tfjs-node-gpu');
  } else {
    console.log('Using CPU for training');
    tfn = require('@tensorflow/tfjs-node');
  }

  console.log('Loading data...');
  const multihot = args.modelType === 'multihot';
  const {xTrain, yTrain, xTest, yTest} =
      await loadData(args.numWords, args.maxLen, multihot);

  console.log('Building model...');
  const model = buildModel(
      args.modelType, args.maxLen, args.numWords, args.embeddingSize);
  model.compile({
    loss: 'binaryCrossentropy',
    optimizer: args.optimizer,
    metrics: ['acc']
  });
  model.summary();

  console.log('Training model...');
  // Treat an empty log directory the same as an unset one, matching the
  // checks applied to `modelSaveDir` and `embeddingFilesPrefix` below.
  const useTensorBoard = args.logDir != null && args.logDir.length > 0;
  await model.fit(xTrain, yTrain, {
    epochs: args.epochs,
    batchSize: args.batchSize,
    validationSplit: args.validationSplit,
    callbacks: useTensorBoard ?
        tfn.node.tensorBoard(args.logDir, {updateFreq: args.logUpdateFreq}) :
        null
  });

  console.log('Evaluating model...');
  // With a metric configured, evaluate() returns [loss, metric] scalars.
  const [testLoss, testAcc] =
      model.evaluate(xTest, yTest, {batchSize: args.batchSize});
  console.log(`Evaluation loss: ${(await testLoss.data())[0].toFixed(4)}`);
  console.log(`Evaluation accuracy: ${(await testAcc.data())[0].toFixed(4)}`);
  // The dataset and evaluation tensors are not used past this point. tfjs
  // tensors must be disposed explicitly to release their backend memory.
  tfn.dispose([xTrain, yTrain, xTest, yTest, testLoss, testAcc]);

  // Save model.
  let metadata;
  if (args.modelSaveDir != null && args.modelSaveDir.length > 0) {
    if (multihot) {
      console.warn(
          'Skipping saving of multihot model, which is not supported.');
    } else {
      // Create base directory first.
      shelljs.mkdir('-p', args.modelSaveDir);
      // Load metadata template.
      console.log('Loading metadata template...');
      metadata = await loadMetadataTemplate();
      // Record the hyperparameters this model was trained with.
      metadata.epochs = args.epochs;
      metadata.embedding_size = args.embeddingSize;
      metadata.max_len = args.maxLen;
      metadata.model_type = args.modelType;
      metadata.batch_size = args.batchSize;
      metadata.vocabulary_size = args.numWords;
      const metadataPath = path.join(args.modelSaveDir, 'metadata.json');
      fs.writeFileSync(metadataPath, JSON.stringify(metadata));
      console.log(`Saved metadata to ${metadataPath}`);
      // Save model artifacts.
      await model.save(`file://${args.modelSaveDir}`);
      console.log(`Saved model to ${args.modelSaveDir}`);
    }
  }

  if (args.embeddingFilesPrefix != null &&
      args.embeddingFilesPrefix.length > 0) {
    // The metadata template supplies `word_index` / `index_from`. Reuse the
    // copy loaded above if the model was saved; otherwise load it now.
    if (metadata == null) {
      metadata = await loadMetadataTemplate();
    }
    await writeEmbeddingMatrixAndLabels(
        model, args.embeddingFilesPrefix, metadata.word_index,
        metadata.index_from);
  }
}