With the forward pass complete, the model can either generate new tokens during inference (which does not use backpropagation) or begin learning with backpropagation. During learning, the error computed from the forward-pass softmax against the target (the loss) is traced back through the network to determine how much each layer contributed to it. Understanding backpropagation matters because poor gradient management can waste enormous resources.
Gradient clipping, which caps gradient values at a threshold to prevent them from growing exponentially, stops exploding gradients that cause training instability and NaN losses. Vanishing gradients can be ameliorated with architectural choices (residual connections, layer normalization) and proper weight initialization. Without this knowledge, undetected gradient issues during multi-day training runs on large GPU clusters can easily cost many thousands in wasted compute. Even in fine-tuning scenarios, catching these problems early saves money.
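To make this concrete, here is a minimal sketch of gradient clipping in PyTorch. The tiny model, random data, and the max_norm value of 1.0 are illustrative assumptions rather than part of our toy example; the key call is torch.nn.utils.clip_grad_norm_, which rescales the combined gradient norm before the optimizer step.

```python
import torch

# Minimal sketch: clip gradients before the parameter update.
# The tiny linear model, random data, and max_norm=1.0 are illustrative choices.
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(4, 8)
loss = model(x).pow(2).mean()

optimizer.zero_grad()
loss.backward()  # backpropagation: compute gradients for every parameter

# Rescale the global gradient norm to at most 1.0 to tame exploding gradients.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

optimizer.step()  # the update uses the clipped gradients
```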
Going backward through the transformer, our goal at each (sub)layer is to calculate gradients with the same shapes as the outputs at that stage of the forward pass, accumulating them with the chain rule so we can update parameters during optimization. The chain rule connects the “gradient out” from earlier layers (going backward) to the “gradient in” for the next layer’s outputs and parameters. The result at each layer is a set of weight gradients and activation gradients that match the shapes of the corresponding weight and activation matrices in the forward pass.
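As a sketch of this shape bookkeeping, consider backpropagation through a single linear sublayer Y = XW; the dimensions below are illustrative, not those of our toy model.

```python
import numpy as np

# Backprop through one linear sublayer Y = X @ W.
# The gradients we produce have exactly the shapes of the forward-pass matrices.
rng = np.random.default_rng(0)
X = rng.normal(size=(3, 4))    # activations entering the layer (3 positions, 4 features)
W = rng.normal(size=(4, 5))    # layer weights
Y = X @ W                      # forward output, shape (3, 5)

dY = rng.normal(size=Y.shape)  # "gradient out" arriving from the layer above

# Chain rule at this layer:
dW = X.T @ dY                  # gradient for the weights, shape (4, 5) like W
dX = dY @ W.T                  # "gradient in" passed to the layer below, shape (3, 4) like X

assert dW.shape == W.shape and dX.shape == X.shape
```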
Alert: Here I highlight the steps and computations of the backpropagation procedure with a numerical example to make the ideas concrete. The approach is consistent with our simplified GPT-2 model. These results are preliminary and will be validated with a formal mathematical verifier in due course.
Target specification
We must first define the target in order to calculate the loss. The target is the input sequence shifted one position forward, so that each token is trained to predict the next token in the sequence. For our toy example, the targets for the first two positions are 'rocks' and '<EOS>' (the next tokens after the first and second input tokens, respectively), with no target for the final position. In matrix form, the target sequence is specified as follows.
For loss calculation, these targets are converted to one-hot vectors where only the position corresponding to the target token has value 1, and all other positions are 0.
The first row represents the one-hot encoding for target 'rocks' (token 1), the second row for target '<EOS>' (token 2), and the third row has no target.
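A minimal sketch of this target construction is shown below. The token IDs (1 for 'rocks', 2 for '<EOS>') follow the description above, while the ID of the first input token and the vocabulary size of 3 are assumptions made only for illustration.

```python
import numpy as np

vocab_size = 3                          # assumed vocabulary size for illustration
input_ids = np.array([0, 1, 2])         # assumed IDs: [first token, 'rocks', '<EOS>']

# Targets are the inputs shifted one position forward; -1 marks "no target".
targets = np.full_like(input_ids, -1)
targets[:-1] = input_ids[1:]            # -> [1, 2, -1]

# One-hot encode the targets; the row with no target stays all zeros.
one_hot = np.zeros((len(targets), vocab_size))
for i, t in enumerate(targets):
    if t >= 0:
        one_hot[i, t] = 1.0

print(targets)   # [ 1  2 -1]
print(one_hot)   # rows: target 'rocks', target '<EOS>', no target
```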
Loss function
With the target specified, we define the loss function. It specifies how far the model’s predicted next-token probabilities are from the true targets. In autoregressive language modelling we use cross-entropy loss. This is the negative log probability of the correct token at each position, averaged over positions with a target. Positions without a target are skipped.
$$\text{Loss} = -\frac{1}{N}\sum_{i=1}^{N} \log p_i(t_i)$$
where $N$ is the number of target tokens, $t_i$ is the true token at position $i$, and $p_i(t_i)$ is the model’s predicted probability for that token. With the target specified and the loss function defined, we can calculate the loss and go through the backpropagation process for our toy example.
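The formula translates directly into a few lines of NumPy. This sketch assumes targets are encoded as token IDs with -1 marking positions without a target, as in the sketch above.

```python
import numpy as np

def cross_entropy(probs: np.ndarray, targets: np.ndarray) -> float:
    """Average negative log probability of the true token over positions with a target.

    probs:   (positions, vocab) matrix of softmax probabilities
    targets: true token IDs per position, with -1 meaning "no target"
    """
    mask = targets >= 0                            # skip positions without a target
    rows = np.arange(len(targets))[mask]
    true_token_probs = probs[rows, targets[mask]]  # p_i(t_i) for each target position
    return float(-np.log(true_token_probs).mean())
```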
Worked example
We now calculate the loss for our toy example using the model’s predicted probabilities and the specified target sequence. The softmax probabilities from the forward pass were as follows.
This is the matrix of predicted probabilities, where each row corresponds to a position in the input sequence and each column corresponds to a token in the vocabulary. The target is the sequence ('rocks', '<EOS>', no target), i.e. token IDs 1 and 2 for the first two positions. The first row gives the probabilities over the vocabulary for the token following the first input token, the second row for the token following 'rocks', and the third row for the token following '<EOS>', which has no target and is skipped.
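The sketch below shows how this loss calculation proceeds in code. The probability values are hypothetical placeholders, not the toy model's actual softmax output; only the target IDs (1 for 'rocks', 2 for '<EOS>', -1 for no target) follow the specification above.

```python
import numpy as np

# Placeholder probabilities for illustration only -- NOT the toy model's actual
# softmax output. Columns are vocabulary tokens; rows are sequence positions.
P = np.array([
    [0.20, 0.50, 0.30],   # next-token probabilities after the first input token
    [0.10, 0.30, 0.60],   # next-token probabilities after 'rocks'
    [0.40, 0.30, 0.30],   # next-token probabilities after '<EOS>' (no target)
])
targets = np.array([1, 2, -1])          # 'rocks', '<EOS>', no target

mask = targets >= 0
loss = -np.log(P[np.arange(len(targets))[mask], targets[mask]]).mean()
print(round(loss, 4))    # -(log 0.5 + log 0.6) / 2, about 0.602
```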
References
LeCun, Y., Bottou, L., Orr, G. B., & Müller, K. R. (1998). Efficient backprop. In Neural networks: Tricks of the trade (pp. 9-50). Springer.
Rumelhart, D. E., Hinton, G. E., & Williams, R. J. (1986). Learning representations by back-propagating errors. Nature, 323(6088), 533-536.
Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł., & Polosukhin, I. (2017). Attention is all you need. In Advances in neural information processing systems (Vol. 30).