Slaying OOMs with PyTorch FSDP and torchao

Have you ever hit an OOM (and wished you had more VRAM)? If you've done much fine-tuning, then you have. And if you're just starting, then you will. Hop on the bus with Mark Saroufim and Jane Xu and feel the road become smoother as we talk about stacking together techniques like FSDP2 + QLoRA + CPU offloading + fused Adam (thanks Intel) + more, natively in PyTorch.

This is a talk from Mastering LLMs: A survey course on applied topics for Large Language Models.

More resources are available here:

*00:00 Introduction*
Mark introduces the session on addressing Out of Memory (OOM) errors in PyTorch, discussing tools and techniques to handle these issues more effectively.

*00:30 Traditional Solutions to OOMs*
Mark describes conventional methods of dealing with OOMs, such as reducing batch size or model size, and the limitations of these approaches.

*00:48 VRAM Constraints*
Mark explains the VRAM constraints of different GPUs and how they impact model training, emphasizing the perpetual challenge of being VRAM-starved.

*01:24 Estimating VRAM requirements for your model*
Mark outlines the components involved in estimating a model's memory usage, including parameters, gradients, and optimizer states.
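As a rough illustration of that kind of estimate, here is a minimal Python sketch. It assumes full fine-tuning with fp32 weights and gradients (4 bytes each) and an Adam-style optimizer that keeps two fp32 states per parameter (8 bytes), and it ignores activations, which depend on batch size and sequence length. The function name is hypothetical.

```python
def estimate_training_memory_gb(num_params, param_bytes=4, grad_bytes=4,
                                optim_state_bytes=8):
    """Rough lower bound on training memory in GB, excluding activations.

    Assumes every parameter is trained, so each one carries a gradient
    and optimizer state (default 8 bytes = two fp32 Adam moments).
    """
    bytes_per_param = param_bytes + grad_bytes + optim_state_bytes
    return num_params * bytes_per_param / 1e9


if __name__ == "__main__":
    # A 7B-parameter model under these assumptions needs about 112 GB
    # before a single activation is stored.
    print(estimate_training_memory_gb(7_000_000_000))
```

Passing `param_bytes=2, grad_bytes=2` models bf16 weights and gradients, which shows how much the per-parameter byte count, rather than the parameter count alone, drives the total.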

*06:06 Quantization Techniques*
Mark introduces quantization techniques, such as 4-bit quantization, that reduce model size and memory requirements. He also demonstrates using torch.compile to generate efficient quantization kernels, avoiding the complexity of writing custom CUDA kernels.
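To make the memory savings concrete, here is a minimal pure-Python sketch of blockwise 4-bit quantization. It uses a simple symmetric absmax scheme (each block is scaled so its largest magnitude maps to 7), not the NF4 data type used by QLoRA and torchao, and the function names are hypothetical. Each value costs 4 bits plus one shared scale per block.

```python
def quantize_int4_blockwise(values, block_size=64):
    """Map floats to signed 4-bit integer codes with one scale per block."""
    codes, scales = [], []
    for start in range(0, len(values), block_size):
        block = values[start:start + block_size]
        absmax = max(abs(v) for v in block) or 1.0  # avoid dividing by zero
        scale = absmax / 7
        scales.append(scale)
        # Clamp to the signed 4-bit range [-8, 7] as a guard.
        codes.extend(max(-8, min(7, round(v / scale))) for v in block)
    return codes, scales


def dequantize_int4_blockwise(codes, scales, block_size=64):
    """Recover approximate floats: code times the scale of its block."""
    return [c * scales[i // block_size] for i, c in enumerate(codes)]


def pack_int4(codes):
    """Store two signed 4-bit codes per byte (offset by 8 into 0..15)."""
    padded = codes + [0] * (len(codes) % 2)
    return bytes(((padded[i] + 8) << 4) | (padded[i + 1] + 8)
                 for i in range(0, len(padded), 2))


if __name__ == "__main__":
    weights = [i / 10 - 3 for i in range(100)]
    codes, scales = quantize_int4_blockwise(weights, block_size=32)
    restored = dequantize_int4_blockwise(codes, scales, block_size=32)
    print(len(pack_int4(codes)), "bytes for", len(weights), "weights")
    print("max error:", max(abs(a - b) for a, b in zip(weights, restored)))
```

Because each code is the nearest integer to `v / scale`, the reconstruction error per value is at most half a block's scale; smaller blocks trade more scale storage for lower error.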

*09:27 LoRA*
Mark introduces LoRA, which saves memory by training only small low-rank adapter matrices while the pretrained weights stay frozen.
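A minimal sketch of the LoRA idea in plain Python (no PyTorch), assuming the common formulation y = Wx + (alpha / r) · B(Ax): the pretrained weight W stays frozen, and only the small matrices A (r × d_in) and B (d_out × r) would be trained. Function names are hypothetical.

```python
def matvec(M, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]


def lora_forward(W, A, B, x, alpha, r):
    """Frozen base projection plus a scaled low-rank update."""
    base = matvec(W, x)               # W: d_out x d_in, never updated
    update = matvec(B, matvec(A, x))  # A: r x d_in, B: d_out x r (trainable)
    return [b + (alpha / r) * u for b, u in zip(base, update)]


def trainable_param_counts(d_in, d_out, r):
    """Parameters updated by full fine-tuning vs. by a rank-r adapter."""
    return {"full": d_in * d_out, "lora": r * (d_in + d_out)}


if __name__ == "__main__":
    # B starts at zero, so the adapter initially leaves the output unchanged.
    W = [[1.0, 2.0], [3.0, 4.0]]
    A = [[0.1, -0.2]]   # r = 1
    B = [[0.0], [0.0]]
    print(lora_forward(W, A, B, [1.0, 1.0], alpha=16, r=1))
    print(trainable_param_counts(4096, 4096, r=8))
```

For a 4096 × 4096 projection, a rank-8 adapter has 65,536 trainable parameters versus 16,777,216 for the full matrix, so gradients and optimizer state shrink by the same factor.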

*09:56 QLoRA Algorithm*
Mark details the QLoRA algorithm, which combines quantized, frozen base parameters with trainable low-rank adapters to enable memory-efficient fine-tuning.
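Putting the two ideas together, a toy QLoRA-style forward pass might look like the sketch below. Assumptions: the frozen base weight is stored as per-row absmax integer codes (a simplified stand-in for NF4 with blockwise scales), it is dequantized only inside the matmul, and the LoRA matrices A and B stay in full precision because they are the only trained parameters. Names are hypothetical.

```python
def quantize_rows(W, levels=7):
    """Per-row symmetric absmax quantization of a frozen weight matrix."""
    codes, scales = [], []
    for row in W:
        scale = (max(abs(v) for v in row) or 1.0) / levels
        scales.append(scale)
        codes.append([round(v / scale) for v in row])
    return codes, scales


def qlora_forward(codes, scales, A, B, x, alpha, r):
    """Dequantize-on-the-fly base matmul plus a full-precision LoRA update."""
    # scale * sum(code * x) equals sum((code * scale) * x): dequantization
    # is folded into the dot product, so no full-precision W is kept around.
    base = [s * sum(c * xi for c, xi in zip(row, x))
            for row, s in zip(codes, scales)]
    low_rank = [sum(a * xi for a, xi in zip(row, x)) for row in A]
    update = [sum(b * h for b, h in zip(row, low_rank)) for row in B]
    return [y + (alpha / r) * u for y, u in zip(base, update)]


if __name__ == "__main__":
    W = [[0.5, -1.0], [0.25, 0.75]]
    codes, scales = quantize_rows(W)
    A = [[0.1, 0.1]]    # r = 1, trainable
    B = [[0.0], [0.0]]  # zero-initialized, trainable
    x = [1.0, 2.0]
    # Close to, but not exactly, W @ x = [-1.5, 1.75] due to quantization.
    print(qlora_forward(codes, scales, A, B, x, alpha=16, r=1))
```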

*10:51 Implementing QLoRA with PyTorch*
Discussion of implementing QLoRA in PyTorch, highlighting the complexity of writing efficient kernels and the benefits of using torch.compile.

*14:38 Introducing Jane's Section on Model Parallelism*
Mark hands over to Jane to discuss parallelism techniques and how to manage memory across multiple devices.

*15:20 Understanding Memory Allocation During Training*
Jane illustrates memory allocation during training, showing the impact of activations, gradients, and optimizer states. Jane also explains data parallelism and model sharding as techniques to distribute memory load across multiple GPUs.
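The difference between plain data parallelism and sharding shows up in back-of-the-envelope arithmetic. This sketch assumes 16 bytes of parameter, gradient, and optimizer state per parameter (fp32 weights and gradients plus two Adam moments) and ignores activations and temporary communication buffers; the function name is hypothetical.

```python
def per_gpu_state_gb(num_params, world_size, bytes_per_param=16,
                     sharded=True):
    """Parameter + gradient + optimizer memory held by each GPU, in GB.

    Data parallelism replicates the full state on every GPU; full sharding
    splits it evenly so each GPU holds 1/world_size of it.
    """
    total_bytes = num_params * bytes_per_param
    per_gpu = total_bytes / world_size if sharded else total_bytes
    return per_gpu / 1e9


if __name__ == "__main__":
    n = 7_000_000_000
    print("replicated:", per_gpu_state_gb(n, 8, sharded=False), "GB per GPU")
    print("sharded:   ", per_gpu_state_gb(n, 8), "GB per GPU")
```

Adding GPUs does nothing for the replicated case, whereas the sharded per-GPU footprint falls in proportion to the world size.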

*17:45 Fully Sharded Data Parallel (FSDP)*
Jane introduces Fully Sharded Data Parallel (FSDP) and how it saves memory by sharding model parameters, gradients, and optimizer states across GPUs.
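FSDP's core data movement can be simulated in a single Python process. In this sketch (which does not use the torch.distributed API), each rank keeps one zero-padded 1/N slice of a flattened parameter, an all-gather rebuilds the full parameter just before it is needed, and a reduce-scatter sums gradients across ranks while leaving each rank with only its own slice (in practice gradients are usually averaged). The helper names are hypothetical.

```python
def shard(flat, world_size):
    """Split a flat list into world_size equal slices, zero-padding the tail."""
    pad = (-len(flat)) % world_size
    padded = flat + [0.0] * pad
    n = len(padded) // world_size
    return [padded[r * n:(r + 1) * n] for r in range(world_size)]


def all_gather(shards, numel):
    """Concatenate every rank's slice and drop the padding."""
    return [v for s in shards for v in s][:numel]


def reduce_scatter(per_rank_grads, world_size):
    """Sum full gradients from all ranks, then keep one slice per rank."""
    summed = [sum(vals) for vals in zip(*per_rank_grads)]
    return shard(summed, world_size)


if __name__ == "__main__":
    params = [float(i) for i in range(10)]
    shards = shard(params, world_size=4)    # 3 values per rank, last padded
    full = all_gather(shards, len(params))  # temporarily unsharded
    print(full == params)
    grads = [[1.0] * 10 for _ in range(4)]  # one full gradient per rank
    print(reduce_scatter(grads, world_size=4))  # rank r keeps slice r
```

Between these collectives each rank stores only its slice, which is where the memory savings come from; the price is communication on every forward and backward pass.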

*21:49 CPU Offloading*
Jane discusses CPU offloading as a method to handle memory constraints by temporarily storing parameters on the CPU during training.
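A toy model of why offloading helps: if each layer's weights live in CPU memory and are copied to the GPU only while that layer runs, peak GPU usage is bounded by the largest layer rather than by the whole model. This sketch tracks a hypothetical memory budget in abstract units and ignores activations and host-to-device transfer time (the real cost of offloading); all names are hypothetical.

```python
class DeviceMemory:
    """Tracks allocations against a fixed capacity, like a GPU's VRAM."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.used = 0
        self.peak = 0

    def alloc(self, size):
        if self.used + size > self.capacity:
            raise MemoryError(f"OOM: need {size}, free {self.capacity - self.used}")
        self.used += size
        self.peak = max(self.peak, self.used)

    def free(self, size):
        self.used -= size


def run_forward(layer_sizes, capacity, offload):
    """Return peak device usage for one pass over the layers."""
    gpu = DeviceMemory(capacity)
    if not offload:
        for size in layer_sizes:  # whole model resident on the GPU
            gpu.alloc(size)
    for size in layer_sizes:
        if offload:
            gpu.alloc(size)       # copy this layer's weights from the CPU
        # ... compute the layer here ...
        if offload:
            gpu.free(size)        # drop the GPU copy; the CPU copy remains
    return gpu.peak


if __name__ == "__main__":
    layers = [4, 4, 4]
    print(run_forward(layers, capacity=10, offload=True))
    try:
        run_forward(layers, capacity=10, offload=False)
    except MemoryError as err:
        print(err)
```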

*23:05 Challenges and Improvements in FSDP*
Jane outlines the challenges with FSDP1 and introduces FSDP2, which offers more flexibility and efficiency in managing memory and data types.
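One way to picture the FSDP1 vs. FSDP2 difference: FSDP1 concatenates all parameters in a wrapped unit into one flat buffer, which forces them to share a dtype, while FSDP2 shards each parameter separately along dim 0 (as a DTensor), so parameters with different dtypes, such as 4-bit base weights next to bf16 LoRA adapters, can sit in the same unit. The pure-Python sketch below models parameters as hypothetical `(name, dtype, rows)` tuples; it is a simplification of the idea, not either implementation.

```python
import math


def flat_param_shard(params, world_size):
    """FSDP1-style: concatenate every parameter, then split the flat buffer."""
    dtypes = {dtype for _, dtype, _ in params}
    if len(dtypes) != 1:
        raise ValueError(f"flat parameter needs one dtype, got {sorted(dtypes)}")
    flat = [v for _, _, rows in params for row in rows for v in row]
    n = math.ceil(len(flat) / world_size)
    return [flat[r * n:(r + 1) * n] for r in range(world_size)]


def per_param_shard(params, world_size):
    """FSDP2-style: chunk each parameter's rows (dim 0), keeping its dtype."""
    out = {}
    for name, dtype, rows in params:
        n = math.ceil(len(rows) / world_size)
        out[name] = (dtype, [rows[r * n:(r + 1) * n] for r in range(world_size)])
    return out


if __name__ == "__main__":
    params = [
        ("base.weight", "nf4", [[1, 2], [3, 4], [5, 6], [7, 8]]),
        ("lora_a.weight", "bf16", [[0.1, 0.2], [0.3, 0.4]]),
    ]
    print(per_param_shard(params, world_size=2)["lora_a.weight"])
    try:
        flat_param_shard(params, world_size=2)
    except ValueError as err:
        print(err)
```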

*29:50 Identifying and Addressing Performance Gaps*
Jane discusses the process of identifying performance gaps in FSDP2 and the steps taken to optimize it to match the performance of FSDP1. She also covers benchmarking and profiling techniques that help when debugging performance.
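A basic benchmarking pattern for this kind of comparison: warm up first (so one-time costs such as compilation or allocator growth are excluded), then time several iterations and report a robust statistic like the median. This is a generic, hypothetical helper; when timing GPU work in PyTorch you would also need to call `torch.cuda.synchronize()` before reading the clock, because CUDA kernels run asynchronously.

```python
import statistics
import time


def benchmark(fn, warmup=3, iters=10):
    """Median wall-clock seconds per call of fn, measured after warmup."""
    for _ in range(warmup):
        fn()  # exclude one-time costs from the timing
    timings = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)


if __name__ == "__main__":
    def step():
        sum(i * i for i in range(100_000))  # stand-in for a training step

    print(f"median step time: {benchmark(step) * 1e3:.2f} ms")
```

A timer like this tells you that two configurations differ; a profiler trace is what tells you why, for example which collective or kernel is stalling.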

*37:06 Overcoming Debugging Challenges*
Jane shares insights from debugging and optimizing the performance of FSDP2, highlighting the importance of detailed trace analysis. She also explains the impact of wrapping policy on memory usage.
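To see why the wrapping policy matters, consider how many parameters are unsharded at once. A deliberately rough model, assuming each FSDP unit is all-gathered just before use and freed right after, with at most `prefetch` upcoming units gathered early: peak parameter memory ≈ (all local shards) + (the largest window of simultaneously unsharded units). It ignores gradients, activations, and communication buffers, slightly double-counts the gathered units' own shards, and the function name is hypothetical.

```python
def peak_param_memory(unit_sizes, world_size, prefetch=1):
    """Approximate peak parameter memory per GPU for a given unit split.

    unit_sizes: parameter size of each FSDP unit, in forward order.
    """
    sharded = sum(unit_sizes) / world_size
    # The current unit plus up to `prefetch` upcoming units are unsharded.
    window = max(sum(unit_sizes[i:i + 1 + prefetch])
                 for i in range(len(unit_sizes)))
    return sharded + window


if __name__ == "__main__":
    # 32 GB of parameters on 8 GPUs: one unit for everything vs. one per layer.
    print(peak_param_memory([32], world_size=8))     # 36.0
    print(peak_param_memory([4] * 8, world_size=8))  # 12.0
```

Wrapping the whole model as a single unit forces every parameter to be gathered at once, which defeats the purpose of sharding; finer-grained units keep the unsharded window small.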

*47:38 How you can get started*
Jane encourages students to try this process themselves in torchtune.