Building a Generative Transformer Decoder in TensorFlow: A Step-by-Step Guide


This post documents my process of implementing a generative transformer decoder model in TensorFlow, in an effort to learn more about this architecture. As an example application, we will train the model on a dataset of domain names.

Data pre-processing

You can find a step-by-step explanation of how to use BPE to tokenize the dataset in a previous post:

Training a Custom Byte-Pair Encoding (BPE) Tokenizer using Hugging Face

Learn how to train a custom Byte-Pair Encoding (BPE) tokenizer on a dataset of domain names using the Hugging Face library. Improve your NLP models' performance with this efficient tokenization technique.

https://mlread.me/blog/byte-pair-encoding-bpe

Let’s start by loading the already tokenized dataset using the Hugging Face datasets library.

from datasets import load_dataset

ds = load_dataset("jeremygf/domains-app-alpha")
ds = ds['train']

Target shifting

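During training, the model is fed the target sequence shifted by one position. For example, if the correct domain name is “example”, the input might be [STA]exampl and the target would be example, so at every position the model sees the true previous tokens and has to predict the next one. Feeding the ground-truth prefix instead of the model’s own predictions is called “teacher forcing.”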
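As a minimal sketch of this shift, assume a made-up tokenization in which every character is one token. The IDs and the STA and END constants below are hypothetical and are not taken from the actual tokenizer.

```python
# Hypothetical token IDs; the real values depend on the trained tokenizer.
STA, END = 1, 2
ids = [STA, 10, 11, 12, 13, 14, 15, 10, END]  # "[STA] e x a m p l e [END]"

input_ids = ids[:-1]   # everything except the last token
target_ids = ids[1:]   # everything except the first token

# At step t the model sees input_ids[: t + 1] and must predict target_ids[t].
for t, target in enumerate(target_ids):
    print(input_ids[: t + 1], "->", target)
```

Both sequences have the same length, and each target is simply the token that follows the corresponding input position.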

We can use the dataset’s map method to produce input and target sequences. Here, x['ids'] contains the tokenized sequence as an array of numerical IDs representing individual tokens.

ds = ds.map(lambda x: {
    'input_ids': x['ids'][:-1],
    'target_ids': x['ids'][1:]
})

ds_dict = ds.train_test_split(test_size=0.1)

To prepare a TensorFlow dataset, we use a data collator from the Hugging Face Transformers library, here DataCollatorForTokenClassification. A data collator assembles individual samples into batches, padding them to a common length as needed for the model’s training or evaluation.
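To illustrate what batching with padding means conceptually, here is a plain NumPy sketch. The pad_batch helper and the padding ID of 0 are assumptions made for illustration; this is not the Hugging Face implementation.

```python
import numpy as np

def pad_batch(sequences, pad_id=0):
    """Right-pad variable-length ID lists to the longest one in the batch."""
    max_len = max(len(seq) for seq in sequences)
    batch = np.full((len(sequences), max_len), pad_id, dtype=np.int64)
    for row, seq in enumerate(sequences):
        batch[row, :len(seq)] = seq
    return batch

# Three made-up token sequences of different lengths.
batch = pad_batch([[5, 8, 3], [7, 2], [4, 9, 6, 1]])
print(batch)
```

Every row ends up with the length of the longest sequence, which is what lets the batch be stored as a single rectangular tensor.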

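from transformers import DataCollatorForTokenClassification

BATCH = 256

# `tokenizer` is not defined in this post; it is assumed to be the BPE
# tokenizer from the previous post, already loaded.
data_collator = DataCollatorForTokenClassification(
    tokenizer=tokenizer,
    return_tensors="tf"
)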

tf_train = ds_dict['train'].to_tf_dataset(
    columns="input_ids",
    label_cols="target_ids",
    collate_fn=data_collator,
    batch_size=BATCH,
    shuffle=True
)

tf_test = ds_dict['test'].to_tf_dataset(
    columns="input_ids",
    label_cols="target_ids",
    collate_fn=data_collator,
    batch_size=BATCH,
    shuffle=True
)

Building a Generative Transformer Decoder Model

Let’s break down how generative transformer decoder models work. The primary goal of such a model is to predict the next element (e.g., a word, character, or any other token) in a sequence given the previous elements. It achieves this by combining several mechanisms from the Transformer architecture.

Positional encoding

Unlike recurrent models (RNNs, LSTMs), Transformers don’t have an inherent sense of sequence order. Positional encodings inject information about the relative positions of tokens into the input representations, which helps the model understand the sequence’s structure.

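The original paper uses the following formula for calculating the positional encoding:

$$PE_{(pos,\,2i)} = \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)$$

Here pos is the token position, i indexes pairs of embedding dimensions, and d_model is the embedding dimension.

This implementation uses sinusoids with varying wavelengths to represent positional information. Each depth index i produces sine and cosine waves of a different frequency, giving the model a richer way to “understand” the relative positions of tokens. Note that the code concatenates the sine and cosine halves instead of interleaving them as in the formula, which only permutes the embedding dimensions.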

import numpy as np
import tensorflow as tf

def positional_encoding(length, depth):
    depth = depth / 2  # half of the dimensions hold sines, half cosines

    positions = np.arange(length)[:, np.newaxis]       # (length, 1)
    depths = np.arange(depth)[np.newaxis, :] / depth   # (1, depth/2)

    angle_rates = 1 / (10000**depths)
    angle_rads = positions * angle_rates               # (length, depth/2)

    pos_encoding = np.concatenate(
        [np.sin(angle_rads), np.cos(angle_rads)],
        axis=-1)                                       # (length, depth)

    return tf.cast(pos_encoding, dtype=tf.float32)
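
A quick way to sanity-check the encoding is to look at its shape and at position 0, where every sine is 0 and every cosine is 1. The sketch below is a NumPy-only variant of the same computation; the name positional_encoding_np and the small sizes are chosen here purely for illustration.

```python
import numpy as np

def positional_encoding_np(length, depth):
    half = depth // 2
    positions = np.arange(length)[:, np.newaxis]      # (length, 1)
    depths = np.arange(half)[np.newaxis, :] / half    # (1, half)
    angle_rads = positions / (10000 ** depths)        # (length, half)
    # First half of the features holds sines, second half cosines.
    return np.concatenate([np.sin(angle_rads), np.cos(angle_rads)], axis=-1)

pe = positional_encoding_np(length=16, depth=8)
print(pe.shape)   # (16, 8)
print(pe[0])      # position 0: sines are 0, cosines are 1
```

All values stay within [-1, 1], so the encoding has a bounded scale regardless of sequence length.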

The token sequences have to be converted to vectors using a standard embedding layer. The positional encoding array is added to the input token embeddings before being fed into the Transformer model.

class PositionalEmbedding(tf.keras.layers.Layer):
    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim
        # mask_zero=True treats token ID 0 as padding.
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim, mask_zero=True)
        # MAX_SEQ_LEN (the maximum sequence length) must be defined beforehand.
        self.pos_encoding = positional_encoding(length=MAX_SEQ_LEN, depth=embedding_dim)

    def call(self, x):
        length = tf.shape(x)[1]
        x = self.embedding(x)
        # Scale the embeddings so the positional encoding does not dominate them.
        x *= tf.math.sqrt(tf.cast(self.embedding_dim, tf.float32))
        x = x + self.pos_encoding[tf.newaxis, :length, :]
        return x

Causal self-attention

Self-attention allows the transformer decoder to focus on different parts of the partially generated sequence. An important aspect of the decoder architecture is causal masking in the attention mechanism: each position can only attend to itself and to earlier tokens, never to later ones. This prevents the model from “cheating” by seeing what the correct next token should be. Because we are not using an encoder, the cross-attention portion of the decoder is omitted.
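
To make the masking concrete, here is a small NumPy sketch of attention weights under a causal mask. It illustrates the idea only and is not how Keras implements it internally; the function name causal_attention_weights is an assumption.

```python
import numpy as np

def causal_attention_weights(scores):
    """Softmax over keys, with future positions masked out."""
    seq_len = scores.shape[-1]
    # Lower-triangular mask: query i may look at keys 0..i only.
    mask = np.tril(np.ones((seq_len, seq_len), dtype=bool))
    masked = np.where(mask, scores, -np.inf)
    masked = masked - masked.max(axis=-1, keepdims=True)
    weights = np.exp(masked)   # exp(-inf) is 0, so future keys get no weight
    return weights / weights.sum(axis=-1, keepdims=True)

scores = np.zeros((4, 4))      # uniform raw scores for 4 tokens
weights = causal_attention_weights(scores)
print(np.round(weights, 2))
```

With uniform scores, the first token attends only to itself, while the fourth spreads its attention evenly over all four visible positions.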

Multi-headed attention

The model contains multiple instances of the attention mechanism. Each “head” within the attention mechanism can focus on different aspects of the input sequence, enabling it to capture complex patterns and relationships within the partially generated sequence.
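
One common way to picture the heads is as a reshape of the feature dimension. The sketch below shows only the shapes involved and assumes the usual convention in which the embedding is divided evenly among the heads; Keras’s MultiHeadAttention instead gives each head its own learned projection of size key_dim. The helper names are hypothetical.

```python
import numpy as np

def split_heads(x, num_heads):
    """(batch, seq, dim) -> (batch, heads, seq, dim // heads)."""
    batch, seq_len, dim = x.shape
    head_dim = dim // num_heads
    x = x.reshape(batch, seq_len, num_heads, head_dim)
    return x.transpose(0, 2, 1, 3)

def merge_heads(x):
    """Inverse of split_heads: (batch, heads, seq, head_dim) -> (batch, seq, dim)."""
    batch, num_heads, seq_len, head_dim = x.shape
    return x.transpose(0, 2, 1, 3).reshape(batch, seq_len, num_heads * head_dim)

x = np.random.default_rng(0).normal(size=(2, 5, 8))
heads = split_heads(x, num_heads=4)
print(heads.shape)   # (2, 4, 5, 2)
```

Each head then runs attention on its own slice, and the results are merged back into a single vector per token.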

Let’s break down the code for the CausalSelfAttention layer class. The constructor takes two parameters:

  • num_heads: The number of attention heads in the multi-head attention mechanism.
  • embedding_dim: The dimensionality of the input embeddings.

With these parameters, we instantiate a MultiHeadAttention layer, the core of the self-attention mechanism. A LayerNormalization object is also created for stabilizing the attention mechanism’s output.

The call method defines what happens when the layer is called on an input x.

  • use_causal_mask=True ensures the model only attends to previous tokens and not to future ones in the sequence.
  • x = x + attn_output performs a residual connection, adding the attention output to the original input. This helps with gradient flow during training.

class CausalSelfAttention(tf.keras.layers.Layer):
    def __init__(self, num_heads, embedding_dim):
        super().__init__()

        # key_dim is the size of each head; embedding_dim // num_heads
        # is a common alternative.
        self.mha = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=embedding_dim
        )

        self.layer_norm = tf.keras.layers.LayerNormalization()

    def call(self, x):
        # Self-attention: queries, keys, and values all come from x.
        attn_output = self.mha(
            query=x,
            value=x,
            use_causal_mask=True
        )
        x = x + attn_output  # residual connection
        x = self.layer_norm(x)
        return x

Feed-Forward Networks

After the attention layers, feed-forward networks independently process each token’s representation. This adds further non-linear transformations to enhance the model’s expressive power.
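
The sketch below illustrates what “independently” means here: perturbing one token’s vector changes only that token’s output. It is a NumPy toy with random weights, and tanh stands in for GELU to keep it dependency-free; none of the names come from the model code.

```python
import numpy as np

rng = np.random.default_rng(0)
embedding_dim, feedforward_dim = 4, 8
w1 = rng.normal(size=(embedding_dim, feedforward_dim))
w2 = rng.normal(size=(feedforward_dim, embedding_dim))

def feed_forward(x):
    # The same two matrices are applied to every token (row) separately.
    hidden = np.tanh(x @ w1)   # tanh stands in for GELU in this sketch
    return hidden @ w2

x = rng.normal(size=(3, embedding_dim))   # 3 tokens
y = feed_forward(x)

x_changed = x.copy()
x_changed[1] += 1.0                       # perturb only token 1
y_changed = feed_forward(x_changed)

print(np.allclose(y[0], y_changed[0]), np.allclose(y[2], y_changed[2]))
```

Mixing information between tokens is left entirely to the attention layers; the feed-forward block only transforms each position on its own.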

class FeedForward(tf.keras.layers.Layer):
    def __init__(self, embedding_dim, feedforward_dim, dropout_rate=0.1):
        super().__init__()

        self.d1 = tf.keras.layers.Dense(feedforward_dim, activation='gelu')
        self.d2 = tf.keras.layers.Dense(embedding_dim)
        self.dropout = tf.keras.layers.Dropout(dropout_rate)

        self.layer_norm = tf.keras.layers.LayerNormalization()

    def call(self, x):
        # Expand to feedforward_dim, project back, then apply dropout.
        ffn_output = self.dropout(self.d2(self.d1(x)))
        # Residual connection: add the block's output to its input.
        x = x + ffn_output
        x = self.layer_norm(x)
        return x

The DecoderLayer class combines the multi-headed attention layer and the feed-forward layer in sequence as a block.

class DecoderLayer(tf.keras.layers.Layer):
    def __init__(self, embedding_dim, num_heads, feedforward_dim, dropout_rate=0.1):
        super(DecoderLayer, self).__init__()

        self.causal_self_attention = CausalSelfAttention(num_heads=num_heads, embedding_dim=embedding_dim)

        self.ffn = FeedForward(embedding_dim=embedding_dim, feedforward_dim=feedforward_dim, dropout_rate=dropout_rate)

    def call(self, x):
        x = self.causal_self_attention(x=x)
        x = self.ffn(x)
        return x    

Putting it all together

To build a full Transformer decoder, we need to stack multiple instances of the DecoderLayer. The final layer of the decoder is a dense layer with one output per vocabulary token. Its outputs are unnormalized scores (logits); applying a softmax to them gives a probability distribution over the entire vocabulary. The next token is then either the one with the highest probability or one sampled from this distribution.
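
As a small illustration of that last step, the NumPy sketch below turns made-up logits for a five-token vocabulary into probabilities and picks the most likely token. The numbers are invented for the example.

```python
import numpy as np

def softmax(logits):
    shifted = logits - logits.max(axis=-1, keepdims=True)  # for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

# Made-up logits for a 5-token vocabulary at the last position.
logits = np.array([1.0, 3.0, 0.5, 2.0, -1.0])
probs = softmax(logits)
next_token = int(np.argmax(probs))
print(next_token)   # 1
```

Because softmax preserves ordering, greedy selection can also take the argmax of the logits directly.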

class TransformerDecoder(tf.keras.Model):
    def __init__(self, vocab_size: int, num_layers: int, embedding_dim: int, 
        num_heads: int, intermediate_dim: int) -> None:

        super().__init__()

        self.embeddings = PositionalEmbedding(vocab_size, embedding_dim)

        self.decoders = [
            DecoderLayer(embedding_dim, num_heads, intermediate_dim)
            for _ in range(num_layers)
        ]

        # No activation: the loss uses from_logits=True, so this layer
        # must output raw logits.
        self.dense = tf.keras.layers.Dense(vocab_size)
        
    def call(self, inputs):
        x = self.embeddings(inputs)  

        for layer in self.decoders:
            x = layer(x)
        return self.dense(x)  

Training

We are now ready to start training an instance of our Generative Transformer Decoder model.

Masked loss and metrics

Since the target sequences have various lengths and are padded, it is important to apply a padding mask when calculating the loss and metrics. Masking ensures padding tokens don’t corrupt our calculations. The functions below treat token ID 0 as padding.

def masked_loss(label, pred):
    mask = label != 0
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
    loss = loss_object(label, pred)

    mask = tf.cast(mask, dtype=loss.dtype)
    loss *= mask

    loss = tf.reduce_sum(loss)/tf.reduce_sum(mask)
    return loss

def masked_accuracy(label, pred):
    pred = tf.argmax(pred, axis=2)
    label = tf.cast(label, pred.dtype)
    match = label == pred

    mask = label != 0

    match = match & mask

    match = tf.cast(match, dtype=tf.float32)
    mask = tf.cast(mask, dtype=tf.float32)
    return tf.reduce_sum(match)/tf.reduce_sum(mask)
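
To see why the mask matters, here is a tiny worked example in NumPy with invented per-token losses. The masked_mean helper is a hypothetical stand-in for the averaging step and assumes, like the functions above, that ID 0 marks padding.

```python
import numpy as np

def masked_mean(per_token_values, labels, pad_id=0):
    """Average only over positions whose label is not padding."""
    mask = (labels != pad_id).astype(per_token_values.dtype)
    return (per_token_values * mask).sum() / mask.sum()

labels = np.array([[4, 7, 0, 0],
                   [3, 5, 9, 0]])
# Invented losses; the padded positions carry large, meaningless values.
per_token_loss = np.array([[1.0, 2.0, 9.0, 9.0],
                           [3.0, 1.0, 3.0, 9.0]])

print(masked_mean(per_token_loss, labels))   # 2.0
print(per_token_loss.mean())                 # 4.625
```

The unmasked average is dominated by padding positions, while the masked average reflects only the five real tokens.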

Hyperparameters and compiled model

The following constants define the hyperparameters of our Transformer decoder architecture:

  • EMBED_DIM: Dimensionality of the token embeddings.
  • NUM_HEADS: Number of attention heads in the multi-head attention layers.
  • FEED_FORWARD_DIM: Dimensionality of the hidden layer in the feed-forward sublayers of the decoder.
  • NUM_LAYERS: The number of Transformer decoder layers stacked in our model.

These parameters are not ‘one-size-fits-all’. It’s common to experiment with different values to find the best configuration for your specific task.

EPOCHS = 10

EMBED_DIM = 256
NUM_HEADS = 4
FEED_FORWARD_DIM = 512
NUM_LAYERS = 2

model = TransformerDecoder(
    vocab_size=tokenizer.vocab_size,
    num_layers=NUM_LAYERS,
    embedding_dim=EMBED_DIM,
    num_heads=NUM_HEADS,
    intermediate_dim=FEED_FORWARD_DIM
)

model.compile(loss=masked_loss, metrics=[masked_accuracy], optimizer='adam')
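
For a rough sense of scale, the sketch below estimates the number of trainable parameters in the decoder layers for these hyperparameters. It assumes that the attention layer uses query, key, and value projections of size NUM_HEADS × key_dim with biases plus an output projection back to EMBED_DIM, and it leaves out the embedding and final dense layers, since those depend on the vocabulary size. Treat the result as an estimate, not a value reported by Keras.

```python
EMBED_DIM, NUM_HEADS, FEED_FORWARD_DIM, NUM_LAYERS = 256, 4, 512, 2
key_dim = EMBED_DIM  # per-head size, as passed to MultiHeadAttention in this post

def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected projection."""
    return n_in * n_out + n_out

# Query, key, and value projections, then the output projection.
attention = (3 * dense_params(EMBED_DIM, NUM_HEADS * key_dim)
             + dense_params(NUM_HEADS * key_dim, EMBED_DIM))
feed_forward = (dense_params(EMBED_DIM, FEED_FORWARD_DIM)
                + dense_params(FEED_FORWARD_DIM, EMBED_DIM))
layer_norms = 2 * (2 * EMBED_DIM)  # scale and offset, two norms per layer

per_layer = attention + feed_forward + layer_norms
print(per_layer, NUM_LAYERS * per_layer)
```

Under these assumptions most of the parameters sit in the attention projections, because key_dim is set to the full embedding size rather than divided among the heads.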

We are now ready to run the training procedure with the fit method of our model.

model.fit(
    tf_train, 
    validation_data=tf_test, 
    epochs=EPOCHS
)

Inference

Once our model is trained, we can use it to generate domain names. The keras_nlp library provides convenient utilities for text generation.

prompt = ''
prompt_length = len(tokenizer.encode(prompt))-1
prompt = tokenizer.encode(prompt, padding='max_length', max_length=MAX_SEQ_LEN, return_tensors='tf')

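def step(prompt, cache, index):
    # Re-run the full model and take the logits that predict position `index`.
    logits = model(prompt)[:, index-1, :]
    hidden_states = None  # hidden states are not needed here
    return logits, hidden_states, cache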

import keras_nlp

sampler = keras_nlp.samplers.TopPSampler(0.5)
output_tokens = sampler(
    next=step,
    prompt=prompt,
    index=prompt_length
)[0]

print(tokenizer.decode(output_tokens).split('[END]')[0])
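
The sampler used above performs top-p (nucleus) sampling: it samples only from the most likely tokens whose combined probability reaches p. The NumPy sketch below shows the general idea with made-up probabilities; the top_p_filter helper is hypothetical, and the details of keras_nlp’s implementation may differ.

```python
import numpy as np

def top_p_filter(probs, p):
    """Keep the smallest set of top tokens whose mass reaches p, then renormalize."""
    order = np.argsort(probs)[::-1]              # token IDs, most likely first
    cumulative = np.cumsum(probs[order])
    # Keep tokens up to and including the first one that reaches mass p.
    cutoff = int(np.searchsorted(cumulative, p)) + 1
    kept = order[:cutoff]
    filtered = np.zeros_like(probs)
    filtered[kept] = probs[kept]
    return filtered / filtered.sum()

probs = np.array([0.1, 0.4, 0.3, 0.15, 0.05])   # made-up next-token distribution
filtered = top_p_filter(probs, p=0.5)

rng = np.random.default_rng(0)
next_token = int(rng.choice(len(probs), p=filtered))
print(filtered, next_token)
```

With p = 0.5, only the two most likely tokens survive here, so a lower p makes generation more conservative and a higher p makes it more varied.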

Here are some example domain names generated by the model:

  • my doc tor y
  • r is k b ay
  • car e er w iz ard

Resources

Training a Custom Byte-Pair Encoding (BPE) Tokenizer using Hugging Face

https://mlread.me/blog/byte-pair-encoding-bpe

Transformer model for language understanding | Text | TensorFlow (embedding and positional encoding layer)

https://www.tensorflow.org/text/tutorials/transformer?ref=mlread.me#the_embedding_and_positional_encoding_layer

Transformer model for language understanding | Text | TensorFlow (loss and metrics)

https://www.tensorflow.org/text/tutorials/transformer?ref=mlread.me#set_up_the_loss_and_metrics

Keras documentation: GPT text generation from scratch with KerasHub

https://keras.io/examples/generative/text_generation_gpt/