in snake-dqn/train.js [67:138]
export async function train(
    agent, batchSize, gamma, learningRate, cumulativeRewardThreshold,
    maxNumFrames, syncEveryFrames, savePath, logDir) {
  let summaryWriter;
  if (logDir != null) {
    summaryWriter = tf.node.summaryFileWriter(logDir);
  }

  // Fill the replay buffer with an initial batch of experience before any
  // training step is performed.
  for (let i = 0; i < agent.replayBufferSize; ++i) {
    agent.playStep();
  }

  // Moving averager: cumulative reward over the 100 most recent episodes.
  const rewardAverager100 = new MovingAverager(100);
  // Moving averager: fruits eaten over the 100 most recent episodes.
  const eatenAverager100 = new MovingAverager(100);

  const optimizer = tf.train.adam(learningRate);
  let tPrev = new Date().getTime();
  let frameCountPrev = agent.frameCount;
  let averageReward100Best = -Infinity;
  while (true) {
    // One gradient-descent step on a sampled replay batch, followed by one
    // environment step with the epsilon-greedy policy.
    agent.trainOnReplayBatch(batchSize, gamma, optimizer);
    const {cumulativeReward, done, fruitsEaten} = agent.playStep();
    if (done) {
      // An episode just ended: update throughput stats and moving averages.
      const t = new Date().getTime();
      const framesPerSecond =
          (agent.frameCount - frameCountPrev) / (t - tPrev) * 1e3;
      tPrev = t;
      frameCountPrev = agent.frameCount;

      rewardAverager100.append(cumulativeReward);
      eatenAverager100.append(fruitsEaten);
      const averageReward100 = rewardAverager100.average();
      const averageEaten100 = eatenAverager100.average();
      console.log(
          `Frame #${agent.frameCount}: ` +
          `cumulativeReward100=${averageReward100.toFixed(1)}; ` +
          `eaten100=${averageEaten100.toFixed(2)} ` +
          `(epsilon=${agent.epsilon.toFixed(3)}) ` +
          `(${framesPerSecond.toFixed(1)} frames/s)`);
      if (summaryWriter != null) {
        summaryWriter.scalar(
            'cumulativeReward100', averageReward100, agent.frameCount);
        summaryWriter.scalar('eaten100', averageEaten100, agent.frameCount);
        summaryWriter.scalar('epsilon', agent.epsilon, agent.frameCount);
        summaryWriter.scalar(
            'framesPerSecond', framesPerSecond, agent.frameCount);
      }
      // Stop when the reward target is reached or the frame budget is
      // exhausted.
      if (averageReward100 >= cumulativeRewardThreshold ||
          agent.frameCount >= maxNumFrames) {
        // TODO(cais): Save online network.
        break;
      }
      // Checkpoint the online network whenever the 100-episode average
      // reward reaches a new best.
      if (averageReward100 > averageReward100Best) {
        averageReward100Best = averageReward100;
        if (savePath != null) {
          if (!fs.existsSync(savePath)) {
            mkdir('-p', savePath);
          }
          await agent.onlineNetwork.save(`file://${savePath}`);
          console.log(`Saved DQN to ${savePath}`);
        }
      }
    }
    // Periodically copy the online network's weights into the target
    // network, which is used to compute the Q-learning targets.
    if (agent.frameCount % syncEveryFrames === 0) {
      copyWeights(agent.targetNetwork, agent.onlineNetwork);
      console.log('Sync\'ed weights from online network to target network');
    }
  }
}
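
// The loop above relies on two helpers, `MovingAverager` and `copyWeights`,
// which are imported or defined elsewhere in the snake-dqn code (along with
// `tf`, `fs`, and shelljs's `mkdir`). The sketches below are illustrative
// assumptions of what minimal versions could look like, not the repo's
// exact implementations.

// A fixed-length buffer that reports the mean of the most recent values.
class MovingAverager {
  constructor(bufferLength) {
    // Start with nulls (coerced to 0 in the sum) so the buffer has a fixed
    // length from the beginning.
    this.buffer = new Array(bufferLength).fill(null);
  }

  append(x) {
    this.buffer.shift();  // Drop the oldest value.
    this.buffer.push(x);  // Add the newest value.
  }

  average() {
    return this.buffer.reduce((sum, x) => sum + x, 0) / this.buffer.length;
  }
}

// Copy the online network's weights into the target network. With
// TensorFlow.js LayersModel instances this can be done via
// getWeights() / setWeights().
function copyWeights(destNetwork, srcNetwork) {
  destNetwork.setWeights(srcNetwork.getWeights());
}

// Hypothetical call to train(), with parameter values chosen only for
// illustration (the real repo drives it from a command-line parser):
//   await train(agent, 64, 0.99, 1e-3, 100, 1e6, 1000, './models/dqn', null);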