How LLMs learn (4 of 4): Optimization

We have computed gradients for all learnable parameters in the network. During optimization, these gradients are used to update the model's parameters so as to reduce the loss on the training data. Moving forward through the transformer, our goal at each layer is to adjust the parameters in the direction that reduces the loss, which is the direction opposite the gradient.
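
Why move against the gradient? A toy check makes it concrete. The sketch below uses a hypothetical one-parameter loss L(w) = (w - 3)^2 that has nothing to do with the model; it simply compares a small step opposite the gradient with a small step along it.

```python
def loss(w: float) -> float:
    """Hypothetical toy loss with its minimum at w = 3."""
    return (w - 3.0) ** 2


def grad(w: float) -> float:
    """Analytic derivative dL/dw of the toy loss."""
    return 2.0 * (w - 3.0)


w = 0.0      # starting parameter value (hypothetical)
step = 0.1   # small step size (hypothetical)

w_against = w - step * grad(w)  # step opposite the gradient (what optimizers do)
w_along = w + step * grad(w)    # step along the gradient, for comparison

print(f"L(w)         = {loss(w):.4f}")
print(f"against grad = {loss(w_against):.4f}")
print(f"along grad   = {loss(w_along):.4f}")
```

Running it prints losses of 9.0000, 5.7600 and 12.9600: only the step opposite the gradient lowers the loss.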

There are numerous optimizers available for updating weights during training. Basic approaches include stochastic gradient descent (SGD) (Robbins & Monro, 1951) and SGD with momentum, which smooths updates with a running average of past gradients (Polyak, 1964). Adaptive methods such as Adam (Kingma & Ba, 2015) and AdamW (Loshchilov & Hutter, 2019) go further and adjust the learning rate for each parameter based on its gradient history; AdamW additionally decouples weight decay from the gradient-based update.
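
To make the differences between these update rules concrete, here is a minimal sketch of one step of each for a single scalar parameter (real optimizers apply the same rule elementwise to every entry of every weight matrix). The momentum form shown is one common formulation, not the only one, and the default hyperparameters are typical choices assumed for illustration rather than values used in this walkthrough.

```python
import math


def sgd_step(theta, g, lr=0.01):
    """Vanilla SGD: theta <- theta - lr * g."""
    return theta - lr * g


def momentum_step(theta, g, velocity, lr=0.01, beta=0.9):
    """SGD with momentum: velocity accumulates a decaying sum of past gradients."""
    velocity = beta * velocity + g
    return theta - lr * velocity, velocity


def adam_step(theta, g, m, v, t, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    """Adam: per-parameter step size from running gradient moments (t counts from 1)."""
    m = beta1 * m + (1 - beta1) * g        # running mean of gradients
    v = beta2 * v + (1 - beta2) * g * g    # running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)           # bias corrections for zero-initialized m, v
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v


def adamw_step(theta, g, m, v, t, lr=0.001, weight_decay=0.01, **kwargs):
    """AdamW: shrink the weight directly (decoupled decay), then take an Adam step."""
    theta = theta - lr * weight_decay * theta
    return adam_step(theta, g, m, v, t, lr=lr, **kwargs)


# Compare how far one step moves a parameter starting at 1.0
for g in (0.5, 0.005):
    sgd_new = sgd_step(1.0, g)
    adam_new, _, _ = adam_step(1.0, g, 0.0, 0.0, t=1)
    print(f"g={g}: SGD moves {1.0 - sgd_new:.6f}, Adam moves {1.0 - adam_new:.6f}")
```

On its first step, Adam moves the parameter by roughly the learning rate regardless of the gradient's size, while SGD's move scales directly with the gradient.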

In summary, optimization applies an update rule based on the gradient, the learning rate, and possibly other factors such as momentum or adaptive learning rates. At each layer, the result is a set of updated parameter matrices with the same shapes as the original weight matrices, now adjusted to better fit the training data.
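
Because each weight is updated using only its own gradient, the update is elementwise and the matrix keeps its shape. A minimal sketch with plain nested lists; the 2×3 weight and gradient values are hypothetical, not taken from the model.

```python
from typing import List

Matrix = List[List[float]]


def sgd_update(weights: Matrix, grads: Matrix, lr: float) -> Matrix:
    """Return a new matrix in which each entry moves against its own gradient."""
    if len(weights) != len(grads) or any(len(w) != len(g) for w, g in zip(weights, grads)):
        raise ValueError("weights and gradients must have the same shape")
    return [[w - lr * g for w, g in zip(w_row, g_row)]
            for w_row, g_row in zip(weights, grads)]


# Hypothetical 2x3 weight matrix and its gradient
W = [[0.2, -0.1, 0.4],
     [0.0, 0.3, -0.2]]
dW = [[0.05, -0.02, 0.10],
      [0.01, 0.00, -0.03]]

W_new = sgd_update(W, dW, lr=0.01)
print(len(W_new), len(W_new[0]))  # 2 3
```

The shape check mirrors the requirement that every parameter has a gradient of matching shape; a mismatch signals a bug in backpropagation rather than something the optimizer should silently handle.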

Over repeated training iterations, convergence can be judged along two dimensions: the quantitative reduction in training loss and the qualitative improvement in output coherence. These two criteria don't always align: a model may reach low loss before its output becomes coherent, or it may produce fluent text while the loss is still decreasing.
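
The quantitative side of convergence can be sketched with a toy problem: repeat the SGD update, record the loss, and stop once the improvement per step becomes negligible. The data (generated by y = 2x), learning rate, and tolerance below are hypothetical choices; judging output coherence, the qualitative side, would require sampling text from the model and is not captured by this loop.

```python
xs = [1.0, 2.0, 3.0]
ys = [2.0, 4.0, 6.0]  # hypothetical data generated by y = 2x


def mse(w: float) -> float:
    """Mean squared error of the one-parameter model y_hat = w * x."""
    return sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def mse_grad(w: float) -> float:
    """Derivative of the mean squared error with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def train(w: float = 0.0, lr: float = 0.05, max_steps: int = 1000, tol: float = 1e-10):
    """Run SGD steps, recording the loss after each one, until the loss stops improving."""
    history = [mse(w)]
    for _ in range(max_steps):
        w -= lr * mse_grad(w)
        history.append(mse(w))
        if history[-2] - history[-1] < tol:  # improvement too small: treat as converged
            break
    return w, history


w_final, history = train()
print(f"w = {w_final:.5f} after {len(history) - 1} steps, final loss {history[-1]:.2e}")
```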

Worked numerical example

Our gradients in this example are small, which is expected given the tiny batch size, small model dimensions, and single training step. Adaptive optimizers like Adam would obscure the relationship between gradients and weight updates, so we use vanilla SGD, which transparently shows how gradients directly determine updates via $\theta_{\text{new}} = \theta_{\text{old}} - \eta \, \frac{\partial L}{\partial \theta}$. In real training with SGD, the learning rate typically ranges from 0.001 to 0.1; here we choose η = 0.01.
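
The arithmetic of a single update with η = 0.01 can be sketched as follows. The weight/gradient pairs are hypothetical values chosen for illustration, not the gradients computed in this walkthrough. They show that a positive gradient lowers the weight, a negative gradient raises it, and a small gradient such as 0.004 changes the weight by only 0.00004 (for example, 0.5 - 0.01 × 0.2 = 0.498).

```python
eta = 0.01  # learning rate chosen in this example


def sgd(theta_old: float, grad: float, lr: float = eta) -> float:
    """One vanilla SGD step: theta_new = theta_old - lr * dL/dtheta."""
    return theta_old - lr * grad


# Hypothetical weight/gradient pairs, not the values from this walkthrough
examples = [(0.5, 0.2), (0.5, -0.2), (-0.3, 0.004)]
for theta_old, g in examples:
    theta_new = sgd(theta_old, g)
    print(f"theta_old={theta_old:+.4f}  grad={g:+.4f}  ->  "
          f"theta_new={theta_new:+.6f}  (change {theta_new - theta_old:+.6f})")
```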

💡

Alert: Here I highlight the steps and computations of the optimization (weight updating procedure) with a numerical example to make the ideas concrete. The approach is consistent with our simplified GPT-2 model. These results are preliminary and will be validated with a formal mathematical verifier in due course.

Embedding
Positional embeddings
LayerNorm 1
Multi-head self-attention
Residual connection
LayerNorm 2
Multilayer perceptron
Residual connection (post MLP)
Logit projection
Softmax
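
Applying the update across the components listed above amounts to looping over every named parameter and applying the same rule. In the minimal sketch below, the parameter names, their tiny flattened sizes, and the uniform gradient of 0.1 are all hypothetical, and biases are omitted for brevity. Residual connections and softmax have no learnable parameters: they pass gradients back to earlier layers during backpropagation but receive no update themselves.

```python
eta = 0.01  # learning rate chosen in this example

# Hypothetical, tiny parameter tensors (flattened to lists) for the listed components.
# Residual connections and softmax are absent because they have no learnable parameters.
params = {
    "embedding.token": [0.10, -0.20, 0.05, 0.30],
    "embedding.position": [0.01, 0.02, -0.01, 0.00],
    "layernorm1.gamma": [1.0, 1.0],
    "layernorm1.beta": [0.0, 0.0],
    "attention.W_q": [0.2, -0.1],
    "attention.W_k": [0.3, 0.1],
    "attention.W_v": [0.0, -0.2],
    "attention.W_o": [0.05, -0.05],
    "layernorm2.gamma": [1.0, 1.0],
    "layernorm2.beta": [0.0, 0.0],
    "mlp.W_in": [0.1, 0.2, -0.1, 0.0],
    "mlp.W_out": [-0.1, 0.05, 0.2, 0.1],
    "logits.W": [0.3, -0.3, 0.1, 0.0],
}

# Hypothetical gradients: same shapes as the parameters, uniform value for readability
grads = {name: [0.1] * len(values) for name, values in params.items()}


def sgd_update_all(params, grads, lr):
    """Apply theta_new = theta_old - lr * grad to every named parameter."""
    updated = {}
    for name, values in params.items():
        g = grads[name]
        if len(g) != len(values):
            raise ValueError(f"shape mismatch for {name}")
        updated[name] = [v - lr * gv for v, gv in zip(values, g)]
    return updated


updated = sgd_update_all(params, grads, eta)
for name in params:
    print(f"{name:20s} {params[name]} -> {[round(v, 4) for v in updated[name]]}")
```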

References

Kingma, D. P., & Ba, J. (2015). Adam: A method for stochastic optimization. In Proceedings of the 3rd International Conference on Learning Representations.

Loshchilov, I., & Hutter, F. (2019). Decoupled weight decay regularization. In Proceedings of the 7th International Conference on Learning Representations.

Polyak, B. T. (1964). Some methods of speeding up the convergence of iteration methods. USSR Computational Mathematics and Mathematical Physics, 4(5), 1-17.

Robbins, H., & Monro, S. (1951). A stochastic approximation method. The Annals of Mathematical Statistics, 22(3), 400-407.

Psychometrics.ai