How LLMs learn (2 of 3): back propagation

With the forward pass complete, the model can either generate new tokens during inference (no back propagation) or begin learning (with back propagation). During back propagation, the final error from the forward pass output is traced back through the network to determine how much each layer contributed to the loss.

At each layer, the chain rule is applied to express the loss gradient (the partial derivative of the loss with respect to activations or parameters) in terms of that layer’s inputs, outputs, and parameters. This produces the gradients needed for the final stage, optimisation, where the weights are updated.
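
To make the chain rule concrete, here is a minimal numeric sketch (a single scalar weight and a squared-error loss, not the article's toy model): the gradient of the loss with respect to the weight is the product of the local derivatives along the path from the loss back to that weight.

```python
# Forward pass: one scalar "layer" a = w * x, followed by a squared-error loss.
x, w, target = 2.0, 0.5, 3.0
a = w * x                        # layer activation (1.0)
loss = 0.5 * (a - target) ** 2   # loss (2.0)

# Backward pass: chain rule, dL/dw = dL/da * da/dw
dL_da = a - target               # derivative of the loss w.r.t. the activation (-2.0)
da_dw = x                        # derivative of the activation w.r.t. the weight (2.0)
dL_dw = dL_da * da_dw            # gradient handed to the optimiser (-4.0)

print(dL_dw)                     # -4.0: nudging w upward would reduce the loss
```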

Target specification

We must first define the target in order to calculate the loss. In this case, the target is the input sequence shifted by one position, so that each token is trained to predict the token that follows it. For our toy example, the target is [rocks] for the first position (predicting the next token after “AI”) and nothing for the second position, since our vocabulary does not define an explicit end-of-sequence (EOS) token. In matrix form, the target sequence is specified as follows.

\begin{aligned}
\text{Input sequence: } &\quad x = [x_1, x_2] = [\text{AI}, \text{rocks}] = \begin{bmatrix} 0 & 1 \end{bmatrix} \\[6pt]
\text{Target sequence: } &\quad y = [y_1, y_2] = [\text{rocks}, \,\text{no target}] = \begin{bmatrix} 1 & - \end{bmatrix}
\end{aligned}
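
In code, the shift looks like the sketch below, using the toy vocabulary ids from the forward-pass article (AI = 0, rocks = 1) and an assumed marker of -100 for “no target” (a common convention, e.g. PyTorch’s ignore_index).

```python
IGNORE = -100                          # assumed "no target" marker

input_ids = [0, 1]                     # [AI, rocks]
targets = input_ids[1:] + [IGNORE]     # shift by one position: [rocks, no target]

print(input_ids, targets)              # [0, 1] [1, -100]
```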

Loss function

With the target specified, we define the loss function: it quantifies how far the model’s predicted next-token probabilities are from the true targets. In autoregressive language modelling we use cross-entropy (negative log-likelihood). This is the negative log probability of the correct token at each position, averaged over positions with a target (positions without a target are skipped).

\mathcal{L} = - \frac{1}{N} \sum_{t=1}^{N} \log p(y_t \mid x_{<t})

where $N$ is the number of target tokens, $y_t$ is the true token at position $t$, and $p(y_t \mid x_{<t})$ is the model’s predicted probability for that token. With the target specified and the loss function defined, we are ready to work through the back propagation process for our numerical example.
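
As a minimal sketch of this loss (assuming probs is a per-position matrix of predicted probabilities and targets holds true token ids, with -100 marking positions that have no target):

```python
import numpy as np

IGNORE = -100  # assumed marker for positions without a target

def cross_entropy(probs, targets):
    """Average negative log probability of the correct token, skipping IGNORE positions."""
    losses = [-np.log(probs[t, y]) for t, y in enumerate(targets) if y != IGNORE]
    return float(np.mean(losses))
```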

Worked example

We will now calculate the loss for our toy example by applying this definition to the model’s predicted probabilities and the specified target sequence. The softmax probabilities from the forward pass were as follows.

P \approx \begin{bmatrix} 0.59 & 0.41 \\ 0.41 & 0.59 \end{bmatrix}

$P$ is the matrix of predicted probabilities, where each row corresponds to a position in the input sequence and each column corresponds to a token in the vocabulary. In our toy example, the first row gives the probabilities for predicting “AI” or “rocks” after the token “AI,” and the second row gives the probabilities for predicting “AI” or “rocks” after the token “rocks.” The target is $[y_1, y_2] = [\text{rocks}, \text{no target}]$.
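
Only the first position has a target (“rocks”, the second column), so the loss reduces to the negative log probability the model assigned to “rocks” after “AI”. A quick numeric check:

```python
import numpy as np

# The softmax probabilities from the forward pass; rows are positions, columns are [AI, rocks].
P = np.array([[0.59, 0.41],
              [0.41, 0.59]])

# Only position 0 has a target ("rocks" = column 1); position 1 is skipped.
loss = -np.log(P[0, 1])
print(round(float(loss), 3))   # ~0.892
```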

From the loss, back propagation traces the gradient through each component of the forward pass in reverse order: softmax and logits, the logit projection, the residual connection (post MLP), the MLP, LayerNorm 2, the residual connection (post attention), the output projection, multi-head attention, LayerNorm 1, the positional encodings, and finally the embedding lookup.

Remainder of this section is coming soon.
