export function setUpUI()

in lstm-text-generation/ui.js [136:423]


/**
 * Initializes the UI state of the LSTM text-generation page and wires up the
 * button callbacks.
 *
 * This function reads and writes names declared outside of it (presumably at
 * module level elsewhere in this file): the DOM element references
 * (`*Button`, `*Input`, `textDataSelect`, `testText`, ...), `logStatus`,
 * `TEXT_DATA_URLS`, `sampleLen`, `sampleStep`, `TextData`,
 * `SaveableLSTMTextGenerator`, and the mutable `textData` / `textGenerator`,
 * which are replaced whenever text data is loaded.
 *
 * Every handler that disables controls while it works re-enables them on its
 * failure paths, so an error never leaves the UI stuck.
 */
export function setUpUI() {
  /**
   * Refresh the status of the locally saved model (in IndexedDB) and update
   * the model controls to match.
   *
   * With no stored model, the button offers "Create model", delete is
   * disabled and the LSTM layer-size input is editable. With a stored model,
   * the button offers "Load model", delete is enabled and the layer-size input
   * is locked.
   */
  async function refreshLocalModelStatus() {
    const modelInfo = await textGenerator.checkStoredModelStatus();
    if (modelInfo == null) {
      modelAvailableInfo.innerText =
          `No locally saved model for "${textGenerator.modelIdentifier()}".`;
      createOrLoadModelButton.textContent = 'Create model';
      deleteModelButton.disabled = true;
      enableModelParameterControls();
    } else {
      modelAvailableInfo.innerText =
          `Saved @ ${modelInfo.dateSaved.toISOString()}`;
      createOrLoadModelButton.textContent = 'Load model';
      deleteModelButton.disabled = false;
      disableModelParameterControls();
    }
    createOrLoadModelButton.disabled = false;
  }

  /** Disable every model-related button (used while a long operation runs). */
  function disableModelButtons() {
    createOrLoadModelButton.disabled = true;
    deleteModelButton.disabled = true;
    trainModelButton.disabled = true;
    generateTextButton.disabled = true;
  }

  /** Enable every model-related button. */
  function enableModelButtons() {
    createOrLoadModelButton.disabled = false;
    deleteModelButton.disabled = false;
    trainModelButton.disabled = false;
    generateTextButton.disabled = false;
  }

  /**
   * Use `textGenerator` to generate text from a seed sentence and show the
   * result in the generated-text box.
   *
   * The seed is the last `textData.sampleLen()` characters of the seed-text
   * input. If that input is empty, a random slice of the loaded text data is
   * used instead and written back into the input.
   *
   * If no text data set is loaded, this logs an error and returns without
   * touching the buttons. Otherwise the model buttons are disabled while it
   * runs and are always re-enabled afterwards, including when validation or
   * generation fails.
   *
   * @returns {Promise<string|undefined>} The generated text, or `undefined`
   *     if validation or generation failed.
   */
  async function generateText() {
    if (textGenerator == null) {
      logStatus('ERROR: Please load text data set first.');
      return;
    }

    disableModelButtons();
    try {
      const generateLength = parseInt(generateLengthInput.value, 10);
      const temperature = parseFloat(temperatureInput.value);
      // `!(x > 0)` also rejects NaN from unparsable input.
      if (!(generateLength > 0)) {
        logStatus(
            `ERROR: Invalid generation length: ${generateLength}. ` +
            `Generation length must be a positive number.`);
        return;
      }
      if (!(temperature > 0 && temperature <= 1)) {
        logStatus(
            `ERROR: Invalid temperature: ${temperature}. ` +
            `Temperature must be a positive number no greater than 1.`);
        return;
      }

      let seedSentence;
      let seedSentenceIndices;
      if (seedTextInput.value.length === 0) {
        // Seed sentence is not specified yet. Get it from the data.
        [seedSentence, seedSentenceIndices] = textData.getRandomSlice();
        seedTextInput.value = seedSentence;
      } else {
        seedSentence = seedTextInput.value;
        if (seedSentence.length < textData.sampleLen()) {
          logStatus(
              `ERROR: Seed text must have a length of at least ` +
              `${textData.sampleLen()}, but has a length of ` +
              `${seedSentence.length}.`);
          return;
        }
        // Only the last `sampleLen()` characters are used as the seed.
        seedSentence =
            seedSentence.slice(seedSentence.length - textData.sampleLen());
        seedSentenceIndices = textData.textToIndices(seedSentence);
      }

      const sentence = await textGenerator.generateText(
          seedSentenceIndices, generateLength, temperature);
      generatedTextInput.value = sentence;
      const status = 'Done generating text.';
      logStatus(status);
      textGenerationStatus.value = status;
      return sentence;
    } catch (err) {
      logStatus(`ERROR: Failed to generate text: ${err.message}, ${err.stack}`);
    } finally {
      // Runs on success, on validation early-returns and on errors alike.
      enableModelButtons();
    }
  }

  /** Lock the LSTM layer-size input. */
  function disableModelParameterControls() {
    lstmLayersSizesInput.disabled = true;
  }

  /** Unlock the LSTM layer-size input. */
  function enableModelParameterControls() {
    lstmLayersSizesInput.disabled = false;
  }

  /** Show the given LSTM layer sizes in the layer-size input. */
  function updateModelParameterControls(lstmLayerSizes) {
    lstmLayersSizesInput.value = lstmLayerSizes;
  }

  /**
   * Populate the text-data `<select>` with one `<option>` per entry of
   * `TEXT_DATA_URLS`. The option value is the key and the label is `needle`.
   */
  function updateTextInputParameters() {
    for (const [key, {needle}] of Object.entries(TEXT_DATA_URLS)) {
      const opt = document.createElement('option');
      opt.value = key;
      // textContent (not innerHTML) so the label is never parsed as markup.
      opt.textContent = needle;
      textDataSelect.appendChild(opt);
    }
  }

  /**
   * djb2-style string hash (seed 5381, multiply by 33, xor each UTF-16 code
   * unit), iterating from the last character to the first.
   *
   * @param {string} str
   * @returns {number} An unsigned 32-bit integer.
   */
  function hashCode(str) {
    let hash = 5381;
    for (let i = str.length - 1; i >= 0; --i) {
      // `^` truncates to int32, so `hash * 33` stays well within safe range.
      hash = (hash * 33) ^ str.charCodeAt(i);
    }
    return hash >>> 0;
  }

  /**
   * Initialize UI state.
   */
  disableModelParameterControls();

  /**
   * Update Text Inputs
   */
  updateTextInputParameters();

  /**
   * Wire up UI callbacks.
   */

  loadTextDataButton.addEventListener('click', async () => {
    textDataSelect.disabled = true;
    loadTextDataButton.disabled = true;
    // Used on every failure path so the user can pick data again and retry.
    const reenableDataControls = () => {
      textDataSelect.disabled = false;
      loadTextDataButton.disabled = false;
    };

    let dataIdentifier = textDataSelect.value;
    if (testText.value.length === 0) {
      const dataSource = TEXT_DATA_URLS[dataIdentifier];
      if (dataSource == null) {
        logStatus(`ERROR: Unknown text data set: ${dataIdentifier}`);
        reenableDataControls();
        return;
      }
      const {url} = dataSource;
      try {
        logStatus(`Loading text data from URL: ${url} ...`);
        const response = await fetch(url);
        // fetch() only rejects on network failure. Without this check an
        // HTTP error page would silently be used as training text.
        if (!response.ok) {
          throw new Error(`HTTP ${response.status} ${response.statusText}`);
        }
        const textString = await response.text();
        testText.value = textString;
        logStatus(
            `Done loading text data ` +
            `(length=${(textString.length / 1024).toFixed(1)}k). ` +
            `Next, please load or create model.`);
      } catch (err) {
        logStatus('Failed to load text data: ' + err.message);
      }
      if (testText.value.length === 0) {
        logStatus('ERROR: Empty text data.');
        reenableDataControls();
        return;
      }
    } else {
      // User-supplied text: identify it by a hash of its content.
      dataIdentifier = hashCode(testText.value);
    }

    try {
      textData =
          new TextData(dataIdentifier, testText.value, sampleLen, sampleStep);
      textGenerator = new SaveableLSTMTextGenerator(textData);
      await refreshLocalModelStatus();
    } catch (err) {
      logStatus(`ERROR: Failed to prepare text data: ${err.message}`);
      reenableDataControls();
    }
  });

  createOrLoadModelButton.addEventListener('click', async () => {
    createOrLoadModelButton.disabled = true;
    if (textGenerator == null) {
      createOrLoadModelButton.disabled = false;
      logStatus('ERROR: Please load text data set first.');
      return;
    }

    try {
      if (await textGenerator.checkStoredModelStatus()) {
        // Load locally-saved model.
        logStatus('Loading model from IndexedDB... Please wait.');
        await textGenerator.loadModel();
        updateModelParameterControls(textGenerator.lstmLayerSizes());
        logStatus(
            'Done loading model from IndexedDB. ' +
            'Now you can train the model further or use it to generate text.');
      } else {
        // Create model from scratch.
        logStatus('Creating model... Please wait.');
        // `split` always yields at least one element. Empty or non-numeric
        // entries parse to NaN and are rejected by the check below.
        const lstmLayerSizes = lstmLayersSizesInput.value.trim().split(',').map(
            (s) => parseInt(s, 10));

        // Sanity check on the LSTM layer sizes.
        for (let i = 0; i < lstmLayerSizes.length; ++i) {
          const lstmLayerSize = lstmLayerSizes[i];
          if (!(lstmLayerSize > 0)) {
            logStatus(
                `ERROR: lstmLayerSizes must be a positive integer, ` +
                `but got ${lstmLayerSize} for layer ${i + 1} ` +
                `of ${lstmLayerSizes.length}.`);
            createOrLoadModelButton.disabled = false;
            return;
          }
        }

        await textGenerator.createModel(lstmLayerSizes);
        logStatus(
            'Done creating model. ' +
            'Now you can train the model or use it to generate text.');
      }
    } catch (err) {
      logStatus(`ERROR: Failed to create or load model: ${err.message}`);
      createOrLoadModelButton.disabled = false;
      return;
    }

    trainModelButton.disabled = false;
    generateTextButton.disabled = false;
  });

  deleteModelButton.addEventListener('click', async () => {
    if (textGenerator == null) {
      logStatus('ERROR: Please load text data set first.');
      return;
    }
    if (!confirm(
            `Are you sure you want to delete the model ` +
            `'${textGenerator.modelIdentifier()}'?`)) {
      return;
    }
    try {
      console.log(await textGenerator.removeModel());
    } catch (err) {
      logStatus(`ERROR: Failed to delete model: ${err.message}`);
    }
    await refreshLocalModelStatus();
  });

  trainModelButton.addEventListener('click', async () => {
    if (textGenerator == null) {
      logStatus('ERROR: Please load text data set first.');
      return;
    }

    const numEpochs = parseInt(epochsInput.value, 10);
    if (!(numEpochs > 0)) {
      logStatus(`ERROR: Invalid number of epochs: ${numEpochs}`);
      return;
    }
    const examplesPerEpoch = parseInt(examplesPerEpochInput.value, 10);
    if (!(examplesPerEpoch > 0)) {
      logStatus(`ERROR: Invalid examples per epoch: ${examplesPerEpoch}`);
      return;
    }
    const batchSize = parseInt(batchSizeInput.value, 10);
    if (!(batchSize > 0)) {
      logStatus(`ERROR: Invalid batch size: ${batchSize}`);
      return;
    }
    const validationSplit = parseFloat(validationSplitInput.value);
    if (!(validationSplit >= 0 && validationSplit < 1)) {
      logStatus(`ERROR: Invalid validation split: ${validationSplit}`);
      return;
    }
    const learningRate = parseFloat(learningRateInput.value);
    if (!(learningRate > 0)) {
      logStatus(`ERROR: Invalid learning rate: ${learningRate}`);
      return;
    }

    disableModelButtons();
    try {
      textGenerator.compileModel(learningRate);
      await textGenerator.fitModel(
          numEpochs, examplesPerEpoch, batchSize, validationSplit);
      console.log(await textGenerator.saveModel());
      await refreshLocalModelStatus();
    } catch (err) {
      logStatus(`ERROR: Failed to train model: ${err.message}`);
      return;
    } finally {
      // Re-enable on both success and failure so the UI never stays locked.
      enableModelButtons();
    }

    await generateText();
  });

  generateTextButton.addEventListener('click', async () => {
    if (textGenerator == null) {
      logStatus('ERROR: Load text data set first.');
      return;
    }
    await generateText();
  });
}