# Word2Vec

**Learning Objectives**

1. Compile all steps into one function
2. Prepare training data for Word2Vec
3. Model and Training
4. Embedding lookup and analysis




## Introduction 
Word2Vec is not a single algorithm. Rather, it is a family of model architectures and optimizations that can be used to learn word embeddings from large datasets. Embeddings learned through Word2Vec have proven successful on a variety of downstream natural language processing tasks.

Note: This notebook is based on [Efficient Estimation of Word Representations in Vector Space](https://arxiv.org/pdf/1301.3781.pdf) and [Distributed Representations of Words and Phrases and their Compositionality](https://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf). It is not an exact implementation of the papers. Rather, it is intended to illustrate the key ideas.

These papers proposed two methods for learning representations of words: 

*   **Continuous Bag-of-Words Model**, which predicts the middle word based on the surrounding context words. The context consists of a few words before and after the current (middle) word. This architecture is called a bag-of-words model because the order of words in the context is not important.
*   **Continuous Skip-gram Model**, which predicts words within a certain range before and after the current word in the same sentence. A worked example of this is given below.


You'll use the skip-gram approach in this notebook. First, you'll explore skip-grams and other concepts using a single sentence for illustration. Next, you'll train your own Word2Vec model on a small dataset. This notebook also contains code to export the trained embeddings and visualize them in the [TensorFlow Embedding Projector](http://projector.tensorflow.org/).


Each learning objective will correspond to a __#TODO__ in the [student lab notebook](../labs/word2vec.ipynb) -- try to complete that notebook first before reviewing this solution notebook.

## Skip-gram and Negative Sampling 

While a bag-of-words model predicts a word given the neighboring context, a skip-gram model predicts the context (or neighbors) of a word, given the word itself. The model is trained on skip-grams, which are n-grams that allow tokens to be skipped (see the diagram below for an example). The context of a word can be represented through a set of skip-gram pairs of `(target_word, context_word)` where `context_word` appears in the neighboring context of `target_word`. 

Consider the following sentence of 8 words.
> The wide road shimmered in the hot sun. 

The context words for each of the 8 words of this sentence are defined by a window size. The window size determines the span of words on either side of a `target_word` that can be considered a `context_word`. Take a look at this table of skip-grams for target words based on different window sizes.

Note: For this tutorial, a window size of `n` implies `n` words on each side, for a total window span of `2*n+1` words centered on the target word.

![word2vec_skipgrams](assets/word2vec_skipgram.png)
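The pairs in a table like this can be reproduced with a few lines of plain Python. This is a minimal sketch that works on string tokens rather than integer indices; `skipgram_pairs` is a hypothetical helper written for illustration, not a TensorFlow function.

```python
def skipgram_pairs(tokens, window_size):
    """Return every (target, context) pair whose positions differ by at most window_size."""
    pairs = []
    for i, target in enumerate(tokens):
        start = max(0, i - window_size)
        stop = min(len(tokens), i + window_size + 1)
        for j in range(start, stop):
            if j != i:  # a word is not its own context
                pairs.append((target, tokens[j]))
    return pairs


tokens = "the wide road shimmered in the hot sun".split()
pairs = skipgram_pairs(tokens, window_size=2)

# Skip-grams whose target is "wide" with a window size of 2.
print([pair for pair in pairs if pair[0] == "wide"])
# [('wide', 'the'), ('wide', 'road'), ('wide', 'shimmered')]
```

Words near the start or end of the sentence have fewer neighbors, so they contribute fewer pairs than words in the middle.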

The training objective of the skip-gram model is to maximize the probability of predicting context words given the target word. For a sequence of words *w<sub>1</sub>, w<sub>2</sub>, ... w<sub>T</sub>*, the objective can be written as the average log probability

![word2vec_skipgram_objective](assets/word2vec_skipgram_objective.png)

where `c` is the size of the training context. The basic skip-gram formulation defines this probability using the softmax function.

![word2vec_full_softmax](assets/word2vec_full_softmax.png)

where *v* and *v<sup>'</sup>* are the target and context vector representations of words and *W* is the vocabulary size.

Computing the denominator of this formulation requires a full softmax over every word in the vocabulary, which is often large (10<sup>5</sup>-10<sup>7</sup> terms).
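To make the cost concrete, here is a minimal sketch of the full softmax in plain Python. The vocabulary size, embedding dimension, and random vectors are toy values assumed for illustration; they are not taken from the notebook.

```python
import math
import random


def full_softmax(target_vec, context_vecs):
    """Return p(context | target) for every word in the vocabulary."""
    # One dot product per vocabulary word: this is the expensive part.
    scores = [sum(t * c for t, c in zip(target_vec, vec)) for vec in context_vecs]
    max_score = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(score - max_score) for score in scores]
    total = sum(exps)  # the denominator sums over the whole vocabulary
    return [e / total for e in exps]


rng = random.Random(42)
toy_vocab_size, toy_dim = 1000, 8  # assumed toy values
context_vecs = [[rng.gauss(0, 1) for _ in range(toy_dim)] for _ in range(toy_vocab_size)]
target_vec = [rng.gauss(0, 1) for _ in range(toy_dim)]

probs = full_softmax(target_vec, context_vecs)
print(len(probs), round(sum(probs), 6))
```

Every training pair needs one dot product per vocabulary word just to normalize the probabilities, which is why the approximations described next are attractive.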

The [Noise Contrastive Estimation](https://www.tensorflow.org/api_docs/python/tf/nn/nce_loss) loss function is an efficient approximation for a full softmax. With an objective to learn word embeddings instead of modelling the word distribution, NCE loss can be [simplified](https://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf) to use negative sampling. 

The simplified negative sampling objective for a target word is to distinguish the context word from *num_ns* negative samples drawn from a noise distribution *P<sub>n</sub>(w)* of words. More precisely, instead of computing a full softmax over the vocabulary, the loss for each skip-gram pair is posed as a classification problem between the context word and *num_ns* negative samples.
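For reference, the per-pair objective from the Mikolov et al. paper linked above can be written as follows. The notation is an adaptation to this tutorial's symbols: *w<sub>I</sub>* is the target word, *w<sub>O</sub>* is the context word, σ is the logistic sigmoid, and the paper's *k* is written as *num_ns*.

$$
\log \sigma\left({v'_{w_O}}^{\top} v_{w_I}\right) + \sum_{i=1}^{num\_ns} \mathbb{E}_{w_i \sim P_n(w)}\left[\log \sigma\left(-{v'_{w_i}}^{\top} v_{w_I}\right)\right]
$$

The first term rewards a high score for the true context word, and the sum rewards low scores for the sampled noise words, so only *num_ns + 1* dot products are needed per pair instead of *W*.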

A negative sample is defined as a `(target_word, context_word)` pair such that the `context_word` does not appear in the `window_size` neighborhood of the `target_word`. For the example sentence, these are a few potential negative samples (when `window_size` is 2).

```
(hot, shimmered)
(wide, hot)
(wide, sun)
```
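The definition above can be checked mechanically. The sketch below is plain Python with a hypothetical helper, `in_window`, that tests whether a context word falls inside the window of any occurrence of the target word (note that `the` occurs twice in the sentence).

```python
def in_window(tokens, target, context, window_size):
    """True if `context` occurs within window_size positions of any occurrence of `target`."""
    target_positions = [i for i, token in enumerate(tokens) if token == target]
    context_positions = [j for j, token in enumerate(tokens) if token == context]
    return any(
        i != j and abs(i - j) <= window_size
        for i in target_positions
        for j in context_positions
    )


tokens = "the wide road shimmered in the hot sun".split()
candidates = [("hot", "shimmered"), ("wide", "hot"), ("wide", "sun"), ("wide", "road")]
for target, context in candidates:
    kind = "positive" if in_window(tokens, target, context, window_size=2) else "negative"
    print(f"({target}, {context}): {kind}")
# (hot, shimmered): negative
# (wide, hot): negative
# (wide, sun): negative
# (wide, road): positive
```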

In the next section, you'll generate skip-grams and negative samples for a single sentence. You'll also learn about subsampling techniques and train a classification model for positive and negative training examples later in the tutorial.

## Setup

In [None]:
# Use the chown command to change the ownership of repository to user.
!sudo chown -R jupyter:jupyter /home/jupyter/training-data-analyst

In [2]:
!pip install -q tqdm

In [3]:
# You can use any Python source file as a module by executing an import statement in some other Python source file.
# The import statement combines two operations; it searches for the named module, then it binds the
# results of that search to a name in the local scope.
import io
import itertools
import numpy as np
import os
import re
import string
import tensorflow as tf
import tqdm

from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import Activation, Dense, Dot, Embedding, Flatten, GlobalAveragePooling1D, Reshape
from tensorflow.keras.layers.experimental.preprocessing import TextVectorization

Check your TensorFlow version using the cell below.

In [None]:
# Show the currently installed version of TensorFlow
print("TensorFlow version: ",tf.version.VERSION)

TensorFlow version:  2.6.0


In [4]:
SEED = 42 
AUTOTUNE = tf.data.experimental.AUTOTUNE

### Vectorize an example sentence

Consider the following sentence:    
`The wide road shimmered in the hot sun.`

Tokenize the sentence:

In [5]:
sentence = "The wide road shimmered in the hot sun"
tokens = list(sentence.lower().split())
print(len(tokens))

8


Create a vocabulary to save mappings from tokens to integer indices.

In [6]:
vocab, index = {}, 1 # start indexing from 1
vocab['<pad>'] = 0 # add a padding token 
for token in tokens:
  if token not in vocab: 
    vocab[token] = index
    index += 1
vocab_size = len(vocab)
print(vocab)

{'<pad>': 0, 'the': 1, 'wide': 2, 'road': 3, 'shimmered': 4, 'in': 5, 'hot': 6, 'sun': 7}


Create an inverse vocabulary to save mappings from integer indices to tokens.

In [7]:
inverse_vocab = {index: token for token, index in vocab.items()}
print(inverse_vocab)

{0: '<pad>', 1: 'the', 2: 'wide', 3: 'road', 4: 'shimmered', 5: 'in', 6: 'hot', 7: 'sun'}


Vectorize your sentence.


In [8]:
example_sequence = [vocab[word] for word in tokens]
print(example_sequence)

[1, 2, 3, 4, 5, 1, 6, 7]


### Generate skip-grams from one sentence

The `tf.keras.preprocessing.sequence` module provides useful functions that simplify data preparation for Word2Vec. You can use `tf.keras.preprocessing.sequence.skipgrams` to generate skip-gram pairs from the `example_sequence` with a given `window_size` from tokens in the range `[0, vocab_size)`.

Note: `negative_samples` is set to `0` here as batching negative samples generated by this function requires a bit of code. You will use another function to perform negative sampling in the next section.


In [9]:
window_size = 2
positive_skip_grams, _ = tf.keras.preprocessing.sequence.skipgrams(
      example_sequence, 
      vocabulary_size=vocab_size,
      window_size=window_size,
      negative_samples=0)
print(len(positive_skip_grams))

26
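The count of 26 can be verified by hand: each token contributes one pair per neighbor inside the window, and tokens near the edges have fewer neighbors. The sketch below does this arithmetic in plain Python, assuming the sequence contains no padding index `0` (which `skipgrams` skips) and no sampling table is used; `count_skipgrams` is a hypothetical helper.

```python
def count_skipgrams(num_tokens, window_size):
    """Count (target, context) pairs for a sequence with no padding tokens."""
    total = 0
    for i in range(num_tokens):
        left = min(i, window_size)                    # neighbors before position i
        right = min(num_tokens - 1 - i, window_size)  # neighbors after position i
        total += left + right
    return total


# Per-position neighbor counts for 8 tokens, window 2: 2 + 3 + 4 + 4 + 4 + 4 + 3 + 2
print(count_skipgrams(num_tokens=8, window_size=2))  # 26
```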


Take a look at a few positive skip-grams.

In [10]:
for target, context in positive_skip_grams[:5]:
  print(f"({target}, {context}): ({inverse_vocab[target]}, {inverse_vocab[context]})")

(1, 3): (the, road)
(4, 1): (shimmered, the)
(5, 6): (in, hot)
(4, 2): (shimmered, wide)
(3, 2): (road, wide)


### Negative sampling for one skip-gram 

The `skipgrams` function returns all positive skip-gram pairs by sliding over a given window span. To produce additional skip-gram pairs that serve as negative samples for training, you need to sample random words from the vocabulary. Use the `tf.random.log_uniform_candidate_sampler` function to sample `num_ns` negative samples for a given target word in a window. You can call the function on one skip-gram's target word and pass the context word as the true class. Note that the candidates are drawn from the whole range `[0, vocab_size)`, so the true class can still show up among the sampled candidates, as the output of the cell below shows.


Key point: *num_ns* (number of negative samples per positive context word) between [5, 20] is [shown to work](https://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf) best for smaller datasets, while *num_ns* between [2,5] suffices for larger datasets. 

In [11]:
# Get target and context words for one positive skip-gram.
target_word, context_word = positive_skip_grams[0]

# Set the number of negative samples per positive context. 
num_ns = 4

context_class = tf.reshape(tf.constant(context_word, dtype="int64"), (1, 1))
negative_sampling_candidates, _, _ = tf.random.log_uniform_candidate_sampler(
    true_classes=context_class, # class that should be sampled as 'positive'
    num_true=1, # each positive skip-gram has 1 positive context class
    num_sampled=num_ns, # number of negative context words to sample
    unique=True, # all the negative samples should be unique
    range_max=vocab_size, # pick index of the samples from [0, vocab_size)
    seed=SEED, # seed for reproducibility
    name="negative_sampling" # name of this operation
)
print(negative_sampling_candidates)
print([inverse_vocab[index.numpy()] for index in negative_sampling_candidates])

tf.Tensor([2 1 4 3], shape=(4,), dtype=int64)
['wide', 'the', 'shimmered', 'road']
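According to the TensorFlow API documentation, the sampler's base distribution is approximately log-uniform: `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`. The sketch below evaluates that formula in plain Python for `range_max=8`, the size of the example vocabulary; `log_uniform_probs` is a hypothetical helper, not a TensorFlow function.

```python
import math


def log_uniform_probs(range_max):
    """Probability of drawing each class index in [0, range_max)."""
    return [
        (math.log(c + 2) - math.log(c + 1)) / math.log(range_max + 1)
        for c in range(range_max)
    ]


probs = log_uniform_probs(8)
for index, p in enumerate(probs):
    print(index, round(p, 4))
```

Lower indices are more likely to be drawn, which is why this sampler works best when the vocabulary is sorted by decreasing frequency.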


### Construct one training example

For a given positive `(target_word, context_word)` skip-gram, you now also have `num_ns` negative sampled context words, which are meant to stand in for words that do not appear in the window-size neighborhood of `target_word`. Batch the `1` positive `context_word` and `num_ns` negative context words into one tensor. This produces a set of positive skip-grams (labelled as `1`) and negative samples (labelled as `0`) for each target word.

In [12]:
# Add a dimension so you can use concatenation (on the next step).
negative_sampling_candidates = tf.expand_dims(negative_sampling_candidates, 1)

# Concat positive context word with negative sampled words.
context = tf.concat([context_class, negative_sampling_candidates], 0)

# Label first context word as 1 (positive) followed by num_ns 0s (negative).
label = tf.constant([1] + [0]*num_ns, dtype="int64") 

# Squeeze target to a scalar (shape ()) and context and label to (num_ns+1,).
target = tf.squeeze(target_word)
context = tf.squeeze(context)
label =  tf.squeeze(label)

Take a look at the context and the corresponding labels for the target word from the skip-gram example above. 

In [13]:
print(f"target_index    : {target}")
print(f"target_word     : {inverse_vocab[target_word]}")
print(f"context_indices : {context}")
print(f"context_words   : {[inverse_vocab[c.numpy()] for c in context]}")
print(f"label           : {label}")

target_index    : 1
target_word     : the
context_indices : [3 2 1 4 3]
context_words   : ['road', 'wide', 'the', 'shimmered', 'road']
label           : [1 0 0 0 0]
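In the output above, `road` (index 3) appears both as the positive context and as a negative sample. If that overlap is undesirable, one option is to redraw any candidate that collides with the positive context. The sketch below illustrates the idea in plain Python; it is not what this notebook does, and `sample_negatives` is a hypothetical helper that only mimics a log-uniform draw.

```python
import math
import random


def sample_negatives(positive, num_ns, vocab_size, seed=42):
    """Draw num_ns unique indices, log-uniformly weighted, that differ from `positive`.

    Assumes num_ns <= vocab_size - 1, otherwise the loop cannot finish.
    """
    rng = random.Random(seed)
    weights = [math.log(c + 2) - math.log(c + 1) for c in range(vocab_size)]
    negatives = []
    while len(negatives) < num_ns:
        candidate = rng.choices(range(vocab_size), weights=weights, k=1)[0]
        # Reject accidental hits on the positive context and repeated draws.
        if candidate != positive and candidate not in negatives:
            negatives.append(candidate)
    return negatives


positive_context = 3  # 'road' in the example vocabulary
negatives = sample_negatives(positive_context, num_ns=4, vocab_size=8)
context = [positive_context] + negatives
label = [1] + [0] * len(negatives)
print(context, label)
```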


A tuple of `(target, context, label)` tensors constitutes one training example for training your skip-gram negative sampling Word2Vec model. Notice that the target is a scalar of shape `()` while the context and label are of shape `(1+num_ns,)`.

In [14]:
print("target  :", target)
print("context :", context)
print("label   :", label)

target  : tf.Tensor(1, shape=(), dtype=int32)
context : tf.Tensor([3 2 1 4 3], shape=(5,), dtype=int64)
label   : tf.Tensor([1 0 0 0 0], shape=(5,), dtype=int64)


### Summary

This picture summarizes the procedure of generating a training example from a sentence.


![word2vec_negative_sampling](assets/word2vec_negative_sampling.png)

## Lab Task 1: Compile all steps into one function


### Skip-gram Sampling table 

A large dataset means a larger vocabulary with a higher number of frequent words such as stopwords. Training examples obtained from sampling commonly occurring words (such as `the`, `is`, `on`) don't add much useful information for the model to learn from. [Mikolov et al.](https://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf) suggest subsampling of frequent words as a helpful practice to improve embedding quality.

The `tf.keras.preprocessing.sequence.skipgrams` function accepts a sampling table argument to encode the probabilities of sampling any token. You can use `tf.keras.preprocessing.sequence.make_sampling_table` to generate a probabilistic sampling table based on word-frequency rank and pass it to the `skipgrams` function. Take a look at the sampling probabilities for a `vocab_size` of 10.

In [15]:
sampling_table = tf.keras.preprocessing.sequence.make_sampling_table(size=10)
print(sampling_table)

[0.00315225 0.00315225 0.00547597 0.00741556 0.00912817 0.01068435
 0.01212381 0.01347162 0.01474487 0.0159558 ]


`sampling_table[i]` denotes the probability of sampling the i-th most common word in a dataset. The function assumes a [Zipf's distribution](https://en.wikipedia.org/wiki/Zipf%27s_law) of the word frequencies for sampling.
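The values printed above can be reproduced in plain Python. The sketch below is an assumption about how the table is computed (a Zipf-based estimate of each rank's frequency, followed by a square-root subsampling rule with a default `sampling_factor` of `1e-5`); it matches the printed numbers but is not the library's source code, and `zipf_sampling_table` is a hypothetical name.

```python
import math


def zipf_sampling_table(size, sampling_factor=1e-5):
    """Keep-probability for each frequency rank, assuming Zipf-distributed words."""
    gamma = 0.577  # approximation of the Euler-Mascheroni constant
    table = []
    for i in range(size):
        rank = max(i, 1)  # rank 0 is treated like rank 1
        # Estimated inverse frequency of the word at this rank.
        inv_freq = rank * (math.log(rank) + gamma) + 0.5 - 1.0 / (12.0 * rank)
        f = sampling_factor * inv_freq
        table.append(min(1.0, f / math.sqrt(f)))
    return table


table = zipf_sampling_table(10)
print([round(p, 8) for p in table])
```

Under this rule the most frequent words get the smallest keep-probability, so pairs built from words like `the` are dropped most often.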

Key point: The `tf.random.log_uniform_candidate_sampler` already assumes that the vocabulary frequency follows a log-uniform (Zipf's) distribution. Using this distribution-weighted sampling also helps approximate the Noise Contrastive Estimation (NCE) loss with simpler loss functions for training a negative sampling objective.

### Generate training data

Compile all the steps described above into a function that can be called on a list of vectorized sentences obtained from any text dataset. Notice that the sampling table is built before sampling skip-gram word pairs. You will use this function in the later sections.

In [16]:
# Generates skip-gram pairs with negative sampling for a list of sequences
# (int-encoded sentences) based on window size, number of negative samples
# and vocabulary size.
def generate_training_data(sequences, window_size, num_ns, vocab_size, seed):
  # Elements of each training example are appended to these lists.
  targets, contexts, labels = [], [], []

  # Build the sampling table for vocab_size tokens.
  # TODO 1a
  sampling_table = tf.keras.preprocessing.sequence.make_sampling_table(vocab_size)

  # Iterate over all sequences (sentences) in dataset.
  for sequence in tqdm.tqdm(sequences):

    # Generate positive skip-gram pairs for a sequence (sentence).
    positive_skip_grams, _ = tf.keras.preprocessing.sequence.skipgrams(
          sequence, 
          vocabulary_size=vocab_size,
          sampling_table=sampling_table,
          window_size=window_size,
          negative_samples=0)
    
    # Iterate over each positive skip-gram pair to produce training examples 
    # with positive context word and negative samples.
    # TODO 1b
    for target_word, context_word in positive_skip_grams:
      context_class = tf.expand_dims(
          tf.constant([context_word], dtype="int64"), 1)
      negative_sampling_candidates, _, _ = tf.random.log_uniform_candidate_sampler(
          true_classes=context_class,
          num_true=1, 
          num_sampled=num_ns, 
          unique=True, 
          range_max=vocab_size, 
          seed=seed, 
          name="negative_sampling")
      
      # Build context and label vectors (for one target word)
      negative_sampling_candidates = tf.expand_dims(
          negative_sampling_candidates, 1)

      context = tf.concat([context_class, negative_sampling_candidates], 0)
      label = tf.constant([1] + [0]*num_ns, dtype="int64")

      # Append each element from the training example to global lists.
      targets.append(target_word)
      contexts.append(context)
      labels.append(label)

  return targets, contexts, labels

## Lab Task 2: Prepare training data for Word2Vec

With an understanding of how to work with one sentence for a skip-gram negative sampling based Word2Vec model, you can proceed to generate training examples from a larger list of sentences!

### Download text corpus


You will use a text file of Shakespeare's writing for this tutorial. Change the following line to run this code on your own data.

In [17]:
path_to_file = tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')

Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt





Read text from the file and take a look at the first few lines. 

In [18]:
with open(path_to_file) as f: 
  lines = f.read().splitlines()
for line in lines[:20]:
  print(line)

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.


Use the non-empty lines to construct a `tf.data.TextLineDataset` object for the next steps.

In [19]:
# TODO 2a
text_ds = tf.data.TextLineDataset(path_to_file).filter(lambda x: tf.cast(tf.strings.length(x), bool))

### Vectorize sentences from the corpus

You can use the `TextVectorization` layer to vectorize sentences from the corpus. Learn more about using this layer in this [Text Classification](https://www.tensorflow.org/tutorials/keras/text_classification) tutorial. Notice from the first few sentences above that the text needs to be in one case and punctuation needs to be removed. To do this, define a `custom_standardization` function that can be used in the `TextVectorization` layer.

In [20]:
# We create a custom standardization function to lowercase the text and 
# remove punctuation.
def custom_standardization(input_data):
  lowercase = tf.strings.lower(input_data)
  return tf.strings.regex_replace(lowercase,
                                  '[%s]' % re.escape(string.punctuation), '')

# Define the vocabulary size and number of words in a sequence.
vocab_size = 4096
sequence_length = 10

# Use the text vectorization layer to normalize, split, and map strings to
# integers. Set output_sequence_length to pad all samples to the same length.
vectorize_layer = TextVectorization(
    standardize=custom_standardization,
    max_tokens=vocab_size,
    output_mode='int',
    output_sequence_length=sequence_length)
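To see what the standardization step does to a line of text, here is a plain-Python equivalent using the `re` module instead of `tf.strings`. It is only a sketch for inspecting the transformation outside of a TensorFlow graph; `standardize` is a hypothetical name.

```python
import re
import string


def standardize(text):
    """Lowercase the text and delete every ASCII punctuation character."""
    pattern = "[%s]" % re.escape(string.punctuation)
    return re.sub(pattern, "", text.lower())


print(standardize("First Citizen:"))         # first citizen
print(standardize("We know't, we know't."))  # we knowt we knowt
```

Note that apostrophes are removed as well, so contractions such as `know't` become single tokens like `knowt` in the vocabulary.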

Call `adapt` on the text dataset to create vocabulary.


In [21]:
vectorize_layer.adapt(text_ds.batch(1024))

Once the state of the layer has been adapted to represent the text corpus, the vocabulary can be accessed with `get_vocabulary()`. This function returns a list of all vocabulary tokens sorted (descending) by their frequency. 

In [22]:
# Save the created vocabulary for reference.
inverse_vocab = vectorize_layer.get_vocabulary()
print(inverse_vocab[:20])

['', '[UNK]', 'the', 'and', 'to', 'i', 'of', 'you', 'my', 'a', 'that', 'in', 'is', 'not', 'for', 'with', 'me', 'it', 'be', 'your']


The `vectorize_layer` can now be used to generate vectors for each element in the `text_ds`.

In [23]:
def vectorize_text(text):
  text = tf.expand_dims(text, -1)
  return tf.squeeze(vectorize_layer(text))

# Vectorize the data in text_ds.
text_vector_ds = text_ds.batch(1024).prefetch(AUTOTUNE).map(vectorize_layer).unbatch()

### Obtain sequences from the dataset

You now have a `tf.data.Dataset` of integer encoded sentences. To prepare the dataset for training a Word2Vec model, flatten the dataset into a list of sentence vector sequences. This step is required because you will iterate over each sentence in the dataset to produce positive and negative examples.

Note: Since the `generate_training_data()` defined earlier uses non-TF python/numpy functions, you could also use a `tf.py_function` or `tf.numpy_function` with `tf.data.Dataset.map()`.

In [24]:
sequences = list(text_vector_ds.as_numpy_iterator())
print(len(sequences))

32777


Take a look at a few examples from `sequences`.


In [25]:
for seq in sequences[:5]:
  print(f"{seq} => {[inverse_vocab[i] for i in seq]}")

[ 89 270   0   0   0   0   0   0   0   0] => ['first', 'citizen', '', '', '', '', '', '', '', '']
[138  36 982 144 673 125  16 106   0   0] => ['before', 'we', 'proceed', 'any', 'further', 'hear', 'me', 'speak', '', '']
[34  0  0  0  0  0  0  0  0  0] => ['all', '', '', '', '', '', '', '', '', '']
[106 106   0   0   0   0   0   0   0   0] => ['speak', 'speak', '', '', '', '', '', '', '', '']
[ 89 270   0   0   0   0   0   0   0   0] => ['first', 'citizen', '', '', '', '', '', '', '', '']
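The trailing zeros come from `output_sequence_length=10`: shorter lines are padded with index `0` and longer lines are truncated. The sketch below mimics that behavior on plain Python lists; `pad_or_truncate` is a hypothetical helper, not part of the `TextVectorization` API.

```python
def pad_or_truncate(token_ids, sequence_length, pad_id=0):
    """Return a list of exactly sequence_length ids, padding or cutting as needed."""
    trimmed = token_ids[:sequence_length]
    return trimmed + [pad_id] * (sequence_length - len(trimmed))


# 'before we proceed any further hear me speak' has 8 tokens, so 2 pads are added.
print(pad_or_truncate([138, 36, 982, 144, 673, 125, 16, 106], 10))
# [138, 36, 982, 144, 673, 125, 16, 106, 0, 0]

# A 12-token line loses its last 2 tokens.
print(pad_or_truncate(list(range(1, 13)), 10))
# [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
```

Index `0` is treated as a non-word by `skipgrams` and is skipped, so the padding does not produce skip-gram pairs.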


### Generate training examples from sequences

`sequences` is now a list of int encoded sentences. Call the `generate_training_data()` function defined earlier to generate training examples for the Word2Vec model. To recap, the function iterates over each word from each sequence to collect positive and negative context words. The lengths of `targets`, `contexts`, and `labels` should be the same, representing the total number of training examples.

In [26]:
targets, contexts, labels = generate_training_data(
    sequences=sequences, 
    window_size=2, 
    num_ns=4, 
    vocab_size=vocab_size, 
    seed=SEED)
print(len(targets), len(contexts), len(labels))


 83%|████████▎ | 27094/32777 [00:22<00:04, 1214.54it/s]

 83%|████████▎ | 27232/32777 [00:23<00:04, 1257.86it/s]

 84%|████████▎ | 27372/32777 [00:23<00:04, 1291.94it/s]

 84%|████████▍ | 27512/32777 [00:23<00:03, 1320.56it/s]

 84%|████████▍ | 27647/32777 [00:23<00:03, 1326.09it/s]

 85%|████████▍ | 27781/32777 [00:23<00:03, 1285.54it/s]

 85%|████████▌ | 27922/32777 [00:23<00:03, 1312.17it/s]

 86%|████████▌ | 28054/32777 [00:23<00:03, 1222.29it/s]

 86%|████████▌ | 28178/32777 [00:23<00:03, 1195.10it/s]

 86%|████████▋ | 28299/32777 [00:23<00:04, 1076.82it/s]

 87%|████████▋ | 28410/32777 [00:24<00:04, 1045.85it/s]

 87%|████████▋ | 28541/32777 [00:24<00:03, 1114.29it/s]

 87%|████████▋ | 28668/32777 [00:24<00:03, 1153.79it/s]

 88%|████████▊ | 28786/32777 [00:24<00:03, 1074.73it/s]

 88%|████████▊ | 28896/32777 [00:24<00:03, 1054.76it/s]

 89%|████████▊ | 29044/32777 [00:24<00:03, 1169.72it/s]

 89%|████████▉ | 29225/32777 [00:24<00:02, 1349.33it/s]

 90%|████████▉ | 29363/32777 [00:24<00:02, 1333.65it/s]

 90%|█████████ | 29522/32777 [00:24<00:02, 1401.68it/s]

 91%|█████████ | 29664/32777 [00:24<00:02, 1285.23it/s]

 91%|█████████ | 29796/32777 [00:25<00:02, 1261.55it/s]

 91%|█████████▏| 29925/32777 [00:25<00:02, 1185.38it/s]

 92%|█████████▏| 30071/32777 [00:25<00:02, 1254.46it/s]

 92%|█████████▏| 30199/32777 [00:25<00:02, 1184.73it/s]

 93%|█████████▎| 30353/32777 [00:25<00:01, 1280.64it/s]

 93%|█████████▎| 30492/32777 [00:25<00:01, 1307.37it/s]

 93%|█████████▎| 30635/32777 [00:25<00:01, 1340.02it/s]

 94%|█████████▍| 30771/32777 [00:25<00:01, 1312.51it/s]

 94%|█████████▍| 30904/32777 [00:26<00:01, 973.49it/s] 

 95%|█████████▍| 31043/32777 [00:26<00:01, 1068.95it/s]

 95%|█████████▌| 31208/32777 [00:26<00:01, 1207.38it/s]

 96%|█████████▌| 31374/32777 [00:26<00:01, 1325.21it/s]

 96%|█████████▌| 31547/32777 [00:26<00:00, 1430.27it/s]

 97%|█████████▋| 31698/32777 [00:26<00:00, 1358.48it/s]

 97%|█████████▋| 31840/32777 [00:26<00:00, 1343.18it/s]

 98%|█████████▊| 31979/32777 [00:26<00:00, 1283.90it/s]

 98%|█████████▊| 32111/32777 [00:26<00:00, 1272.53it/s]

 98%|█████████▊| 32241/32777 [00:27<00:00, 1242.12it/s]

 99%|█████████▉| 32382/32777 [00:27<00:00, 1286.56it/s]

 99%|█████████▉| 32571/32777 [00:27<00:00, 1457.17it/s]

100%|█████████▉| 32727/32777 [00:27<00:00, 1485.36it/s]

100%|██████████| 32777/32777 [00:27<00:00, 1196.59it/s]

64362 64362 64362





### Configure the dataset for performance

To batch the potentially large number of training examples efficiently, use the `tf.data.Dataset` API. After this step, you will have a `tf.data.Dataset` object of `(target_word, context_word), (label)` elements to train your Word2Vec model!

In [27]:
BATCH_SIZE = 1024
BUFFER_SIZE = 10000
dataset = tf.data.Dataset.from_tensor_slices(((targets, contexts), labels))
dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)
print(dataset)

<BatchDataset shapes: (((1024,), (1024, 5, 1)), (1024, 5)), types: ((tf.int32, tf.int64), tf.int64)>
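
The effect of `drop_remainder=True` on the number of batches can be checked with a little arithmetic. The sketch below uses the example count printed above (64362) and the `BATCH_SIZE` from the cell; the names `num_examples`, `full_batches`, and `dropped` are illustrative only and do not appear in the notebook.

```python
# Illustrative arithmetic only; names other than BATCH_SIZE are hypothetical.
num_examples = 64362  # length of targets/contexts/labels printed above
BATCH_SIZE = 1024

# With drop_remainder=True, only full batches are kept.
full_batches = num_examples // BATCH_SIZE
dropped = num_examples - full_batches * BATCH_SIZE

print(full_batches)  # 62
print(dropped)       # 874
```

The 62 full batches correspond to the `62` steps per epoch shown in the training progress output of this notebook. Because negative sampling is random, the example count may differ slightly from run to run.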


Add `cache()` and `prefetch()` to improve performance.

In [28]:
dataset = dataset.cache().prefetch(buffer_size=AUTOTUNE)
print(dataset)

<PrefetchDataset shapes: (((1024,), (1024, 5, 1)), (1024, 5)), types: ((tf.int32, tf.int64), tf.int64)>


## Lab Task 3: Model and Training

The Word2Vec model can be implemented as a classifier to distinguish between true context words from skip-grams and false context words obtained through negative sampling. You can perform a dot product between the embeddings of target and context words to obtain predictions for labels and compute loss against true labels in the dataset.

### Subclassed Word2Vec Model

Use the [Keras Subclassing API](https://www.tensorflow.org/guide/keras/custom_layers_and_models) to define your Word2Vec model with the following layers:


* `target_embedding`: A `tf.keras.layers.Embedding` layer which looks up the embedding of a word when it appears as a target word. The number of parameters in this layer is `(vocab_size * embedding_dim)`.
* `context_embedding`: Another `tf.keras.layers.Embedding` layer which looks up the embedding of a word when it appears as a context word. The number of parameters in this layer is the same as in `target_embedding`, i.e. `(vocab_size * embedding_dim)`.
* `dots`: A `tf.keras.layers.Dot` layer that computes the dot product of target and context embeddings from a training pair.
* `flatten`: A `tf.keras.layers.Flatten` layer to flatten the result of the `dots` layer into logits.

With the subclassed model, you can define the `call()` function that accepts `(target, context)` pairs, each of which is passed into its corresponding embedding layer. Reshape the `context_embedding` to perform a dot product with `target_embedding` and return the flattened result.

Key point: The `target_embedding` and `context_embedding` layers can be shared as well. You could also use a concatenation of both embeddings as the final Word2Vec embedding.

In [29]:
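class Word2Vec(Model):
  def __init__(self, vocab_size, embedding_dim):
    super(Word2Vec, self).__init__()
    # Embedding used when a word appears as the target word.
    self.target_embedding = Embedding(vocab_size,
                                      embedding_dim,
                                      input_length=1,
                                      name="w2v_embedding")
    # Embedding used when a word appears as a context word.
    self.context_embedding = Embedding(vocab_size,
                                       embedding_dim,
                                       input_length=num_ns+1)
    # Contracts axis 3 of the context embeddings with axis 2 of the
    # target embeddings.
    self.dots = Dot(axes=(3,2))
    self.flatten = Flatten()

  def call(self, pair):
    target, context = pair
    we = self.target_embedding(target)    # target word embedding
    ce = self.context_embedding(context)  # context word embeddings
    dots = self.dots([ce, we])
    return self.flatten(dots)             # one logit per context candidate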
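
To make the data flow in `call()` concrete, here is a minimal pure-Python sketch of the same computation for a single training example, without Keras. The tiny sizes, the random initialization, and the helper names `make_table` and `forward` are assumptions chosen for illustration; they are not part of the notebook. The context ids are assumed to be ordered as one positive word followed by `num_ns` negative samples.

```python
import random

# Toy sizes for illustration; the notebook uses much larger values.
vocab_size, embedding_dim, num_ns = 10, 4, 4

random.seed(0)

def make_table(rows, cols):
    """Create a rows x cols table of small random floats."""
    return [[random.uniform(-0.05, 0.05) for _ in range(cols)]
            for _ in range(rows)]

# Two separate lookup tables, mirroring target_embedding and context_embedding.
target_table = make_table(vocab_size, embedding_dim)
context_table = make_table(vocab_size, embedding_dim)

def forward(target_id, context_ids):
    """Return one logit per candidate context word (dot product with the target)."""
    we = target_table[target_id]
    logits = []
    for cid in context_ids:
        ce = context_table[cid]
        logits.append(sum(w * c for w, c in zip(we, ce)))
    return logits

# One positive context word followed by num_ns negative samples.
logits = forward(target_id=3, context_ids=[7, 1, 2, 5, 9])
print(len(logits))  # 5, i.e. num_ns + 1

# Each table holds vocab_size * embedding_dim parameters.
print(vocab_size * embedding_dim)  # 40
```

The list returned by `forward` plays the role of the flattened output of the `dots` layer: one score per candidate, which the loss then compares against the labels.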

### Define loss function and compile model


For simplicity, you can use `tf.keras.losses.CategoricalCrossentropy` as an alternative to the negative sampling loss. If you would like to write your own custom loss function, you can also do so as follows:

``` python
def custom_loss(y_true, y_pred):
  # Keras calls loss functions as loss(y_true, y_pred); labels are cast to the logits' dtype.
  labels = tf.cast(y_true, y_pred.dtype)
  return tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=y_pred)
```
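
For intuition about what a sigmoid-based loss computes, the sketch below evaluates binary cross-entropy on logits for one example in plain Python. It uses the numerically stable form `max(x, 0) - x * z + log(1 + exp(-|x|))` given in the TensorFlow documentation for `tf.nn.sigmoid_cross_entropy_with_logits`. The function name `sigmoid_xent` and the sample logits are illustrative assumptions, not values from the notebook.

```python
import math

def sigmoid_xent(logit, label):
    """Stable binary cross-entropy for one logit x and one label z in {0, 1}."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))

# One positive candidate followed by four negative samples (made-up logits).
logits = [2.0, -1.0, 0.0, -3.0, 0.5]
labels = [1, 0, 0, 0, 0]

# Each candidate is scored independently as a yes/no decision.
losses = [sigmoid_xent(x, z) for x, z in zip(logits, labels)]
print(losses)

# A logit of zero means "undecided" and costs log(2) for either label.
print(math.isclose(sigmoid_xent(0.0, 0), math.log(2)))  # True
```

Unlike the categorical loss, this treats every candidate as its own binary classification problem, which is closer to the negative sampling objective.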

It's time to build your model! Instantiate your Word2Vec class with an embedding dimension of 128 (you could experiment with different values). Compile the model with the `tf.keras.optimizers.Adam` optimizer. 

In [30]:
# TODO 3a
embedding_dim = 128
word2vec = Word2Vec(vocab_size, embedding_dim)
word2vec.compile(optimizer='adam',
                 loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
                 metrics=['accuracy'])
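
As a sanity check on the chosen loss, the following sketch computes softmax cross-entropy by hand for a single example with `num_ns + 1 = 5` candidates. When all logits are equal, as is roughly the case right after initialization, the loss is `ln(5) ≈ 1.609`, which is in line with the first-epoch loss values reported in this notebook's training output. The helper `softmax_xent` is a hypothetical name, not a Keras function.

```python
import math

def softmax_xent(logits, one_hot):
    """Cross-entropy between a one-hot label and softmax(logits)."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return -sum(y * (x - log_sum_exp) for x, y in zip(logits, one_hot))

label = [1, 0, 0, 0, 0]  # the true context word comes first

# Untrained model: all candidates score the same.
print(round(softmax_xent([0.0] * 5, label), 4))  # 1.6094

# A model that scores the true context word higher gets a lower loss.
print(softmax_xent([3.0, 0.0, 0.0, 0.0, 0.0], label) < math.log(5))  # True
```

The same reasoning explains why an untrained model starts near 20% accuracy: with five equally scored candidates, picking the true context word is a one-in-five guess.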

Also define a callback to log training statistics for TensorBoard.

In [31]:
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="logs")

Train the model with `dataset` prepared above for some number of epochs.

In [32]:
word2vec.fit(dataset, epochs=20, callbacks=[tensorboard_callback])

Epoch 1/20
 1/62 [..............................] - ETA: 57s - loss: 1.6086 - accuracy: 0.2236
 3/62 [>.............................] - ETA: 2s - loss: 1.6090 - accuracy: 0.2168
11/62 [====>.........................] - ETA: 0s - loss: 1.6093 - accuracy: 0.2127
Epoch 2/20
 1/62 [..............................] - ETA: 0s - loss: 1.5899 - accuracy: 0.7832
10/62 [===>..........................] - ETA: 0s - loss: 1.5930 - accuracy: 0.6796
Epoch 3/20
 1/62 [..............................] - ETA: 0s - loss: 1.5601 - accuracy: 0.7656
 9/62 [===>..........................] - ETA: 0s - loss: 1.5630 - accuracy: 0.7238
Epoch 4/20
 1/62 [..............................] - ETA: 0s - loss: 1.4925 - accuracy: 0.6670
11/62 [====>.........................] - ETA: 0s - loss: 1.4947 - accuracy: 0.6363
Epoch 5/20
 1/62 [..............................] - ETA: 0s - loss: 1.3973 - accuracy: 0.6113
11/62 [====>.........................] - ETA: 0s - loss: 1.3988 - accuracy: 0.6039
Epoch 6/20
 1/62 [..............................] - ETA: 0s - loss: 1.2968 - accuracy: 0.6182
10/62 [===>..........................] - ETA: 0s - loss: 1.2982 - accuracy: 0.6144
Epoch 7/20
 1/62 [..............................] - ETA: 0s - loss: 1.2016 - accuracy: 0.6445
10/62 [===>..........................] - ETA: 0s - loss: 1.2031 - accuracy: 0.6404
Epoch 8/20
 1/62 [..............................] - ETA: 0s - loss: 1.1132 - accuracy: 0.6758
10/62 [===>..........................] - ETA: 0s - loss: 1.1151 - accuracy: 0.6757
Epoch 9/20
 1/62 [..............................] - ETA: 0s - loss: 1.0313 - accuracy: 0.7100
10/62 [===>..........................] - ETA: 0s - loss: 1.0339 - accuracy: 0.7087
Epoch 10/20
 1/62 [..............................] - ETA: 0s - loss: 0.9556 - accuracy: 0.7344
10/62 [===>..........................] - ETA: 0s - loss: 0.9586 - accuracy: 0.7377
Epoch 11/20
 1/62 [..............................] - ETA: 0s - loss: 0.8857 - accuracy: 0.7627
11/62 [====>.........................] - ETA: 0s - loss: 0.8891 - accuracy: 0.7651
Epoch 12/20
 1/62 [..............................] - ETA: 0s - loss: 0.8213 - accuracy: 0.7881
10/62 [===>..........................] - ETA: 0s - loss: 0.8248 - accuracy: 0.7873
Epoch 13/20
 1/62 [..............................] - ETA: 0s - loss: 0.7622 - accuracy: 0.8057
11/62 [====>.........................] - ETA: 0s - loss: 0.7659 - accuracy: 0.8057
Epoch 14/20
 1/62 [..............................] - ETA: 0s - loss: 0.7082 - accuracy: 0.8145
11/62 [====>.........................] - ETA: 0s - loss: 0.7119 - accuracy: 0.8214
Epoch 15/20
 1/62 [..............................] - ETA: 0s - loss: 0.6588 - accuracy: 0.8320
11/62 [====>.........................] - ETA: 0s - loss: 0.6624 - accuracy: 0.8367
Epoch 16/20
 1/62 [..............................] - ETA: 0s - loss: 0.6139 - accuracy: 0.8525
 9/62 [===>..........................] - ETA: 0s - loss: 0.6170 - accuracy: 0.8525
Epoch 17/20
 1/62 [..............................] - ETA: 0s - loss: 0.5730 - accuracy: 0.8672
11/62 [====>.........................] - ETA: 0s - loss: 0.5764 - accuracy: 0.8659
Epoch 18/20
 1/62 [..............................] - ETA: 0s - loss: 0.5359 - accuracy: 0.8760
11/62 [====>.........................] - ETA: 0s - loss: 0.5391 - accuracy: 0.8761
Epoch 19/20
 1/62 [..............................] - ETA: 0s - loss: 0.5022 - accuracy: 0.8828
 9/62 [===>..........................] - ETA: 0s - loss: 0.5049 - accuracy: 0.8834
Epoch 20/20
 1/62 [..............................] - ETA: 0s - loss: 0.4716 - accuracy: 0.8965
11/62 [====>.........................] - ETA: 0s - loss: 0.4745 - accuracy: 0.8938

<tensorflow.python.keras.callbacks.History at 0x7f65706f69b0>

TensorBoard now shows the Word2Vec model's accuracy and loss.

In [None]:
!tensorboard --bind_all --port=8081 --load_fast=false --logdir logs

Run the following command in **Cloud Shell:**

<code>gcloud beta compute ssh --zone &lt;instance-zone&gt; &lt;notebook-instance-name&gt; --project &lt;project-id&gt; -- -L 8081:localhost:8081</code> 

Make sure to replace &lt;instance-zone&gt;, &lt;notebook-instance-name&gt; and &lt;project-id&gt;.

In Cloud Shell, click *Web Preview* > *Change Port* and enter port number *8081*. Click *Change and Preview* to open TensorBoard.

![embeddings_classifier_accuracy.png](assets/embeddings_classifier_accuracy.png)

**To quit TensorBoard, click Kernel > Interrupt kernel**.

## Lab Task 4: Embedding lookup and analysis

Obtain the weights from the model using `get_layer()` and `get_weights()`. The `get_vocabulary()` function provides the vocabulary to build a metadata file with one token per line. 

In [33]:
# TODO 4a
weights = word2vec.get_layer('w2v_embedding').get_weights()[0]
vocab = vectorize_layer.get_vocabulary()

Create and save the vectors and metadata file. 

In [34]:
with io.open('vectors.tsv', 'w', encoding='utf-8') as out_v, \
     io.open('metadata.tsv', 'w', encoding='utf-8') as out_m:
  for index, word in enumerate(vocab):
    if index == 0:
      continue  # skip 0, it's padding.
    vec = weights[index]
    out_v.write('\t'.join([str(x) for x in vec]) + "\n")
    out_m.write(word + "\n")

Download the `vectors.tsv` and `metadata.tsv` to analyze the obtained embeddings in the [Embedding Projector](https://projector.tensorflow.org/).

In [35]:
try:
  from google.colab import files
  files.download('vectors.tsv')
  files.download('metadata.tsv')
except Exception:
  pass  # Not running in Colab (or the download failed); nothing to do.
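
If you prefer a quick local check over the Embedding Projector, nearest neighbors by cosine similarity can be computed directly from the two TSV files. The following is a minimal sketch under stated assumptions: the helper names `load_embeddings`, `cosine`, and `nearest` are hypothetical, and the toy vectors and words are made up so that the example runs on its own. To try it on your results, point the paths at the `vectors.tsv` and `metadata.tsv` written above.

```python
import math
import os
import tempfile

def load_embeddings(vectors_path, metadata_path):
    """Read tab-separated vectors and one-token-per-line metadata into a dict."""
    with open(vectors_path, encoding="utf-8") as fv, \
         open(metadata_path, encoding="utf-8") as fm:
        return {
            word.rstrip("\n"): [float(x) for x in line.split("\t")]
            for word, line in zip(fm, fv)
        }

def cosine(u, v):
    """Cosine similarity of two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0

def nearest(word, embeddings, k=3):
    """Return the k words whose vectors are most similar to the given word."""
    query = embeddings[word]
    scored = [(cosine(query, vec), other)
              for other, vec in embeddings.items() if other != word]
    return [other for _, other in sorted(scored, reverse=True)[:k]]

# Toy files in the same layout as vectors.tsv / metadata.tsv (made-up values).
tmp = tempfile.mkdtemp()
vec_path = os.path.join(tmp, "vectors.tsv")
meta_path = os.path.join(tmp, "metadata.tsv")
with open(vec_path, "w", encoding="utf-8") as f:
    f.write("1.0\t0.0\n0.9\t0.1\n0.0\t1.0\n")
with open(meta_path, "w", encoding="utf-8") as f:
    f.write("king\nqueen\napple\n")

emb = load_embeddings(vec_path, meta_path)
print(nearest("king", emb, k=1))  # ['queen']
```

With embeddings trained on a small dataset for a few epochs, the neighbors may be noisy, so treat the results as a rough qualitative check.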

## Next steps


This tutorial has shown you how to implement a skip-gram Word2Vec model with negative sampling from scratch and visualize the obtained word embeddings.

* To learn more about word vectors and their mathematical representations, refer to these [notes](https://web.stanford.edu/class/cs224n/readings/cs224n-2019-notes01-wordvecs1.pdf).

* To learn more about advanced text processing, read the [Transformer model for language understanding](https://www.tensorflow.org/tutorials/text/transformer) tutorial.

* If you’re interested in pre-trained embedding models, see [Exploring the TF-Hub CORD-19 Swivel Embeddings](https://www.tensorflow.org/hub/tutorials/cord_19_embeddings_keras) or the [Multilingual Universal Sentence Encoder](https://www.tensorflow.org/hub/tutorials/cross_lingual_similarity_with_tf_hub_multilingual_universal_encoder).

* You may also like to train the model on a new dataset (there are many available in [TensorFlow Datasets](https://www.tensorflow.org/datasets)).
