Underrated Training Optimizations That Actually Move The Needle
Modern deep learning training has become an expensive balancing act between compute costs, memory constraints, and training time. While most practitioners focus on the obvious optimizations like mixed precision or gradient accumulation, there’s a collection of lesser-known techniques that can dramatically improve training efficiency without requiring expensive hardware upgrades or architectural overhauls.
These aren’t the optimizations that are already popular or considered common knowledge; they’re practical, implementable techniques that experienced engineers use to squeeze substantially more performance out of their existing setups. The beauty of these methods lies in their accessibility: most can be implemented with just a few lines of code, yet they often deliver 20–50% improvements in training throughput or memory efficiency.
What makes these optimizations particularly valuable is that they address different bottlenecks in the training pipeline. Some tackle memory bandwidth issues, others eliminate computational waste, and several focus on removing hidden inefficiencies that silently drain performance. When combined thoughtfully, these techniques create a multiplicative effect that can transform a slow, memory-hungry training run into something super efficient.
The following techniques come from the trenches of large-scale model training. Each addresses a specific pain point that most people encounter but few know how to solve effectively. By understanding not just the “what” but the “why” behind each optimization, you’ll develop the intuition to apply these techniques appropriately, avoid the common pitfalls, and sidestep the debugging nightmares that come with them. Let’s start with model-internal compute optimizations.
Swapping In Official And Custom Attention Implementations For Speedups
For many modern SOTA Transformer models like Qwen, Llama or GPT-OSS, the vanilla scaled dot-product attention forms the full score matrix

$$S = \frac{QK^\top}{\sqrt{d_k}} \in \mathbb{R}^{L \times L},$$

then applies a (masked) softmax $P = \mathrm{softmax}(S + M)$ and computes the output $O = PV$, where $L$ is the sequence length, $d_k$ the head dimension, and $M$ the attention mask.
That naïve path writes the entire L×L matrix out to GPU memory (HBM) and reads it back, and that memory traffic, not the FLOPs, is the real bottleneck, especially on newer GPUs whose compute has grown faster than their memory bandwidth. IO-aware kernels (FlashAttention, Liger kernels, etc.) tile Q/K/V so that the blocks fit in the GPU’s on-chip SRAM, stream through K/V, maintain running row-wise max/sum statistics for a numerically stable “online softmax,” and only ever materialize small tiles instead of the full L×L matrix.
Optimizing attention internals might sound complex, but in practice it’s simple: official implementations like FlashAttention can be enabled with a one-liner when using 🤗 Transformers.
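A minimal sketch of what that one-liner looks like, assuming a recent transformers release, the flash-attn package installed, and an FA2-compatible GPU (the model id is a placeholder):

```python
import torch
from transformers import AutoModelForCausalLM

# attn_implementation swaps the attention kernel used under the hood;
# FlashAttention-2 needs fp16/bf16 weights and a supported CUDA GPU.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",               # placeholder model id
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
).to("cuda")
```

If flash-attn isn’t available, `attn_implementation="sdpa"` falls back to PyTorch’s built-in fused attention and is usually still a solid win over the eager path.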
What To Check For Before You Brag About Speedups 🙈
- DType & GPU: FA2/FA3 expect fp16/bf16 on CUDA. FA3 is a newer implementation that targets NVIDIA GPUs from the Hopper architecture onward, like the H100.
- Mix-n-Match! Forcing experimental models to combine and compile with incompatible attention implementations sometimes raises no error but silently produces incorrect outputs, which is even worse.
Extending Kernel Fusion Beyond Attention Blocks
Community-provided solutions such as Triton, the Hugging Face Kernel Hub and Liger Kernel ship Triton-based kernel implementations for RMSNorm, RoPE, SwiGLU and Cross-Entropy, offering higher throughput and serious memory reductions!
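As a rough sketch of how that looks with Liger Kernel (the patch function and flags follow the liger-kernel package conventions for Llama-style models; check the package docs for your exact model family and version, and the model id is a placeholder):

```python
from liger_kernel.transformers import apply_liger_kernel_to_llama
from transformers import AutoModelForCausalLM

# Monkey-patch the HF Llama modules with Triton kernels for RoPE, RMSNorm,
# SwiGLU and a fused linear + cross-entropy loss, then load the model as usual.
apply_liger_kernel_to_llama(
    rope=True,
    rms_norm=True,
    swiglu=True,
    fused_linear_cross_entropy=True,
)

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")  # placeholder id
```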
Stretching Context Windows Without Melting Your GPUs
One of the most obvious ways to improve a model’s utility is to let it handle longer sequences, whether you’re doing document QA, long-form generation, or vibe-coding your way through a project by pushing your entire codebase in as context.
Enter YaRN (Yet another RoPE extensioN), a tiny but effective trick to extend context length without retraining from scratch or hurting short-context performance. It only takes a couple of lines in config and a short finetune run to make it stable.
Most LLaMA-style models use RoPE (rotary position embeddings) to encode token position. But if you naively push context (e.g., 8K to 64K), those position signals drift and the model basically loses track of “where it is” in the sequence.
YaRN rescales the frequency components of RoPE using temperature scaling and extrapolation factors to accommodate longer sequences. It modifies how the rotational frequencies are computed, effectively “stretching” the position encoding space without breaking the learned patterns. Kinda like telling the model: “you’re still reading, just… further down the page.” Most newer models already support the YaRN mechanism and encourage users to enable it when fine-tuning for longer context lengths.
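For models whose Transformers implementation wires RoPE scaling through the config, those “couple of lines” look roughly like this (a hedged sketch: the exact keys and supported factors vary by model family, so check the model card first; the model id is a placeholder):

```python
from transformers import AutoConfig, AutoModelForCausalLM

model_id = "Qwen/Qwen2.5-7B"                      # placeholder model id
config = AutoConfig.from_pretrained(model_id)

# Stretch RoPE roughly 4x beyond the original training window via YaRN.
config.rope_scaling = {
    "rope_type": "yarn",
    "factor": 4.0,
    "original_max_position_embeddings": config.max_position_embeddings,
}
config.max_position_embeddings = int(4.0 * config.max_position_embeddings)

model = AutoModelForCausalLM.from_pretrained(model_id, config=config)
```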
When fine-tuning with YaRN on longer sequences, mix short and long sequences in each batch so the model retains its short-context behavior. A full pretraining cycle is never an option here; a few thousand steps, preferably with a PEFT strategy, does the job for a regular-sized model and keeps the whole run cheap.
Squeezing More Out of PyTorch with torch.compile (No Graph Breaks, Please)
torch.compile is one of those switches that can make your training loop noticeably faster… or do absolutely nothing if you leave graph breaks all over the place. The idea is simple: PyTorch traces your model, fuses ops, and uses CUDA Graphs to cut down Python overhead. But if the compiler hits a “graph break” (something it can’t trace), it bails back to eager mode and you lose most of the benefit.
A “graph break” happens when torch.compile hits something it can’t turn into a single optimized graph, like a data-dependent branch or an unsupported operation. It stops compiling there and runs that part in slow, normal PyTorch, also known as “eager” mode. Too many of these, and you lose the speed benefits entirely.
Generally, torch.compile is a one-liner.
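A minimal sketch, assuming `model` is the `nn.Module` you are training:

```python
import torch

# The first few steps are slower while graphs are traced and compiled;
# after that, each step runs the fused, optimized version.
model = torch.compile(model)
```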
That said, given the diversity of current models and their custom operations designed for specific optimizations, some implementations might not be fully compatible with torch.compile. Adding arguments like `fullgraph=True` and `mode="reduce-overhead"` usually gives a nice extra boost and surfaces incompatibilities early.
- `fullgraph=True` forces you to fix the breaks early instead of silently running half-compiled. Look at the next section on how to fix them.
- `mode="reduce-overhead"` helps with most LLMs/VLMs that have consistent shapes. If every batch has wildly different shapes, compiling can be slower.
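Putting both flags together, something like:

```python
# fullgraph=True errors out on graph breaks instead of silently falling back
# to eager; "reduce-overhead" uses CUDA Graphs, which pays off when input
# shapes stay consistent across steps.
model = torch.compile(model, fullgraph=True, mode="reduce-overhead")
```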
Kill the Graph Breaks, Retain the Speed
- Always set `TORCH_LOGS="graph_breaks"` or call `torch._dynamo.explain(model, *example_inputs)` to see exactly where it’s failing (see the snippet after this list).
- Bucket/pad to fixed sizes (e.g., set pad sequence length to 4096 or bucket {1k, 2k, 4k}). Basically, varying batch shapes = recompiles and breaks.
- Use tensor ops instead of Python-side data deps like `.item()`, `len(tensor)`, or `if x.shape[0] > 1:` inside `forward`.
- Keep random/bookkeeping out of the `forward` signature: no `random`, no `time`, no `logging` prints. If you need randomness, use tensor RNG (`torch.rand`) and pass seeds in.
- Swap unsupported ops for compiled ones (e.g., prefer `torch.nn.functional.scaled_dot_product_attention` over custom ops) if your goal is to retain full-graph `torch.compile`.
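A quick way to surface the breaks, assuming a recent PyTorch 2.x (`example_batch` is a placeholder for one real input; older releases accept `torch._dynamo.explain(model, example_batch)` directly):

```python
import torch

# Summarizes the captured graphs, the number of graph breaks,
# and the reason for each one.
explanation = torch._dynamo.explain(model)(example_batch)
print(explanation)

# Alternatively, launch training with the env var TORCH_LOGS="graph_breaks"
# set to log every break (and its reason) as it happens.
```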
Smarter Activation Checkpointing instead of Checkpointing Everything
As models keep getting bigger, activation memory starts to dominate your cloud bills. PyTorch gives you activation checkpointing to cut that memory: instead of saving all intermediates for the backward pass, it recomputes some of them later, trading a bit of extra compute for a lot less memory.
What usually gets checkpointed? (and why that’s bad)
The common pattern is “checkpoint the whole block” (or worse, the whole model). That means you’re recomputing everything inside each Transformer layer (norms, small linears, cheap ops) along with the truly heavy stuff. The memory drop is real, but throughput nosedives as well.
What should be checkpointed instead? (Selective Activation Checkpointing)
Selective Activation Checkpointing (SAC) fixes that. Instead of recomputing everything in a checkpointed region, you selectively save the expensive operations and only recompute the cheap ones, balancing memory savings against recompute cost.
While checkpointing with SAC, save (and don’t recompute) the compute‑heavy ops like:
- Matmuls: `mm`, `bmm`, `addmm`, scaled MM, large linears (e.g., MLP modules, SwiGLU activations).
- Attention cores: scaled-dot-product variants, FlashAttention/efficient attention kernels.
- Convolutions or upsample layers for the vision towers in vision-based models.
- MoE experts (the expert MLPs); you can still recompute the light routers/norms.
An easy way to implement this is by passing the COMPUTE_INTENSIVE ops through a SAC Policy.
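A hedged sketch using PyTorch’s selective-checkpointing API (available in recent PyTorch releases); the op list and the `block`/`hidden_states` names are illustrative:

```python
import functools
import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# Illustrative list of compute-heavy ops whose outputs we keep in memory.
COMPUTE_INTENSIVE_OPS = [
    torch.ops.aten.mm.default,
    torch.ops.aten.bmm.default,
    torch.ops.aten.addmm.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
]

def sac_policy(ctx, op, *args, **kwargs):
    # Save the expensive matmuls/attention; recompute the cheap rest.
    if op in COMPUTE_INTENSIVE_OPS:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

context_fn = functools.partial(create_selective_checkpoint_contexts, sac_policy)

# Inside the model's forward, wrap each Transformer block with SAC:
hidden_states = checkpoint(
    block, hidden_states, use_reentrant=False, context_fn=context_fn
)
```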
If you prefer using torch.compile, you can also let PyTorch handle SAC automatically for you like this.
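One way to do that is the compile-time partitioner’s memory-budget knob; it lives under a private namespace and is still experimental, so treat this as a sketch:

```python
import torch

# 0.0 recomputes as much as possible, 1.0 saves everything; values in
# between let the partitioner choose the save/recompute split for you.
torch._functorch.config.activation_memory_budget = 0.5

compiled_model = torch.compile(model)
```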
TF32, CUDNN Autotuner, DataLoader knobs: Flip It On
These are low-effort settings that help you remove hidden bottlenecks.
Tensor-Float 32
TensorFloat‑32 (TF32) is an NVIDIA Tensor Core math mode that keeps FP32’s range (8‑bit exponent) but uses a 10‑bit mantissa, so the multiply‑accumulates run on Tensor Cores at Tensor‑Core speed while accumulation stays in FP32. In practice you get FP32‑like convergence for most DL workloads, with a solid speed bump. Most LLM/VLM finetunes run bf16/fp16 mixed‑precision; in that mode, the heavy matmuls already hit Tensor Cores, so TF32 only affects ops that still run in FP32 (e.g., matmuls or convolutions outside autocast, or stray layers that autocast keeps in float32).
This simply gains you training efficiency with little-to-no accuracy loss. It is a hardware-constrained mode, though: it works on Ampere and newer GPUs (A100, H100, H200, etc.). Think of it as a good-to-have perk of already paying for expensive compute.
Implementation is even simpler than understanding the concept.
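A couple of lines at the top of the training script, assuming an Ampere-or-newer GPU:

```python
import torch

# Route FP32 matmuls and convolutions through TF32 Tensor Cores.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Equivalent newer-style switch for matmul precision:
torch.set_float32_matmul_precision("high")
```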
For larger models, the time taken per training step should drop as the Tensor Cores go brrrr… instead of sitting idle behind a bottleneck.
CUDNN Autotuner
The autotuner is another built-in PyTorch feature: it benchmarks multiple convolution algorithms at the start of training and picks the fastest one for your tensor shapes.
What does it do?
- On the first few convolution calls, cuDNN tries different kernels and records the fastest.
- After that, it reuses the winning algorithm for every forward/backward pass with the same input shape.
This is especially helpful for vision training, where inputs with a fixed aspect ratio and fixed resolution are passed in every step. You can enable it with a simple switch like this.
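The switch in question:

```python
import torch

# Benchmark cuDNN's convolution algorithms on the first calls and cache the
# fastest one per input shape; best when input resolutions stay fixed.
torch.backends.cudnn.benchmark = True
```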
DataLoader Knobs That Actually Matter
Data loading is one of the silent killers of training speed: your GPU is sitting there, ready to go, but it’s waiting on the CPU to hand over the next batch. A few small PyTorch DataLoader tweaks can smooth that pipeline and give you an easy ~10% boost without touching the model, especially for larger datasets.
Here’s a setup I personally use almost everywhere.
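A representative configuration (the dataset, batch size and worker count are placeholders to tune for your machine):

```python
from torch.utils.data import DataLoader

train_loader = DataLoader(
    train_dataset,             # placeholder dataset
    batch_size=8,              # tune for your GPU memory
    shuffle=True,
    num_workers=8,             # CPU workers preparing batches in parallel
    pin_memory=True,           # page-locked host memory -> faster H2D copies
    persistent_workers=True,   # keep workers alive between epochs
    prefetch_factor=2,         # batches each worker preloads ahead of time
    drop_last=True,            # keep batch shapes static (helps torch.compile)
)
```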
If you run into issues with the DataLoader mid-training, try reducing num_workers to 2–4.
Wrapping It Up
None of these tricks is a magic bullet on its own, but stack a few of them together and you start to feel the difference. They’re the kinds of changes that don’t require rewriting your whole stack or burning through compute credits, just a little awareness of where the hidden inefficiencies live. The real win is that once you know these levers exist, you can pull them exactly when a bottleneck shows up instead of throwing more GPUs at the problem. Efficient training isn’t about one big hack but about a lot of small, smart and crafty ones.
