How LLMs learn (3 of 4): Back propagation

With the forward pass complete, the model can either generate new tokens during inference, which does not use backpropagation, or begin learning with backpropagation. During learning, the loss computed by comparing the softmax output of the forward pass against the target is traced back through the network to determine how each layer contributed to it. Understanding backpropagation matters because poor gradient management can waste enormous resources.

Gradient clipping, which caps gradient values at a threshold to prevent them from growing exponentially, stops exploding gradients that cause training instability and NaN losses. Vanishing gradients can be ameliorated with architectural choices (residual connections, layer normalization) and proper weight initialization. Without this knowledge, undetected gradient issues during multi-day training runs on large GPU clusters can easily cost many thousands in wasted compute. Even in fine-tuning scenarios, catching these problems early can save substantial money.
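As a minimal, framework-agnostic sketch of gradient clipping by global norm (the function name and the dict-of-NumPy-arrays gradient format are assumptions for illustration, not part of the model described in this book):

```python
import numpy as np

def clip_gradients_by_global_norm(grads, max_norm=1.0):
    """Scale all gradients so their combined L2 norm does not exceed max_norm.

    `grads` is assumed to be a dict mapping parameter names to NumPy arrays;
    it is modified in place, and the pre-clipping norm is returned so a
    training loop can log it and spot exploding gradients early.
    """
    global_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads.values()))
    if global_norm > max_norm:
        scale = max_norm / (global_norm + 1e-6)
        for name in grads:
            grads[name] *= scale
    return global_norm
```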

Going backwards through the transformer, our goal at each (sub)layer is to calculate gradients with the same shapes as the corresponding quantities from the forward pass, accumulating them with the chain rule so that parameters can be updated during optimization. The chain rule connects the “gradient out” flowing from the layer above (going backward) to the “gradient in” with respect to the current layer’s inputs and parameters. The results at each layer are parameter gradients and activation gradients that match the shapes of the corresponding weight and activation matrices in the forward pass.
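To make the shape bookkeeping concrete, here is a small sketch of one chain-rule step for a single linear sublayer; the shapes and variable names are made up for illustration and are not taken from the toy model:

```python
import numpy as np

# Illustrative chain-rule step for a linear (sub)layer with forward pass Y = X @ W.
rng = np.random.default_rng(0)
X = rng.normal(size=(3, 4))    # activations entering the layer
W = rng.normal(size=(4, 5))    # layer weights
dY = rng.normal(size=(3, 5))   # "gradient out": dLoss/dY arriving from the layer above

dW = X.T @ dY                  # dLoss/dW, same shape as W -> (4, 5), used for the update
dX = dY @ W.T                  # dLoss/dX, same shape as X -> (3, 4), passed further back

assert dW.shape == W.shape and dX.shape == X.shape
```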

💡

Alert: Here I highlight the steps and computations of the backpropagation procedure with a numerical example to make the ideas concrete. The approach is consistent with our simplified GPT-2 model. These results are preliminary and will be validated with a formal mathematical verifier in due course.

Target specification

We must first define the target in order to calculate the loss. In this case, the target is the input sequence shifted by one position, so that each token is trained to predict the next token in the sequence. For our toy example, the target is [rocks, <EOS>] for the first two positions, predicting the next tokens after AI and rocks respectively, with no target for the final position. In matrix form, the target sequence is specified as follows.

\begin{aligned} \text{Input sequence: } &\quad x = [x_1, x_2, x_3] = [\text{AI}, \text{rocks}, \text{<EOS>}] = \begin{bmatrix} 0 & 1 & 2 \end{bmatrix} \\[6pt] \text{Target sequence: } &\quad y = [y_1, y_2, y_3] = [\text{rocks}, \text{<EOS>}, \text{no target}] = \begin{bmatrix} 1 & 2 & - \end{bmatrix} \end{aligned}

For loss calculation, these targets are converted to one-hot vectors where only the position corresponding to the target token has value 1, and all other positions are 0.

\text{One-hot targets: } \quad Y_{\text{one-hot}} = \begin{bmatrix} 0 & 1 & 0 \\ 0 & 0 & 1 \\ - & - & - \end{bmatrix}

The first row represents the one-hot encoding for target 'rocks' (token 1), the second row for target '<EOS>' (token 2), and the third row has no target.
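A short sketch of how these shifted targets and one-hot vectors could be built in code; the -1 “no target” marker is an assumed convention for illustration, not something defined by the model:

```python
import numpy as np

# Toy vocabulary and input from the example: AI=0, rocks=1, <EOS>=2.
vocab_size = 3
x = np.array([0, 1, 2])        # input sequence [AI, rocks, <EOS>]

# Targets are the inputs shifted by one position; the last position has no target,
# marked here with the hypothetical ignore value -1.
y = np.array([1, 2, -1])

# One-hot encode only the positions that have a target.
Y_onehot = np.zeros((len(x), vocab_size))
for t, target in enumerate(y):
    if target >= 0:
        Y_onehot[t, target] = 1.0

print(Y_onehot)
# [[0. 1. 0.]
#  [0. 0. 1.]
#  [0. 0. 0.]]
```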

Loss function

With the target specified, we define the loss function, which quantifies how far the model’s predicted next-token probabilities are from the true targets. In autoregressive language modelling we use cross-entropy loss: the negative log probability of the correct token at each position, averaged over the positions that have a target. Positions without a target are skipped.

\mathcal{L} = - \frac{1}{N} \sum_{t=1}^{N} \log p(y_t \mid x_{<t})

where N is the number of target tokens, y_t is the true token at position t, and p(y_t ∣ x_{<t}) is the model’s predicted probability for that token. With the target specified and the loss function defined, we are ready to work through the backpropagation process for our toy example.
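A minimal sketch of this loss in code, assuming the same -1 “no target” convention as above (the helper name cross_entropy_loss is hypothetical):

```python
import numpy as np

def cross_entropy_loss(P, y):
    """Average negative log probability of the target token at each position.

    P is the (positions x vocab) matrix of predicted probabilities; y holds the
    target token ids, with -1 marking positions that have no target.
    """
    mask = y >= 0
    target_probs = P[np.arange(len(y))[mask], y[mask]]
    return -np.mean(np.log(target_probs))
```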

Worked example

We now calculate the loss for our toy example using the model’s predicted probabilities and the specified target sequence. The softmax probabilities from the forward pass were as follows.

P \approx \begin{bmatrix} 0.4114 & 0.2532 & 0.3354 \\ 0.2524 & 0.4323 & 0.3153 \\ 0.3280 & 0.2986 & 0.3734 \end{bmatrix}

P is the matrix of predicted probabilities, where each row corresponds to a position in the input sequence and each column corresponds to a token in the vocabulary. The target is [y_1, y_2, y_3] = [rocks, <EOS>, -]. The first row gives the probabilities for predicting AI, rocks and <EOS> after the token AI. The second row gives the probabilities for predicting AI, rocks and <EOS> after the token rocks. The third row gives the probabilities for predicting AI, rocks and <EOS> after the token <EOS>.
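The probabilities assigned to the correct tokens are therefore 0.2532 (rocks after AI) and 0.3153 (<EOS> after rocks); the final position has no target and is skipped. As a quick arithmetic check of the loss under these numbers (my own calculation from the matrix above, not a figure quoted from the text):

```python
import numpy as np

P = np.array([
    [0.4114, 0.2532, 0.3354],
    [0.2524, 0.4323, 0.3153],
    [0.3280, 0.2986, 0.3734],
])
y = np.array([1, 2, -1])  # targets: rocks, <EOS>, no target

# Only the two positions with a target contribute to the average.
loss = -(np.log(P[0, 1]) + np.log(P[1, 2])) / 2
print(round(loss, 4))  # approximately 1.264
```

From this loss, backpropagation works back through the transformer in the reverse order of the forward pass, through the following steps.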

Softmax backward
Logit projection backward
Residual connection backward (post MLP)
MLP backward
LayerNorm 2 backward
Residual connection backward (post attention)
Multihead attention backward
LayerNorm 1 backward
Positional encodings backward
Embedding lookup backward
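The first of these steps, softmax backward, has a particularly simple form when the softmax is combined with cross-entropy loss: the gradient of the loss with respect to the logits is the predicted probabilities minus the one-hot targets, scaled by 1/N and zeroed at positions without a target. This is a standard identity rather than a derivation specific to this book; a sketch using the matrices from the worked example:

```python
import numpy as np

P = np.array([
    [0.4114, 0.2532, 0.3354],
    [0.2524, 0.4323, 0.3153],
    [0.3280, 0.2986, 0.3734],
])
Y_onehot = np.array([
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
    [0.0, 0.0, 0.0],  # no target at the final position
])
mask = np.array([1.0, 1.0, 0.0])   # 1 where a target exists
N = int(mask.sum())                # number of target positions

# Gradient of the averaged cross-entropy loss with respect to the logits:
# (probabilities - one-hot targets) / N, with untargeted rows zeroed out.
dlogits = (P - Y_onehot) * mask[:, None] / N
print(np.round(dlogits, 4))
```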



This work is licensed under a Creative Commons Attribution-NonCommercial 4.0 International License (CC BY-NC 4.0).