Generating Text with PyTorch

Reading in Text Data

To read a text file into Python, we use the with statement and the open() function to manage opening and closing the file, using the following parameters:

  • 'r': specifies reading mode
  • encoding='utf-8': specifies the encoding to decode the text file

Next, the as keyword assigns the open file to the variable f, and the .read() method reads the file's contents into a string named text.

with open('datasets/text.txt', 'r', encoding='utf-8') as f:
    text = f.read()

Initializing Clean LSTM States to Generate Text

To generate text from an LSTM, we’ll need to initialize clean hidden and cell states (vectors containing zero values).

The clean states allow the starting prompt to set the context for the LSTM to generate relevant text without bias from any previous inputs.

As the LSTM generates new text, both the hidden and cell states are continually updated with information from the newly generated token.
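
A minimal sketch of initializing clean states, assuming a single-layer LSTM, a batch size of 1, and a hidden size of 72 (all illustrative values):

import torch
# clean hidden and cell states filled with zeros
# shape: (number of layers, batch size, hidden size)
hidden_state = torch.zeros(1, 1, 72)
cell_state = torch.zeros(1, 1, 72)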

LSTM Layer in PyTorch

In PyTorch, an LSTM layer can be implemented using nn.LSTM from the torch.nn module (aliased as nn).

For text generation, the LSTM layer takes in the following input parameters:

  • input_size: equal to the embedding dimension in the previous embedding layer
  • hidden_size: specifies the number of neurons in the hidden layer
  • batch_first=True: specifies that the input and output tensors place the batch dimension first, matching how the training data is loaded in batches by the DataLoader utility class (see the usage sketch below)

from torch import nn
# LSTM layer in PyTorch
nn.LSTM(input_size=36,
        hidden_size=72,
        batch_first=True)
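
As a hedged usage sketch (the input sizes are illustrative), batch_first=True means the input is shaped (batch, sequence length, embedding dimension), and the layer returns the output along with the updated hidden and cell states:

import torch
from torch import nn

lstm = nn.LSTM(input_size=36, hidden_size=72, batch_first=True)
# batch of 1, sequence of 5 tokens, 36-dimensional embeddings
embedded = torch.randn(1, 5, 36)
output, (hidden, cell) = lstm(embedded)
# output shape: (1, 5, 72); hidden and cell shapes: (1, 1, 72)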

Long Short-Term Memory (LSTM) Networks

LSTMs maintain a hidden state and a memory cell state that are updated and passed from one unit to the next.

To learn longer sequences, the LSTM uses three gated units to manage the cell state:

  • Forget gate: decides what is forgotten
  • Input gate: decides what is used to update the cell state
  • Output gate: decides how to update the hidden state

The output is influenced by the updated hidden and cell states.
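
A minimal sketch of a single LSTM step, assuming PyTorch's nn.LSTMCell with illustrative sizes; the cell applies its forget, input, and output gates internally and returns the updated hidden and cell states:

import torch
from torch import nn

cell = nn.LSTMCell(input_size=36, hidden_size=72)
x = torch.randn(1, 36)     # one input embedding (illustrative)
h = torch.zeros(1, 72)     # previous hidden state
c = torch.zeros(1, 72)     # previous cell state
h, c = cell(x, (h, c))     # gates update both states for the next unit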

RNN Issues: The Vanishing and Exploding Gradients

A major challenge facing RNNs is their inability to remember longer sequences due to the following issues:

  • Vanishing gradient problem: as the gradients across the units become smaller and smaller, the gradient vanishes toward zero leading to extremely small or no model weight updates
  • Exploding gradient problem: as the gradients across the units become larger and larger, the gradient explodes to a very large number leading to unstable model weight updates

PyTorch Utility Classes: Dataset and DataLoader

Since neural networks require heavy computations, the data needs to be efficiently accessed and loaded using the PyTorch utility classes Dataset and DataLoader, where:

  • the Dataset class accesses the features and their corresponding labels

  • the DataLoader loads the features and their labels into smaller batches to train the network one batch at a time

from torch.utils.data import Dataset, DataLoader
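
A minimal sketch of the pattern, assuming bigram-style (context, target) token ID pairs; the class name and example data are illustrative:

import torch
from torch.utils.data import Dataset, DataLoader

class BigramDataset(Dataset):
    def __init__(self, bigrams):
        # bigrams: list of [context_id, target_id] pairs
        self.bigrams = bigrams

    def __len__(self):
        return len(self.bigrams)

    def __getitem__(self, idx):
        context, target = self.bigrams[idx]
        return torch.tensor(context), torch.tensor(target)

dataset = BigramDataset([[0, 1], [1, 2], [2, 3]])
loader = DataLoader(dataset, batch_size=2, shuffle=True)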

Recurrent Neural Networks

Recurrent Neural Networks (RNNs) are a type of deep learning model for sequential data: a recurrent connection links individual units so that the network can process information from the full sequence.

The key component of an RNN is a hidden state that attempts to retain information from multiple previous units and is used to predict the next output in the sequence.
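
A hedged sketch of a basic RNN layer in PyTorch (the sizes are illustrative); the returned hidden state carries information forward across the sequence:

import torch
from torch import nn

rnn = nn.RNN(input_size=36, hidden_size=72, batch_first=True)
sequence = torch.randn(1, 5, 36)   # batch of 1, 5 steps, 36 features
output, hidden = rnn(sequence)     # hidden retains information from the sequence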

Using Cross-Entropy Loss for Text Generation Models

Text generation models that look to predict the next token can be treated as performing a multiclass classification task where each token in the vocabulary is represented as its own category.

Therefore, the multiclass version of the cross-entropy loss function is used and can be implemented in PyTorch using nn.CrossEntropyLoss() from the torch.nn module.

from torch import nn
loss = nn.CrossEntropyLoss()
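
A minimal usage sketch with illustrative tensors: the scores have one column per vocabulary token, and the targets are the true next-token IDs:

import torch
from torch import nn

loss_fn = nn.CrossEntropyLoss()
scores = torch.randn(4, 100)            # batch of 4, vocabulary of 100 tokens
targets = torch.tensor([5, 42, 7, 99])  # true next-token IDs
output = loss_fn(scores, targets)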

Token Scoring and Prediction

The final layer of a text generation model is typically a linear layer with an output size equal to the vocabulary size.

The output is a vector of scores for each token in the vocabulary and the token with the highest score is predicted to be the next generated token.

Note: LLMs like GPT will calculate token scores as well as add sampling techniques to introduce randomness and generate more diverse and interesting text.
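
A hedged sketch of the scoring step with illustrative sizes: a linear layer maps the final hidden output to one score per vocabulary token, and torch.argmax picks the highest-scoring token ID:

import torch
from torch import nn

vocab_size = 100                              # illustrative vocabulary size
fc = nn.Linear(in_features=72, out_features=vocab_size)
hidden_output = torch.randn(1, 72)            # final LSTM hidden output (illustrative)
scores = fc(hidden_output)                    # one score per token in the vocabulary
next_token_id = torch.argmax(scores, dim=-1)  # highest-scoring token is the prediction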

Embedding Layer in PyTorch

Embeddings are created in PyTorch using the nn.Embedding layer from the torch.nn module (aliased as nn) and are parameterized by two inputs:

  • num_embeddings: the number of embeddings to create (equal to the vocabulary size)
  • embedding_dim: the dimension size of each embedding

The embedding layer is indexed by the token’s token ID (integer value) to obtain the token’s embedding vector.

from torch import nn
# create 100 embeddings with 16 dimensions
nn.Embedding(num_embeddings=100,
             embedding_dim=16)
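
A minimal lookup sketch (the token IDs are illustrative): indexing the layer with a tensor of token IDs returns one embedding vector per ID:

import torch
from torch import nn

embedding = nn.Embedding(num_embeddings=100, embedding_dim=16)
token_ids = torch.tensor([4, 17, 2])  # illustrative token IDs
vectors = embedding(token_ids)        # shape: (3, 16), one vector per token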

Bigram model

A Bigram model is a type of language model that uses the previous token to generate the next token.

Specifically, two consecutive tokens in a sequence are paired together to form a bigram: the first token serves as the context token that the model uses to predict the second token, which serves as the target token.

Note: since models cannot understand text, the tokens in the bigrams must be referenced by their token IDs.

# bigram pair
bigram = [context_token, target_token]
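
A minimal sketch of building bigram pairs, assuming the text has already been converted to token IDs (the IDs below are illustrative):

token_ids = [4, 17, 2, 9]
# pair each token with the token that follows it
bigrams = [[token_ids[i], token_ids[i + 1]] for i in range(len(token_ids) - 1)]
# >> [[4, 17], [17, 2], [2, 9]]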

Token Embeddings

Embeddings are dense vector representations of tokens containing continuous values that aim to capture the semantic meaning of each token.

The key idea is that semantically similar tokens should have similar token embeddings, while dissimilar tokens should have dissimilar token embeddings.

Embeddings allow machine learning models to understand and process text data effectively in a numerical format.

# tokens represented as 2-dimensional embeddings
'vanity' =====> [ 0.7148, 0.3928]
'and' =====> [-0.4809, -0.9904]
'pride' =====> [ 0.0126, -0.1968]
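
As a hedged illustration (reusing the 2-dimensional vectors above), cosine similarity is one way to compare two embedding vectors: values close to 1 indicate similar embeddings, while values near -1 indicate dissimilar ones:

import torch

vanity = torch.tensor([0.7148, 0.3928])
pride = torch.tensor([0.0126, -0.1968])
# cosine similarity between the two embedding vectors
similarity = torch.cosine_similarity(vanity, pride, dim=0)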

Creating Vocabularies

The vocabulary is created by first obtaining a sorted list of unique tokens from the tokenized text using the set(), list(), and sorted() functions.

Then, the vocabulary dictionary is created using the enumerate() function within a dictionary comprehension that assigns each token's positional index ix as its token ID.

The inverse vocabulary is created by reversing the items in the vocabulary dictionary which maps the token ID back to each token.

# sorted list of unique tokens
unique_tokens = sorted(list(set(tokenized_text)))
# vocabulary (token-to-index)
t2ix = {token:ix for ix, token in enumerate(unique_tokens)}
# inverse vocabulary (index-to-token)
tx2w = {ix:token for token,ix in t2ix.items()}
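
As a follow-up sketch, the two dictionaries convert between tokens and token IDs (assuming tokenized_text from the tokenization step):

# convert tokens to token IDs
token_ids = [t2ix[token] for token in tokenized_text]
# convert token IDs back to tokens
tokens = [tx2w[ix] for ix in token_ids]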

Vocabulary and Inverse Vocabulary

A vocabulary is built that maps each unique token within the tokenized text to a unique integer token ID.

An inverse vocabulary maps the token ID back to each token by reversing the vocabulary items.

For example, the token 'vanity' is assigned to token ID 6921 as such: {'vanity':6921}. The inverse vocabulary maps the token ID back to the token as such: {6921:'vanity'}.

Tokenization

Tokenization is the process of breaking down a text into individual units called tokens.

Two tokenization strategies include:

  • word-based tokenization that breaks down a text into individual word-based tokens using the .split() method
  • character-based tokenization that breaks down a text into individual character-based tokens using the list() function

text = '''Vanity and pride are different things'''
# word-based tokenization
word_tokens = text.split()
# output:
# >> ['Vanity', 'and', 'pride', 'are', 'different', 'things']
# character-based tokenization
character_tokens = list(text)
# output:
# >> ['V', 'a', 'n', 'i', 't', 'y', ' ', 'a', 'n', 'd', ' ', 'p', 'r', 'i', 'd', 'e', ' ', 'a', 'r', 'e', ' ', 'd', 'i', 'f', 'f', 'e', 'r', 'e', 'n', 't', ' ', 't', 'h', 'i', 'n', 'g', 's']

How Models Understand Text

For machine learning and deep learning models (including today's most advanced large language models) to understand text, the text must be transformed into numerical representations that the models can process and analyze.

Text Data

Text data is an example of sequential data where the syntactic order of words is important to the grammatical structure and contextual meaning behind the text.

text = '''Vanity and pride are different things, though the words are often used synonymously.'''

Text Generation Models

Text generation models are language models that are trained to generate human-like text in response to an input.

Their generated text can assist us with tasks like:

  • answering questions
  • analyzing texts
  • writing documents
  • generating code
  • brainstorming ideas

Examples of text generation models include Bigram models, Recurrent Neural Networks (RNNs), Long Short-Term Memory networks (LSTMs), and Large Language Models (LLMs).

Granular Strategies for Text Generation

Two granular strategies to train a text generation model to generate text include:

  • word-based strategy that generates one word at a time
  • character-based strategy that generates one character at a time
